Skip to content

Commit

Permalink
Add structured readout versions of binary and multiclass classificati…
Browse files Browse the repository at this point in the history
…on tasks.

PiperOrigin-RevId: 588115018
  • Loading branch information
dzelle authored and tensorflower-gardener committed Dec 5, 2023
1 parent 68f1e28 commit e5988bb
Show file tree
Hide file tree
Showing 4 changed files with 467 additions and 26 deletions.
2 changes: 2 additions & 0 deletions tensorflow_gnn/api_def/runner-symbols.txt
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ runner.ParameterServerStrategy
runner.PassthruDatasetProvider
runner.PassthruSampleDatasetsProvider
runner.Predictions
runner.NodeBinaryClassification
runner.RootNodeBinaryClassification
runner.RootNodeLabelFn
runner.RootNodeMeanAbsoluteError
Expand All @@ -34,6 +35,7 @@ runner.RootNodeMeanAbsolutePercentageError
runner.RootNodeMeanSquaredError
runner.RootNodeMeanSquaredLogScaledError
runner.RootNodeMeanSquaredLogarithmicError
runner.NodeMulticlassClassification
runner.RootNodeMulticlassClassification
runner.RunResult
runner.SampleTFRecordDatasetsProvider
Expand Down
6 changes: 4 additions & 2 deletions tensorflow_gnn/runner/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,10 +94,12 @@
# in `distribute_test.py`.)
#
# Tasks (Classification)
RootNodeBinaryClassification = classification.RootNodeBinaryClassification
RootNodeMulticlassClassification = classification.RootNodeMulticlassClassification
GraphBinaryClassification = classification.GraphBinaryClassification
GraphMulticlassClassification = classification.GraphMulticlassClassification
NodeBinaryClassification = classification.NodeBinaryClassification
NodeMulticlassClassification = classification.NodeMulticlassClassification
RootNodeBinaryClassification = classification.RootNodeBinaryClassification
RootNodeMulticlassClassification = classification.RootNodeMulticlassClassification
# Tasks (Link Prediction)
DotProductLinkPrediction = link_prediction.DotProductLinkPrediction
HadamardProductLinkPrediction = link_prediction.HadamardProductLinkPrediction
Expand Down
Loading

0 comments on commit e5988bb

Please sign in to comment.