From bf55ebc8c4f3df89a2d82499d6263eedbd373d9c Mon Sep 17 00:00:00 2001 From: Vasily Shamporov Date: Fri, 14 Apr 2023 12:30:25 +0200 Subject: [PATCH] Fix derived class handling --- docs/api/source/conf.py | 9 +++++++-- nncf/common/api_marker.py | 4 +++- nncf/common/quantization/structs.py | 1 - nncf/torch/automl/environment/__init__.py | 2 +- 4 files changed, 11 insertions(+), 5 deletions(-) diff --git a/docs/api/source/conf.py b/docs/api/source/conf.py index d353c2da529..4e1ae63bbf6 100644 --- a/docs/api/source/conf.py +++ b/docs/api/source/conf.py @@ -58,8 +58,13 @@ def collect_api_entities() -> List[str]: if objects_module == modname: if inspect.isclass(obj) or inspect.isfunction(obj): if hasattr(obj, "_nncf_api_marker"): - print(f"\t{obj_name}") - api_fqns.append(f"{modname}.{obj_name}") + marked_object_name = obj._nncf_api_marker + # Check the actual name of the originally marked object + # so that the classes derived from base API classes don't + # all automatically end up in API + if marked_object_name == obj.__name__: + print(f"\t{obj_name}") + api_fqns.append(f"{modname}.{obj_name}") print() skipped_str = '\n'.join([f"{k}: {v}" for k, v in skipped_modules.items()]) diff --git a/nncf/common/api_marker.py b/nncf/common/api_marker.py index 912ee8f3ef3..5961a061476 100644 --- a/nncf/common/api_marker.py +++ b/nncf/common/api_marker.py @@ -5,5 +5,7 @@ def __init__(self): pass def __call__(self, obj): - setattr(obj, api.API_MARKER_ATTR, True) + # The value of the marker will be useful in determining + # whether we are handling a base class or a derived one. + setattr(obj, api.API_MARKER_ATTR, obj.__name__) return obj diff --git a/nncf/common/quantization/structs.py b/nncf/common/quantization/structs.py index 251991b2d46..61f284b21e1 100644 --- a/nncf/common/quantization/structs.py +++ b/nncf/common/quantization/structs.py @@ -15,7 +15,6 @@ from enum import Enum from typing import Dict, List, Optional, Any -from nncf.common.api_marker import api from nncf.common.graph import NNCFNode from nncf.common.graph import NNCFNodeName from nncf.config.schemata.defaults import QUANTIZATION_BITS diff --git a/nncf/torch/automl/environment/__init__.py b/nncf/torch/automl/environment/__init__.py index de0da022181..8727b935935 100644 --- a/nncf/torch/automl/environment/__init__.py +++ b/nncf/torch/automl/environment/__init__.py @@ -9,4 +9,4 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. -""" \ No newline at end of file +"""