Skip to content

Commit

Permalink
fix test loops
Browse files Browse the repository at this point in the history
  • Loading branch information
ofirgo committed Jan 14, 2025
1 parent 06fefb0 commit a77e905
Showing 1 changed file with 50 additions and 38 deletions.
88 changes: 50 additions & 38 deletions tests_pytest/base_test_classes/base_tpc_attach2fw_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

from typing import NamedTuple

import pytest

Expand All @@ -36,6 +36,8 @@
signedness=schema.Signedness.AUTO
)

OpSet = NamedTuple("OpSet", [('op_name', str), ('op_list', list)])


class BaseTpcAttach2FrameworkTest:

Expand Down Expand Up @@ -63,26 +65,31 @@ def test_attach2fw_attach_without_attributes(self):
default_qc_options = schema.QuantizationConfigOptions(quantization_configurations=(default_op_cfg,))
tested_qc_options = schema.QuantizationConfigOptions(quantization_configurations=(tested_op_cfg,))

for op_name, op_list in self.attach2fw._opset2layer.items():
if op_name not in self.attach2fw._opset2attr_mapping.keys():
tpc = schema.TargetPlatformCapabilities(
default_qco=default_qc_options,
operator_set=tuple([schema.OperatorsSet(name=op_name, qc_options=tested_qc_options)]))
opsets_without_attrs = [OpSet(op_name=op_name, op_list=op_list)
for op_name, op_list in self.attach2fw._opset2layer.items()
if op_name not in self.attach2fw._opset2attr_mapping.keys()]

assert len(opsets_without_attrs) > 0

fw_quant_capabilities = self.attach2fw.attach(tpc) # Run 'attach' to test operator attach to framework
for opset in opsets_without_attrs:
tpc = schema.TargetPlatformCapabilities(
default_qco=default_qc_options,
operator_set=tuple([schema.OperatorsSet(name=opset.op_name, qc_options=tested_qc_options)]))

assert isinstance(fw_quant_capabilities, FrameworkQuantizationCapabilities)
fw_quant_capabilities = self.attach2fw.attach(tpc) # Run 'attach' to test operator attach to framework

all_mapped_ops = fw_quant_capabilities.layer2qco.copy()
all_mapped_ops.update(fw_quant_capabilities.filterlayer2qco)
if len(op_list) == 0:
assert len(all_mapped_ops) == 0
else:
assert len(all_mapped_ops) == len(op_list)
assert isinstance(fw_quant_capabilities, FrameworkQuantizationCapabilities)

for qco in all_mapped_ops.values():
assert len(qco.quantization_configurations) == 1
assert qco.base_config.activation_n_bits == 42
all_mapped_ops = fw_quant_capabilities.layer2qco.copy()
all_mapped_ops.update(fw_quant_capabilities.filterlayer2qco)
if len(opset.op_list) == 0:
assert len(all_mapped_ops) == 0
else:
assert len(all_mapped_ops) == len(opset.op_list)

for qco in all_mapped_ops.values():
assert len(qco.quantization_configurations) == 1
assert qco.base_config.activation_n_bits == 42


def test_attach2fw_attach_linear_op_with_attributes(self):
Expand All @@ -103,32 +110,37 @@ def test_attach2fw_attach_linear_op_with_attributes(self):
default_qc_options = schema.QuantizationConfigOptions(quantization_configurations=(default_op_cfg,))
tested_qc_options = schema.QuantizationConfigOptions(quantization_configurations=(tested_op_cfg,))

for op_name, op_list in self.attach2fw._opset2layer.items():
if op_name in self.attach2fw._opset2attr_mapping.keys():
tpc = schema.TargetPlatformCapabilities(
default_qco=default_qc_options,
operator_set=tuple([schema.OperatorsSet(name=op_name, qc_options=tested_qc_options)]))
opsets_with_attrs = [OpSet(op_name=op_name, op_list=op_list)
for op_name, op_list in self.attach2fw._opset2layer.items()
if op_name in self.attach2fw._opset2attr_mapping.keys()]

assert len(opsets_with_attrs) > 0

for opset in opsets_with_attrs:
tpc = schema.TargetPlatformCapabilities(
default_qco=default_qc_options,
operator_set=tuple([schema.OperatorsSet(name=opset.op_name, qc_options=tested_qc_options)]))

fw_quant_capabilities = self.attach2fw.attach(tpc) # Run 'attach' to test operator attach to framework
fw_linear_attr_names = self.attach2fw._opset2attr_mapping[op_name]
fw_quant_capabilities = self.attach2fw.attach(tpc) # Run 'attach' to test operator attach to framework
fw_linear_attr_names = self.attach2fw._opset2attr_mapping[opset.op_name]

assert isinstance(fw_quant_capabilities, FrameworkQuantizationCapabilities)
assert isinstance(fw_quant_capabilities, FrameworkQuantizationCapabilities)

all_mapped_ops = fw_quant_capabilities.layer2qco.copy()
all_mapped_ops.update(fw_quant_capabilities.filterlayer2qco)
if len(op_list) == 0:
assert len(all_mapped_ops) == 0
else:
assert len(all_mapped_ops) == len(op_list)
all_mapped_ops = fw_quant_capabilities.layer2qco.copy()
all_mapped_ops.update(fw_quant_capabilities.filterlayer2qco)
if len(opset.op_list) == 0:
assert len(all_mapped_ops) == 0
else:
assert len(all_mapped_ops) == len(opset.op_list)

for qco in all_mapped_ops.values():
assert len(qco.quantization_configurations) == 1
assert qco.base_config.default_weight_attr_config == default_attr_config
for qco in all_mapped_ops.values():
assert len(qco.quantization_configurations) == 1
assert qco.base_config.default_weight_attr_config == default_attr_config

for attr_name, fw_layer2attr_mapping in fw_linear_attr_names.items():
assert isinstance(fw_layer2attr_mapping, DefaultDict)
layer_attr_mapping = fw_layer2attr_mapping.get(op_list[0])
assert qco.base_config.attr_weights_configs_mapping.get(layer_attr_mapping) == tested_attr_cfg
for attr_name, fw_layer2attr_mapping in fw_linear_attr_names.items():
assert isinstance(fw_layer2attr_mapping, DefaultDict)
layer_attr_mapping = fw_layer2attr_mapping.get(opset.op_list[0])
assert qco.base_config.attr_weights_configs_mapping.get(layer_attr_mapping) == tested_attr_cfg


def test_attach2fw_attach_to_default_config(self):
Expand Down

0 comments on commit a77e905

Please sign in to comment.