diff --git a/tensordict/utils.py b/tensordict/utils.py index 0c8585d39..cdc0756f8 100644 --- a/tensordict/utils.py +++ b/tensordict/utils.py @@ -2140,7 +2140,7 @@ def _is_json_serializable(item): return isinstance(item, (str, int, float, bool)) or item is None -def print_directory_tree(path, indent="", display_metadata=True): +def print_directory_tree(path, indent="", display_metadata=True) -> str: """Prints the directory tree starting from the specified path. Args: @@ -2149,7 +2149,11 @@ def print_directory_tree(path, indent="", display_metadata=True): display_metadata (bool): if ``True``, metadata of the dir will be displayed too. + Returns: + the string printed with the logger. + """ + string = [] if display_metadata: def get_directory_size(path="."): @@ -2171,17 +2175,23 @@ def format_size(size): total_size_bytes = get_directory_size(path) formatted_size = format_size(total_size_bytes) - logger.info(f"Directory size: {formatted_size}") + string.append(f"Directory size: {formatted_size}") + logger.info(string[-1]) if os.path.isdir(path): - logger.info(indent + os.path.basename(path) + "/") + string.append(indent + os.path.basename(path) + "/") + logger.info(string[-1]) indent += " " for item in os.listdir(path): - print_directory_tree( - os.path.join(path, item), indent=indent, display_metadata=False + string.append( + print_directory_tree( + os.path.join(path, item), indent=indent, display_metadata=False + ) ) else: - logger.info(indent + os.path.basename(path)) + string.append(indent + os.path.basename(path)) + logger.info(string[-1]) + return "\n".join(string) def isin(