Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a few typos in basic.py and add missing docstring. #28

Merged
merged 1 commit into from
May 10, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions jraph/examples/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,10 @@ def run():
senders=np.array([0, 1]), receivers=np.array([2, 2]))
logging.info("Nested graph %r", nested_graph)

# Creates a GraphsTuple from scratch containing a 2 graphs using an implicit
# Creates a GraphsTuple from scratch containing 2 graphs using an implicit
# batch dimension.
# The first graph has 3 nodes and 2 edges.
# The second graph has 1 nodes and 1 edges.
# The second graph has 1 node and 1 edge.
# Each node has a 4-dimensional feature vector.
# Each edge has a 5-dimensional feature vector.
# The graph itself has a 6-dimensional feature vector.
Expand Down Expand Up @@ -93,7 +93,7 @@ def run():
# Creates a padded GraphsTuple from an existing GraphsTuple.
# The padded GraphsTuple will contain 10 nodes, 5 edges, and 4 graphs.
# Three graphs are added for the padding.
# First an dummy graph which contains the padding nodes and edges and secondly
# First a dummy graph which contains the padding nodes and edges and secondly
# two empty graphs without nodes or edges to pad out the graphs.
padded_graph = jraph.pad_with_graphs(
single_graph, n_node=10, n_edge=5, n_graph=4)
Expand All @@ -104,7 +104,7 @@ def run():
single_graph = jraph.unpad_with_graphs(padded_graph)
logging.info("Unpadded graph %r", single_graph)

# Creates a GraphsTuple containing a 2 graphs using an explicit batch
# Creates a GraphsTuple containing 2 graphs using an explicit batch
# dimension.
# An explicit batch dimension requires more memory, but can simplify
# the definition of functions operating on the graph.
Expand All @@ -113,7 +113,7 @@ def run():
# Using an explicit batch requires padding all feature vectors to
# the maximum size of nodes and edges.
# The first graph has 3 nodes and 2 edges.
# The second graph has 1 nodes and 1 edges.
# The second graph has 1 node and 1 edge.
# Each node has a 4-dimensional feature vector.
# Each edge has a 5-dimensional feature vector.
# The graph itself has a 6-dimensional feature vector.
Expand All @@ -125,7 +125,7 @@ def run():
receivers=np.array([[2, 2], [0, -1]]))
logging.info("Explicitly batched graph %r", explicitly_batched_graph)

# Running a graph propagation steps.
# Running a graph propagation step.
# First define the update functions for the edges, nodes and globals.
# In this example we use the identity everywhere.
# For Graph neural networks, each update function is typically a neural
Expand Down Expand Up @@ -156,6 +156,7 @@ def update_globals_fn(
aggregated_node_features,
aggregated_edge_features,
globals_):
"""Returns the global features."""
del aggregated_node_features
del aggregated_edge_features
return globals_
Expand All @@ -166,8 +167,8 @@ def update_globals_fn(
aggregate_nodes_for_globals_fn = jraph.segment_sum
aggregate_edges_for_globals_fn = jraph.segment_sum

# Optionally define attention logit function and attention reduce function.
# This can be used for graph attention.
# Optionally define an attention logit function and an attention reduce
# function. This can be used for graph attention.
# The attention function calculates attention weights, and the apply
# attention function calculates the new edge feature given the weights.
# We don't use graph attention here, and just pass the defaults.
Expand Down