support ipex xpu (#1348)
Signed-off-by: Xin He <[email protected]>
Signed-off-by: Cheng, Zixuan <[email protected]>
xin3he authored Nov 22, 2023
1 parent c3214c9 commit af0b50f
Showing 14 changed files with 555 additions and 249 deletions.
@@ -106,6 +106,8 @@
help='run benchmark')
parser.add_argument('--ipex', dest='ipex', action='store_true',
help='tuning or benchmark with Intel PyTorch Extension')
parser.add_argument('--xpu', action='store_true',
help='whether to use xpu')

best_acc1 = 0

@@ -225,7 +227,8 @@ def main_worker(gpu, ngpus_per_node, args):
model.cuda()
else:
model = torch.nn.DataParallel(model)

if args.xpu:
model = model.to("xpu")
# define loss function (criterion) and optimizer
criterion = nn.CrossEntropyLoss()
#criterion = nn.CrossEntropyLoss().cuda(args.gpu)
@@ -297,7 +300,10 @@ def eval_func(model):
if args.tune:
from neural_compressor import PostTrainingQuantConfig
from neural_compressor import quantization
conf = PostTrainingQuantConfig(backend='ipex')
if args.xpu:
conf = PostTrainingQuantConfig(backend='ipex', device="xpu")
else:
conf = PostTrainingQuantConfig(backend='ipex')
q_model = quantization.fit(model,
conf,
calib_dataloader=val_loader,
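Taken together, the changes to this example enable the flow below. A minimal sketch, assuming an XPU device and an IPEX build with XPU support; the toy model and random data stand in for the script's ResNet and ImageNet loader and are not part of the patch:

    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, TensorDataset
    import intel_extension_for_pytorch as ipex  # noqa: F401, registers the "xpu" device
    from neural_compressor import PostTrainingQuantConfig, quantization

    # Toy stand-ins for the example's ResNet model and ImageNet val_loader.
    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Flatten(), nn.Linear(8 * 30 * 30, 10))
    dataset = TensorDataset(torch.randn(8, 3, 32, 32), torch.randint(0, 10, (8,)))
    val_loader = DataLoader(dataset, batch_size=4)

    model = model.to("xpu")  # mirrors the new `args.xpu` branch
    conf = PostTrainingQuantConfig(backend="ipex", device="xpu")
    q_model = quantization.fit(model, conf, calib_dataloader=val_loader)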
@@ -417,6 +423,9 @@ def validate(val_loader, model, criterion, args):
if args.gpu is not None:
input = input.cuda(args.gpu, non_blocking=True)
target = target.cuda(args.gpu, non_blocking=True)
if args.xpu:
input = input.to("xpu")
target = target.to("xpu")

# compute output
output = model(input)
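A device-agnostic variant of the per-flag branches above, shown as a hypothetical refactor rather than what the patch does:

    import torch

    def validate_on(device, model, val_loader, criterion):
        # Move each batch to `device` once, instead of branching on args.gpu / args.xpu.
        model.eval()
        with torch.no_grad():
            for input, target in val_loader:
                input = input.to(device, non_blocking=True)
                target = target.to(device, non_blocking=True)
                output = model(input)
                loss = criterion(output, target)
                # ... accuracy/loss bookkeeping as in the original validate()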
@@ -1,5 +1,5 @@
accelerate
datasets>=1.8.0
transformers==4.30.0
transformers>=4.34.1
tensorboard
tqdm
@@ -113,7 +113,9 @@ class ModelArguments:
"help": "The inference iterations to run for benchmark."
},
)

xpu: bool = field(
default=False, metadata={"help": "whether to use xpu"}
)

@dataclass
class DataTrainingArguments:
@@ -650,6 +652,9 @@ def take_eval_steps(model, trainer, metric_name, save_metrics=False):
def eval_func(model):
return take_eval_steps(model, trainer, metric_name)

if model_args.xpu:
model = model.to("xpu")

if model_args.tune:
ipex.nn.utils._model_convert.replace_dropout_with_identity(model)
from neural_compressor.config import PostTrainingQuantConfig
@@ -664,6 +669,8 @@ def eval_func(model):
else:
example_inputs = None # please provide correct example_inputs if necessary.
conf = PostTrainingQuantConfig(backend="ipex", calibration_sampling_size=800, example_inputs=example_inputs)
if model_args.xpu:
conf.device = "xpu"
q_model = quantization.fit(model,
conf,
calib_dataloader=eval_dataloader,
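Setting conf.device after construction, as above, should be equivalent to passing the device up front, which is what the ImageNet example in this commit does:

    from neural_compressor.config import PostTrainingQuantConfig

    # Presumed-equivalent construction-time form (same values as the conf built above).
    conf = PostTrainingQuantConfig(backend="ipex", calibration_sampling_size=800, device="xpu")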
@@ -680,7 +687,7 @@ def eval_func(model):
example_inputs = get_example_inputs(model, eval_dataloader)
model = ipex.optimize(model)
with torch.no_grad():
model = torch.jit.trace(model, example_inputs, strict=False)
model = torch.jit.trace(model, example_inputs=example_inputs, strict=False)
model = torch.jit.freeze(model)

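For reference, the trace-and-freeze sequence in isolation: a minimal sketch with a toy module standing in for the Transformers model. Note that example_inputs is now passed by keyword, which matches torch.jit.trace's parameter name:

    import torch

    class Toy(torch.nn.Module):
        def forward(self, x):
            return torch.relu(x)

    model = Toy().eval()
    example_inputs = torch.randn(1, 8)
    with torch.no_grad():
        model = torch.jit.trace(model, example_inputs=example_inputs, strict=False)
        model = torch.jit.freeze(model)  # folds weights/attributes into constants
    model(example_inputs)  # warm-up call triggers JIT optimization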
if model_args.benchmark or model_args.accuracy_only:
@@ -692,6 +699,8 @@ def eval_func(model):
iteration=model_args.iters,
cores_per_instance=4,
num_of_instance=1)
if model_args.xpu:
b_conf.device = "xpu"
benchmark.fit(model, b_conf, b_dataloader=eval_dataloader)
else:
eval_func(model)
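The benchmark path with the new device override, as a sketch; q_model and eval_dataloader are assumed from the surrounding script:

    from neural_compressor import benchmark
    from neural_compressor.config import BenchmarkConfig

    b_conf = BenchmarkConfig(warmup=5, iteration=100, cores_per_instance=4, num_of_instance=1)
    b_conf.device = "xpu"  # same attribute override the patch applies
    benchmark.fit(q_model, b_conf, b_dataloader=eval_dataloader)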
@@ -44,7 +44,7 @@ function run_tuning {
--dataset_name squad \
--do_eval \
--max_seq_length 384 \
--no_cuda \
--no_cuda `# remove if using xpu` \
--tune \
--output_dir $tuned_checkpoint
fi
@@ -55,7 +55,7 @@ function run_tuning {
--dataset_name squad \
--do_eval \
--max_seq_length 384 \
--no_cuda \
--no_cuda `# remove if using xpu` \
--tune \
--output_dir $tuned_checkpoint
fi
121 changes: 54 additions & 67 deletions neural_compressor/adaptor/pytorch.py
@@ -78,73 +78,37 @@ def get_torch_white_list(approach):
return white_list


def pytorch_forward_wrapper(model, input, device="cpu", conf=None, running_mode="inference"):
def pytorch_forward_wrapper(
model,
input,
conf=None,
backend="default",
running_mode="inference",
):
version = get_torch_version()
if isinstance(input, dict) or isinstance(input, UserDict):
if device == "cpu":
output = model(**input)
elif device == "ipex":
# have to split the case to avoid exposing ipex.DEVICE outside
# which require intel extension installed
if version.release < Version("1.12.0").release: # pragma: no cover
if running_mode == "calibration":
with ipex.quantization.calibrate(conf, default_recipe=True): # pylint: disable=E1101
output = model(**input)
else:
output = model(**input)
else:
output = model(**input)
else: # pragma: no cover
for inp in input.keys():
input[inp] = (
input[inp].to("dpcpp" if device == "gpu" else device)
if isinstance(input[inp], torch.Tensor)
else input[inp]
)
output = model(**input)
elif isinstance(input, list) or isinstance(input, tuple):
if device == "cpu":
output = model(*input)
elif device == "ipex":
if version.release < Version("1.12.0").release: # pragma: no cover
if running_mode == "calibration":
with ipex.quantization.calibrate(conf, default_recipe=True): # pylint: disable=E1101
output = model(*input)
else:
output = model(*input)
else:
output = model(*input)
else: # pragma: no cover
tmp_device = "dpcpp" if device == "gpu" else device
input = [
inp.to(tmp_device) if isinstance(inp, torch.Tensor) else inp for inp in input
] # pylint: disable=E1133
output = model(*input)
from .torch_utils.util import forward_wrapper

if (
version.release < Version("1.12.0").release and backend == "ipex" and running_mode == "calibration"
): # pragma: no cover
with ipex.quantization.calibrate(conf, default_recipe=True): # pylint: disable=E1101
output = forward_wrapper(model, input)
else:
if device == "cpu" or not isinstance(input, torch.Tensor):
output = model(input)
elif device == "ipex":
if version.release < Version("1.12.0").release: # pragma: no cover
if running_mode == "calibration":
with ipex.quantization.calibrate(conf, default_recipe=True): # pylint: disable=E1101
output = model(input)
else:
output = model(input)
else:
output = model(input)
else: # pragma: no cover
input = input.to("dpcpp" if device == "gpu" else device) # pylint: disable=no-member
output = model(input)
output = forward_wrapper(model, input)
return output

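The per-device branching deleted above collapses into a shared helper. Inferred from the deleted code, forward_wrapper plausibly dispatches on the container type, roughly as below; the actual implementation lives in neural_compressor/adaptor/torch_utils/util.py:

    from collections import UserDict

    def forward_wrapper(model, input):
        # Dispatch on container type, mirroring the deleted per-device branches.
        if isinstance(input, (dict, UserDict)):
            return model(**input)
        if isinstance(input, (list, tuple)):
            return model(*input)
        return model(input)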

def get_example_inputs(model, dataloader):
version = get_torch_version()
from .torch_utils.util import move_input_device

# Suggestion: set this dataloader the same way as calib_dataloader
if dataloader is None:
return None
device = next(model.parameters()).device
try:
for idx, (input, label) in enumerate(dataloader):
input = move_input_device(input, device)
output = pytorch_forward_wrapper(model, input)
if isinstance(input, (dict, UserDict)): # pragma: no cover
assert version.release >= Version("1.12.0").release, "INC requires IPEX version >= 1.12.0"
@@ -162,6 +126,7 @@ def get_example_inputs(model, dataloader):
break
except Exception as e: # pragma: no cover
for idx, input in enumerate(dataloader):
input = move_input_device(input, device)
output = pytorch_forward_wrapper(model, input)
if isinstance(input, (dict, UserDict)): # pragma: no cover
assert version.release >= Version("1.12.0").release, "INC requires IPEX version >= 1.12.0"
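move_input_device is imported but not shown in this diff. From its call sites, a plausible sketch (illustrative, not the actual source):

    from collections import UserDict

    import torch

    def move_input_device(input, device="cpu"):
        # Recursively move tensors to `device`; leave everything else untouched.
        if isinstance(input, (dict, UserDict)):
            return {k: move_input_device(v, device) for k, v in input.items()}
        if isinstance(input, (list, tuple)):
            return type(input)(move_input_device(v, device) for v in input)
        if isinstance(input, torch.Tensor):
            return input.to(device)
        return input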
@@ -814,6 +779,7 @@ def __init__(self, framework_specific_info):
self.bf16_ops = []
self.use_bf16 = framework_specific_info.get("use_bf16", True)
self.device = framework_specific_info["device"]
self.backend = framework_specific_info.get("backend", "default")
self.q_dataloader = framework_specific_info["q_dataloader"]
self.q_func = framework_specific_info.get("q_func", None)
self.benchmark = GLOBAL_STATE.STATE == MODE.BENCHMARK
@@ -881,14 +847,14 @@ def calib_func(self, model, dataloader, tmp_iterations, conf=None):
try:
for idx, (input, label) in enumerate(dataloader):
output = pytorch_forward_wrapper(
model, input, device=self.device, conf=conf, running_mode="calibration"
model, input, backend=self.backend, conf=conf, running_mode="calibration"
)
if idx >= tmp_iterations - 1:
break
except Exception as e:
for idx, input in enumerate(dataloader):
output = pytorch_forward_wrapper(
model, input, device=self.device, conf=conf, running_mode="calibration"
model, input, backend=self.backend, conf=conf, running_mode="calibration"
)
if idx >= tmp_iterations - 1:
break
@@ -936,7 +902,7 @@ def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration
if measurer is not None:
measurer.start()

output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
output = pytorch_forward_wrapper(model, input, backend=self.backend, conf=conf)
if self.device != "cpu": # pragma: no cover
output = output.to("cpu")
label = label.to("cpu")
@@ -978,7 +944,7 @@ def eval_func(self, model, dataloader, postprocess, metrics, measurer, iteration
if measurer is not None:
measurer.start()

output = pytorch_forward_wrapper(model, input, device=self.device, conf=conf)
output = pytorch_forward_wrapper(model, input, backend=self.backend, conf=conf)

if measurer is not None:
measurer.end()
@@ -2272,7 +2238,7 @@ def train(self, model, dataloader, optimizer_tuple, criterion_tuple, hooks, **kw
on_step_begin(cnt)
print(".", end="", flush=True)
cnt += 1
output = pytorch_forward_wrapper(model_, image, device=device)
output = pytorch_forward_wrapper(model_, image)
loss = criterion(output, target)
if hooks is not None:
loss = on_after_compute_loss(image, output, loss)
@@ -2639,7 +2605,9 @@ def __init__(self, framework_specific_info):
super(PyTorch_IPEXAdaptor, self).__init__(framework_specific_info)
self.version = get_torch_version()
query_config_file = "pytorch_ipex.yaml"
self.query_handler = PyTorchQuery(local_config_file=os.path.join(os.path.dirname(__file__), query_config_file))
self.query_handler = PyTorchQuery(
device=self.device, local_config_file=os.path.join(os.path.dirname(__file__), query_config_file)
)
self.cfgs = None
self.fuse_ops = None
self.op_infos_from_cfgs = None
@@ -2651,7 +2619,6 @@ def __init__(self, framework_specific_info):
os.remove(self.ipex_config_path)
except:
logger.warning("Fail to remove {}.".format(self.ipex_config_path))
self.device = "ipex"

@dump_elapsed_time("Pass quantize model")
def quantize(self, tune_cfg, model, dataloader, q_func=None):
Expand All @@ -2669,6 +2636,8 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
# IPEX bug #1: deepcopied prepared model cannot do calibration, need model._model
# q_model._model is useless, but we need to copy other attributes, and pass the converted
# model to q_model. Also, sq will collect state_dict to origin_stat for recover
if self.device == "xpu":
model.to(self.device)
if self.performance_only:
q_model = model
else:
@@ -2722,7 +2691,12 @@ def quantize(self, tune_cfg, model, dataloader, q_func=None):
if not hasattr(model._model, "save_qconf_summary") or not hasattr(model._model, "load_qconf_summary"):
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

if self.version.release >= Version("2.1").release:
if self.device == "xpu":
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
)
elif self.version.release >= Version("2.1").release:
static_qconfig = ipex.quantization.default_static_qconfig_mapping
else:
static_qconfig = QConfig(
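The XPU branch pins per-tensor MinMax observers for both activations and weights. A hedged sketch of how such a qconfig is consumed by the IPEX prepare/convert pair, assuming an XPU-enabled IPEX build:

    import torch
    import intel_extension_for_pytorch as ipex
    from torch.ao.quantization import MinMaxObserver, QConfig

    static_qconfig = QConfig(
        activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
        weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
    )
    model = torch.nn.Linear(8, 8).eval().to("xpu")
    example_input = torch.randn(1, 8).to("xpu")
    prepared = ipex.quantization.prepare(model, static_qconfig, example_inputs=example_input)
    prepared(example_input)  # one calibration pass
    converted = ipex.quantization.convert(prepared)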
@@ -3107,7 +3081,12 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
), "IPEX need q_dataloader or example_inputs to prepare the model"
from torch.ao.quantization import MinMaxObserver, PerChannelMinMaxObserver, QConfig

if self.version.release >= Version("2.1").release:
if self.device == "xpu":
static_qconfig = QConfig(
activation=MinMaxObserver.with_args(qscheme=torch.per_tensor_affine, dtype=torch.quint8),
weight=MinMaxObserver.with_args(dtype=torch.qint8, qscheme=torch.per_tensor_symmetric),
)
elif self.version.release >= Version("2.1").release:
# HistogramObserver will cause a performance issue.
# static_qconfig = ipex.quantization.default_static_qconfig_mapping
qconfig = QConfig(
@@ -3145,6 +3124,9 @@ def _get_quantizable_ops_recursively(self, model, prefix, quantizable_ops):
static_qconfig = ipex.quantization.get_smooth_quant_qconfig_mapping(alpha=0.5)
if self.example_inputs is None:
self.example_inputs = get_example_inputs(model, self.q_dataloader)
from neural_compressor.adaptor.torch_utils.util import move_input_device

self.example_inputs = move_input_device(self.example_inputs, device=self.device)
if isinstance(self.example_inputs, dict):
model = ipex.quantization.prepare(
model, static_qconfig, example_kwarg_inputs=self.example_inputs, inplace=True
@@ -3400,7 +3382,7 @@ def _simple_inference(self, q_model, dataloader, iterations=1):
"""The function is used for ipex warm-up inference."""
if self.example_inputs is not None:
for _ in range(iterations):
if isinstance(self.example_inputs, tuple):
if isinstance(self.example_inputs, (tuple, list)):
q_model(*self.example_inputs)
elif isinstance(self.example_inputs, dict):
q_model(**self.example_inputs)
@@ -3919,7 +3901,7 @@ def train(self, model, dataloader, optimizer_tuple, criterion_tuple, hooks, **kw
on_step_begin(cnt)
print(".", end="", flush=True)
cnt += 1
output = pytorch_forward_wrapper(model._model, input, device=device)
output = pytorch_forward_wrapper(model._model, input)
loss = criterion(output, target)
if hooks is not None:
loss = on_after_compute_loss(input, output, loss)
@@ -4936,10 +4918,11 @@ def query_fw_capability(self, model):


class PyTorchQuery(QueryBackendCapability):
def __init__(self, local_config_file=None):
def __init__(self, device="cpu", local_config_file=None):
super().__init__()
self.version = get_torch_version()
self.cfg = local_config_file
self.device = device
self.cur_config = None
self._one_shot_query()

@@ -4973,6 +4956,10 @@ def _one_shot_query(self):
raise ValueError(
"Please check if the format of {} follows " "Neural Compressor yaml scheme.".format(self.cfg)
)
if self.device == "xpu":
self.cur_config = self.cur_config[self.device]
elif "cpu" in self.cur_config:
self.cur_config = self.cur_config["cpu"]
self._update_cfg_with_usr_definition()

def _update_cfg_with_usr_definition(self):
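The device-keyed lookup above implies pytorch_ipex.yaml now nests its capability tables per device. Roughly this shape, illustrative only:

    # Illustrative shape of self.cur_config after the yaml is loaded; the real
    # capability tables live in pytorch_ipex.yaml.
    cur_config = {
        "cpu": {"int8": {}},  # picked when "cpu" is present and device != "xpu"
        "xpu": {"int8": {}},  # picked when device == "xpu"
    }
    device = "xpu"
    cur_config = cur_config[device] if device == "xpu" else cur_config.get("cpu", cur_config)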