Skip to content

Commit

Permalink
[Refactor] Refactor implement_for (#556)
Browse files Browse the repository at this point in the history
  • Loading branch information
vmoens authored Nov 13, 2023
1 parent 924a46a commit 95bd846
Showing 1 changed file with 92 additions and 28 deletions.
120 changes: 92 additions & 28 deletions tensordict/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from __future__ import annotations

import collections
import dataclasses
import inspect
import math

import sys
import time

import warnings
Expand All @@ -16,7 +18,7 @@
from copy import copy
from functools import wraps
from importlib import import_module
from typing import Any, Callable, List, Sequence, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, Dict, List, Sequence, Tuple, TYPE_CHECKING, Union

import numpy as np
import torch
Expand Down Expand Up @@ -984,6 +986,7 @@ class implement_for:
# Stores pointers to fitting implementations: dict[func_name] = func_pointer
_implementations = {}
_setters = []
_cache_modules = {}

def __init__(
self,
Expand All @@ -1005,35 +1008,87 @@ def check_version(version, from_version, to_version):
@staticmethod
def get_class_that_defined_method(f):
"""Returns the class of a method, if it is defined, and None otherwise."""
return f.__globals__.get(f.__qualname__.split(".")[0], None)
out = f.__globals__.get(f.__qualname__.split(".")[0], None)
return out

@property
def func_name(self):
return self.fn.__name__
@classmethod
def get_func_name(cls, fn):
# produces a name like torchrl.module.Class.method or torchrl.module.function
first = str(fn).split(".")[0][len("<function ") :]
last = str(fn).split(".")[1:]
if last:
first = [first]
last[-1] = last[-1].split(" ")[0]
else:
last = [first.split(" ")[0]]
first = []
return ".".join([fn.__module__] + first + last)

def _get_cls(self, fn):
cls = self.get_class_that_defined_method(fn)
if cls is None:
# class not yet defined
return
if cls.__class__.__name__ == "function":
cls = inspect.getmodule(fn)
return cls

def module_set(self):
"""Sets the function in its module, if it exists already."""
prev_setter = type(self)._implementations.get(self.get_func_name(self.fn), None)
if prev_setter is not None:
prev_setter.do_set = False
type(self)._implementations[self.get_func_name(self.fn)] = self
cls = self.get_class_that_defined_method(self.fn)
if cls is None:
if cls is not None:
if cls.__class__.__name__ == "function":
cls = inspect.getmodule(self.fn)
else:
# class not yet defined
return
if cls.__class__.__name__ == "function":
cls = inspect.getmodule(self.fn)
setattr(cls, self.fn.__name__, self.fn)

@staticmethod
def import_module(module_name: Union[Callable, str]) -> str:
@classmethod
def import_module(cls, module_name: Union[Callable, str]) -> str:
"""Imports module and returns its version."""
if not callable(module_name):
module = import_module(module_name)
module = cls._cache_modules.get(module_name, None)
if module is None:
if module_name in sys.modules:
sys.modules[module_name] = module = import_module(module_name)
else:
cls._cache_modules[module_name] = module = import_module(
module_name
)
else:
module = module_name()
return module.__version__

_lazy_impl = collections.defaultdict(list)

def _delazify(self, func_name):
for local_call in implement_for._lazy_impl[func_name]:
out = local_call()
return out

def __call__(self, fn):
# function names are unique
self.func_name = self.get_func_name(fn)
self.fn = fn
implement_for._lazy_impl[self.func_name].append(self._call)

@wraps(fn)
def _lazy_call_fn(*args, **kwargs):
# first time we call the function, we also do the replacement.
# This will cause the imports to occur only during the first call to fn
return self._delazify(self.func_name)(*args, **kwargs)

return _lazy_call_fn

def _call(self):

# If the module is missing replace the function with the mock.
fn = self.fn
func_name = self.func_name
implementations = implement_for._implementations

Expand All @@ -1043,41 +1098,50 @@ def unsupported(*args, **kwargs):
f"Supported version of '{func_name}' has not been found."
)

do_set = False
self.do_set = False
# Return fitting implementation if it was encountered before.
if func_name in implementations:
try:
# check that backends don't conflict
version = self.import_module(self.module_name)
if self.check_version(version, self.from_version, self.to_version):
do_set = True
if not do_set:
return implementations[func_name]
self.do_set = True
if not self.do_set:
return implementations[func_name].fn
except ModuleNotFoundError:
# then it's ok, there is no conflict
return implementations[func_name]
return implementations[func_name].fn
else:
try:
version = self.import_module(self.module_name)
if self.check_version(version, self.from_version, self.to_version):
do_set = True
self.do_set = True
except ModuleNotFoundError:
return unsupported
if do_set:
implementations[func_name] = fn
if self.do_set:
self.module_set()
return fn
return unsupported

@classmethod
def reset(cls, setters=None):
if setters is None:
setters = copy(cls._setters)
cls._setters = []
cls._implementations = {}
for setter in setters:
setter(setter.fn)
cls._setters.append(setter)
def reset(cls, setters_dict: Dict[str, implement_for] = None):
"""Resets the setters in setter_dict.
``setter_dict`` is a copy of implementations. We just need to iterate through its
values and call :meth:`~.module_set` for each.
"""
if setters_dict is None:
setters_dict = copy(cls._implementations)
for setter in setters_dict.values():
setter.module_set()

def __repr__(self):
return (
f"{self.__class__.__name__}("
f"module_name={self.module_name}({self.from_version, self.to_version}), "
f"fn_name={self.fn.__name__}, cls={self._get_cls(self.fn)}, is_set={self.do_set})"
)


def _unfold_sequence(seq):
Expand Down

0 comments on commit 95bd846

Please sign in to comment.