In [0]:
#@title Imports
%install '.package(url: "https://github.com/marcrasi/swift-vec2", .branch("master"))' Vec2
// Clear installation messages from output area.
print("\u{001B}[2J")
import Vec2
typealias Vector2 = Vec2
In [0]:
#@title Helpers
/// Number of optimization steps the cosine learning-rate schedule spans.
let iterationCount = 100
/// Learning rate at step 0 (the top of the schedule).
let maxLr = Float(1e-1)
/// Learning rate the schedule decays to at step `iterationCount`.
let minLr = Float(1e-2)

/// Cosine-annealed learning rate for optimization step `i`.
///
/// Starts at `maxLr` for `i == 0`, follows half a cosine period, and reaches
/// `minLr` at `i == iterationCount`. Steps beyond `iterationCount` are not
/// clamped; the cosine simply continues.
///
/// - Parameter i: Current optimization step; defaults to 0.
/// - Returns: The learning rate to use at step `i`.
func learningRate(_ i: Int = 0) -> Float {
    let phase = Float.pi * Float(i) / Float(iterationCount)
    // Maps cos(phase) from [-1, 1] onto [0, 1].
    let halfCosine = (cos(phase) + 1) / 2
    return (maxLr - minLr) * halfCosine + minLr
}
extension Array where Element == Float {
    /// Differentiable sum of all elements (0 for an empty array).
    @differentiable
    func sum() -> Float {
        return differentiableReduce(Float(0)) { partialSum, element in
            partialSum + element
        }
    }
}
In [0]:
#@title State Data Structures
import Vec2
/// State of a single ball: where its center is and how fast it is moving.
///
/// Conforms to `Differentiable` so simulations built from it can be
/// differentiated with respect to ball state.
struct Ball: AdditiveArithmetic, Differentiable {
    /// Center of the ball.
    var position: Vec2
    /// Velocity; `stepped` advances `position` by `dt * velocity` per step.
    var velocity: Vec2

    /// Radius shared by every ball.
    static let ballRadius = Float(1)

    @differentiable
    init(position: Vec2, velocity: Vec2) {
        self.velocity = velocity
        self.position = position
    }

    /// Returns a copy of this ball with its position replaced.
    @differentiable
    func updating(position newPosition: Vec2) -> Ball {
        return Ball(position: newPosition, velocity: self.velocity)
    }

    /// Returns a copy of this ball with its velocity replaced.
    @differentiable
    func updating(velocity newVelocity: Vec2) -> Ball {
        return Ball(position: self.position, velocity: newVelocity)
    }

    /// Returns a copy of this ball translated by `delta`.
    @differentiable
    func moved(_ delta: Vec2) -> Ball {
        let shifted = position + delta
        return updating(position: shifted)
    }

    /// Returns a copy of this ball whose velocity is changed by `delta`.
    @differentiable
    func impulsed(_ delta: Vec2) -> Ball {
        let boosted = velocity + delta
        return updating(velocity: boosted)
    }
}
/// Outcome of a ball–ball collision: both resulting balls plus a scalar
/// collision term.
struct BallTup: Differentiable {
    var ball1: Ball
    var ball2: Ball
    /// The collision code in this file stores the squared magnitude of the
    /// exchanged impulse here (0 when no impulse is applied).
    var collisionEnergy: Float

    @differentiable
    init(_ first: Ball, _ second: Ball, _ energy: Float) {
        self.collisionEnergy = energy
        self.ball1 = first
        self.ball2 = second
    }
}
/// A straight wall segment running from endpoint `p1` to endpoint `p2`.
struct Wall {
    var p1, p2: Vec2
}
/// Tunable settings for advancing the simulation.
struct SimulationParameters {
    /// Duration of a single simulation step.
    var dt: Float = 0.02
    /// Rate of random velocity kicks: in each step a ball is perturbed with
    /// probability `1 - exp(-lambda * dt)`. The default of 0 disables kicks.
    var lambda: Float = 0
}
extension Ball {
    /// Advances this ball by one time step of length `params.dt`.
    ///
    /// 1. Friction: a vector of magnitude `frictionAcceleration * params.dt`,
    ///    built along `velocity.direction`, is subtracted from the velocity.
    ///    If it is larger than the current speed, the ball stops instead.
    /// 2. Random kick: with probability `1 - exp(-params.lambda * params.dt)`,
    ///    each velocity component is perturbed by a random fraction in
    ///    [-0.5, 0.5] of the current speed.
    /// 3. The position is advanced using the *updated* velocity.
    @differentiable
    func stepped(_ params: SimulationParameters) -> Ball {
        let frictionAcceleration = Float(1)
        // NOTE(review): for a ball at rest this evaluates `direction` of the zero
        // vector; confirm Vec2 handles that without producing NaN.
        let friction = Vec2(magnitude: frictionAcceleration * params.dt, direction: velocity.direction)
        var newVelocity = friction.magnitude > velocity.magnitude ? Vec2(0, 0) : velocity - friction
        if Float.random(in: 0..<1) > exp(-params.lambda * params.dt) {
            newVelocity = newVelocity + newVelocity.magnitude * Vec2(Float.random(in: (-0.5...0.5)), Float.random(in: (-0.5...0.5)))
        }
        return Ball(
            position: position + params.dt * newVelocity,
            velocity: newVelocity)
    }

    /// Whether the two balls touch or overlap (centers at most `2 * ballRadius` apart).
    func touches(_ other: Ball) -> Bool {
        return (position - other.position).magnitude <= 2 * Ball.ballRadius
    }

    /// Line parameter `t` of the orthogonal projection of `position` onto the
    /// infinite line through `wall.p1` (t = 0) and `wall.p2` (t = 1).
    ///
    /// This is `(position - p1) · (p2 - p1) / |p2 - p1|²` with the numerator
    /// expanded. A degenerate wall (`p1 == p2`) divides by zero.
    private func projected(to wall: Wall) -> Float {
        (wall.p1.magnitudeSquared + position.dot(wall.p2 - wall.p1) - wall.p1.dot(wall.p2)) / (wall.p2 - wall.p1).magnitudeSquared
    }

    /// Whether the ball's center is within `ballRadius` of the wall segment.
    ///
    /// Only points whose projection falls between the two endpoints count.
    func touches(_ wall: Wall) -> Bool {
        let t = projected(to: wall)
        if t < 0 || t > 1 { return false }
        let projection = (1 - t) * wall.p1 + t * wall.p2
        return (position - projection).magnitudeSquared <= Ball.ballRadius * Ball.ballRadius
    }

    /// Reflects the component of the velocity normal to `wall`.
    ///
    /// The tangential component and the position are unchanged. Returns `self`
    /// as-is when the ball is already moving away from the wall.
    @differentiable
    func bounced(on wall: Wall) -> Ball {
        let tangent = wall.p2 - wall.p1
        let unitTangent = tangent / tangent.magnitude
        let unitNormal = Vec2(-unitTangent.y, unitTangent.x)
        let t = projected(to: wall)
        let projection = (1 - t) * wall.p1 + t * wall.p2
        // Points from the closest point on the wall toward the ball.
        let displacement = position - projection
        if velocity.dot(displacement) > 0 { return self }
        let newVelocity = velocity.dot(unitTangent) * unitTangent - velocity.dot(unitNormal) * unitNormal
        return Ball(position: position, velocity: newVelocity)
    }

    /// Resolves a perfectly elastic collision between two equal-mass balls.
    ///
    /// Returns both resulting balls and the squared magnitude of the exchanged
    /// impulse.
    @differentiable
    static func collided(_ a: Ball, _ b: Ball) -> BallTup {
        updateCollisionVelocities(a, b)
    }

    /// Applies equal and opposite velocity impulses along the line of centers.
    ///
    /// The impulse exchanges the balls' relative normal velocity. Marked
    /// `@differentiable` because `collided` differentiates through it.
    @differentiable
    private static func updateCollisionVelocities(_ ball1: Ball, _ ball2: Ball) -> BallTup {
        // Perfectly elastic collision. This is the impulse along the normal that preserves kinetic energy.
        let p = ball2.position - ball1.position
        let v = ball2.velocity - ball1.velocity
        let vdotp = v.dot(p)
        // Separating or not approaching: no impulse. Using `>=` also covers
        // coincident centers (p == 0, hence vdotp == 0). There the impulse below
        // would be 0 / 0 and fill the state with NaN. For p != 0 and
        // vdotp == 0 the impulse would be zero anyway.
        if vdotp >= 0 { return BallTup(ball1, ball2, 0) }
        let impulse = (-vdotp / p.magnitudeSquared) * p
        return BallTup(ball1.impulsed(-1 * impulse), ball2.impulsed(impulse), impulse.magnitudeSquared)
    }
}
In [0]:
#@title World State
/// Complete state of the billiards simulation at one instant.
///
/// `balls`, `targetDistances` and `t` take part in differentiation. `targets`
/// and `walls` are `@noDerivative`.
struct World: Differentiable {
    /// Current state of all the balls.
    var balls: [Ball]
    /// Minimum ball distance (over all previous states) to each target.
    var targetDistances: [Float]
    /// Amount of time simulation has been running.
    var t: Float
    /// Positions of the targets.
    @noDerivative var targets: [Vec2]
    /// Wall segments the balls bounce off.
    @noDerivative var walls: [Wall]

    @differentiable
    init(balls: [Ball], targetDistances: [Float], t: Float, targets: [Vec2], walls: [Wall]) {
        self.targets = targets
        self.walls = walls
        self.t = t
        self.targetDistances = targetDistances
        self.balls = balls
    }
}
In [0]:
#@title Simulation Logic
extension World {
    /// Creates a world at time 0.
    ///
    /// Every target's running minimum distance starts at `+infinity`.
    @differentiable
    init(balls: [Ball], targets: [Vec2], walls: [Wall]) {
        let targetCount = withoutDerivative(at: targets.count)
        self.init(
            balls: balls,
            targetDistances: [Float](repeating: .infinity, count: targetCount),
            t: 0,
            targets: targets,
            walls: walls)
    }

    /// Builds the demo scene.
    ///
    /// - Ball 1 starts at (-20, 0) with the given velocity.
    /// - Ball 2 rests at (-10, 0).
    /// - There are three targets and a single wall from (7, -10) to (7, 20).
    @differentiable
    init(ball1InitialVelocity: Vec2) {
        let firstBall = Ball(position: Vec2(-20, 0), velocity: ball1InitialVelocity)
        let secondBall = Ball(position: Vec2(-10, 0), velocity: Vec2(0, 0))
        self.init(
            balls: [firstBall, secondBall],
            targets: [Vec2(0, 5), Vec2(-5, 15), Vec2(-2, -20)],
            walls: [Wall(p1: Vec2(7, -10), p2: Vec2(7, 20))])
    }
}
extension World {
    /// Whether stepping should stop.
    ///
    /// True once more than 7 time units have elapsed, or when no ball has a
    /// positive speed. The time cap bounds how long `simulate` and
    /// `steppedUntilStill` keep stepping.
    var still: Bool {
        t > 7 || !balls.contains { $0.velocity.magnitude > 0 }
    }
}
extension World {
/// Advances the whole world by one time step of `params.dt`.
///
/// In order, it:
/// 1. steps every ball with `Ball.stepped`;
/// 2. bounces each ball off the first wall (in `walls` order) that it touches;
/// 3. resolves a collision between balls 0 and 1 if they touch;
/// 4. updates each target's running minimum distance;
/// 5. advances `t` by `params.dt`.
///
/// NOTE(review): assumes `balls` has at least two elements. Only indices 0 and
/// 1 are collided and measured against targets, and after a collision the
/// array is rebuilt from just those two balls.
@differentiable
func stepped(_ params: SimulationParameters = SimulationParameters()) -> World {
// Integrate the ball velocity.
var updatedBalls = balls.differentiableMap { $0.stepped(params) }
// Collide the balls with the walls.
// Each ball bounces off at most one wall per step: the first one it touches.
// The index range is wrapped in `withoutDerivative(at:)` so the loop counter
// is excluded from differentiation.
updatedBalls = updatedBalls.differentiableMap { [walls = walls] (ball: Ball) -> Ball in
for i in withoutDerivative(at: walls.indices) {
let wall = walls[i]
if ball.touches(wall) {
return ball.bounced(on: wall)
}
}
return ball
}
// Collide the balls with each other.
// This replaces the array with exactly the two collided balls.
if updatedBalls[0].touches(updatedBalls[1]) {
let collidedBalls = Ball.collided(updatedBalls[0], updatedBalls[1])
updatedBalls = [collidedBalls.ball1, collidedBalls.ball2]
}
// Update min target distance.
// For each target, take the distance to the nearer of balls 0 and 1 and clamp
// it from below at 2 * ballRadius. Then keep the smaller of that and the
// previous minimum from `targetDistances`. The array is grown with `+` rather
// than `append`, presumably because in-place mutation was not differentiable
// in the toolchain this targets (NOTE(review): confirm).
var newMinTargetDistance: [Float] = []
for i in withoutDerivative(at: targets.indices) {
let distTo1 = (updatedBalls[0].position - targets[i]).magnitude
let distTo2 = (updatedBalls[1].position - targets[i]).magnitude
var curTargetDistance = distTo1 < distTo2 ? distTo1 : distTo2
if curTargetDistance < 2 * Ball.ballRadius { curTargetDistance = 2 * Ball.ballRadius }
if curTargetDistance < targetDistances[i] {
newMinTargetDistance = newMinTargetDistance + [curTargetDistance]
} else {
newMinTargetDistance = newMinTargetDistance + [targetDistances[i]]
}
}
return World(
balls: updatedBalls,
// Derivative patch: if the incoming derivative for `targetDistances` is an
// empty array (NOTE(review): presumably how a zero array tangent arrives),
// replace it with `count` zeros so its length matches the number of targets.
targetDistances: newMinTargetDistance.withDerivative { [count = withoutDerivative(at: newMinTargetDistance.count)] (d: inout Array<Float>.DifferentiableView) -> () in
if d.base.count == 0 {
d = Array.DifferentiableView(Array(repeating: 0, count: count))
}
},
t: t + params.dt,
targets: targets,
walls: walls
)
}
/// Repeatedly applies `stepped(params)` until `still` is true and returns the
/// final state.
///
/// `f` is called on every visited state, including the starting state and the
/// returned final state, so it is called at least once even if the world is
/// already still.
@differentiable
func steppedUntilStill(_ params: SimulationParameters, _ f: (World) -> () = { _ in }) -> World {
var state = self
while !state.still {
f(state)
state = state.stepped(params)
}
f(state)
return state
}
}
/// An arrow overlaid on a ball in the SVG rendering produced by `svg(...)`.
struct DrawingArrow {
    /// Start of the arrow relative to the ball's drawn center.
    /// Axes are rotated like ball positions but not scaled.
    var offset: Vec2
    /// SVG stroke color, e.g. "blue".
    var color: String
    /// Arrow vector from its start point.
    /// Axes are rotated like ball positions but not scaled.
    var direction: Vec2
}
/// Renders a sequence of world states as a looping, animated SVG document.
///
/// Targets, balls, arrows and walls are drawn from `states.first`. Each
/// consecutive pair of states becomes one `params.dt`-long `cx`/`cy` animation
/// segment per ball, offset by `delay` seconds. The whole animation restarts
/// every `states.last.t + delay` seconds. A target's outline turns from red to
/// green at the moment its recorded distance first reaches `2 * Ball.ballRadius`.
///
/// - Parameters:
///   - states: Simulation states in time order. An empty array yields an SVG
///     with only the timing element.
///   - params: Supplies the per-frame duration `dt`.
///   - vectors: Arrows per ball index, drawn at the initial ball positions.
///     They fade out once `delay` has elapsed.
///   - delay: Seconds to hold the initial frame before animating.
/// - Returns: SVG markup suitable for embedding in HTML.
func svg(states: [World], params: SimulationParameters = SimulationParameters(), vectors: [[DrawingArrow]] = [], delay: Float = 0) -> String {
    let scale = Float(7)
    let origin = Vec2(175, 70)
    let size = Vec2(350, 224)
    // A target counts as hit once its recorded distance reaches this value.
    // `World.stepped` clamps target distances from below at the same value.
    let hitDistance = 2 * Ball.ballRadius

    // World -> pixel coordinates: map axes as (x, y) -> (y, -x), scale, then translate.
    func transformed(_ position: Vec2) -> Vec2 {
        let p2 = Vector2(x: position.y, y: -position.x)
        return scale * p2 + origin
    }

    var r = ""
    // SVG attributes are separated by whitespace only (no commas).
    r += """
    <svg width="\(size.x)" height="\(size.y)">\n
    """
    let totalDuration = (states.last?.t ?? 0) + delay
    // Invisible timing element. Every other animation begins relative to
    // `looper.begin`, and `begin="0;looper.end"` restarts it whenever it ends,
    // so the drawing loops with period `totalDuration`.
    r += """
    <rect>
    <animate
    id="looper"
    begin="0;looper.end"
    attributeName="visibility"
    from="hidden"
    to="hidden"
    dur="\(totalDuration)s" />
    </rect>
    """

    // NOTE(review): no `fill` attribute, so SVG's default (black) fill applies
    // inside the target outline; confirm that is intended.
    func target(id: String, cx: String, cy: String, border: String) -> String {
        """
        <circle
        id="\(id)"
        r="\(scale * Ball.ballRadius)"
        cx="\(cx)"
        cy="\(cy)"
        stroke-width="3"
        stroke="\(border)" />\n
        """
    }

    func circle(id: String, cx: String, cy: String, fill: String) -> String {
        """
        <circle
        id="\(id)"
        r="\(scale * Ball.ballRadius)"
        cx="\(cx)"
        cy="\(cy)"
        fill="\(fill)" />\n
        """
    }

    func line(p1: Vec2, p2: Vec2) -> String {
        """
        <line
        x1="\(p1.x)"
        y1="\(p1.y)"
        x2="\(p2.x)"
        y2="\(p2.y)"
        style="stroke:#000;stroke-width:2" />\n
        """
    }

    // Draws `arrow` anchored at pixel position `base`. Offset and direction get
    // the same axis mapping as `transformed`, but no scaling.
    func arrow(id: String, _ arrow: DrawingArrow, _ base: Vec2) -> String {
        let off2 = Vector2(x: arrow.offset.y, y: -arrow.offset.x)
        let p1 = base + off2
        let dir2 = Vector2(x: arrow.direction.y, y: -arrow.direction.x)
        let p2 = p1 + dir2
        return """
        <line
        id="\(id)"
        x1="\(p1.x)"
        y1="\(p1.y)"
        x2="\(p2.x)"
        y2="\(p2.y)"
        style="stroke:\(arrow.color);stroke-width:2" />\n
        """
    }

    for (index, finalPosition) in (states.first?.targets ?? []).enumerated() {
        let position = transformed(finalPosition)
        r += target(id: "target\(index)", cx: "\(position.x)", cy: "\(position.y)", border: "red")
    }

    for (index, ball) in (states.first?.balls ?? []).enumerated() {
        let position = transformed(ball.position)
        r += circle(id: "ball\(index)", cx: "\(position.x)", cy: "\(position.y)", fill: "orange")
        if vectors.count > index {
            for (index2, vector) in vectors[index].enumerated() {
                let id = "arrow\(index)_\(index2)"
                r += arrow(id: id, vector, position)
                // Fade the arrow in when the loop (re)starts, and out after `delay`.
                r += """
                <animate
                xlink:href="#\(id)"
                attributeName="opacity"
                from="0"
                to="1"
                dur="\(params.dt)s"
                begin="looper.begin"
                fill="freeze" />
                <animate
                xlink:href="#\(id)"
                attributeName="opacity"
                from="1"
                to="0"
                dur="\(params.dt)s"
                begin="looper.begin+\(delay)s"
                fill="freeze" />
                """
            }
        }
    }

    for wall in (states.first?.walls ?? []) {
        r += line(p1: transformed(wall.p1), p2: transformed(wall.p2))
    }

    for (timeIndex, (state, nextState)) in zip(states, states.dropFirst(1)).enumerated() {
        let t = Float(timeIndex) * params.dt + delay
        for (ballIndex, (ballState, nextBallState)) in zip(state.balls, nextState.balls).enumerated() {
            func animate(attributeName: String, from: String, to: String) -> String {
                """
                <animate
                xlink:href="#ball\(ballIndex)"
                attributeName="\(attributeName)"
                from="\(from)"
                to="\(to)"
                dur="\(params.dt)s"
                begin="looper.begin+\(t)s" />\n
                """
            }
            let position = transformed(ballState.position)
            let nextPosition = transformed(nextBallState.position)
            r += animate(attributeName: "cx", from: String(position.x), to: String(nextPosition.x))
            r += animate(attributeName: "cy", from: String(position.y), to: String(nextPosition.y))
        }
        // A target is hit on the frame where its running minimum first drops to the
        // threshold. The first animation resets the outline to red at every loop
        // start; the second turns it green at the hit time.
        for (targetIndex, (targetDistance, nextTargetDistance)) in zip(state.targetDistances, nextState.targetDistances).enumerated()
            where targetDistance > hitDistance && nextTargetDistance <= hitDistance {
            r += """
            <animate
            xlink:href="#target\(targetIndex)"
            attributeName="stroke"
            from="green"
            to="red"
            dur="\(params.dt)s"
            begin="looper.begin"
            fill="freeze" />
            """
            r += """
            <animate
            xlink:href="#target\(targetIndex)"
            attributeName="stroke"
            from="red"
            to="green"
            dur="\(params.dt)s"
            begin="looper.begin+\(t)s"
            fill="freeze" />
            """
        }
    }

    r += "</svg>\n"
    return r
}
import Python
%include "EnableIPythonDisplay.swift"
// Enable inline matplotlib output in the notebook via the included IPython display helper.
IPythonDisplay.shell.enable_matplotlib("inline")
// Handle to Python's `IPython.core.display` module; `drawSVG` uses its
// `HTML` and `display` members to render SVG output.
let display = Python.import("IPython.core.display")
/// Renders `states` with `svg(...)` and shows the result as HTML in the notebook.
func drawSVG(states: [World], params: SimulationParameters = SimulationParameters(), vectors: [[DrawingArrow]] = [], delay: Float = 0) {
    let markup = svg(states: states, params: params, vectors: vectors, delay: delay)
    let html = display.HTML(markup)
    // The module handle and the function are both named `display`; the function
    // is fetched explicitly through the dynamic-member subscript.
    display[dynamicMember: "display"](html)
}
/// Simulates the demo scene from `ball1InitialVelocity` until it is still and
/// displays the animation.
///
/// Before the motion starts, a 1-second pause shows each ball's initial
/// velocity as a blue arrow drawn at twice the velocity's length.
func drawSimulation(ball1InitialVelocity: Vec2, params: SimulationParameters = SimulationParameters()) {
    let initialState = World(ball1InitialVelocity: ball1InitialVelocity)
    // One blue arrow per ball showing its (2x scaled) initial velocity.
    let vectors = initialState.balls.map { ball in
        [DrawingArrow(offset: Vec2(0, 0), color: "blue", direction: 2 * ball.velocity)]
    }
    var allStates: [World] = []
    // Only the recorded states are needed. The returned final state is also the
    // last one passed to the closure, so the result is discarded explicitly.
    _ = initialState.steppedUntilStill(params) { allStates.append($0) }
    drawSVG(states: allStates, params: params, vectors: vectors, delay: 1)
}
/// Shows the initial scene with its velocity and gradient arrows.
///
/// - Blue: each ball's velocity, drawn at twice its length.
/// - Green (ball 1 only): `-10 * grad`, drawn from the tip of the blue arrow.
///   This is the direction a gradient-descent step moves the initial velocity.
///
/// The 3600-second delay keeps this static frame on screen.
func drawGradients(ball1InitialVelocity: Vec2, params: SimulationParameters = SimulationParameters(), grad: Vec2) {
    let initialState = World(ball1InitialVelocity: ball1InitialVelocity)
    let gradientArrows = [-10 * grad]
    var vectors: [[DrawingArrow]] = []
    for (ballIndex, ball) in initialState.balls.enumerated() {
        let velocityArrow = 2 * ball.velocity
        var arrows = [DrawingArrow(offset: Vec2(0, 0), color: "blue", direction: velocityArrow)]
        if ballIndex < gradientArrows.count {
            arrows.append(DrawingArrow(offset: velocityArrow, color: "green", direction: gradientArrows[ballIndex]))
        }
        vectors.append(arrows)
    }
    drawSVG(states: [initialState], params: params, vectors: vectors, delay: 3600)
}
In [0]:
/// Steps the world from `initialState` with default parameters until it is
/// `still`, then returns the final state.
///
/// Differentiable, so losses computed on the final state can be
/// back-propagated to the initial state.
@differentiable
func simulate(_ initialState: World) -> World {
    var current = initialState
    while !current.still {
        current = current.stepped(SimulationParameters())
    }
    return current
}
In [13]:
// Initial velocity of ball 1: mostly along +x with a tiny y component.
// Declared `var` because later cells update it by gradient descent.
var v0 = Vec2(20, 0.01)
drawSimulation(ball1InitialVelocity: v0, params: SimulationParameters())
In [0]:
/// Objective to minimize, for a simulation launched with ball-1 velocity `v0`.
///
/// For each target, take the closest distance either ball (0 or 1) reached
/// during the run; the loss is the sum over targets. The simulation clamps
/// these distances from below at `2 * Ball.ballRadius`.
@differentiable
func loss(_ v0: Vector2) -> Float {
    let finalState = simulate(World(ball1InitialVelocity: v0))
    return finalState.targetDistances.sum()
}
loss(v0)
In [0]:
// One gradient-descent step on `v0`, visualizing the gradient first.
let grad = gradient(at: v0, in: loss)
drawGradients(ball1InitialVelocity: v0, params: SimulationParameters(), grad: grad)
v0 = v0 - learningRate(0) * grad
In [0]:
drawSimulation(ball1InitialVelocity: v0, params: SimulationParameters())
In [0]:
#@title Training Loop
// Cosine-annealed gradient descent on ball 1's initial velocity, logging the loss each step.
for step in 0..<iterationCount {
    let (currentLoss, currentGrad) = valueWithGradient(at: v0, in: loss)
    print("\(step): loss \(currentLoss)")
    v0 = v0 - learningRate(step) * currentGrad
}
In [0]:
drawSimulation(ball1InitialVelocity: v0, params: SimulationParameters())
In [0]: