In [0]:
#@title ##### License
# Copyright 2018 The GraphNets Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
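The purpose of this tutorial is to get practical experience using the Graph Nets library via examples of:
- Representing graph-structured data as `graph_nets.graphs.GraphsTuple` instances using `graph_nets.utils_np`.
- Working with graphs whose attributes are TensorFlow tensors using `graph_nets.utils_tf`.
- Building graph networks with `graph_nets.modules`.
- Building custom graph networks from `graph_nets.blocks`.
For more information about graph networks, see our arXiv paper: Relational inductive biases, deep learning, and graph networks.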
In [0]:
#@title ### Install the Graph Nets library on this Colaboratory runtime { form-width: "60%", run: "auto"}
#@markdown <br>1. Connect to a local or hosted Colaboratory runtime by clicking the **Connect** button at the top-right.<br>2. Choose "Yes" below to install the Graph Nets library on the runtime machine with the correct dependencies. Note, this works both with local and hosted Colaboratory runtimes.
install_graph_nets_library = "No" #@param ["Yes", "No"]
if install_graph_nets_library.lower() == "yes":
  print("Installing Graph Nets library and dependencies:")
  print("Output message from command:\n")
  !pip install graph_nets "dm-sonnet<2" "tensorflow_probability<0.9"
else:
  print("Skipping installation of Graph Nets library")
If you are running this notebook locally (i.e., not through Colaboratory), run the following on the command line to install the Graph Nets library together with the other dependencies this notebook uses:
```
pip install graph_nets matplotlib scipy "tensorflow>=1.15,<2" "dm-sonnet<2" "tensorflow_probability<0.9"
```
In [0]:
#@title #### (Imports)
%tensorflow_version 1.x # For Google Colab only.
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from graph_nets import blocks
from graph_nets import graphs
from graph_nets import modules
from graph_nets import utils_np
from graph_nets import utils_tf
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import sonnet as snt
import tensorflow as tf
Tutorial of the Graph Nets library
- How to represent graphs as a `graphs.GraphsTuple`
  - Different ways of expressing data as a graph
- Working with tensor `GraphsTuple`s
  - Creating a constant tensor `GraphsTuple` from data dicts
- Creating a `modules.GraphNetwork`
- Various canonical Graph Net modules
  - Independent Graph Net (`modules.GraphIndependent`)
  - Message-passing neural networks (`modules.InteractionNetwork`, `modules.CommNet`)
  - Non-local neural networks (`modules.SelfAttention`)
The `graphs.GraphsTuple` class
The Graph Nets library contains models which operate on graph-structured data, so the first thing to understand is how graph-structured data is represented in the code.
The `graph_nets.graphs.GraphsTuple` class, defined in `graph_nets/graphs.py`, represents a batch of one or more graphs. All graph network modules take instances of `GraphsTuple` as input and return instances of `GraphsTuple` as output. The graphs are directed (one-way edges), attributed (node-, edge-, and graph-level features are allowed), multigraphs (multiple edges can connect any two nodes, and self-edges are allowed). See Box 3, page 11 in our companion arXiv paper for details.
A `GraphsTuple` has attributes:
- `n_node` (shape=[num_graphs]): Number of nodes in each graph in the batch.
- `n_edge` (shape=[num_graphs]): Number of edges in each graph in the batch.
- `globals` (shape=[num_graphs] + global_feature_dimensions): Global features for each graph in the batch.
- `nodes` (shape=[total_num_nodes] + node_feature_dimensions): Node features for each node in the batch of graphs.
- `edges` (shape=[total_num_edges] + edge_feature_dimensions): Edge features for each edge in the batch of graphs.
- `senders` (shape=[total_num_edges]): Indices of the nodes in `nodes`, which indicate the source node of each directed edge in `edges`.
- `receivers` (shape=[total_num_edges]): Indices of the nodes in `nodes`, which indicate the destination node of each directed edge in `edges`.
The nodes and edges from the different graphs in the batch are concatenated along the first axis of the `nodes` and `edges` fields, and can be partitioned using the `n_node` and `n_edge` fields respectively. Note that all but the `n_*` fields are optional (see examples below).
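To make this layout concrete, here is a small NumPy sketch that partitions a batch back into per-graph arrays using `n_node` and `n_edge`. Because `senders` and `receivers` index into the concatenated `nodes` array, the sketch subtracts each graph's node offset to recover indices local to that graph. `split_batch` and the toy arrays are hypothetical and only illustrate the layout; this is not how `utils_np` is implemented.

```python
import numpy as np

# Toy batch of two graphs in GraphsTuple layout (hypothetical data):
# graph 0 has 3 nodes and 2 edges, graph 1 has 2 nodes and 1 edge.
n_node = np.array([3, 2])
n_edge = np.array([2, 1])
nodes = np.arange(5, dtype=np.float32).reshape(5, 1)  # 5 nodes in total.
edges = np.arange(3, dtype=np.float32).reshape(3, 1)  # 3 edges in total.
# Indices into the concatenated `nodes` array (graph 1 starts at row 3).
senders = np.array([0, 1, 3])
receivers = np.array([1, 2, 4])

def split_batch(n_node, n_edge, nodes, edges, senders, receivers):
  """Hypothetical helper: splits batched arrays into one dict per graph."""
  node_offsets = np.concatenate([[0], np.cumsum(n_node)[:-1]])
  node_splits = np.cumsum(n_node)[:-1]
  edge_splits = np.cumsum(n_edge)[:-1]
  per_graph = []
  for i, (g_nodes, g_edges, g_senders, g_receivers) in enumerate(zip(
      np.split(nodes, node_splits), np.split(edges, edge_splits),
      np.split(senders, edge_splits), np.split(receivers, edge_splits))):
    per_graph.append({
        "nodes": g_nodes,
        "edges": g_edges,
        # Subtract the offset so the indices are local to this graph again.
        "senders": g_senders - node_offsets[i],
        "receivers": g_receivers - node_offsets[i],
    })
  return per_graph

per_graph_dicts = split_batch(n_node, n_edge, nodes, edges, senders, receivers)
print(per_graph_dicts[1]["senders"], per_graph_dicts[1]["receivers"])  # [0] [1]
```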
The attributes of a `GraphsTuple` instance are typically either Numpy arrays or TensorFlow tensors. The library contains utilities for manipulating graphs with each of these types of attributes, respectively:
- `utils_np` (for Numpy arrays)
- `utils_tf` (for TensorFlow tensors)
An important method of the `GraphsTuple` class is `GraphsTuple.replace`. Similarly to `collections.namedtuple._replace` (in fact, `GraphsTuple` is a subclass of `collections.namedtuple`), this method creates a copy of the `GraphsTuple` that keeps references to all of the original attributes, except for those replaced by the values provided as keyword arguments.
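The copy-on-replace behaviour can be seen with a plain `collections.namedtuple`. `MiniGraphsTuple` is a hypothetical three-field stand-in for the real class, and its `replace` method here simply delegates to `_replace`; the real `GraphsTuple` has more fields.

```python
import collections

class MiniGraphsTuple(collections.namedtuple("MiniGraphsTuple",
                                             ["nodes", "edges", "n_node"])):
  """Hypothetical, trimmed-down stand-in for graphs.GraphsTuple."""

  def replace(self, **kwargs):
    # namedtuple._replace builds a new tuple; untouched fields are shared.
    return self._replace(**kwargs)

toy_graph = MiniGraphsTuple(nodes=[[1.0], [2.0]], edges=[[0.5]], n_node=[2])
no_edges = toy_graph.replace(edges=None)

print(no_edges.edges)                     # None: the replaced field.
print(no_edges.nodes is toy_graph.nodes)  # True: shared reference, not a copy.
print(toy_graph.edges)                    # [[0.5]]: the original is unchanged.
```

Because untouched fields are shared by reference, `replace` is cheap even for large arrays or tensors.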
Each graph will have a global feature, several nodes, and several edges. The graphs can have different numbers of nodes and edges, but the lengths of the global, node, and edge attribute vectors must be the same across graphs.
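In order to create a `graphs.GraphsTuple` instance, we can define a `list` whose elements are `dict`s, one per graph, with the following keys holding that graph's data:
- "globals": a `float`-valued feature vector.
- "nodes": a list of `float`-valued feature vectors, one per node.
- "edges": a list of `float`-valued feature vectors, one per edge.
- "senders": for each edge, the `int`-valued index of the node it starts from (its sender).
- "receivers": for each edge, the `int`-valued index of the node it points to (its receiver).
Try running the cell below to create some dummy graph data.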
In [0]:
# Global features for graph 0.
globals_0 = [1., 2., 3.]
# Node features for graph 0.
nodes_0 = [[10., 20., 30.], # Node 0
[11., 21., 31.], # Node 1
[12., 22., 32.], # Node 2
[13., 23., 33.], # Node 3
[14., 24., 34.]] # Node 4
# Edge features for graph 0.
edges_0 = [[100., 200.], # Edge 0
[101., 201.], # Edge 1
[102., 202.], # Edge 2
[103., 203.], # Edge 3
[104., 204.], # Edge 4
[105., 205.]] # Edge 5
# The sender and receiver nodes associated with each edge for graph 0.
senders_0 = [0, # Index of the sender node for edge 0
1, # Index of the sender node for edge 1
1, # Index of the sender node for edge 2
2, # Index of the sender node for edge 3
2, # Index of the sender node for edge 4
3] # Index of the sender node for edge 5
receivers_0 = [1, # Index of the receiver node for edge 0
2, # Index of the receiver node for edge 1
3, # Index of the receiver node for edge 2
0, # Index of the receiver node for edge 3
3, # Index of the receiver node for edge 4
4] # Index of the receiver node for edge 5
# Global features for graph 1.
globals_1 = [1001., 1002., 1003.]
# Node features for graph 1.
nodes_1 = [[1010., 1020., 1030.], # Node 0
[1011., 1021., 1031.]] # Node 1
# Edge features for graph 1.
edges_1 = [[1100., 1200.], # Edge 0
[1101., 1201.], # Edge 1
[1102., 1202.], # Edge 2
[1103., 1203.]] # Edge 3
# The sender and receiver nodes associated with each edge for graph 1.
senders_1 = [0, # Index of the sender node for edge 0
0, # Index of the sender node for edge 1
1, # Index of the sender node for edge 2
1] # Index of the sender node for edge 3
receivers_1 = [0, # Index of the receiver node for edge 0
1, # Index of the receiver node for edge 1
0, # Index of the receiver node for edge 2
0] # Index of the receiver node for edge 3
data_dict_0 = {
"globals": globals_0,
"nodes": nodes_0,
"edges": edges_0,
"senders": senders_0,
"receivers": receivers_0
}
data_dict_1 = {
"globals": globals_1,
"nodes": nodes_1,
"edges": edges_1,
"senders": senders_1,
"receivers": receivers_1
}
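The constraints stated earlier (the same feature sizes across graphs, one sender and one receiver per edge, all indices pointing at existing nodes) can be checked before batching. `find_data_dict_problems` is a hypothetical helper, not part of the library; it assumes every dict has all five keys and at least one node and one edge. Called as `find_data_dict_problems([data_dict_0, data_dict_1])`, it returns an empty list.

```python
def find_data_dict_problems(data_dicts):
  """Hypothetical helper: returns a list of constraint violations (empty if OK)."""
  problems = []
  first = data_dicts[0]
  # Reference feature sizes, taken from the first graph.
  sizes = {"globals": len(first["globals"]),
           "nodes": len(first["nodes"][0]),
           "edges": len(first["edges"][0])}
  for i, d in enumerate(data_dicts):
    num_nodes, num_edges = len(d["nodes"]), len(d["edges"])
    if len(d["globals"]) != sizes["globals"]:
      problems.append("graph {}: globals size differs".format(i))
    for key in ("nodes", "edges"):
      if any(len(row) != sizes[key] for row in d[key]):
        problems.append("graph {}: {} feature size differs".format(i, key))
    for key in ("senders", "receivers"):
      if len(d[key]) != num_edges:
        problems.append("graph {}: {} length != number of edges".format(i, key))
      if any(not 0 <= idx < num_nodes for idx in d[key]):
        problems.append("graph {}: {} index out of range".format(i, key))
  return problems

good_graph = {"globals": [1.], "nodes": [[1.], [2.]], "edges": [[3.]],
              "senders": [0], "receivers": [1]}
bad_graph = {"globals": [1.], "nodes": [[1.]], "edges": [[3.]],
             "senders": [0], "receivers": [1]}  # Node 1 does not exist here.
print(find_data_dict_problems([good_graph, good_graph]))  # []
print(find_data_dict_problems([good_graph, bad_graph]))
# ['graph 1: receivers index out of range']
```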
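Creating a `graphs.GraphsTuple`
The `utils_np` module contains a function named `utils_np.data_dicts_to_graphs_tuple`, which takes a `list` of `dict`s with the keys specified above and returns a `GraphsTuple` that represents the sequence of graphs.
The `data_dicts_to_graphs_tuple` function does three things:
1. It concatenates the data from the individual graphs along their first (batch) axis, so that the node and edge features of all graphs can be processed in parallel.
2. It counts the number of nodes and edges in each graph and stores them in the `n_node` and `n_edge` fields, each of which has one entry per graph.
3. It adds offsets to the `senders` and `receivers` indices, equal to the number of nodes in the preceding graphs, so that the indices point to the correct rows of the concatenated `nodes` field.
Try running the cell below to put the graph dictionaries into a `GraphsTuple` using `utils_np.data_dicts_to_graphs_tuple`.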
In [0]:
data_dict_list = [data_dict_0, data_dict_1]
graphs_tuple = utils_np.data_dicts_to_graphs_tuple(data_dict_list)
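The three batching steps listed above can be sketched in plain NumPy. `batch_data_dicts` is a hypothetical helper for illustration only (use `utils_np.data_dicts_to_graphs_tuple` in practice), and it assumes every dict contains all five keys.

```python
import numpy as np

def batch_data_dicts(data_dicts):
  """Hypothetical sketch of the three batching steps; not the library code."""
  # Step 2: count nodes and edges per graph.
  n_node = np.array([len(d["nodes"]) for d in data_dicts])
  n_edge = np.array([len(d["senders"]) for d in data_dicts])
  # Step 3: each graph's offset is the number of nodes in all preceding graphs.
  offsets = np.concatenate([[0], np.cumsum(n_node)[:-1]])
  senders = np.concatenate(
      [np.asarray(d["senders"]) + off for d, off in zip(data_dicts, offsets)])
  receivers = np.concatenate(
      [np.asarray(d["receivers"]) + off for d, off in zip(data_dicts, offsets)])
  # Step 1: concatenate features along the leading (batch) axis.
  return {
      "globals": np.stack([d["globals"] for d in data_dicts]),
      "nodes": np.concatenate([d["nodes"] for d in data_dicts], axis=0),
      "edges": np.concatenate([d["edges"] for d in data_dicts], axis=0),
      "senders": senders,
      "receivers": receivers,
      "n_node": n_node,
      "n_edge": n_edge,
  }

# Two toy graphs in the data-dict format used above.
g0 = {"globals": [0.], "nodes": [[1.], [2.], [3.]], "edges": [[10.], [20.]],
      "senders": [0, 1], "receivers": [1, 2]}
g1 = {"globals": [1.], "nodes": [[4.], [5.]], "edges": [[30.]],
      "senders": [1], "receivers": [0]}
batch = batch_data_dicts([g0, g1])
print(batch["n_node"], batch["senders"], batch["receivers"])
# [3 2] [0 1 4] [1 2 3]
```

Graph 1's sender `1` becomes `4` because graph 0 contributes three nodes before it.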
In [0]:
graphs_nx = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
_, axs = plt.subplots(ncols=2, figsize=(6, 3))
for iax, (graph_nx, ax) in enumerate(zip(graphs_nx, axs)):
  nx.draw(graph_nx, ax=ax)
  ax.set_title("Graph {}".format(iax))
In [0]:
def print_graphs_tuple(graphs_tuple):
  print("Shapes of `GraphsTuple`'s fields:")
  print(graphs_tuple.map(lambda x: x if x is None else x.shape, fields=graphs.ALL_FIELDS))
  print("\nData contained in `GraphsTuple`'s fields:")
  print("globals:\n{}".format(graphs_tuple.globals))
  print("nodes:\n{}".format(graphs_tuple.nodes))
  print("edges:\n{}".format(graphs_tuple.edges))
  print("senders:\n{}".format(graphs_tuple.senders))
  print("receivers:\n{}".format(graphs_tuple.receivers))
  print("n_node:\n{}".format(graphs_tuple.n_node))
  print("n_edge:\n{}".format(graphs_tuple.n_edge))
print_graphs_tuple(graphs_tuple)
In [0]:
recovered_data_dict_list = utils_np.graphs_tuple_to_data_dicts(graphs_tuple)
In [0]:
# Number of nodes
n_node = 3
# Three edges connecting the nodes in a cycle
senders = [0, 1, 2] # Indices of nodes sending the edges
receivers = [1, 2, 0] # Indices of nodes receiving the edges
data_dict = {
"n_node": n_node,
"senders": senders,
"receivers": receivers,
}
graphs_tuple = utils_np.data_dicts_to_graphs_tuple([data_dict])
In [0]:
# Node features.
nodes = [[10.], # Node 0
[11.], # Node 1
[12.]] # Node 2
data_dict = {
"nodes": nodes,
}
graphs_tuple = utils_np.data_dicts_to_graphs_tuple([data_dict])
# We can visualize the graph using networkx.
graphs_nx = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
ax = plt.figure(figsize=(3, 3)).gca()
nx.draw(graphs_nx[0], ax=ax)
_ = ax.set_title("Graph without edges")
Creating a `GraphsTuple` from a `networkx` graph
`networkx` is a powerful graph manipulation library in Python. A `GraphsTuple` can be built from `networkx` graphs as follows:
In [0]:
graph_nx = nx.OrderedMultiDiGraph()
# Globals.
graph_nx.graph["features"] = np.array([0.6, 0.7, 0.8])
# Nodes.
graph_nx.add_node(0, features=np.array([0.3, 1.3]))
graph_nx.add_node(1, features=np.array([0.4, 1.4]))
graph_nx.add_node(2, features=np.array([0.5, 1.5]))
graph_nx.add_node(3, features=np.array([0.6, 1.6]))
# Edges.
graph_nx.add_edge(0, 1, features=np.array([3.6, 3.7]))
graph_nx.add_edge(2, 0, features=np.array([5.6, 5.7]))
graph_nx.add_edge(3, 0, features=np.array([6.6, 6.7]))
ax = plt.figure(figsize=(3, 3)).gca()
nx.draw(graph_nx, ax=ax)
ax.set_title("Graph")
graphs_tuple = utils_np.networkxs_to_graphs_tuple([graph_nx])
print_graphs_tuple(graphs_tuple)
In [0]:
#@title #### (Define functions for generating and plotting graphs)
GLOBAL_SIZE = 4
NODE_SIZE = 5
EDGE_SIZE = 6
def get_graph_data_dict(num_nodes, num_edges):
  return {
      "globals": np.random.rand(GLOBAL_SIZE).astype(np.float32),
      "nodes": np.random.rand(num_nodes, NODE_SIZE).astype(np.float32),
      "edges": np.random.rand(num_edges, EDGE_SIZE).astype(np.float32),
      "senders": np.random.randint(num_nodes, size=num_edges, dtype=np.int32),
      "receivers": np.random.randint(num_nodes, size=num_edges, dtype=np.int32),
  }
graph_3_nodes_4_edges = get_graph_data_dict(num_nodes=3, num_edges=4)
graph_5_nodes_8_edges = get_graph_data_dict(num_nodes=5, num_edges=8)
graph_7_nodes_13_edges = get_graph_data_dict(num_nodes=7, num_edges=13)
graph_9_nodes_25_edges = get_graph_data_dict(num_nodes=9, num_edges=25)
graph_dicts = [graph_3_nodes_4_edges, graph_5_nodes_8_edges,
graph_7_nodes_13_edges, graph_9_nodes_25_edges]
def plot_graphs_tuple_np(graphs_tuple):
  networkx_graphs = utils_np.graphs_tuple_to_networkxs(graphs_tuple)
  num_graphs = len(networkx_graphs)
  _, axes = plt.subplots(1, num_graphs, figsize=(5*num_graphs, 5))
  if num_graphs == 1:
    # plt.subplots returns a single Axes (not an array) for one panel.
    axes = (axes,)
  for graph, ax in zip(networkx_graphs, axes):
    plot_graph_networkx(graph, ax)
def plot_graph_networkx(graph, ax, pos=None):
  # Label nodes, edges and the graph with the first element of their features.
  node_labels = {node: "{:.3g}".format(data["features"][0])
                 for node, data in graph.nodes(data=True)
                 if data["features"] is not None}
  edge_labels = {(sender, receiver): "{:.3g}".format(data["features"][0])
                 for sender, receiver, data in graph.edges(data=True)
                 if data["features"] is not None}
  global_label = ("{:.3g}".format(graph.graph["features"][0])
                  if graph.graph["features"] is not None else None)
  if pos is None:
    pos = nx.spring_layout(graph)
  nx.draw_networkx(graph, pos, ax=ax, labels=node_labels)
  if edge_labels:
    nx.draw_networkx_edge_labels(graph, pos, edge_labels, ax=ax)
  if global_label:
    # Draw on `ax` itself; plt.text would draw on the current axes instead.
    ax.text(0.05, 0.95, global_label, transform=ax.transAxes)
  ax.yaxis.set_visible(False)
  ax.xaxis.set_visible(False)
  return pos
def plot_compare_graphs(graphs_tuples, labels):
  num_graphs = len(graphs_tuples)
  _, axes = plt.subplots(1, num_graphs, figsize=(5*num_graphs, 5))
  if num_graphs == 1:
    axes = (axes,)
  # Reuse the first layout so every panel draws nodes at the same positions.
  pos = None
  for name, graphs_tuple, ax in zip(labels, graphs_tuples, axes):
    graph = utils_np.graphs_tuple_to_networkxs(graphs_tuple)[0]
    pos = plot_graph_networkx(graph, ax, pos=pos)
    ax.set_title(name)
Creating a constant tensor `GraphsTuple` from data dicts
Similar to `utils_np.data_dicts_to_graphs_tuple`, the `utils_tf` module, which manipulates graphs whose attributes are represented as TensorFlow tensors, contains a function named `utils_tf.data_dicts_to_graphs_tuple`, which creates a constant tensor graph from data dicts containing either numpy arrays or tensors.
In [0]:
tf.reset_default_graph()
graphs_tuple_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
with tf.Session() as sess:
  graphs_tuple_np = sess.run(graphs_tuple_tf)
plot_graphs_tuple_np(graphs_tuple_np)
In [0]:
# If the GraphsTuple has fields set to None, we need to use
# `utils_tf.make_runnable_in_session` before fetching it with `sess.run`.
tf.reset_default_graph()
graphs_tuple_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
# Removing the edges from a graph.
graph_with_nones = graphs_tuple_tf.replace(
    edges=None, senders=None, receivers=None, n_edge=graphs_tuple_tf.n_edge*0)
runnable_in_session_graph = utils_tf.make_runnable_in_session(graph_with_nones)
with tf.Session() as sess:
  graphs_tuple_np = sess.run(runnable_in_session_graph)
plot_graphs_tuple_np(graphs_tuple_np)
`GraphsTuple` placeholders
In TensorFlow, data is often passed into a session via placeholder tensors. The cell below shows how to create placeholders for graph data.
In [0]:
tf.reset_default_graph()
# Create a placeholder using the first graph in the list as template.
graphs_tuple_ph = utils_tf.placeholders_from_data_dicts(graph_dicts[0:1])
with tf.Session() as sess:
  # Feed a batch of graphs with different sizes, and different
  # numbers of nodes and edges, through the placeholder.
  feed_dict = utils_tf.get_feed_dict(
      graphs_tuple_ph, utils_np.data_dicts_to_graphs_tuple(graph_dicts[1:]))
  graphs_tuple_np = sess.run(graphs_tuple_ph, feed_dict)
plot_graphs_tuple_np(graphs_tuple_np)
A similar utility is provided to create placeholders from `networkx` graphs: `utils_tf.placeholders_from_networkxs`.
In [0]:
# Slicing a single graph, or a range of graphs, out of a batch with `utils_tf.get_graph`.
tf.reset_default_graph()
graphs_tuple_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
first_graph_tf = utils_tf.get_graph(graphs_tuple_tf, 0)
three_graphs_tf = utils_tf.get_graph(graphs_tuple_tf, slice(1, 4))
with tf.Session() as sess:
  first_graph_np = sess.run(first_graph_tf)
  three_graphs_np = sess.run(three_graphs_tf)
plot_graphs_tuple_np(first_graph_np)
plot_graphs_tuple_np(three_graphs_np)
In [0]:
# Concatenating along the batch dimension
tf.reset_default_graph()
graphs_tuple_1_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts[0:1])
graphs_tuple_2_tf = utils_tf.data_dicts_to_graphs_tuple(graph_dicts[1:])
graphs_tuple_tf = utils_tf.concat([graphs_tuple_1_tf, graphs_tuple_2_tf], axis=0)
with tf.Session() as sess:
  graphs_tuple_np = sess.run(graphs_tuple_tf)
plot_graphs_tuple_np(graphs_tuple_np)
Similarly, we can concatenate along feature dimensions, assuming all of the batches to be concatenated have the same graph structure/connectivity.
See `utils_tf` for more methods to work with `GraphsTuple`s containing tensors.
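Concatenating along the feature dimension leaves the graph structure untouched and only widens the feature arrays. A minimal NumPy sketch, assuming graphs stored as plain dicts with `GraphsTuple` field names; `concat_features` is a hypothetical helper and not part of `utils_tf`.

```python
import numpy as np

def concat_features(graph_a, graph_b):
  """Hypothetical helper: feature-axis concatenation of two batches that
  share the same connectivity (similar in spirit to concat with axis=1)."""
  for key in ("senders", "receivers", "n_node", "n_edge"):
    if not np.array_equal(graph_a[key], graph_b[key]):
      raise ValueError("graph structure differs in field {!r}".format(key))
  out = dict(graph_a)  # Structure fields are shared unchanged.
  for key in ("nodes", "edges", "globals"):
    out[key] = np.concatenate([graph_a[key], graph_b[key]], axis=1)
  return out

structure = {"senders": np.array([0]), "receivers": np.array([1]),
             "n_node": np.array([2]), "n_edge": np.array([1])}
a = dict(structure, nodes=np.ones((2, 3)), edges=np.ones((1, 2)),
         globals=np.ones((1, 4)))
b = dict(structure, nodes=np.zeros((2, 5)), edges=np.zeros((1, 1)),
         globals=np.zeros((1, 1)))
c = concat_features(a, b)
print(c["nodes"].shape, c["edges"].shape, c["globals"].shape)
# (2, 8) (1, 3) (1, 5)
```

This is the kind of concatenation the recurrent-state example later in this tutorial uses to combine an input graph with a state graph.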
Creating a `modules.GraphNetwork`
A graph network has up to three learnable sub-functions: edge ($\phi^e$), node ($\phi^v$), and global ($\phi^u$) in the schematic above. See Section 3.2.2, page 12 in our companion arXiv paper for details.
To instantiate a graph network module in the library, these sub-functions are specified via constructor arguments which are `callable`s that return Sonnet modules, such as `snt.Linear` or `snt.nets.MLP`.
A `callable` is provided, instead of the module itself, so that the Graph Net object constructs the sub-modules itself and therefore owns them and the variables they create.
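The ownership argument can be illustrated without TensorFlow. In this sketch, `FakeLinear` and `ToyGraphNetwork` are hypothetical stand-ins (a counter plays the role of newly created variables); it is a simplified illustration of the factory pattern, not how Sonnet actually scopes variables.

```python
import itertools

# Hypothetical counter standing in for "a new set of variables was created".
_param_counter = itertools.count()

class FakeLinear(object):
  """Hypothetical stand-in for a Sonnet module such as snt.Linear."""

  def __init__(self, output_size):
    self.output_size = output_size
    self.param_id = next(_param_counter)  # Fresh "variables" per instance.

class ToyGraphNetwork(object):
  """Hypothetical sketch: calls each factory once, inside its own
  constructor, so every instance builds and owns its own sub-modules."""

  def __init__(self, edge_model_fn, node_model_fn):
    self.edge_model = edge_model_fn()
    self.node_model = node_model_fn()

  @property
  def variables(self):
    return [self.edge_model.param_id, self.node_model.param_id]

net_a = ToyGraphNetwork(lambda: FakeLinear(10), lambda: FakeLinear(11))
net_b = ToyGraphNetwork(lambda: FakeLinear(10), lambda: FakeLinear(11))
print(net_a.variables, net_b.variables)  # [0, 1] [2, 3]
```

In this toy model, the same factories yield independent parameters for each network, whereas passing one pre-built module to both would make them silently share it.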
In [0]:
tf.reset_default_graph()
OUTPUT_EDGE_SIZE = 10
OUTPUT_NODE_SIZE = 11
OUTPUT_GLOBAL_SIZE = 12
graph_network = modules.GraphNetwork(
edge_model_fn=lambda: snt.Linear(output_size=OUTPUT_EDGE_SIZE),
node_model_fn=lambda: snt.Linear(output_size=OUTPUT_NODE_SIZE),
global_model_fn=lambda: snt.Linear(output_size=OUTPUT_GLOBAL_SIZE))
Feeding a `GraphsTuple` to a Graph Net
A `GraphsTuple` can be fed into a graph network, which returns an output graph with the same number of nodes, edges, and edge connectivity, but with updated edge, node and global features. All of the output features are conditioned on the input features according to the graph structure, and are fully differentiable.
In [0]:
input_graphs = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
output_graphs = graph_network(input_graphs)
print("Output edges size: {}".format(output_graphs.edges.shape[-1])) # Equal to OUTPUT_EDGE_SIZE
print("Output nodes size: {}".format(output_graphs.nodes.shape[-1])) # Equal to OUTPUT_NODE_SIZE
print("Output globals size: {}".format(output_graphs.globals.shape[-1])) # Equal to OUTPUT_GLOBAL_SIZE
Connecting a `GraphNetwork` recurrently
A Graph Net module can be chained recurrently by matching the output feature sizes to the input feature sizes, and feeding the output back to the input multiple times (arXiv paper, bottom of Fig. 6a).
In [0]:
tf.reset_default_graph()
input_graphs = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
graph_network = modules.GraphNetwork(
edge_model_fn=lambda: snt.Linear(output_size=EDGE_SIZE),
node_model_fn=lambda: snt.Linear(output_size=NODE_SIZE),
global_model_fn=lambda: snt.Linear(output_size=GLOBAL_SIZE))
num_recurrent_passes = 3
previous_graphs = input_graphs
for unused_pass in range(num_recurrent_passes):
  previous_graphs = graph_network(previous_graphs)
output_graphs = previous_graphs
Alternatively, we can process the input graph multiple times with a graph state that gets updated recurrently.
In [0]:
def zeros_graph(sample_graph, edge_size, node_size, global_size):
  zeros_graphs = sample_graph.replace(nodes=None, edges=None, globals=None)
  zeros_graphs = utils_tf.set_zero_edge_features(zeros_graphs, edge_size)
  zeros_graphs = utils_tf.set_zero_node_features(zeros_graphs, node_size)
  zeros_graphs = utils_tf.set_zero_global_features(zeros_graphs, global_size)
  return zeros_graphs
tf.reset_default_graph()
graph_network = modules.GraphNetwork(
edge_model_fn=lambda: snt.Linear(output_size=OUTPUT_EDGE_SIZE),
node_model_fn=lambda: snt.Linear(output_size=OUTPUT_NODE_SIZE),
global_model_fn=lambda: snt.Linear(output_size=OUTPUT_GLOBAL_SIZE))
input_graphs = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
initial_state = zeros_graph(
input_graphs, OUTPUT_EDGE_SIZE, OUTPUT_NODE_SIZE, OUTPUT_GLOBAL_SIZE)
num_recurrent_passes = 3
current_state = initial_state
for unused_pass in range(num_recurrent_passes):
  input_and_state_graphs = utils_tf.concat(
      [input_graphs, current_state], axis=1)
  current_state = graph_network(input_and_state_graphs)
output_graphs = current_state
Similarly, recurrent modules with gating, such as an LSTM or GRU, can be applied on the edges, nodes, and globals of the state and input graphs separately.
Other canonical modules discussed in Figure 4 of our arXiv paper are provided in `graph_nets.modules`:
- `modules.GraphIndependent` (updates the global, node, and edge features independently, without message-passing)
- `modules.InteractionNetwork` (an example of a "Message-passing neural network")
- `modules.CommNet` (another example of a "Message-passing neural network")
- `modules.SelfAttention` (an example of a "Non-local neural network")
- `modules.RelationNetwork`
- `modules.DeepSets`
See the documentation for more details and corresponding references.
Broadcast operations transfer information between different types of elements in the graph:
- `blocks.broadcast_globals_to_nodes`: Copy/broadcast global features across all nodes.
- `blocks.broadcast_globals_to_edges`: Copy/broadcast global features across all edges.
- `blocks.broadcast_sender_nodes_to_edges`: Copy/broadcast node information from each node, across all edges for which that node is a sender.
- `blocks.broadcast_receiver_nodes_to_edges`: Copy/broadcast node information from each node, across all edges for which that node is a receiver.
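On a batched graph, these four broadcasts amount to repeating and gathering rows. The NumPy sketch below reproduces that behaviour on a hypothetical two-graph batch; it mirrors the semantics described above, not the library's TensorFlow implementation.

```python
import numpy as np

# Toy batch of two graphs (3 nodes + 2 nodes), in GraphsTuple layout.
n_node = np.array([3, 2])
n_edge = np.array([2, 1])
globals_ = np.array([[100.], [200.]])
nodes = np.array([[1.], [2.], [3.], [4.], [5.]])
senders = np.array([0, 1, 3])     # Indices into the concatenated nodes.
receivers = np.array([1, 2, 4])

# Globals -> nodes/edges: repeat each graph's global row once per element.
globals_to_nodes = np.repeat(globals_, n_node, axis=0)
globals_to_edges = np.repeat(globals_, n_edge, axis=0)
# Nodes -> edges: gather the sender (or receiver) node row for each edge.
sender_nodes_to_edges = nodes[senders]
receiver_nodes_to_edges = nodes[receivers]

print(globals_to_nodes.ravel())         # [100. 100. 100. 200. 200.]
print(globals_to_edges.ravel())         # [100. 100. 200.]
print(sender_nodes_to_edges.ravel())    # [1. 2. 4.]
print(receiver_nodes_to_edges.ravel())  # [2. 3. 5.]
```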
In [0]:
tf.reset_default_graph()
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_broadcast_globals_to_nodes = graphs_tuple.replace(
nodes=blocks.broadcast_globals_to_nodes(graphs_tuple))
updated_broadcast_globals_to_edges = graphs_tuple.replace(
edges=blocks.broadcast_globals_to_edges(graphs_tuple))
updated_broadcast_sender_nodes_to_edges = graphs_tuple.replace(
edges=blocks.broadcast_sender_nodes_to_edges(graphs_tuple))
updated_broadcast_receiver_nodes_to_edges = graphs_tuple.replace(
edges=blocks.broadcast_receiver_nodes_to_edges(graphs_tuple))
with tf.Session() as sess:
  output_graphs = sess.run([
      graphs_tuple,
      updated_broadcast_globals_to_nodes,
      updated_broadcast_globals_to_edges,
      updated_broadcast_sender_nodes_to_edges,
      updated_broadcast_receiver_nodes_to_edges])
plot_compare_graphs(output_graphs, labels=[
"Input graph",
"blocks.broadcast_globals_to_nodes",
"blocks.broadcast_globals_to_edges",
"blocks.broadcast_sender_nodes_to_edges",
"blocks.broadcast_receiver_nodes_to_edges"])
We can easily use broadcasters to, for example, set the value of each edge to be the sum of the first feature element of: the input edges, the sender nodes, the receiver nodes, and the global feature.
In [0]:
tf.reset_default_graph()
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
updated_graphs_tuple = graphs_tuple.replace(
edges=(graphs_tuple.edges[:, :1] +
blocks.broadcast_receiver_nodes_to_edges(graphs_tuple)[:, :1] +
blocks.broadcast_sender_nodes_to_edges(graphs_tuple)[:, :1] +
blocks.broadcast_globals_to_edges(graphs_tuple)[:, :1]))
with tf.Session() as sess:
  output_graphs = sess.run([
      graphs_tuple,
      updated_graphs_tuple])
plot_compare_graphs(output_graphs, labels=[
"Input graph",
"Updated graph"])
Aggregators perform reduce operations between different elements of the graph:
- `blocks.EdgesToGlobalsAggregator`: Aggregates the sets of features for all edges of each graph into a single set of global features for that graph.
- `blocks.NodesToGlobalsAggregator`: Aggregates the sets of features for all nodes of each graph into a single set of global features for that graph.
- `blocks.SentEdgesToNodesAggregator`: Aggregates the sets of features for all edges sent by each node into a single set of features for that node.
- `blocks.ReceivedEdgesToNodesAggregator`: Aggregates the sets of features for all edges received by each node into a single set of features for that node.
Different types of reduce operations are:
- `tf.unsorted_segment_sum`: Elementwise sum. Set to 0 for empty sets.
- `tf.unsorted_segment_mean`: Elementwise mean. Set to 0 for empty sets.
- `tf.unsorted_segment_prod`: Elementwise product. Set to 1 for empty sets.
- `blocks.unsorted_segment_max_or_zero`: Elementwise max. Set to 0 for empty sets.
- `blocks.unsorted_segment_min_or_zero`: Elementwise min. Set to 0 for empty sets.
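Each aggregator is a segment reduction: every edge (or node) carries a segment id (its receiver, sender, or graph index) and rows sharing an id are reduced together. The NumPy analogue below uses `np.add.at` and `np.maximum.at`; `segment_sum` and `segment_max_or_zero` are hypothetical helpers that follow the empty-set conventions listed above, not the TensorFlow kernels.

```python
import numpy as np

def segment_sum(data, segment_ids, num_segments):
  """NumPy analogue of an unsorted segment sum: empty segments stay 0."""
  out = np.zeros((num_segments,) + data.shape[1:], dtype=data.dtype)
  np.add.at(out, segment_ids, data)
  return out

def segment_max_or_zero(data, segment_ids, num_segments):
  """Elementwise max per segment, set to 0 for empty segments."""
  out = np.full((num_segments,) + data.shape[1:], -np.inf)
  np.maximum.at(out, segment_ids, data)
  counts = np.bincount(segment_ids, minlength=num_segments)
  out[counts == 0] = 0.0
  return out

# Two graphs: graph 0 has nodes 0-2 and edges 0-1, graph 1 has node 3, edge 2.
n_edge = np.array([2, 1])
edges = np.array([[1., 10.], [2., 20.], [5., 50.]])
receivers = np.array([1, 1, 3])

# Like ReceivedEdgesToNodesAggregator: segment id = receiver node index.
received_sum = segment_sum(edges, receivers, num_segments=4)
received_max = segment_max_or_zero(edges, receivers, num_segments=4)
# Like EdgesToGlobalsAggregator: segment id = index of the graph of each edge.
edge_graph_index = np.repeat(np.arange(len(n_edge)), n_edge)
edges_to_globals = segment_sum(edges, edge_graph_index, len(n_edge))
print(received_sum)
print(received_max)
print(edges_to_globals)
```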
In [0]:
tf.reset_default_graph()
graphs_tuple = utils_tf.data_dicts_to_graphs_tuple([data_dict_0])
reducer = tf.unsorted_segment_sum
updated_edges_to_globals = graphs_tuple.replace(
globals=blocks.EdgesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_nodes_to_globals = graphs_tuple.replace(
globals=blocks.NodesToGlobalsAggregator(reducer=reducer)(graphs_tuple))
updated_sent_edges_to_nodes = graphs_tuple.replace(
nodes=blocks.SentEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
updated_received_edges_to_nodes = graphs_tuple.replace(
nodes=blocks.ReceivedEdgesToNodesAggregator(reducer=reducer)(graphs_tuple))
with tf.Session() as sess:
  output_graphs = sess.run([
      graphs_tuple,
      updated_edges_to_globals,
      updated_nodes_to_globals,
      updated_sent_edges_to_nodes,
      updated_received_edges_to_nodes])
plot_compare_graphs(output_graphs, labels=[
"Input graph",
"blocks.EdgesToGlobalsAggregator",
"blocks.NodesToGlobalsAggregator",
"blocks.SentEdgesToNodesAggregator",
"blocks.ReceivedEdgesToNodesAggregator"])
`blocks.EdgeBlock`
An `EdgeBlock` consists of applying a function to the concatenation of:
- `graphs_tuple.edges`
- `blocks.broadcast_sender_nodes_to_edges(graphs_tuple)`
- `blocks.broadcast_receiver_nodes_to_edges(graphs_tuple)`
- `blocks.broadcast_globals_to_edges(graphs_tuple)`
The result is a graph with new edge features conditioned on input edges, nodes and global features according to the graph structure.
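The concatenation above can be reproduced in NumPy for a single toy graph. `edge_block_inputs` is a hypothetical helper; the column order follows the list above and may differ from the library's internal ordering, and a random linear map stands in for `snt.Linear(output_size=10)`.

```python
import numpy as np

def edge_block_inputs(edges, nodes, globals_, senders, receivers, n_edge):
  """Hypothetical helper: the per-edge input an EdgeBlock feeds its model."""
  return np.concatenate([
      edges,                                # graphs_tuple.edges
      nodes[senders],                       # sender nodes broadcast to edges
      nodes[receivers],                     # receiver nodes broadcast to edges
      np.repeat(globals_, n_edge, axis=0),  # globals broadcast to edges
  ], axis=1)

rng = np.random.RandomState(0)
edges = rng.rand(3, 2)      # 3 edges, 2 features each.
nodes = rng.rand(4, 5)      # 4 nodes, 5 features each.
globals_ = rng.rand(1, 4)   # 1 graph, 4 global features.
senders, receivers = np.array([0, 1, 2]), np.array([1, 2, 3])
inputs = edge_block_inputs(edges, nodes, globals_, senders, receivers,
                           n_edge=np.array([3]))
# Hypothetical linear edge model with random weights.
w, b = rng.rand(inputs.shape[1], 10), np.zeros(10)
new_edges = inputs.dot(w) + b
print(inputs.shape, new_edges.shape)  # (3, 16) (3, 10)
```

The input width is 2 + 5 + 5 + 4 = 16; only the edge model's output size determines the new edge feature size.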
In [0]:
tf.reset_default_graph()
edge_block = blocks.EdgeBlock(
edge_model_fn=lambda: snt.Linear(output_size=10))
input_graphs = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
output_graphs = edge_block(input_graphs)
print("Output edges size: {}".format(output_graphs.edges.shape[-1]))
`blocks.NodeBlock`
A `NodeBlock` consists of applying a function to the concatenation of:
- `graphs_tuple.nodes`
- `blocks.ReceivedEdgesToNodesAggregator(<reducer-function>)(graphs_tuple)`
- `blocks.broadcast_globals_to_nodes(graphs_tuple)`
The result is a graph with new node features conditioned on input edges, nodes and global features according to the graph structure.
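For a single toy graph, the `NodeBlock` input above can be assembled in NumPy with a sum reducer. `node_block_inputs` is a hypothetical helper whose column order follows the list above (the library's internal ordering may differ); the node model itself is omitted.

```python
import numpy as np

def node_block_inputs(nodes, edges, globals_, receivers, n_node):
  """Hypothetical helper: per-node NodeBlock input, using a sum reducer."""
  received = np.zeros((nodes.shape[0], edges.shape[1]))
  np.add.at(received, receivers, edges)  # Sum of edges received by each node.
  return np.concatenate([
      nodes,                                # graphs_tuple.nodes
      received,                             # aggregated received edges
      np.repeat(globals_, n_node, axis=0),  # globals broadcast to nodes
  ], axis=1)

nodes = np.array([[1.], [2.], [3.]])
edges = np.array([[10.], [20.], [30.]])
globals_ = np.array([[7.]])
receivers = np.array([1, 1, 2])
inputs = node_block_inputs(nodes, edges, globals_, receivers, np.array([3]))
print(inputs)
# Rows: [1, 0, 7], [2, 30, 7], [3, 30, 7]
```

Node 0 receives no edges, so its aggregated edge features are 0, matching the empty-set convention of the sum reducer.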
In [0]:
tf.reset_default_graph()
node_block = blocks.NodeBlock(
node_model_fn=lambda: snt.Linear(output_size=15))
input_graphs = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
output_graphs = node_block(input_graphs)
print("Output nodes size: {}".format(output_graphs.nodes.shape[-1]))
`blocks.GlobalBlock`
A `GlobalBlock` consists of applying a function to the concatenation of:
- `graphs_tuple.globals`
- `blocks.EdgesToGlobalsAggregator(<reducer-function>)(graphs_tuple)`
- `blocks.NodesToGlobalsAggregator(<reducer-function>)(graphs_tuple)`
The result is a graph with new global features conditioned on input edges, nodes and global features.
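The per-graph `GlobalBlock` input can likewise be sketched in NumPy for a two-graph batch with a sum reducer. `global_block_inputs` is a hypothetical helper whose column order follows the list above (the library's internal ordering may differ); each edge and node is mapped to its graph index with `np.repeat`.

```python
import numpy as np

def global_block_inputs(globals_, nodes, edges, n_node, n_edge):
  """Hypothetical helper: per-graph GlobalBlock input, using a sum reducer."""
  num_graphs = globals_.shape[0]
  edge_graph = np.repeat(np.arange(num_graphs), n_edge)  # Graph of each edge.
  node_graph = np.repeat(np.arange(num_graphs), n_node)  # Graph of each node.
  edges_agg = np.zeros((num_graphs, edges.shape[1]))
  nodes_agg = np.zeros((num_graphs, nodes.shape[1]))
  np.add.at(edges_agg, edge_graph, edges)
  np.add.at(nodes_agg, node_graph, nodes)
  return np.concatenate([globals_, edges_agg, nodes_agg], axis=1)

globals_ = np.array([[0.], [1.]])
nodes = np.array([[1.], [2.], [3.], [4.]])  # Graph 0: 3 nodes, graph 1: 1 node.
edges = np.array([[10.], [20.], [30.]])     # Graph 0: 1 edge, graph 1: 2 edges.
out = global_block_inputs(globals_, nodes, edges,
                          n_node=np.array([3, 1]), n_edge=np.array([1, 2]))
print(out)
# Rows: [0, 10, 6] and [1, 50, 4]
```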
In [0]:
tf.reset_default_graph()
global_block = blocks.GlobalBlock(
global_model_fn=lambda: snt.Linear(output_size=20))
input_graphs = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
output_graphs = global_block(input_graphs)
print("Output globals size: {}".format(output_graphs.globals.shape[-1]))
In [0]:
tf.reset_default_graph()
graph_network = modules.GraphNetwork(
edge_model_fn=lambda: snt.Linear(output_size=10),
node_model_fn=lambda: snt.Linear(output_size=15),
global_model_fn=lambda: snt.Linear(output_size=20))
input_graphs = utils_tf.data_dicts_to_graphs_tuple(graph_dicts)
output_graphs = graph_network(input_graphs)
for var in graph_network.variables:
  print(var)
Most of the existing neural networks operating on graphs can be built from this set of building blocks by using their different configuration options. See `graph_nets.modules` for some examples.