From 79d8f1c6920e12068b7d1f8af5311727dbfbe1a4 Mon Sep 17 00:00:00 2001 From: Vasily Shamporov Date: Mon, 17 Apr 2023 14:14:53 +0200 Subject: [PATCH] Add more API entities --- nncf/data/dataset.py | 2 ++ nncf/parameters.py | 5 +++++ nncf/scopes.py | 12 +++++++----- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/nncf/data/dataset.py b/nncf/data/dataset.py index 8b7ea3f3960..3b974991bce 100644 --- a/nncf/data/dataset.py +++ b/nncf/data/dataset.py @@ -18,11 +18,13 @@ from typing import Generic from typing import TypeVar +from nncf.common.api_marker import api DataItem = TypeVar('DataItem') ModelInput = TypeVar('ModelInput') +@api() class Dataset(Generic[DataItem, ModelInput]): """ The `nncf.Dataset` class defines the interface by which compression algorithms diff --git a/nncf/parameters.py b/nncf/parameters.py index cb4a90cb9b2..d035c7a2b15 100644 --- a/nncf/parameters.py +++ b/nncf/parameters.py @@ -13,6 +13,10 @@ from enum import Enum +from nncf.common.api_marker import api + + +@api() class TargetDevice(Enum): """ Describes the target device the specificity of which will be taken @@ -32,6 +36,7 @@ class TargetDevice(Enum): CPU_SPR = 'CPU_SPR' +@api() class ModelType(Enum): """ Describes the model type the specificity of which will be taken into diff --git a/nncf/scopes.py b/nncf/scopes.py index feff6869c35..6ded64fd3a4 100644 --- a/nncf/scopes.py +++ b/nncf/scopes.py @@ -16,10 +16,12 @@ import re from typing import List, Optional -from nncf.common.graph.graph import NNCFGraph +from nncf.common.api_marker import api from nncf.common.logging import nncf_logger +from nncf.common.graph.graph import NNCFGraph +@api() @dataclass class IgnoredScope: """ @@ -111,11 +113,11 @@ def get_ignored_node_names_from_ignored_scope(ignored_scope: IgnoredScope, if ignored_node_name in node_names: matched_by_names.append(ignored_node_name) if strict and len(ignored_scope.names) != len(matched_by_names): - skipped_names = set(ignored_scope.names) - set(matched_by_names) - raise RuntimeError(f'Ignored nodes with name {list(skipped_names)}' - ' were not found in the NNCFGraph. ' + error_msg) + skipped_names = set(ignored_scope.names) - set(matched_by_names) + raise RuntimeError(f'Ignored nodes with name {list(skipped_names)}' + ' were not found in the NNCFGraph. ' + error_msg) nncf_logger.info(f'{len(matched_by_names)}' - ' ignored nodes was found by name in the NNCFGraph') + ' ignored nodes was found by name in the NNCFGraph') matched_by_patterns = [] if ignored_scope.patterns: