In [0]:
#@title Imports

%install '.package(url: "https://github.com/marcrasi/swift-vec2", .branch("master"))' Vec2

// Clear installation messages from output area.
// ("\u{001B}[2J" is the ANSI escape sequence ESC[2J, which clears the screen.)
print("\u{001B}[2J")

import Vec2
// Alias so later cells can spell the 2-D vector type as either `Vec2` or `Vector2`.
typealias Vector2 = Vec2

In [0]:
#@title Helpers

/// Number of optimization steps performed by the training loop.
let iterationCount = 100
/// Learning rate at the start of the cosine schedule (step 0).
let maxLr = Float(1e-1)
/// Learning rate the schedule has decayed to at step `iterationCount`.
let minLr = Float(1e-2)

/// Cosine-annealed learning rate for optimization step `i`.
///
/// Equals `maxLr` at `i == 0` and decays smoothly to `minLr` at
/// `i == iterationCount`, following half a cosine period.
func learningRate(_ i: Int = 0) -> Float {
    let angle = Float.pi * Float(i) / Float(iterationCount)
    let span = maxLr - minLr
    let annealed = span * (cos(angle) + 1) / 2
    return minLr + annealed
}

extension Array where Element == Float {
    /// Returns the sum of all elements (0 for an empty array).
    ///
    /// Implemented with `differentiableReduce` so that, as marked by
    /// `@differentiable`, the sum can be differentiated with respect to the
    /// array's elements.
    @differentiable
    func sum() -> Float {
        differentiableReduce(0, +)
    }
}

In [0]:
#@title State Data Structures

import Vec2

/// A ball in the 2-D simulation, described by its center and its velocity.
///
/// Every "modifying" method returns a fresh `Ball` built through the
/// `@differentiable` initializer, so derivatives flow through each update.
struct Ball: AdditiveArithmetic, Differentiable {
  /// Center of the ball.
  var position: Vec2
  /// Current velocity of the ball.
  var velocity: Vec2

  /// Radius shared by every ball (used by the contact and target checks).
  static let ballRadius = Float(1)

  @differentiable
  init(position: Vec2, velocity: Vec2) {
    self.velocity = velocity
    self.position = position
  }

  /// A copy of this ball placed at `position`, keeping its velocity.
  @differentiable
  func updating(position newPosition: Vec2) -> Ball {
    return Ball(position: newPosition, velocity: self.velocity)
  }

  /// A copy of this ball moving with `velocity`, keeping its position.
  @differentiable
  func updating(velocity newVelocity: Vec2) -> Ball {
    return Ball(position: self.position, velocity: newVelocity)
  }

  /// A copy of this ball translated by `delta`.
  @differentiable
  func moved(_ delta: Vec2) -> Ball {
    let shifted = position + delta
    return Ball(position: shifted, velocity: velocity)
  }

  /// A copy of this ball whose velocity has been changed by `delta`.
  @differentiable
  func impulsed(_ delta: Vec2) -> Ball {
    let kicked = velocity + delta
    return Ball(position: position, velocity: kicked)
  }
}

/// Outcome of colliding two balls: both resulting balls plus a scalar
/// describing how strong the collision was.
struct BallTup: Differentiable {
  /// First ball after the collision.
  var ball1: Ball
  /// Second ball after the collision.
  var ball2: Ball
  /// As produced by `Ball.collided`, the squared magnitude of the exchanged
  /// impulse (0 when no impulse was applied).
  var collisionEnergy: Float

  @differentiable
  init(_ first: Ball, _ second: Ball, _ energy: Float) {
    collisionEnergy = energy
    ball2 = second
    ball1 = first
  }
}

/// A straight wall segment from `p1` to `p2` that balls bounce off.
struct Wall {
  /// First endpoint of the segment.
  var p1: Vec2
  /// Second endpoint of the segment.
  var p2: Vec2
}

/// Tunable constants consumed by `Ball.stepped` and `World.stepped`.
struct SimulationParameters {
  /// Integration time step, in seconds (also used as each SVG frame's duration).
  var dt: Float = 0.02
  /// Rate of random velocity kicks: on each step a ball is perturbed with
  /// probability `1 - exp(-lambda * dt)`, so the default 0 disables the noise.
  var lambda: Float = 0.0
}

extension Ball {
  /// Advances this ball by one time step of length `params.dt`.
  ///
  /// Friction slows the ball down. An optional random kick may then perturb the
  /// velocity. Finally the position is advanced using the *updated* velocity.
  @differentiable
  func stepped(_ params: SimulationParameters) -> Ball {
    let frictionAcceleration = Float(1)
    // Velocity lost to friction during this step, built along
    // `velocity.direction`. NOTE(review): assumes `Vec2(magnitude:direction:)`
    // yields a vector pointing the same way as `velocity` — confirm in Vec2.
    let friction = Vec2(magnitude: frictionAcceleration * params.dt, direction: velocity.direction)
    // If friction would remove more speed than the ball has, stop the ball
    // instead of letting the subtraction push it backwards.
    var newVelocity = friction.magnitude > velocity.magnitude ? Vec2(0, 0) : velocity - friction

    // Random kick with probability 1 - exp(-lambda * dt) per step (never when
    // lambda == 0). The kick is the current speed times a vector whose
    // components are uniform in [-0.5, 0.5].
    if Float.random(in: 0..<1) > exp(-params.lambda * params.dt) {
        newVelocity = newVelocity + newVelocity.magnitude * Vec2(Float.random(in: (-0.5...0.5)), Float.random(in: (-0.5...0.5)))
    }

    // Position moves with the velocity computed above, not the old one.
    return Ball(
      position: position + params.dt * newVelocity,
      velocity: newVelocity)
  }

  /// True when the centers are at most `2 * ballRadius` apart, i.e. the two
  /// balls touch or overlap.
  func touches(_ other: Ball) -> Bool {
    return (position - other.position).magnitude <= 2 * Ball.ballRadius
  }

  /// Parameter `t` of the point on the line through `wall.p1`/`wall.p2` closest
  /// to this ball's center (`t == 0` at `p1`, `t == 1` at `p2`).
  ///
  /// This is the expanded form of
  /// `(position - p1)·(p2 - p1) / |p2 - p1|²`.
  /// NOTE(review): a zero-length wall (`p1 == p2`) divides by zero.
  private func projected(to wall: Wall) -> Float {
    (wall.p1.magnitudeSquared + position.dot(wall.p2 - wall.p1) - wall.p1.dot(wall.p2)) / (wall.p2 - wall.p1).magnitudeSquared
  }

  /// True when the ball's center projects onto the wall segment itself
  /// (0 <= t <= 1) and lies within `ballRadius` of that projection.
  ///
  /// Contact with the wall's endpoints from beyond the segment is not detected.
  func touches(_ wall: Wall) -> Bool {
    let t = projected(to: wall)
    if t < 0 || t > 1 { return false }
    let projection = (1 - t) * wall.p1 + t * wall.p2
    return (position - projection).magnitudeSquared <= Ball.ballRadius * Ball.ballRadius
  }

  /// Reflects the ball's velocity off `wall`.
  ///
  /// The component along the wall is kept and the component along the wall's
  /// normal is negated. The position is unchanged. If the ball is already
  /// moving away from the wall, `self` is returned as-is. This means a ball that
  /// is still in contact on the next step is not turned back into the wall.
  @differentiable
  func bounced(on wall: Wall) -> Ball {
    let tangent = wall.p2 - wall.p1
    let unitTangent = tangent / tangent.magnitude
    // Tangent rotated by +90°.
    let unitNormal = Vec2(-unitTangent.y, unitTangent.x)
    let t = projected(to: wall)
    let projection = (1 - t) * wall.p1 + t * wall.p2
    // Vector from the closest point on the wall to the ball's center.
    let displacement = position - projection
    if velocity.dot(displacement) > 0 { return self }
    let newVelocity = velocity.dot(unitTangent) * unitTangent - velocity.dot(unitNormal) * unitNormal
    return Ball(position: position, velocity: newVelocity)
  }

  /// Resolves a contact between `a` and `b` (see `updateCollisionVelocities`).
  @differentiable
  static func collided(_ a: Ball, _ b: Ball) -> BallTup {
    updateCollisionVelocities(a, b)
  }

  /// Computes post-collision balls for an elastic collision.
  ///
  /// Equal and opposite velocity changes are applied, so the balls are
  /// implicitly treated as having equal mass. If the balls are already
  /// separating (relative velocity · relative position > 0), they are returned
  /// unchanged with energy 0. Otherwise the impulse is the projection of the
  /// relative velocity onto the line of centers. `collisionEnergy` is reported
  /// as the impulse's squared magnitude.
  /// NOTE(review): coincident centers make `p.magnitudeSquared` zero and the
  /// impulse NaN.
  private static func updateCollisionVelocities(_ ball1: Ball, _ ball2: Ball) -> BallTup {
    // Perfectly elastic collision. This is the impulse along the normal that preserves kinetic energy.
    let p = ball2.position - ball1.position
    let v = ball2.velocity - ball1.velocity
    let vdotp = v.dot(p)
    if vdotp > 0 { return BallTup(ball1, ball2, 0) }
    let impulse = (-vdotp / p.magnitudeSquared) * p
    return BallTup(ball1.impulsed(-1 * impulse), ball2.impulsed(impulse), impulse.magnitudeSquared)
  }
}

In [0]:
#@title World State

/// Complete simulation state.
///
/// Contains the moving balls, per-target progress, elapsed time, and the
/// static scene. Only `balls`, `targetDistances` and `t` take part in
/// differentiation; the scene geometry is marked `@noDerivative`.
struct World: Differentiable {
    /// Position and velocity of every ball right now.
    var balls: [Ball]

    /// For each target, the closest any ball has come to it in this or any
    /// earlier state (as maintained by `World.stepped`).
    var targetDistances: [Float]

    /// Simulated time elapsed since the initial state.
    var t: Float

    /// Where the targets are; not differentiated.
    @noDerivative var targets: [Vec2]

    /// The wall segments balls bounce off; not differentiated.
    @noDerivative var walls: [Wall]

    @differentiable
    init(
      balls: [Ball],
      targetDistances: [Float],
      t: Float,
      targets: [Vec2],
      walls: [Wall]
    ) {
        // Static scene first, then the differentiable state.
        self.walls = walls
        self.targets = targets
        self.t = t
        self.targetDistances = targetDistances
        self.balls = balls
    }
}

In [0]:
#@title Simulation Logic

extension World {
  /// Creates a world at time 0 in which no target has been approached yet.
  /// Every entry of `targetDistances` starts at `+infinity`.
  @differentiable
  init(balls: [Ball], targets: [Vec2], walls: [Wall]) {
    let targetCount = withoutDerivative(at: targets.count)
    let unreachedDistances = Array(repeating: Float.infinity, count: targetCount)
    self.init(balls: balls, targetDistances: unreachedDistances, t: 0, targets: targets, walls: walls)
  }

  /// The demo scene.
  ///
  /// Ball 1 starts at (-20, 0) moving with `ball1InitialVelocity`. Ball 2 rests
  /// at (-10, 0). The scene also has three targets and one wall.
  @differentiable
  init(ball1InitialVelocity: Vec2) {
    let launchedBall = Ball(position: Vec2(-20, 0), velocity: ball1InitialVelocity)
    let restingBall = Ball(position: Vec2(-10, 0), velocity: Vec2(0, 0))
    self.init(
      balls: [launchedBall, restingBall],
      targets: [Vec2(0, 5), Vec2(-5, 15), Vec2(-2, -20)],
      walls: [Wall(p1: Vec2(7, -10), p2: Vec2(7, 20))]
    )
  }
}

extension World {
  /// Whether the simulation should stop.
  ///
  /// True once more than 7 seconds of simulated time have elapsed, or once no
  /// ball has a positive speed.
  var still: Bool {
    let timedOut = t > 7
    return timedOut || !balls.contains(where: { $0.velocity.magnitude > 0 })
  }
}

extension World {
  /// Advances the whole world by one time step of `params.dt`.
  ///
  /// Order of operations:
  /// 1. Integrate each ball (friction plus optional noise).
  /// 2. Bounce each ball off the first wall it touches.
  /// 3. Resolve a ball–ball collision.
  /// 4. Update every target's closest-approach distance.
  ///
  /// NOTE(review): the collision and target sections index `updatedBalls[0]`
  /// and `updatedBalls[1]` directly. This assumes exactly two balls, and a
  /// collision also rebuilds the array from just those two.
  @differentiable
  func stepped(_ params: SimulationParameters = SimulationParameters()) -> World {
    // Integrate the ball velocity.
    var updatedBalls = balls.differentiableMap { $0.stepped(params) }

    // Collide the balls with the walls.
    // `walls` is captured by value instead of through `self` (presumably to keep
    // `self` out of the differentiable closure). Only the first touching wall
    // in `walls` order is bounced off.
    updatedBalls = updatedBalls.differentiableMap { [walls = walls] (ball: Ball) -> Ball in
      for i in withoutDerivative(at: walls.indices) {
        let wall = walls[i]
        if ball.touches(wall) {
          return ball.bounced(on: wall)
        }
      }
      return ball
    }
    
    // Collide the balls with each other.
    if updatedBalls[0].touches(updatedBalls[1]) {
      let collidedBalls = Ball.collided(updatedBalls[0], updatedBalls[1])
      updatedBalls = [collidedBalls.ball1, collidedBalls.ball2]
    }

    // Update min target distance.
    // Each new entry is min(previous minimum, current nearest-ball distance).
    // Distances below 2 * ballRadius are clamped to exactly that value, so
    // "target hit" shows up as the clamp value.
    var newMinTargetDistance: [Float] = []
    for i in withoutDerivative(at: targets.indices) {
        let distTo1 = (updatedBalls[0].position - targets[i]).magnitude
        let distTo2 = (updatedBalls[1].position - targets[i]).magnitude
        var curTargetDistance = distTo1 < distTo2 ? distTo1 : distTo2
        if curTargetDistance < 2 * Ball.ballRadius { curTargetDistance = 2 * Ball.ballRadius }
        // Built by concatenation rather than `append`, presumably because
        // in-place append was not differentiable in the targeted toolchain.
        if curTargetDistance < targetDistances[i] {
            newMinTargetDistance = newMinTargetDistance + [curTargetDistance]
        } else {
            newMinTargetDistance = newMinTargetDistance + [targetDistances[i]]
        }
    }

    // The `withDerivative` hook replaces an empty incoming derivative for
    // `targetDistances` with explicit zeros of the right length. NOTE(review):
    // presumably the empty array is the zero tangent, and the pullback of the
    // concatenations above needs a matching element count — confirm.
    return World(
        balls: updatedBalls,
        targetDistances: newMinTargetDistance.withDerivative { [count = withoutDerivative(at: newMinTargetDistance.count)] (d: inout Array<Float>.DifferentiableView) -> () in
            if d.base.count == 0 {
                d = Array.DifferentiableView(Array(repeating: 0, count: count))
            }
        },
        t: t + params.dt,
        targets: targets,
        walls: walls
    )
  }

  /// Steps the world until `still` is true and returns the final state.
  ///
  /// `f` is called on every visited state, including `self` and the final
  /// state. With the default no-op `f`, this is equivalent to plain stepping.
  @differentiable
  func steppedUntilStill(_ params: SimulationParameters, _ f: (World) -> () = { _ in }) -> World {
    var state = self
    while !state.still {
      f(state)
      state = state.stepped(params)
    }
    f(state)
    return state
  }
}

/// An arrow to overlay on a ball in the SVG rendering produced by `svg`.
struct DrawingArrow {
    /// Start of the arrow relative to the ball's on-screen center. It is rotated
    /// into screen axes like positions are, but not multiplied by the drawing scale.
    var offset: Vec2
    /// SVG stroke color of the arrow, e.g. "blue".
    var color: String
    /// Arrow vector from its start point, in the same rotated, unscaled axes.
    var direction: Vec2
}

/// Builds an animated, endlessly looping SVG of a simulation run.
///
/// - Parameters:
///   - states: Frames to animate, assumed to be `params.dt` apart. Targets,
///     walls and the starting ball positions are read from the first frame.
///   - params: Simulation parameters; `dt` is used as every frame's duration.
///   - vectors: `vectors[i]` are arrows drawn from ball `i`'s starting
///     position. They appear at the start of each loop and disappear after
///     `delay` seconds.
///   - delay: Seconds to hold the first frame before the balls start moving.
/// - Returns: The SVG markup.
func svg(states: [World], params: SimulationParameters = SimulationParameters(), vectors: [[DrawingArrow]] = [], delay: Float = 0) -> String {
  let scale = Float(7)
  let origin = Vec2(175, 70)
  let size = Vec2(350, 224)
  // A target is shown as hit once a ball's closest approach reaches this value.
  // `World.stepped` clamps target distances to exactly `2 * Ball.ballRadius`,
  // so the threshold is derived from the same constant instead of hard-coded.
  let hitDistance = 2 * Ball.ballRadius

  // Simulation -> pixel coordinates. Simulation +x maps to screen up and +y to
  // screen right; the result is then scaled by `scale` and shifted by `origin`.
  func transformed(_ position: Vec2) -> Vec2 {
    let p2 = Vector2(x: position.y, y: -position.x)
    return scale * p2 + origin
  }

  var r = ""
  // SVG attributes are whitespace-separated; a comma between them is malformed.
  r += """
    <svg width="\(size.x)" height="\(size.y)">\n
  """

  // Invisible timer. Every other animation is scheduled relative to
  // `looper.begin`, and `begin="0;looper.end"` restarts it whenever it ends, so
  // the drawing replays forever. The size-less <rect> itself draws nothing.
  let totalDuration = (states.last?.t ?? 0) + delay
  r += """
    <rect>
        <animate
            id="looper"
            begin="0;looper.end"
            attributeName="visibility"
            from="hide"
            to="hide"
            dur="\(totalDuration)s" />
    </rect>
  """

  // Target marker: a circle with a colored stroke (animated when the target is
  // hit). NOTE(review): no `fill` is set, so SVG's default fill applies —
  // confirm that is intended.
  func target(id: String, cx: String, cy: String, border: String) -> String {
    """
      <circle
        id="\(id)"
        r="\(scale * Ball.ballRadius)"
        cx="\(cx)"
        cy="\(cy)"
        stroke-width="3"
        stroke="\(border)" />\n
    """
  }

  // Filled circle used for a ball; its cx/cy are animated per frame.
  func circle(id: String, cx: String, cy: String, fill: String) -> String {
    """
      <circle
        id="\(id)"
        r="\(scale * Ball.ballRadius)"
        cx="\(cx)"
        cy="\(cy)"
        fill="\(fill)" />\n
    """
  }

  // Black line segment between two pixel positions (used for walls).
  func line(p1: Vec2, p2: Vec2) -> String {
    """
      <line
        x1="\(p1.x)"
        y1="\(p1.y)"
        x2="\(p2.x)"
        y2="\(p2.y)"
        style="stroke:#000;stroke-width:2" />\n
    """
  }

  // Arrow drawn from pixel position `base`. The offset and direction are rotated
  // into screen axes like positions are, but are not scaled.
  func arrow(id: String, _ arrow: DrawingArrow, _ base: Vec2) -> String {
      let off2 = Vector2(x: arrow.offset.y, y: -arrow.offset.x)
      let p1 = base + off2
      let dir2 = Vector2(x: arrow.direction.y, y: -arrow.direction.x)
      let p2 = p1 + dir2
      return """
        <line
            id="\(id)"
            x1="\(p1.x)"
            y1="\(p1.y)"
            x2="\(p2.x)"
            y2="\(p2.y)"
            style="stroke:\(arrow.color);stroke-width:2" />\n
      """
  }

  for (index, targetPosition) in (states.first?.targets ?? []).enumerated() {
    let position = transformed(targetPosition)
    r += target(id: "target\(index)", cx: "\(position.x)", cy: "\(position.y)", border: "red")
  }

  // Balls at their starting positions, each with its arrows. The arrows fade in
  // at the start of the loop and fade out once `delay` has elapsed.
  for (index, ball) in (states.first?.balls ?? []).enumerated() {
    let position = transformed(ball.position)
    r += circle(id: "ball\(index)", cx: "\(position.x)", cy: "\(position.y)", fill: "orange")

    if vectors.count > index {
        for (index2, vector) in vectors[index].enumerated() {
            let id = "arrow\(index)_\(index2)"
            r += arrow(id: id, vector, position)
            r += """
                
                <animate
                    xlink:href="#\(id)"
                    attributeName="opacity"
                    from="0"
                    to="1"
                    dur="\(params.dt)s"
                    begin="looper.begin"
                    fill="freeze" />
                <animate
                    xlink:href="#\(id)"
                    attributeName="opacity"
                    from="1"
                    to="0"
                    dur="\(params.dt)s"
                    begin="looper.begin+\(delay)s"
                    fill="freeze" />
            """
        }
    }
  }

  for wall in (states.first?.walls ?? []) {
    r += line(p1: transformed(wall.p1), p2: transformed(wall.p2))
  }

  // One animation segment per consecutive pair of frames, starting after `delay`.
  for (timeIndex, (state, nextState)) in zip(states, states.dropFirst(1)).enumerated() {
    let t = Float(timeIndex) * params.dt + delay
    for (ballIndex, (ballState, nextBallState)) in zip(state.balls, nextState.balls).enumerated() {
      func animate(attributeName: String, from: String, to: String) -> String {
        """
          <animate
            xlink:href="#ball\(ballIndex)"
            attributeName="\(attributeName)"
            from="\(from)"
            to="\(to)"
            dur="\(params.dt)s"
            begin="looper.begin+\(t)s" />\n
        """
      }
      let position = transformed(ballState.position)
      let nextPosition = transformed(nextBallState.position)
      r += animate(attributeName: "cx", from: String(position.x), to: String(nextPosition.x))
      r += animate(attributeName: "cy", from: String(position.y), to: String(nextPosition.y))
    }

    // When a target is first hit in this frame, reset its ring to red at the
    // start of every loop and switch it to green at the moment of the hit.
    for (targetIndex, (targetDistance, nextTargetDistance)) in zip(state.targetDistances, nextState.targetDistances).enumerated() {
        if targetDistance > hitDistance && nextTargetDistance <= hitDistance {
            r += """
                <animate
                    xlink:href="#target\(targetIndex)"
                    attributeName="stroke"
                    from="green"
                    to="red"
                    dur="\(params.dt)s"
                    begin="looper.begin"
                    fill="freeze" />
                """
            r += """
                <animate
                    xlink:href="#target\(targetIndex)"
                    attributeName="stroke"
                    from="red"
                    to="green"
                    dur="\(params.dt)s"
                    begin="looper.begin+\(t)s"
                    fill="freeze" />
                """
        }
    }
  }

  r += "</svg>\n"
  return r
}

import Python
%include "EnableIPythonDisplay.swift"
// Enable inline matplotlib output in the notebook, and load IPython's display
// module; `drawSVG` uses it to render HTML in the cell output.
IPythonDisplay.shell.enable_matplotlib("inline")
let display = Python.import("IPython.core.display")

/// Renders `states` with `svg(...)` and shows the markup as HTML in the
/// notebook's output area through IPython's display module.
func drawSVG(states: [World], params: SimulationParameters = SimulationParameters(), vectors: [[DrawingArrow]] = [], delay: Float = 0) {
    let markup = svg(states: states, params: params, vectors: vectors, delay: delay)
    let html = display.HTML(markup)
    // Call the Python-side `display` function of the imported module.
    let show = display[dynamicMember: "display"]
    show(html)
}

/// Simulates the demo scene launched with `ball1InitialVelocity` until it comes
/// to rest, then displays the run as an animation.
///
/// Each ball gets a blue arrow showing twice its initial velocity. The arrows
/// are shown for the first second of every loop.
func drawSimulation(ball1InitialVelocity: Vec2, params: SimulationParameters = SimulationParameters()) {
    let start = World(ball1InitialVelocity: ball1InitialVelocity)
    let velocityArrows = start.balls.map { (ball: Ball) -> [DrawingArrow] in
        [DrawingArrow(offset: Vec2(0, 0), color: "blue", direction: 2 * ball.velocity)]
    }
    // Record every visited state; the final state returned is not needed separately.
    var history: [World] = []
    _ = start.steppedUntilStill(params) { history.append($0) }
    drawSVG(states: history, params: params, vectors: velocityArrows, delay: 1)
}

/// Displays the demo scene's initial state with velocity and gradient arrows.
///
/// Each ball gets a blue arrow of twice its initial velocity. The first ball
/// additionally gets a green arrow of `-10 * grad`, starting at the tip of its
/// blue arrow. The arrows stay visible for an hour-long loop delay.
func drawGradients(ball1InitialVelocity: Vec2, params: SimulationParameters = SimulationParameters(), grad: Vec2) {
    let start = World(ball1InitialVelocity: ball1InitialVelocity)
    let gradientArrows = [-10 * grad]
    var arrowsPerBall: [[DrawingArrow]] = []
    for (ballIndex, ball) in start.balls.enumerated() {
        let shownVelocity = 2 * ball.velocity
        var arrows = [DrawingArrow(offset: Vec2(0, 0), color: "blue", direction: shownVelocity)]
        if ballIndex < gradientArrows.count {
            arrows.append(DrawingArrow(offset: shownVelocity, color: "green", direction: gradientArrows[ballIndex]))
        }
        arrowsPerBall.append(arrows)
    }
    drawSVG(states: [start], params: params, vectors: arrowsPerBall, delay: 3600)
}

In [0]:
/// Steps `initialState` with default `SimulationParameters` until it is
/// `still`, and returns the final state.
///
/// Same loop as `World.steppedUntilStill`, without the per-state callback.
@differentiable
func simulate(_ initialState: World) -> World {
    var world = initialState
    while !world.still {
        world = world.stepped()
    }
    return world
}

In [13]:
// Initial guess for ball 1's launch velocity (this is the value being optimized
// below), followed by a preview of the resulting simulation.
var v0 = Vector2(20, 0.01)
drawSimulation(ball1InitialVelocity: v0)



In [0]:
/// Objective to minimize for ball 1's launch velocity `v0`.
///
/// Runs the demo world to rest and returns the sum, over all targets, of the
/// closest approach any ball made to that target.
@differentiable
func loss(_ v0: Vector2) -> Float {
    let finalState = simulate(World(ball1InitialVelocity: v0))
    let closestApproaches = finalState.targetDistances
    return closestApproaches.sum()
}

loss(v0)

In [0]:
// One manual gradient-descent step. Compute d(loss)/d(v0), draw -10 * grad on
// top of the initial scene, then move v0 against the gradient using the
// schedule's starting learning rate.
let grad = gradient(at: v0, in: loss)
drawGradients(ball1InitialVelocity: v0, grad: grad)
v0 -= learningRate() * grad

In [0]:
// Re-run the simulation with the velocity after the single gradient step.
drawSimulation(ball1InitialVelocity: v0)

In [0]:
#@title Training Loop

// Gradient descent on ball 1's launch velocity with a cosine-annealed step
// size, printing the loss before each update.
for step in 0..<iterationCount {
    let (currentLoss, currentGrad) = valueWithGradient(at: v0, in: loss)
    print("\(step): loss \(currentLoss)")
    v0 -= learningRate(step) * currentGrad
}

In [0]:
// Show the simulation with the optimized launch velocity.
drawSimulation(ball1InitialVelocity: v0)

In [0]: