In [1]:
"""
Bonus Challenge!
Write your code in Add (scroll down).
"""
class Node(object):
    """
    Base class for every node in the computation graph.

    Subclasses override forward() to compute `self.value` from the
    `value` attributes of their inbound nodes.
    """

    def __init__(self, inbound_nodes=None):
        """
        Wire this node into the graph.

        `inbound_nodes`: iterable of Nodes this node receives values from.
        Defaults to no inputs. The iterable is copied into a new list, so
        later changes to the caller's container do not affect the graph.
        """
        # Nodes from which this Node receives values.
        # A fresh list per instance: a mutable default (`=[]`) would be one
        # list object shared by every Node created without arguments.
        self.inbound_nodes = [] if inbound_nodes is None else list(inbound_nodes)
        # Nodes to which this Node passes values
        self.outbound_nodes = []
        # A calculated value
        self.value = None
        # Add this node as an outbound node on its inputs.
        for n in self.inbound_nodes:
            n.outbound_nodes.append(self)

    # These will be implemented in a subclass.
    def forward(self):
        """
        Forward propagation.

        Compute the output value based on `inbound_nodes` and
        store the result in self.value.

        Raises NotImplementedError: subclasses must override this method.
        """
        # NotImplementedError (the exception), not NotImplemented (a constant
        # that cannot be raised).
        raise NotImplementedError
In [2]:
class Input(Node):
    """
    Entry point of the graph: a node with no inbound nodes.

    Input is the only kind of Node whose value may be handed directly to
    forward(). Every other Node reads its operands from the `value`
    attributes of its `inbound_nodes`, for example
    ``val0 = self.inbound_nodes[0].value``.
    """

    def __init__(self):
        # Nothing feeds into an Input, so the base initialiser is called
        # without any inbound nodes.
        super(Input, self).__init__()

    def forward(self, value=None):
        """
        Optionally overwrite this node's value.

        `value`: the new value. When it is None the current value is kept.
        """
        if value is None:
            return
        self.value = value
In [8]:
"""
Can you augment the Add class so that it accepts
any number of nodes as input?
Hint: this may be useful:
https://docs.python.org/3/tutorial/controlflow.html#unpacking-argument-lists
"""
class Add(Node):
    """
    Sum of the values of any number of inbound nodes.

    ``Add(a, b, c)`` computes ``a.value + b.value + c.value``. With no
    inputs the result is 0.
    """

    def __init__(self, *inputs):
        # `*inputs` gathers every positional argument into a tuple, so
        # Add(x, y) and Add(x, y, z, ...) are both accepted.
        Node.__init__(self, inputs)

    def forward(self):
        """Set self.value to the sum of the inbound nodes' values."""
        # sum() starts from 0, matching an explicit accumulator loop.
        self.value = sum(node.value for node in self.inbound_nodes)
In [17]:
class Mul(Node):
    """
    Product of the values of any number of inbound nodes.

    ``Mul(a, b, c)`` computes ``a.value * b.value * c.value``. With no
    inputs the result is 1 (the empty product).
    """

    def __init__(self, *inputs):
        # `*inputs` gathers every positional argument into a tuple.
        Node.__init__(self, inputs)

    def forward(self):
        """Set self.value to the product of the inbound nodes' values."""
        value = 1
        for node in self.inbound_nodes:
            value *= node.value
        self.value = value
In [18]:
def topological_sort(feed_dict):
    """
    Sort the nodes in topological order using Kahn's Algorithm.

    `feed_dict`: A dictionary where the key is an `Input` Node and the value
    is the respective value fed to that Node.

    As a side effect, every `Input` placed in the result has its `value` set
    from `feed_dict`.

    Returns a list of sorted nodes: every node reachable from the keys of
    `feed_dict` through `outbound_nodes`.
    """
    from collections import deque

    input_nodes = list(feed_dict)

    # Build edge sets for every node reachable from the inputs. `seen` makes
    # sure each node is queued and expanded only once; without it, a node
    # reachable along several paths is re-expanded once per incoming edge.
    G = {}
    seen = set(input_nodes)
    queue = deque(input_nodes)
    while queue:
        n = queue.popleft()
        G.setdefault(n, {'in': set(), 'out': set()})
        for m in n.outbound_nodes:
            G.setdefault(m, {'in': set(), 'out': set()})
            G[n]['out'].add(m)
            G[m]['in'].add(n)
            if m not in seen:
                seen.add(m)
                queue.append(m)

    # Kahn's algorithm: repeatedly emit a node with no remaining incoming edges.
    L = []
    S = set(input_nodes)
    while S:
        n = S.pop()
        if isinstance(n, Input):
            n.value = feed_dict[n]
        L.append(n)
        # Walk the de-duplicated edge set, not `n.outbound_nodes`. A node
        # wired to the same consumer twice (e.g. Add(x, x)) appears twice in
        # `outbound_nodes` but only once in the set, and removing the same
        # edge twice would raise KeyError.
        for m in list(G[n]['out']):
            G[n]['out'].remove(m)
            G[m]['in'].remove(n)
            # if no other incoming edges add to S
            if not G[m]['in']:
                S.add(m)
    return L
In [19]:
def forward_pass(output_node, sorted_nodes):
    """
    Run forward() on each node of a topologically sorted list, in order.

    Arguments:
        `output_node`: A node in the graph, should be the output node
            (have no outgoing edges).
        `sorted_nodes`: A topologically sorted list of nodes.

    Returns `output_node.value` after every node has been evaluated.
    """
    # Because the list is in dependency order, each node's inputs already
    # hold their values by the time the node itself runs.
    for node in sorted_nodes:
        node.forward()
    result = output_node.value
    return result
In [20]:
if __name__ == "__main__":
    """
    No need to change anything here!
    If all goes well, this should work after you
    modify the Add class in miniflow.py.
    """
    # In this notebook Input, Add, topological_sort and forward_pass are
    # defined in the cells above, so nothing needs to be imported here.
    x, y, z = Input(), Input(), Input()
    f = Add(x, y, z)
    feed_dict = {x: 4, y: 5, z: 10}
    graph = topological_sort(feed_dict)
    output = forward_pass(f, graph)
    # should output 19
    assert output == 19, output
    print("{} + {} + {} = {} (according to miniflow)".format(feed_dict[x], feed_dict[y], feed_dict[z], output))
In [21]:
# x, y and z are the Input nodes created in the previous cell.
x2, y2, z2 = Input(), Input(), Input()
f = Add(x, y, z, x2, y2, z2)
feed_dict = {x: 4, y: 5, z: 10, x2: 4, y2: 5, z2: 10}
graph = topological_sort(feed_dict)
output = forward_pass(f, graph)
# should output 38
assert output == 38, output
print("{} + {} + {} + {} + {} + {} = {} (according to miniflow)".format(feed_dict[x], feed_dict[y], feed_dict[z], feed_dict[x2], feed_dict[y2], feed_dict[z2], output))
In [22]:
# Use fresh inputs. topological_sort follows outbound_nodes, and the Add
# nodes built in the cells above are outbound nodes of x, y and z. Reusing
# those inputs would pull the Add nodes into this graph, so their forward()
# would also run here on values left over from earlier cells.
x3, y3, z3 = Input(), Input(), Input()
f_mul = Mul(x3, y3, z3)
feed_dict = {x3: 4, y3: 5, z3: 10}
graph = topological_sort(feed_dict)
output = forward_pass(f_mul, graph)
# should output 200
assert output == 200, output
print("{} * {} * {} = {} (according to miniflow)".format(feed_dict[x3], feed_dict[y3], feed_dict[z3], output))
In [ ]: