Skip to content

Commit

Permalink
Add more API entities
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Apr 19, 2023
1 parent 30ab20a commit 79d8f1c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
2 changes: 2 additions & 0 deletions nncf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from typing import Generic
from typing import TypeVar

from nncf.common.api_marker import api

DataItem = TypeVar('DataItem')
ModelInput = TypeVar('ModelInput')


@api()
class Dataset(Generic[DataItem, ModelInput]):
"""
The `nncf.Dataset` class defines the interface by which compression algorithms
Expand Down
5 changes: 5 additions & 0 deletions nncf/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

from enum import Enum

from nncf.common.api_marker import api


@api()
class TargetDevice(Enum):
"""
Describes the target device the specificity of which will be taken
Expand All @@ -32,6 +36,7 @@ class TargetDevice(Enum):
CPU_SPR = 'CPU_SPR'


@api()
class ModelType(Enum):
"""
Describes the model type the specificity of which will be taken into
Expand Down
12 changes: 7 additions & 5 deletions nncf/scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
import re
from typing import List, Optional

from nncf.common.graph.graph import NNCFGraph
from nncf.common.api_marker import api
from nncf.common.logging import nncf_logger
from nncf.common.graph.graph import NNCFGraph


@api()
@dataclass
class IgnoredScope:
"""
Expand Down Expand Up @@ -111,11 +113,11 @@ def get_ignored_node_names_from_ignored_scope(ignored_scope: IgnoredScope,
if ignored_node_name in node_names:
matched_by_names.append(ignored_node_name)
if strict and len(ignored_scope.names) != len(matched_by_names):
skipped_names = set(ignored_scope.names) - set(matched_by_names)
raise RuntimeError(f'Ignored nodes with name {list(skipped_names)}'
' were not found in the NNCFGraph. ' + error_msg)
skipped_names = set(ignored_scope.names) - set(matched_by_names)
raise RuntimeError(f'Ignored nodes with name {list(skipped_names)}'
' were not found in the NNCFGraph. ' + error_msg)
nncf_logger.info(f'{len(matched_by_names)}'
' ignored nodes was found by name in the NNCFGraph')
' ignored nodes was found by name in the NNCFGraph')

matched_by_patterns = []
if ignored_scope.patterns:
Expand Down

0 comments on commit 79d8f1c

Please sign in to comment.