Skip to content

Commit

Permalink
support registry with basicsr suffix: can have repeated names in Basi…
Browse files Browse the repository at this point in the history
…cSR and other repos
  • Loading branch information
xinntao committed Apr 6, 2022
1 parent 6697f41 commit 2f0ad00
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions basicsr/utils/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,15 @@ def __init__(self, name):
self._name = name
self._obj_map = {}

def _do_register(self, name, obj):
def _do_register(self, name, obj, suffix=None):
if isinstance(suffix, str):
name = name + '_' + suffix

assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
f"in '{self._name}' registry!")
self._obj_map[name] = obj

def register(self, obj=None):
def register(self, obj=None, suffix=None):
"""
Register the given object under the the name `obj.__name__`.
Can be used as either a decorator or not.
Expand All @@ -50,17 +53,20 @@ def register(self, obj=None):
# used as a decorator
def deco(func_or_class):
name = func_or_class.__name__
self._do_register(name, func_or_class)
self._do_register(name, func_or_class, suffix)
return func_or_class

return deco

# used as a function call
name = obj.__name__
self._do_register(name, obj)
self._do_register(name, obj, suffix)

def get(self, name):
def get(self, name, suffix='basicsr'):
ret = self._obj_map.get(name)
if ret is None:
ret = self._obj_map.get(name + '_' + suffix)
print(f'Name {name} is not found, use name: {name}_{suffix}!')
if ret is None:
raise KeyError(f"No object named '{name}' found in '{self._name}' registry!")
return ret
Expand Down

0 comments on commit 2f0ad00

Please sign in to comment.