Skip to content

Commit

Permalink
Merge pull request #1397 from pyiron/file_obj_str
Browse files Browse the repository at this point in the history
File: no longer derive from string
  • Loading branch information
pmrv authored Mar 25, 2024
2 parents c1aa430 + 8dc1d4f commit 42d0f43
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 11 deletions.
22 changes: 17 additions & 5 deletions pyiron_base/jobs/job/extension/files.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import posixpath
from typing import List
from pyiron_base.jobs.job.util import (
_working_directory_list_files,
Expand Down Expand Up @@ -96,12 +97,12 @@ def __getitem__(self, item):
working_directory=self._working_directory,
include_archive=False,
):
return File(os.path.join(self._working_directory, item))
return File(posixpath.join(self._working_directory, item))
elif item in _working_directory_list_files(
working_directory=self._working_directory,
include_archive=True,
):
return File(os.path.join(self._working_directory, item))
return File(posixpath.join(self._working_directory, item))
else:
raise FileNotFoundError(item)

Expand All @@ -115,13 +116,24 @@ def __getattr__(self, item):
raise FileNotFoundError(item) from None


class File(str):
class File:
__slots__ = ("_path",)

def __init__(self, path):
self._path = path

def __str__(self):
return self._path

def tail(self, lines: int = 100):
print(
*_working_directory_read_file(
working_directory=os.path.dirname(self),
file_name=os.path.basename(self),
working_directory=os.path.dirname(self._path),
file_name=os.path.basename(self._path),
tail=lines,
),
sep="",
)

def __eq__(self, other):
return self.__str__().__eq__(other)
17 changes: 13 additions & 4 deletions pyiron_base/jobs/job/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
_doc_str_job_core_attr,
)
from pyiron_base.jobs.job.extension.executable import Executable
from pyiron_base.jobs.job.extension.files import File
from pyiron_base.jobs.job.extension.jobstatus import JobStatus
from pyiron_base.jobs.job.runfunction import (
run_job_with_parameter_repair,
Expand Down Expand Up @@ -275,7 +276,7 @@ def restart_file_list(self):
Returns:
list: list of files
"""
return self._restart_file_list
return [str(f) if isinstance(f, File) else f for f in self._restart_file_list]

@restart_file_list.setter
def restart_file_list(self, filenames):
Expand All @@ -286,6 +287,8 @@ def restart_file_list(self, filenames):
filenames (list):
"""
for f in filenames:
if isinstance(f, File):
f = str(f)
if not (os.path.isfile(f)):
raise IOError("File: {} does not exist".format(f))
self.restart_file_list.append(f)
Expand All @@ -295,7 +298,7 @@ def restart_file_dict(self):
"""
A dictionary of the new name of the copied restart files
"""
for actual_name in [os.path.basename(f) for f in self._restart_file_list]:
for actual_name in [os.path.basename(f) for f in self.restart_file_list]:
if actual_name not in self._restart_file_dict.keys():
self._restart_file_dict[actual_name] = actual_name
return self._restart_file_dict
Expand All @@ -305,7 +308,13 @@ def restart_file_dict(self, val):
if not isinstance(val, dict):
raise ValueError("restart_file_dict should be a dictionary!")
else:
self._restart_file_dict = val
self._restart_file_dict = {}
for k, v in val.items():
if isinstance(k, File):
k = str(k)
if isinstance(v, File):
v = str(v)
self._restart_file_dict[k] = v

@property
def exclude_nodes_hdf(self):
Expand Down Expand Up @@ -1043,7 +1052,7 @@ def to_dict(self):
data_dict = self._type_to_dict()
data_dict["status"] = self.status.string
data_dict["input/generic_dict"] = {
"restart_file_list": self._restart_file_list,
"restart_file_list": self.restart_file_list,
"restart_file_dict": self._restart_file_dict,
"exclude_nodes_hdf": self._exclude_nodes_hdf,
"exclude_groups_hdf": self._exclude_groups_hdf,
Expand Down
2 changes: 1 addition & 1 deletion tests/flex/test_executablecontainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_job_files(self):
if os.name != "nt":
self.assertEqual(job.files.error_out, output_file_path)
else:
self.assertEqual(job.files.error_out.replace("/", "\\"), output_file_path)
self.assertEqual(job.files.error_out, output_file_path.replace("\\", "/"))

def test_create_job_factory_typeerror(self):
create_catjob = create_job_factory(
Expand Down
3 changes: 2 additions & 1 deletion tests/flex/test_wrap_executable.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import posixpath
from pyiron_base._tests import TestWithProject


Expand Down Expand Up @@ -28,7 +29,7 @@ def test_python_version(self):
self.assertTrue(python_version_step.status.finished)
self.assertEqual(
python_version_step.files.error_out,
os.path.join(python_version_step.working_directory, "error.out")
posixpath.join(python_version_step.working_directory, "error.out")
)

def test_cat(self):
Expand Down

0 comments on commit 42d0f43

Please sign in to comment.