From d69eaabddc697f6ab7e7ebfedb4974181c61d575 Mon Sep 17 00:00:00 2001 From: "Brandon T. Willard" Date: Mon, 11 Jul 2022 19:46:31 -0500 Subject: [PATCH] Use context managers with open --- aesara/compile/compiledir.py | 22 ++-- aesara/compile/profiling.py | 196 +++++++++++++++------------- aesara/configdefaults.py | 3 +- aesara/link/c/cmodule.py | 14 +- aesara/link/c/cutils.py | 3 +- aesara/link/c/lazylinker_c.py | 12 +- aesara/misc/pkl_utils.py | 16 +-- doc/tutorial/loading_and_saving.rst | 3 +- tests/misc/test_pkl_utils.py | 6 +- 9 files changed, 146 insertions(+), 129 deletions(-) diff --git a/aesara/compile/compiledir.py b/aesara/compile/compiledir.py index 3af884469d..6ecb6e0eda 100644 --- a/aesara/compile/compiledir.py +++ b/aesara/compile/compiledir.py @@ -33,14 +33,13 @@ def cleanup(): """ compiledir = config.compiledir for directory in os.listdir(compiledir): - file = None try: - try: - filename = os.path.join(compiledir, directory, "key.pkl") - file = open(filename, "rb") - # print file + filename = os.path.join(compiledir, directory, "key.pkl") + # print file + with open(filename, "rb") as file: try: keydata = pickle.load(file) + for key in list(keydata.keys): have_npy_abi_version = False have_c_compiler = False @@ -86,14 +85,11 @@ def cleanup(): "the clean-up, please remove manually " "the directory containing it." ) - except OSError: - _logger.error( - f"Could not clean up this directory: '{directory}'. To complete " - "the clean-up, please remove it manually." - ) - finally: - if file is not None: - file.close() + except OSError: + _logger.error( + f"Could not clean up this directory: '{directory}'. To complete " + "the clean-up, please remove it manually." 
+ ) def print_title(title, overline="", underline=""): diff --git a/aesara/compile/profiling.py b/aesara/compile/profiling.py index c352ada42b..3a6c37d4d5 100644 --- a/aesara/compile/profiling.py +++ b/aesara/compile/profiling.py @@ -15,6 +15,7 @@ import sys import time from collections import defaultdict +from contextlib import contextmanager from typing import Dict, List import numpy as np @@ -25,6 +26,17 @@ from aesara.link.utils import get_destroy_dependencies +@contextmanager +def extended_open(filename, mode="r"): + if filename == "<stdout>": + yield sys.stdout + elif filename == "<stderr>": + yield sys.stderr + else: + with open(filename, mode=mode) as f: + yield f + + logger = logging.getLogger("aesara.compile.profiling") aesara_imported_time = time.time() @@ -37,93 +49,92 @@ def _atexit_print_fn(): - """ - Print ProfileStat objects in _atexit_print_list to _atexit_print_file. - - """ + """Print `ProfileStat` objects in `_atexit_print_list` to `_atexit_print_file`.""" if config.profile: to_sum = [] if config.profiling__destination == "stderr": - destination_file = sys.stderr + destination_file = "<stderr>" elif config.profiling__destination == "stdout": - destination_file = sys.stdout + destination_file = "<stdout>" else: - destination_file = open(config.profiling__destination, "w") - - # Reverse sort in the order of compile+exec time - for ps in sorted( - _atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time - )[::-1]: - if ( - ps.fct_callcount >= 1 - or ps.compile_time > 1 - or getattr(ps, "callcount", 0) > 1 - ): - ps.summary( + destination_file = config.profiling__destination + + with extended_open(destination_file, mode="w"): + + # Reverse sort in the order of compile+exec time + for ps in sorted( + _atexit_print_list, key=lambda a: a.compile_time + a.fct_call_time + )[::-1]: + if ( + ps.fct_callcount >= 1 + or ps.compile_time > 1 + or getattr(ps, "callcount", 0) > 1 + ): + ps.summary( + file=destination_file, + n_ops_to_print=config.profiling__n_ops, +
n_apply_to_print=config.profiling__n_apply, + ) + + if ps.show_sum: + to_sum.append(ps) + else: + # TODO print the name if there is one! + print("Skipping empty Profile") + if len(to_sum) > 1: + # Make a global profile + cum = copy.copy(to_sum[0]) + msg = f"Sum of all({len(to_sum)}) printed profiles at exit." + cum.message = msg + for ps in to_sum[1:]: + for attr in [ + "compile_time", + "fct_call_time", + "fct_callcount", + "vm_call_time", + "optimizer_time", + "linker_time", + "validate_time", + "import_time", + "linker_node_make_thunks", + ]: + setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr)) + + # merge dictionary + for attr in [ + "apply_time", + "apply_callcount", + "apply_cimpl", + "variable_shape", + "variable_strides", + "variable_offset", + "linker_make_thunk_time", + ]: + cum_attr = getattr(cum, attr) + for key, val in getattr(ps, attr.items()): + assert key not in cum_attr, (key, cum_attr) + cum_attr[key] = val + + if cum.optimizer_profile and ps.optimizer_profile: + try: + merge = cum.optimizer_profile[0].merge_profile( + cum.optimizer_profile[1], ps.optimizer_profile[1] + ) + assert len(merge) == len(cum.optimizer_profile[1]) + cum.optimizer_profile = (cum.optimizer_profile[0], merge) + except Exception as e: + print(e) + cum.optimizer_profile = None + else: + cum.optimizer_profile = None + + cum.summary( file=destination_file, n_ops_to_print=config.profiling__n_ops, n_apply_to_print=config.profiling__n_apply, ) - if ps.show_sum: - to_sum.append(ps) - else: - # TODO print the name if there is one! - print("Skipping empty Profile") - if len(to_sum) > 1: - # Make a global profile - cum = copy.copy(to_sum[0]) - msg = f"Sum of all({len(to_sum)}) printed profiles at exit." 
- cum.message = msg - for ps in to_sum[1:]: - for attr in [ - "compile_time", - "fct_call_time", - "fct_callcount", - "vm_call_time", - "optimizer_time", - "linker_time", - "validate_time", - "import_time", - "linker_node_make_thunks", - ]: - setattr(cum, attr, getattr(cum, attr) + getattr(ps, attr)) - - # merge dictionary - for attr in [ - "apply_time", - "apply_callcount", - "apply_cimpl", - "variable_shape", - "variable_strides", - "variable_offset", - "linker_make_thunk_time", - ]: - cum_attr = getattr(cum, attr) - for key, val in getattr(ps, attr.items()): - assert key not in cum_attr, (key, cum_attr) - cum_attr[key] = val - - if cum.optimizer_profile and ps.optimizer_profile: - try: - merge = cum.optimizer_profile[0].merge_profile( - cum.optimizer_profile[1], ps.optimizer_profile[1] - ) - assert len(merge) == len(cum.optimizer_profile[1]) - cum.optimizer_profile = (cum.optimizer_profile[0], merge) - except Exception as e: - print(e) - cum.optimizer_profile = None - else: - cum.optimizer_profile = None - - cum.summary( - file=destination_file, - n_ops_to_print=config.profiling__n_ops, - n_apply_to_print=config.profiling__n_apply, - ) - if config.print_global_stats: print_global_stats() @@ -139,24 +150,25 @@ def print_global_stats(): """ if config.profiling__destination == "stderr": - destination_file = sys.stderr + destination_file = "<stderr>" elif config.profiling__destination == "stdout": - destination_file = sys.stdout + destination_file = "<stdout>" else: - destination_file = open(config.profiling__destination, "w") - - print("=" * 50, file=destination_file) - print( - ( - "Global stats: ", - f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, " - f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, " - "Time spent compiling Aesara functions: " - f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ", - ), - file=destination_file, - ) - print("=" * 50, file=destination_file) + destination_file =
config.profiling__destination + + with extended_open(destination_file, mode="w"): + print("=" * 50, file=destination_file) + print( + ( + "Global stats: ", + f"Time elasped since Aesara import = {time.time() - aesara_imported_time:6.3f}s, " + f"Time spent in Aesara functions = {total_fct_exec_time:6.3f}s, " + "Time spent compiling Aesara functions: " + f" optimization = {total_graph_opt_time:6.3f}s, linker = {total_time_linker:6.3f}s ", + ), + file=destination_file, + ) + print("=" * 50, file=destination_file) _profiler_printers = [] diff --git a/aesara/configdefaults.py b/aesara/configdefaults.py index 6ec9c45778..4635a8601d 100644 --- a/aesara/configdefaults.py +++ b/aesara/configdefaults.py @@ -1300,7 +1300,8 @@ def _filter_compiledir(path): init_file = os.path.join(path, "__init__.py") if not os.path.exists(init_file): try: - open(init_file, "w").close() + with open(init_file, "w"): + pass except OSError as e: if os.path.exists(init_file): pass # has already been created diff --git a/aesara/link/c/cmodule.py b/aesara/link/c/cmodule.py index ecae473f9f..0d71fedeb6 100644 --- a/aesara/link/c/cmodule.py +++ b/aesara/link/c/cmodule.py @@ -1008,8 +1008,8 @@ def unpickle_failure(): entry = key_data.get_entry() try: # Test to see that the file is [present and] readable. 
- open(entry).close() - gone = False + with open(entry): + gone = False except OSError: gone = True @@ -1505,8 +1505,8 @@ def clear_unversioned(self, min_age=None): if filename.startswith("tmp"): try: fname = os.path.join(self.dirname, filename, "key.pkl") - open(fname).close() - has_key = True + with open(fname): + has_key = True except OSError: has_key = False if not has_key: @@ -1599,7 +1599,8 @@ def _rmtree( if os.path.exists(parent): try: _logger.info(f'placing "delete.me" in {parent}') - open(os.path.join(parent, "delete.me"), "w").close() + with open(os.path.join(parent, "delete.me"), "w"): + pass except Exception as ee: _logger.warning( f"Failed to remove or mark cache directory {parent} for removal {ee}" @@ -2641,7 +2642,8 @@ def print_command_line_error(): if py_module: # touch the __init__ file - open(os.path.join(location, "__init__.py"), "w").close() + with open(os.path.join(location, "__init__.py"), "w"): + pass assert os.path.isfile(lib_filename) return dlimport(lib_filename) diff --git a/aesara/link/c/cutils.py b/aesara/link/c/cutils.py index 311a8eb33c..a610c28b45 100644 --- a/aesara/link/c/cutils.py +++ b/aesara/link/c/cutils.py @@ -96,7 +96,8 @@ def compile_cutils(): assert e.errno == errno.EEXIST assert os.path.exists(location), location if not os.path.exists(os.path.join(location, "__init__.py")): - open(os.path.join(location, "__init__.py"), "w").close() + with open(os.path.join(location, "__init__.py"), "w"): + pass try: from cutils_ext.cutils_ext import * # noqa diff --git a/aesara/link/c/lazylinker_c.py b/aesara/link/c/lazylinker_c.py index bed42b4f45..4cf30ce653 100644 --- a/aesara/link/c/lazylinker_c.py +++ b/aesara/link/c/lazylinker_c.py @@ -59,7 +59,8 @@ def try_reload(): init_file = os.path.join(location, "__init__.py") if not os.path.exists(init_file): try: - open(init_file, "w").close() + with open(init_file, "w"): + pass except OSError as e: if os.path.exists(init_file): pass # has already been created @@ -126,10 +127,12 @@ def 
try_reload(): "code generation." ) raise ImportError("The file lazylinker_c.c is not available.") - code = open(cfile).read() + + with open(cfile) as f: + code = f.read() + loc = os.path.join(config.compiledir, dirname) if not os.path.exists(loc): - try: os.mkdir(loc) except OSError as e: @@ -140,14 +143,17 @@ def try_reload(): GCC_compiler.compile_str(dirname, code, location=loc, preargs=args) # Save version into the __init__.py file. init_py = os.path.join(loc, "__init__.py") + with open(init_py, "w") as f: f.write(f"_version = {version}\n") + # If we just compiled the module for the first time, then it was # imported at the same time: we need to make sure we do not # reload the now outdated __init__.pyc below. init_pyc = os.path.join(loc, "__init__.pyc") if os.path.isfile(init_pyc): os.remove(init_pyc) + try_import() try_reload() from lazylinker_ext import lazylinker_ext as lazy_c diff --git a/aesara/misc/pkl_utils.py b/aesara/misc/pkl_utils.py index 253d602ff6..80a0ba824f 100644 --- a/aesara/misc/pkl_utils.py +++ b/aesara/misc/pkl_utils.py @@ -42,21 +42,19 @@ class StripPickler(Pickler): - """ - Subclass of Pickler that strips unnecessary attributes from Aesara objects. - - .. versionadded:: 0.8 + """Subclass of `Pickler` that strips unnecessary attributes from Aesara objects. 
- Example of use:: + Example + ------- fn_args = dict(inputs=inputs, outputs=outputs, updates=updates) dest_pkl = 'my_test.pkl' - f = open(dest_pkl, 'wb') - strip_pickler = StripPickler(f, protocol=-1) - strip_pickler.dump(fn_args) - f.close() + with open(dest_pkl, 'wb') as f: + strip_pickler = StripPickler(f, protocol=-1) + strip_pickler.dump(fn_args) + """ def __init__(self, file, protocol=0, extra_tag_to_remove=None): diff --git a/doc/tutorial/loading_and_saving.rst b/doc/tutorial/loading_and_saving.rst index 86e075f42a..da743a462a 100644 --- a/doc/tutorial/loading_and_saving.rst +++ b/doc/tutorial/loading_and_saving.rst @@ -118,7 +118,8 @@ For instance, you can define functions along the lines of: def __setstate__(self, d): self.__dict__.update(d) - self.training_set = cPickle.load(open(self.training_set_file, 'rb')) + with open(self.training_set_file, 'rb') as f: + self.training_set = cPickle.load(f) Robust Serialization diff --git a/tests/misc/test_pkl_utils.py b/tests/misc/test_pkl_utils.py index 6413202c95..acb0f0949e 100644 --- a/tests/misc/test_pkl_utils.py +++ b/tests/misc/test_pkl_utils.py @@ -66,6 +66,6 @@ def test_basic(self): with open("test.pkl", "wb") as f: m = matrix() dest_pkl = "my_test.pkl" - f = open(dest_pkl, "wb") - strip_pickler = StripPickler(f, protocol=-1) - strip_pickler.dump(m) + with open(dest_pkl, "wb") as f: + strip_pickler = StripPickler(f, protocol=-1) + strip_pickler.dump(m)