Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Include files outside the archive in job.files.list() #1323

Merged
merged 9 commits into from
Feb 13, 2024
21 changes: 15 additions & 6 deletions pyiron_base/jobs/job/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,13 +364,20 @@ def _working_directory_list_files(working_directory):
list of str: file names
"""
if os.path.isdir(working_directory):
uncompressed_files_lst = os.listdir(working_directory)
if _working_directory_is_compressed(working_directory=working_directory):
with tarfile.open(
_get_compressed_job_name(working_directory=working_directory), "r"
) as tar:
return [member.name for member in tar.getmembers() if member.isfile()]
compressed_job_name = _get_compressed_job_name(
working_directory=working_directory
)
with tarfile.open(compressed_job_name, "r") as tar:
job_archive_name = os.path.basename(compressed_job_name)
compressed_files_lst = [
member.name for member in tar.getmembers() if member.isfile()
]
uncompressed_files_lst.remove(job_archive_name)
return uncompressed_files_lst + compressed_files_lst
else:
return os.listdir(working_directory)
return uncompressed_files_lst
return []


Expand Down Expand Up @@ -411,7 +418,9 @@ def _working_directory_read_file(working_directory, file_name, tail=None):
):
raise FileNotFoundError(file_name)

if _working_directory_is_compressed(working_directory=working_directory):
if _working_directory_is_compressed(
working_directory=working_directory
) and file_name not in os.listdir(working_directory):
with tarfile.open(
_get_compressed_job_name(working_directory=working_directory),
encoding="utf8",
Expand Down
24 changes: 24 additions & 0 deletions tests/job/test_genericJob.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,30 @@ def test_child_ids_finished(self):
def test_index(self):
pass

def test_compress_file_list(self):
file_lst = ["file_not_to_compress", "file_to_compress"]
ham = self.project.create.job.ScriptJob("job_script_compress")
os.makedirs(ham.working_directory, exist_ok=True)
for file in file_lst:
with open(os.path.join(ham.working_directory, file), "w") as f:
f.writelines(["content: " + file])
for file in file_lst:
self.assertTrue(file in ham.files.list())
for file in file_lst:
self.assertTrue(file in os.listdir(ham.working_directory))
ham.compress(files_to_compress=["file_to_compress"])
for file in ["job_script_compress.tar.bz2", "file_not_to_compress"]:
self.assertTrue(file in os.listdir(ham.working_directory))
for file in file_lst:
self.assertTrue(file in ham.files.list())
with contextlib.redirect_stdout(io.StringIO(newline=os.linesep)) as f:
ham.files.file_not_to_compress.tail()
self.assertEqual(f.getvalue().replace('\r', ''), "content: file_not_to_compress\n")
with contextlib.redirect_stdout(io.StringIO(newline=os.linesep)) as f:
ham.files.file_to_compress.tail()
self.assertEqual(f.getvalue().replace('\r', ''), "content: file_to_compress\n")
ham.remove()

def test_job_name(self):
cwd = self.file_location
with self.subTest("ensure create is working"):
Expand Down
Loading