Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
address more comments
Browse files Browse the repository at this point in the history
lithomas1 committed Jun 13, 2024

Verified

This commit was signed with the committer’s verified signature. The key has expired.
jeff-mccoy Megamind
1 parent 699efd3 commit e242182
Showing 4 changed files with 25 additions and 13 deletions.
10 changes: 8 additions & 2 deletions python/cudf/cudf/_lib/pylibcudf/io/types.pyx
Original file line number Diff line number Diff line change
@@ -200,6 +200,8 @@ cdef class SinkInfo:
if isinstance(sinks[0], io.StringIO):
data_sinks.reserve(len(sinks))
for s in sinks:
if not isinstance(s, io.StringIO):
raise ValueError("All sinks must be of the same type!")
self.sink_storage.push_back(
unique_ptr[data_sink](new iobase_data_sink(s))
)
@@ -219,8 +221,10 @@ cdef class SinkInfo:
}:
raise NotImplementedError(f"Unsupported encoding {s.encoding}")
sink = move(unique_ptr[data_sink](new iobase_data_sink(s.buffer)))
else:
elif isinstance(s, io.BytesIO):
sink = move(unique_ptr[data_sink](new iobase_data_sink(s)))
else:
raise ValueError("All sinks must be of the same type!")

self.sink_storage.push_back(
move(sink)
@@ -230,7 +234,9 @@ cdef class SinkInfo:
elif isinstance(sinks[0], (basestring, os.PathLike)):
paths.reserve(len(sinks))
for s in sinks:
if not isinstance(s, (basestring, os.PathLike)):
raise ValueError("All sinks must be of the same type!")
paths.push_back(<string> os.path.expanduser(s).encode())
self.c_obj = sink_info(move(paths))
else:
raise TypeError("Unrecognized input type: {}".format(type(sinks)))
raise TypeError("Unrecognized input type: {}".format(type(sinks[0])))
Original file line number Diff line number Diff line change
@@ -7,19 +7,24 @@
import cudf._lib.pylibcudf as plc


@pytest.fixture(params=[plc.io.SourceInfo, plc.io.SinkInfo])
def io_class(request):
return request.param


@pytest.mark.parametrize(
"source", ["a.txt", b"hello world", io.BytesIO(b"hello world")]
)
def test_source_info_ctor(source, tmp_path):
def test_source_info_ctor(io_class, source, tmp_path):
if isinstance(source, str):
file = tmp_path / source
file.write_bytes("hello world".encode("utf-8"))
source = str(file)

plc.io.SourceInfo([source])
if io_class is plc.io.SinkInfo and isinstance(source, bytes):
pytest.skip("bytes is not a valid input for SinkInfo")

# TODO: test contents of source_info buffer is correct
# once buffers are exposed on python side
io_class([source])


@pytest.mark.parametrize(
@@ -30,18 +35,17 @@ def test_source_info_ctor(source, tmp_path):
[io.BytesIO(b"hello world"), io.BytesIO(b"hello there")],
],
)
def test_source_info_ctor_multiple(sources, tmp_path):
def test_source_info_ctor_multiple(io_class, sources, tmp_path):
for i in range(len(sources)):
source = sources[i]
if isinstance(source, str):
file = tmp_path / source
file.write_bytes("hello world".encode("utf-8"))
sources[i] = str(file)
elif io_class is plc.io.SinkInfo and isinstance(source, bytes):
pytest.skip("bytes is not a valid input for SinkInfo")

plc.io.SourceInfo(sources)

# TODO: test contents of source_info buffer is correct
# once buffers are exposed on python side
io_class(sources)


@pytest.mark.parametrize(
@@ -56,7 +60,7 @@ def test_source_info_ctor_multiple(sources, tmp_path):
],
],
)
def test_source_info_ctor_mixing_invalid(sources, tmp_path):
def test_source_info_ctor_mixing_invalid(io_class, sources, tmp_path):
# Unlike the previous test
# don't create files so that they are missing
for i in range(len(sources)):
@@ -65,5 +69,7 @@ def test_source_info_ctor_mixing_invalid(sources, tmp_path):
file = tmp_path / source
file.write_bytes("hello world".encode("utf-8"))
sources[i] = str(file)
elif io_class is plc.io.SinkInfo and isinstance(source, bytes):
pytest.skip("bytes is not a valid input for SinkInfo")
with pytest.raises(ValueError):
plc.io.SourceInfo(sources)
io_class(sources)

0 comments on commit e242182

Please sign in to comment.