diff --git a/WDL/runtime/task.py b/WDL/runtime/task.py index cf7b1b37..cf7ac85b 100644 --- a/WDL/runtime/task.py +++ b/WDL/runtime/task.py @@ -491,17 +491,30 @@ def _parse_boolean(s: str) -> WDL.Value.Boolean: self._override_static("read_boolean", _read_something(_parse_boolean)) - def _read_lines(container_file: WDL.Value.File, lib: _StdLib = self) -> WDL.Value.Array: - host_file = lib.container.host_file(container_file.value, lib.inputs_only) + def parse_lines(s: str) -> WDL.Value.Array: ans = [] - with open(host_file, "r") as infile: - for line in infile: - if line.endswith("\n"): - line = line[:-1] - ans.append(WDL.Value.String(line)) + if s: + ans = [ + WDL.Value.String(line) + for line in (s[:-1] if s.endswith("\n") else s).split("\n") + ] return WDL.Value.Array(WDL.Type.Array(WDL.Type.String()), ans) - self._override_static("read_lines", _read_lines) + self._override_static("read_lines", _read_something(parse_lines)) + + def parse_tsv(s: str) -> WDL.Value.Array: + # TODO: should a blank line parse as [] or ['']? + ans = [ + WDL.Value.Array( + WDL.Type.Array(WDL.Type.String()), + [WDL.Value.String(field) for field in line.value.split("\t")], + ) + for line in parse_lines(s).value + ] + # pyre-ignore + return WDL.Value.Array(WDL.Type.Array(WDL.Type.Array(WDL.Type.String())), ans) + + self._override_static("read_tsv", _read_something(parse_tsv)) def _write_something( serialize: Callable[[WDL.Value.Base, BinaryIO], None], lib: _StdLib = self @@ -538,6 +551,22 @@ def _serialize_lines(array: WDL.Value.Array, outfile: BinaryIO) -> None: _write_something(lambda v, outfile: outfile.write(json.dumps(v.json).encode("utf-8"))), ) + self._override_static( + "write_tsv", + _write_something( + lambda v, outfile: _serialize_lines( + WDL.Value.Array( + WDL.Type.Array(WDL.Type.String()), + [ + WDL.Value.String("\t".join([part.value for part in parts.value])) + for parts in v.value + ], + ), + outfile, + ) + ), + ) + class InputStdLib(_StdLib): # StdLib for evaluation of task inputs and command diff --git a/tests/test_5stdlib.py b/tests/test_5stdlib.py index 6b61e996..75ccc19b 100644 --- a/tests/test_5stdlib.py +++ b/tests/test_5stdlib.py @@ -441,6 +441,7 @@ def test_write(self): version 1.0 task hello { File foo = write_lines(["foo","bar","baz"]) + File tsv = write_tsv([["one", "two", "three"], ["un", "deux", "trois"]]) File json = write_json({"key1": "value1", "key2": "value2"}) command <<< @@ -448,12 +449,18 @@ def test_write(self): if [ "$foo_sha" != "b1b113c6ed8ab3a14779f7c54179eac2b87d39fcebbf65a50556b8d68caaa2fb" ]; then exit 1 fi + tsv_sha=$(sha256sum < ~{tsv} | cut -f1 -d ' ') + if [ "$tsv_sha" != "a7124e688203195cd674cf147bbf965eda49e8df581d01c05944330fab096084" ]; then + exit 1 + fi >>> output { File o_json = json + Array[Array[String]] o_tsv = read_tsv(tsv) } } """) with open(outputs["o_json"]) as infile: self.assertEqual(json.load(infile), {"key1": "value1", "key2": "value2"}) + self.assertEqual(outputs["o_tsv"], [["one", "two", "three"], ["un", "deux", "trois"]])