Skip to content

Commit

Permalink
Fix for inconsistent ragged row partitions of edges and indices to se…
Browse files Browse the repository at this point in the history
…lect unique edges.

PiperOrigin-RevId: 549425021
  • Loading branch information
aferludin authored and tensorflower-gardener committed Jul 19, 2023
1 parent 2757a70 commit 9e560b2
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 14 deletions.
6 changes: 5 additions & 1 deletion tensorflow_gnn/experimental/sampler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1438,12 +1438,16 @@ def _get_unique_parallel_edges_indices(
original_edges_idx = tf.ragged.range(sizes).values
# assignes to each edge its graph index within the original graph tensor.
original_graph_idx = tf.repeat(tf.range(num_graphs), sizes)
# TODO(b/285269757): replace with `graph_tensor.row_splits_dtype`
row_splits_dtype = graph_tensor.edge_sets[
edge_set_name
].adjacency.source.dtype
result[edge_set_name] = tf.RaggedTensor.from_value_rowids(
map_to_unique_edge(original_edges_idx),
map_to_unique_edge(original_graph_idx),
nrows=num_graphs,
validate=False,
)
).with_row_splits_dtype(row_splits_dtype)

return result

Expand Down
37 changes: 24 additions & 13 deletions tensorflow_gnn/experimental/sampler/core_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1201,7 +1201,7 @@ def testHeterogeneous2(self):
)


class ParallelEdgesRemovalTest(tf.test.TestCase):
class ParallelEdgesRemovalTest(tf.test.TestCase, parameterized.TestCase):

def testNoEdges(self):
graph = core.build_graph_tensor(
Expand Down Expand Up @@ -1280,29 +1280,40 @@ def testHeterogeneous(self):
graph.edge_sets['B->A'].adjacency.target, rt([[1, 1], [], [], [1], []])
)

def testHomogeneous(self):
@parameterized.product(
indices_dtype=[tf.int32, tf.int64], row_splits_dtype=[tf.int32, tf.int64]
)
def testHomogeneous(
self, indices_dtype: tf.DType, row_splits_dtype: tf.DType
):
graph = core.build_graph_tensor(
edge_sets={
'A,A->A,A': {
'#source': rt(
[['a'], ['c'] * 5, [], ['e'] * 10, ['g'] * 15 + ['k'] * 10]
[[1], [3] * 5, [], [5] * 10, [7] * 15 + [9] * 10],
dtype=indices_dtype,
row_splits_dtype=row_splits_dtype,
),
'#target': rt(
[['b'], ['d'] * 5, [], ['e'] * 10, ['h'] * 15 + ['k'] * 10]
[[2], [4] * 5, [], [5] * 10, [8] * 15 + [9] * 10],
dtype=indices_dtype,
row_splits_dtype=row_splits_dtype,
),
'f': rt(
[[1], [2] * 5, [], [3] * 10, [4] * 15 + [5] * 10],
row_splits_dtype=row_splits_dtype,
),
'f': rt([[1], [2] * 5, [], [3] * 10, [4] * 15 + [5] * 10]),
},
},
remove_parallel_edges=True,
)
self.assertAllEqual(
graph.edge_sets['A->A'].adjacency.source,
rt([[0], [0], [], [0], [0, 1]]),
)
self.assertAllEqual(
graph.edge_sets['A->A'].adjacency.target,
rt([[1], [1], [], [0], [2, 1]]),
)
source = graph.edge_sets['A->A'].adjacency.source
target = graph.edge_sets['A->A'].adjacency.target
self.assertAllEqual(source, rt([[0], [0], [], [0], [0, 1]]))
self.assertAllEqual(source.row_splits.dtype, row_splits_dtype)
self.assertAllEqual(target, rt([[1], [1], [], [0], [2, 1]]))
self.assertAllEqual(target.row_splits.dtype, row_splits_dtype)

self.assertAllEqual(
graph.edge_sets['A->A']['f'], rt([[1], [2], [], [3], [4, 5]])
)
Expand Down

0 comments on commit 9e560b2

Please sign in to comment.