Skip to content

Commit

Permalink
INC Bench pruning support (#295)
Browse files Browse the repository at this point in the history
Signed-off-by: bmyrcha <[email protected]>
  • Loading branch information
bmyrcha authored Dec 17, 2022
1 parent 5b9be25 commit d24fea6
Show file tree
Hide file tree
Showing 90 changed files with 5,216 additions and 764 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,13 +64,6 @@ def execute_benchmark(data: Dict[str, Any]) -> None:
project_id = benchmark_details["project_id"]
project_details = ProjectAPIInterface.get_project_details({"id": project_id})

BenchmarkAPIInterface.update_benchmark_status(
{
"id": benchmark_id,
"status": ExecutionStatus.WIP,
},
)

response_data = execute_real_benchmark(
request_id=request_id,
project_details=project_details,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Pruning configuration generator class."""
from typing import Any

from neural_compressor.ux.components.config_generator.config_generator import ConfigGenerator
from neural_compressor.ux.utils.workload.config import Config
from neural_compressor.ux.utils.workload.evaluation import Accuracy, Evaluation, Metric
from neural_compressor.ux.utils.workload.pruning import Pruning


class PruningConfigGenerator(ConfigGenerator):
"""PruningConfigGenerator class."""

def __init__(self, *args: Any, **kwargs: Any) -> None:
"""Initialize configuration generator."""
super().__init__(*args, **kwargs)
data = kwargs.get("data", {})
self.pruning_configuration: dict = data["pruning_details"]

def generate(self) -> None:
"""Generate yaml config file."""
config = Config()
config.load(self.predefined_config_path)
config.quantization = None
config.model = self.generate_model_config()
config.evaluation = self.generate_evaluation_config()
config.pruning = self.generate_pruning_config()
config.dump(self.config_path)

def generate_evaluation_config(self) -> Evaluation:
"""Generate evaluation configuration."""
evaluation = Evaluation()
evaluation.accuracy = Accuracy()

if self.metric:
evaluation.accuracy.metric = Metric(self.metric)

evaluation.accuracy.dataloader = self.generate_dataloader_config(batch_size=1)
evaluation.set_accuracy_postprocess_transforms(self.transforms)
return evaluation

def generate_pruning_config(self) -> Pruning:
"""Generate graph optimization configuration."""
pruning = Pruning(self.pruning_configuration)
if pruning.train is not None:
pruning.train.dataloader = self.generate_dataloader_config(batch_size=1)
pruning.train.set_postprocess_transforms(self.transforms)
return pruning
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Configuration type parser."""
from typing import Any


class PruningConfigParser:
"""Pruning configuration parser class."""

def parse(self, input_data: list) -> dict:
"""Parse configuration."""
raise NotImplementedError

def generate_tree(self, input_data: dict) -> list:
"""Generate tree from pruning configuration."""
parsed_tree = self.parse_entry(input_data)
return parsed_tree

def parse_entry(self, input_data: dict) -> Any:
"""Parse configuration entry to tree element."""
config_tree = []
for key, value in input_data.items():
if key in ["train", "approach"] and value is None:
continue
parsed_entry = {"name": key}
if isinstance(value, dict):
children = self.parse_entry(value)
parsed_entry.update({"children": children})
elif isinstance(value, list):
for list_entry in value:
parsed_list_entries = self.parse_entry(list_entry)
parsed_entry.update({"children": parsed_list_entries})
else:
parsed_entry.update({"value": value})
config_tree.append(parsed_entry)
return config_tree
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
# -*- coding: utf-8 -*-
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# flake8: noqa
# mypy: ignore-errors
"""pruning_support
Revision ID: 644ec953a7dc
Revises: 6ece06672ed3
Create Date: 2022-12-09 17:22:17.310141
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.orm import sessionmaker

from neural_compressor.ux.components.db_manager.db_manager import DBManager
from neural_compressor.ux.components.db_manager.db_models.optimization_type import OptimizationType
from neural_compressor.ux.components.db_manager.db_models.precision import (
Precision,
precision_optimization_type_association,
)
from neural_compressor.ux.utils.consts import OptimizationTypes, Precisions

db_manager = DBManager()
Session = sessionmaker(bind=db_manager.engine)

# revision identifiers, used by Alembic.
revision = "644ec953a7dc"
down_revision = "6ece06672ed3"
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
with Session.begin() as db_session:
pruning_optimization_id = OptimizationType.add(
db_session=db_session,
name=OptimizationTypes.PRUNING.value,
)
fp32_precision_id = Precision.get_precision_by_name(
db_session=db_session,
precision_name=Precisions.FP32.value,
)[0]

query = precision_optimization_type_association.insert().values(
precision_id=fp32_precision_id,
optimization_type_id=pruning_optimization_id,
)
db_session.execute(query)

op.create_table(
"pruning_details",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("train", sa.String(), nullable=True),
sa.Column("approach", sa.String(), nullable=True),
sa.Column(
"created_at",
sa.DateTime(),
server_default=sa.text("(CURRENT_TIMESTAMP)"),
nullable=False,
),
sa.Column("modified_at", sa.DateTime(), nullable=True),
sa.PrimaryKeyConstraint("id", name=op.f("pk_pruning_details")),
)
with op.batch_alter_table("pruning_details", schema=None) as batch_op:
batch_op.create_index(batch_op.f("ix_pruning_details_id"), ["id"], unique=True)

op.create_table(
"example",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("name", sa.String(length=50), nullable=False),
sa.Column("framework", sa.Integer(), nullable=False),
sa.Column("domain", sa.Integer(), nullable=False),
sa.Column("dataset_type", sa.String(length=50), nullable=False),
sa.Column("model_url", sa.String(length=250), nullable=False),
sa.Column("config_url", sa.String(length=250), nullable=False),
sa.Column("created_at", sa.DateTime(), nullable=False),
sa.ForeignKeyConstraint(["domain"], ["domain.id"], name=op.f("fk_example_domain_domain")),
sa.ForeignKeyConstraint(
["framework"], ["framework.id"], name=op.f("fk_example_framework_framework")
),
sa.PrimaryKeyConstraint("id", name=op.f("pk_example")),
)
with op.batch_alter_table("example", schema=None) as batch_op:
batch_op.create_index(batch_op.f("ix_example_id"), ["id"], unique=False)

with op.batch_alter_table("model", schema=None) as batch_op:
batch_op.add_column(
sa.Column(
"supports_pruning",
sa.Boolean(),
default=False,
nullable=True,
),
)
op.execute("UPDATE model SET supports_pruning = false")

with op.batch_alter_table("model", schema=None) as batch_op:
batch_op.alter_column("supports_pruning", nullable=False)

with op.batch_alter_table("optimization", schema=None) as batch_op:
batch_op.add_column(sa.Column("pruning_details_id", sa.Integer(), nullable=True))
batch_op.create_foreign_key(
batch_op.f("fk_optimization_pruning_details_id_pruning_details"),
"pruning_details",
["pruning_details_id"],
["id"],
)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table("optimization", schema=None) as batch_op:
batch_op.drop_constraint(
batch_op.f("fk_optimization_pruning_details_id_pruning_details"), type_="foreignkey"
)
batch_op.drop_column("pruning_details_id")

with op.batch_alter_table("model", schema=None) as batch_op:
batch_op.drop_column("supports_pruning")

with op.batch_alter_table("example", schema=None) as batch_op:
batch_op.drop_index(batch_op.f("ix_example_id"))

op.drop_table("example")
with op.batch_alter_table("pruning_details", schema=None) as batch_op:
batch_op.drop_index(batch_op.f("ix_pruning_details_id"))

op.drop_table("pruning_details")
# ### end Alembic commands ###
4 changes: 4 additions & 0 deletions neural_compressor/ux/components/db_manager/db_models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ class Model(Base):
output_nodes = Column(String(250), nullable=False, default="")
supports_profiling = Column(Boolean, nullable=False, default=False)
supports_graph = Column(Boolean, nullable=False, default=False)
supports_pruning = Column(Boolean, nullable=False, default=False)
created_at = Column(DateTime, nullable=False, default=func.now())

project: Any = relationship("Project", back_populates="models")
Expand Down Expand Up @@ -104,6 +105,7 @@ def add(
domain_flavour_id: int,
supports_profiling: bool,
supports_graph: bool,
supports_pruning: bool,
) -> int:
"""
Add model to database.
Expand All @@ -123,6 +125,7 @@ def add(
domain_flavour_id=domain_flavour_id,
supports_profiling=supports_profiling,
supports_graph=supports_graph,
supports_pruning=supports_pruning,
)
db_session.add(new_model)
db_session.flush()
Expand Down Expand Up @@ -206,5 +209,6 @@ def build_info(model: Any) -> dict:
"output_nodes": json.loads(model.output_nodes),
"supports_profiling": model.supports_profiling,
"supports_graph": model.supports_graph,
"supports_pruning": model.supports_pruning,
"created_at": str(model.created_at),
}
Loading

0 comments on commit d24fea6

Please sign in to comment.