Skip to content

Commit

Permalink
Add a test dataset for reachability dataflow.
Browse files Browse the repository at this point in the history
This adds a target //programl/test/data:reachability_dataflow_dataset
which defines a tarball of test data for the dataflow task.

github.com/ChrisCummins/ProGraML/issues/119
  • Loading branch information
ChrisCummins committed Aug 30, 2020
1 parent 86e703f commit 14cb9c9
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 2 deletions.
1 change: 1 addition & 0 deletions programl/task/dataflow/dataset/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ cc_binary(
"@labm8//labm8/cpp:logging",
"@labm8//labm8/cpp:strutil",
],
visibility = ["//visibility:public"],
)

py_binary(
Expand Down
3 changes: 1 addition & 2 deletions programl/task/dataflow/dataset/create_vocab.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,7 @@ of total node texts that are described by the current and prior lines.
those without a text representation. <count> is the number of matching node
texts, and <node_text> is the unique text value.)";

DEFINE_string(path, (labm8::fsutil::GetHomeDirectoryOrDie() / "programl/dataflow").string(),
"The directory to write generated files to.");
DEFINE_string(path, "/tmp/programl/dataflow", "The directory to write generated files to.");
DEFINE_int32(limit, 0,
"If --limit > 0, limit the number of input graphs processed to "
"this number.");
Expand Down
27 changes: 27 additions & 0 deletions programl/test/data/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3088,3 +3088,30 @@ filegroup(
testonly = 1,
srcs = ["module_with_unreachable_instructions.ll"],
)

# Test-only rule that produces reachability_dataflow_dataset.tar.bz2.
# The command runs three steps chained with `&&`:
#   1. run the generator tool, writing the dataset into a scratch directory
#      `dtmp` under the rule's output directory;
#   2. pack the contents of that scratch directory into a bzip2 tarball
#      (`-C` makes the archive paths relative to the scratch directory);
#   3. delete the scratch directory so only the declared output remains.
genrule(
name = "reachability_dataflow_dataset",
testonly = 1,
outs = ["reachability_dataflow_dataset.tar.bz2"],
cmd = (
"$(location :make_reachability_dataflow_dataset) --path=$(@D)/dtmp && " +
"tar cjf $(@D)/reachability_dataflow_dataset.tar.bz2 -C $(@D)/dtmp . && " +
"rm -rf $(@D)/dtmp"
),
tools = [":make_reachability_dataflow_dataset"],
)

# Test-only script that generates the miniature reachability dataflow dataset.
# The `data` attribute lists the runtime inputs of the script: the LLVM-IR
# files, their program graphs, the reachability features, and the
# create_vocab binary that the script invokes as a subprocess.
py_binary(
name = "make_reachability_dataflow_dataset",
testonly = 1,
srcs = ["make_reachability_dataflow_dataset.py"],
data = [
":llvm_ir_graphs",
":llvm_ir",
":llvm_ir_reachability_features",
"//programl/task/dataflow/dataset:create_vocab",
],
deps = [
"//third_party/py/labm8",
],
)
96 changes: 96 additions & 0 deletions programl/test/data/make_reachability_dataflow_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
# Copyright 2019-2020 the ProGraML authors.
#
# Contact Chris Cummins <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Create a mini reachability dataflow dataset using test data.
Usage:
$ bazel run //programl/test/data:make_reachability_dataflow_dataset \
--path /path/to/generated/dataset
"""
import os
import shutil
import subprocess
from pathlib import Path

from labm8.py import app, bazelutil

# Command-line flag naming the directory that the dataset is written to. The
# default is None; main() checks that a value was supplied.
app.DEFINE_string("path", None, "The path of to write the generated dataset to.")
FLAGS = app.FLAGS


# Input locations, resolved through bazelutil.DataPath. Each is copied into the
# generated dataset by make_reachability_dataflow_dataset().
LLVM_IR = bazelutil.DataPath("programl/programl/test/data/llvm_ir")

LLVM_IR_GRAPHS = bazelutil.DataPath("programl/programl/test/data/llvm_ir_graphs")

# NOTE(review): the BUILD data dependency is labelled
# ":llvm_ir_reachability_features" while this directory is named
# "llvm_ir_reachability" -- confirm the target's files live under this path.
LLVM_IR_GRAPH_REACHABILITY_FEATURES = bazelutil.DataPath(
"programl/programl/test/data/llvm_ir_reachability"
)

# The create_vocab binary, run as a subprocess once the dataset is laid out.
CREATE_VOCAB = bazelutil.DataPath(
"programl/programl/task/dataflow/dataset/create_vocab"
)


def make_reachability_dataflow_dataset(root: Path) -> Path:
    """Make a miniature dataset for reachability dataflow.

    The dataset is laid out under `root` as:

        graphs/                a copy of LLVM_IR_GRAPHS
        ir/                    a copy of LLVM_IR
        labels/reachability/   a copy of LLVM_IR_GRAPH_REACHABILITY_FEATURES
        train/, val/, test/    relative symlinks into graphs/

    The entries of the graphs directory are split roughly 60/20/20 between
    train/, val/ and test/. Once the layout is complete the CREATE_VOCAB
    binary is run with `--path <root>`.

    Args:
        root: The root of the dataset. The subdirectories created here must
            not already exist.

    Returns:
        The root of the dataset.

    Raises:
        FileExistsError: If one of the dataset subdirectories already exists.
        subprocess.CalledProcessError: If CREATE_VOCAB exits with a non-zero
            status.
    """
    (root / "train").mkdir(parents=True)
    (root / "val").mkdir()
    (root / "test").mkdir()
    (root / "labels").mkdir()

    shutil.copytree(LLVM_IR_GRAPHS, root / "graphs")
    shutil.copytree(LLVM_IR, root / "ir")
    shutil.copytree(
        LLVM_IR_GRAPH_REACHABILITY_FEATURES, root / "labels" / "reachability"
    )

    # List the directory once and sort it: iterdir() yields entries in
    # arbitrary order, so without sorting the train/val/test assignment could
    # differ from one run to the next.
    graphs = sorted(LLVM_IR_GRAPHS.iterdir())
    ntrain = int(len(graphs) * 0.6)
    nval = int(len(graphs) * 0.8)

    for i, graph in enumerate(graphs):
        if i < ntrain:
            dst = "train"
        elif i < nval:
            dst = "val"
        else:
            dst = "test"
        # Link by the entry's own file name so that the symlink always points
        # at a file that exists in graphs/, whatever its suffix. The target is
        # relative so that the dataset can be relocated or archived.
        os.symlink(f"../graphs/{graph.name}", root / dst / graph.name)

    subprocess.check_call([str(CREATE_VOCAB), "--path", str(root)])

    return root


def main():
    """Main entry point.

    Reads the output directory from the --path flag and generates the dataset
    there.

    Raises:
        app.UsageError: If --path was not provided.
    """
    # Validate explicitly instead of using `assert`, which is stripped when
    # Python runs with -O and would let a None path through to Path().
    if not FLAGS.path:
        raise app.UsageError("--path must be set to the dataset output directory")
    make_reachability_dataflow_dataset(Path(FLAGS.path))


if __name__ == "__main__":
    app.Run(main)
10 changes: 10 additions & 0 deletions programl/test/py/plugins/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,13 @@ py_library(
"//third_party/py/labm8",
],
)

# Test-only library exposing the reachability dataflow dataset to Python
# tests. The dataset tarball is attached as a runtime data dependency.
py_library(
name = "reachability_dataflow_dataset",
testonly = 1,
srcs = ["reachability_dataflow_dataset.py"],
data = ["//programl/test/data:reachability_dataflow_dataset"],
deps = [
"//third_party/py/labm8",
],
)
29 changes: 29 additions & 0 deletions programl/test/py/plugins/reachability_dataflow_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright 2019-2020 the ProGraML authors.
#
# Contact Chris Cummins <[email protected]>.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from typing import Iterator

from labm8.py import bazelutil, test

# The test dataset tarball, attached as a data dependency of this library.
# NOTE(review): the generator script resolves its inputs under
# "programl/programl/test/data/..."; confirm whether this path needs the same
# prefix.
REACHABILITY_DATAFLOW_DATASET = bazelutil.DataArchive(
    "programl/test/data/reachability_dataflow_dataset.tar.bz2"
)


@test.Fixture(scope="function")
def reachability_dataflow_dataset() -> Iterator[Path]:
    """A test fixture which yields the root of a dataflow dataset.

    The archive is entered as a context manager for the duration of each test
    function, and the value it produces is yielded to the test.
    """
    with REACHABILITY_DATAFLOW_DATASET as d:
        yield d

0 comments on commit 14cb9c9

Please sign in to comment.