Update runner.md.
PiperOrigin-RevId: 519228096
dzelle authored and tensorflower-gardener committed Mar 24, 2023
1 parent 2517114 commit 255d264
Showing 1 changed file with 24 additions and 28 deletions.
52 changes: 24 additions & 28 deletions tensorflow_gnn/docs/guide/runner.md
@@ -44,11 +44,10 @@ initial_node_states = lambda node_set, node_set_name: node_set["embedding"]
 # This `tf.keras.layers.Layer` matches the `GraphTensorProcessorFn` protocol.
 map_features = tfgnn.keras.layers.MapFeatures(node_sets_fn=initial_node_states)

-# Extract labels from the graph context.
-extract_labels = lambda inputs: inputs.context["label"]
-
 # Binary classification by the root node.
-task = runner.RootNodeBinaryClassification("nodes", label_fn=extract_labels)
+task = runner.RootNodeBinaryClassification(
+    "nodes",
+    label_fn=runner.ContextLabelFn("label"))

 trainer = runner.KerasTrainer(
     strategy=tf.distribute.TPUStrategy(...),
@@ -70,8 +69,7 @@ runner.run(
     task=task,
     gtspec=gtspec,
     global_batch_size=128,
-    # Extract any labels before losing them via `MapFeatures`!
-    feature_processors=[extract_labels, map_features],
+    feature_processors=[map_features],
     valid_ds_provider=valid_ds_provider)
 ```

@@ -122,17 +120,17 @@ Contributors have free rein in their implementation of `get_dataset`, e.g.:
 in-memory generation of synthetic graphs or real time conversion of different
 graph persistence formats.

-### Graph Task Adaptation, Preprocessing and Objectives
+### Task Preprocessing, Prediction and Objectives

 ```python
 class Task(abc.ABC):

   @abc.abstractmethod
-  def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
+  def preprocess(self, inputs: GraphTensor) -> tuple[Union[GraphTensor, Sequence[GraphTensor]], Field]:
     raise NotImplementedError()

   @abc.abstractmethod
-  def preprocessors(self) -> Sequence[Callable[..., tf.data.Dataset]]:
+  def predict(self, *args: GraphTensor) -> Field:
     raise NotImplementedError()

   @abc.abstractmethod
@@ -144,13 +142,13 @@ class Task(abc.ABC):
     raise NotImplementedError()
 ```

-A `Task` collects the ancillary pieces for training a Keras model with a graph
-learning objective. A `Task` is expected to return preprocessors for a
-`tf.data.Dataset` (of `GraphTensor`) and adapt a Keras model for a graph
-learning objective. Model adaptation refers to the addition of the readout and
-prediction heads (see step 3 of
+A `Task` represents a learning objective for a GNN model and defines all the
+non-GNN pieces around the base GNN. A `Task` is expected to define preprocessing
+for a `tf.data.Dataset` (of `GraphTensor`) and produce prediction outputs (via
+`predict(...)`). `predict(...)` typically performs the addition of the readout
+and prediction heads (see step 3 of
 [The big picture](gnn_modeling.md#the-big-picture-initialization-graph-updates-and-readout)).
-It also provides any losses and metrics for that objective. Common
+The `Task` also provides any losses and metrics for that objective. Common
 implementations for classification and regression (by graph or root node) are
 provided:

@@ -178,7 +176,8 @@ example, an imagined `RadiaInfomax` that:
 * Masks arbitrary nodes,
 * Creates pseudo labels;

-* For an arbitrary model,
+* For an arbitrary input (where that input is the base GNN output for those
+  `GraphTensor` returned by `preprocess(...)`),

 * Adds a head to `R^4` from the root node hidden state;

@@ -189,22 +188,19 @@ example, an imagined `RadiaInfomax` that:
 ```python
 class RadiaInfomax(runner.Task):

-  def adapt(self, model: tf.keras.Model) -> tf.keras.Model:
-    tfgnn.check_scalar_graph_tensor(model.output, name="RadiaInfomax")
+  def preprocess(self, inputs: GraphTensor) -> tuple[GraphTensor, Field]:
+    return mask_some_nodes(inputs), create_psuedolabels()
+
+  def predict(self, inputs: GraphTensor) -> Field:
+    # A single `GraphTensor` input corresponding to the base GNN output given
+    # the `GraphTensor` returned by `preprocess(...)`.
+    tfgnn.check_scalar_graph_tensor(inputs, name="RadiaInfomax")
     activations = tfgnn.keras.layers.ReadoutFirstNode(
         node_set_name="nodes",
-        feature_name=tfgnn.HIDDEN_STATE)(model.output)
-    logits = tf.keras.layers.Dense(
+        feature_name=tfgnn.HIDDEN_STATE)(inputs)
+    return tf.keras.layers.Dense(
         4,  # Apply RadiaInfomax in R^4.
         name="logits")(activations)
-    return tf.keras.Model(model.inputs, logits)
-
-  def preprocessors(self) -> Sequence[Callable[..., tf.data.Dataset]]:
-    def preprocessor(gt: tfgnn.GraphTensor):
-      gt = mask_some_nodes(gt)
-      labels = create_psuedolabels()
-      return gt, labels
-    return [lambda ds: ds.map(preprocessor),]

   def losses(self) -> Sequence[Callable[[tf.Tensor, tf.Tensor], tf.Tensor]]:
     return [tf.keras.losses.CosineSimilarity(),]
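
The revised guide text describes a flow in which `preprocess(...)` yields the model input and the label, the base GNN runs on that input, and `predict(...)` turns the GNN output into predictions. The following is a minimal pure-Python sketch of that flow, not the library's implementation. Plain dicts stand in for `GraphTensor`, and `context_label_fn`, `RootNodeSum`, `double_states` and `run_one_example` are hypothetical names invented for illustration. In particular, the assumption that a context label function removes the label from the graph it returns is ours and is not confirmed by this commit.

```python
import abc
from typing import Any, Callable, Mapping, Tuple

# Hypothetical stand-in for `tfgnn.GraphTensor`: a dict with a "context"
# mapping and a "nodes" mapping of node set name -> list of hidden states.
Graph = Mapping[str, Any]


class Task(abc.ABC):
  """Pure-Python mirror of the two new hooks (losses and metrics omitted)."""

  @abc.abstractmethod
  def preprocess(self, inputs: Graph) -> Tuple[Graph, Any]:
    raise NotImplementedError()

  @abc.abstractmethod
  def predict(self, *args: Graph) -> Any:
    raise NotImplementedError()


def context_label_fn(feature_name: str) -> Callable[[Graph], Tuple[Graph, Any]]:
  """Hypothetical label function: pops a label out of the graph context."""
  def fn(graph: Graph) -> Tuple[Graph, Any]:
    context = dict(graph["context"])  # Copy so the input is not mutated.
    label = context.pop(feature_name)
    return {**graph, "context": context}, label
  return fn


class RootNodeSum(Task):
  """Toy task: the prediction is the sum of the root (first) node's state."""

  def __init__(self, node_set_name: str, label_fn):
    self._node_set_name = node_set_name
    self._label_fn = label_fn

  def preprocess(self, inputs: Graph) -> Tuple[Graph, Any]:
    return self._label_fn(inputs)

  def predict(self, *args: Graph) -> Any:
    (graph,) = args  # One base GNN output per preprocessed graph.
    return sum(graph["nodes"][self._node_set_name][0])


def double_states(graph: Graph) -> Graph:
  """Stand-in for a base GNN: doubles every hidden state value."""
  nodes = {name: [[2 * v for v in state] for state in states]
           for name, states in graph["nodes"].items()}
  return {**graph, "nodes": nodes}


def run_one_example(task: Task, base_gnn, graph: Graph) -> Tuple[Any, Any]:
  """Applies preprocess -> base GNN -> predict to a single example."""
  x, y = task.preprocess(graph)
  return task.predict(base_gnn(x)), y


example = {"context": {"label": 1},
           "nodes": {"nodes": [[1.0, 2.0], [3.0, 4.0]]}}
task = RootNodeSum("nodes", label_fn=context_label_fn("label"))
prediction, label = run_one_example(task, double_states, example)
```

In this sketch the label is separated from the input inside `preprocess(...)`, so no extra label-extraction step is needed among the feature processors, which mirrors the motivation for dropping `extract_labels` from `feature_processors` in the first two hunks.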