Skip to content

Commit

Permalink
Merge branch 'main' into new-tests-example-1
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirgo committed Jan 28, 2025
2 parents a29ca35 + c3ec981 commit 9463b0d
Show file tree
Hide file tree
Showing 27 changed files with 1,139 additions and 377 deletions.
31 changes: 0 additions & 31 deletions .github/workflows/run_keras_sony_custom_layers.yml

This file was deleted.

5 changes: 3 additions & 2 deletions .github/workflows/run_keras_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,9 @@ jobs:
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install tensorflow==${{ inputs.tf-version }} sony-custom-layers pytest
pip install tensorflow==${{ inputs.tf-version }} sony-custom-layers
pip install pytest pytest-mock
pip check
- name: Run unittests
run: |
python -m unittest discover tests/keras_tests -v
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/run_pytorch_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==${{ inputs.torch-version }} torchvision onnx onnxruntime onnxruntime-extensions
pip install pytest
pip install pytest pytest-mock
pip check
- name: Run unittests
run: |
python -m unittest discover tests/pytorch_tests -v
Expand Down
31 changes: 18 additions & 13 deletions .github/workflows/run_tests_suite_coverage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -30,43 +30,48 @@ jobs:
with:
python-version: '3.10'

- name: Set up Coverage
- name: Set up environment for common tests
run: |
python -m pip install --upgrade pip
pip install coverage
pip install -r requirements.txt coverage pytest pytest-mock
- name: Run common tests (unittest)
run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest discover tests/common_tests -v

- name: Run common tests (pytest)
run: coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/common

- name: Set up TensorFlow environment
run: |
python -m venv tf_env
source tf_env/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install tensorflow==2.13.* coverage pytest
pip install -r requirements.txt tensorflow==2.13.* sony-custom-layers coverage pytest pytest-mock
- name: Run TensorFlow testsuite
- name: Run TensorFlow tests (unittest)
run: |
source tf_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest tests/test_suite.py -v
- name: Run TensorFlow pytest
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest discover tests/keras_tests -v
- name: Run TensorFlow tests (pytest)
run: |
source tf_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/keras
- name: Set up Pytorch environment
- name: Set up PyTorch environment
run: |
python -m venv torch_env
source torch_env/bin/activate
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install torch==2.0.* torchvision onnx onnxruntime onnxruntime-extensions coverage pytest
pip install torch==2.0.* torchvision onnx onnxruntime onnxruntime-extensions sony-custom-layers coverage pytest pytest-mock
- name: Run torch testsuite
- name: Run PyTorch tests (unittest)
run: |
source torch_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest tests/test_suite.py -v
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" unittest discover tests/pytorch_tests -v
- name: Run torch pytest
- name: Run PyTorch tests (pytest)
run: |
source torch_env/bin/activate
coverage run --parallel-mode -m --omit "*__init__.py" --include "model_compression_toolkit/**/*.py" pytest tests_pytest/pytorch
Expand Down
9 changes: 8 additions & 1 deletion .github/workflows/tests_common.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,13 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -r requirements.txt
pip install -r requirements.txt
pip install pytest pytest-mock
pip check
- name: Run unittests
run: python -m unittest discover tests/common_tests -v

- name: Run pytest
run: pytest tests_pytest/common

9 changes: 6 additions & 3 deletions model_compression_toolkit/core/common/graph/base_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@
FrameworkQuantizationCapabilities


WeightAttrT = Union[str, int]


class BaseNode:
"""
Class to represent a node in a graph that represents the model.
Expand All @@ -40,7 +43,7 @@ def __init__(self,
framework_attr: Dict[str, Any],
input_shape: Tuple[Any],
output_shape: Tuple[Any],
weights: Dict[Union[str, int], np.ndarray],
weights: Dict[WeightAttrT, np.ndarray],
layer_class: type,
reuse: bool = False,
reuse_group: str = None,
Expand Down Expand Up @@ -189,7 +192,7 @@ def is_reused(self) -> bool:
"""
return self.reuse or self.reuse_group is not None

def _get_weight_name(self, name: Union[str, int]) -> List[Union[str, int]]:
def _get_weight_name(self, name: WeightAttrT) -> List[WeightAttrT]:
"""
Get weight names that match argument name (either string weights or integer for
positional weights).
Expand All @@ -203,7 +206,7 @@ def _get_weight_name(self, name: Union[str, int]) -> List[Union[str, int]]:
return [k for k in self.weights.keys()
if (isinstance(k, int) and name == k) or (isinstance(k, str) and name in k)]

def get_weights_by_keys(self, name: Union[str, int]) -> np.ndarray:
def get_weights_by_keys(self, name: WeightAttrT) -> np.ndarray:
"""
Get a node's weight by its name.
Args:
Expand Down
Loading

0 comments on commit 9463b0d

Please sign in to comment.