Skip to content

Commit

Permalink
Add more API entities
Browse files Browse the repository at this point in the history
  • Loading branch information
vshampor committed Apr 17, 2023
1 parent bf55ebc commit 85d5191
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 4 deletions.
2 changes: 2 additions & 0 deletions nncf/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@
from typing import Generic
from typing import TypeVar

from nncf.common.api_marker import api

DataItem = TypeVar('DataItem')
ModelInput = TypeVar('ModelInput')


@api()
class Dataset(Generic[DataItem, ModelInput]):
"""
The `nncf.Dataset` class defines the interface by which compression algorithms
Expand Down
5 changes: 5 additions & 0 deletions nncf/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@

from enum import Enum

from nncf.common.api_marker import api


@api()
class TargetDevice(Enum):
"""
Describes the target device the specificity of which will be taken
Expand All @@ -32,6 +36,7 @@ class TargetDevice(Enum):
CPU_SPR = 'CPU_SPR'


@api()
class ModelType(Enum):
"""
Describes the model type the specificity of which will be taken into
Expand Down
10 changes: 6 additions & 4 deletions nncf/scopes.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@
import re
from typing import List, Optional

from nncf.common.api_marker import api
from nncf.common.logging import nncf_logger
from nncf.common.graph.graph import NNCFGraph


@api()
class IgnoredScope:
"""
Dataclass that contains description of the ignored scope.
Expand Down Expand Up @@ -114,11 +116,11 @@ def get_ignored_node_names_from_ignored_scope(ignored_scope: IgnoredScope,
if ignored_node_name in node_names:
matched_by_names.append(ignored_node_name)
if strict and len(ignored_scope.names) != len(matched_by_names):
skipped_names = set(ignored_scope.names) - set(matched_by_names)
raise RuntimeError(f'Ignored nodes with name {list(skipped_names)}'
' were not found in the NNCFGraph. ' + error_msg)
skipped_names = set(ignored_scope.names) - set(matched_by_names)
raise RuntimeError(f'Ignored nodes with name {list(skipped_names)}'
' were not found in the NNCFGraph. ' + error_msg)
nncf_logger.info(f'{len(matched_by_names)}'
' ignored nodes was found by name in the NNCFGraph')
' ignored nodes was found by name in the NNCFGraph')

matched_by_patterns = []
if ignored_scope.patterns:
Expand Down

0 comments on commit 85d5191

Please sign in to comment.