Skip to content

Commit

Permalink
[SW-191415] update fp8 maxAbs observer using torch.copy_
Browse files Browse the repository at this point in the history
Change-Id: I3923c832f9a8a2b14e392f3f4719d233a457702f
  • Loading branch information
dudilester committed Jul 24, 2024
1 parent 7f62871 commit 5e3a679
Showing 1 changed file with 4 additions and 18 deletions.
22 changes: 4 additions & 18 deletions neural_compressor/torch/algorithms/fp8_quant/_core/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
)
patched_types.add(type(mod))

set_hqt_config(mod, top_level_config)
mod_extra_config = init_measure_object(
mod,
name,
Expand All @@ -104,7 +105,6 @@ def register_patched_measure_modules(model, mod_list, observer_class, d_shapes=N
(d_shapes[name] if ((d_shapes is not None) and (name in d_shapes)) else None),
params,
)
set_hqt_config(mod, top_level_config) # set config in the module, as it consumed by the patched module
pmod = patch_module_measure(mod, mod_extra_config, mod_default_dict)
for param_name in pmod._mod_extra_config.params:
param = getattr(pmod, param_name)
Expand Down Expand Up @@ -247,27 +247,13 @@ def __init__(self, name, mod, d_shape=None, params=None):
self.mod = mod
self.first = True
self.used = False
self.state = self.init_state_from_shape(d_shape)

def init_state(self, x):
device = x.device
state = torch.zeros((1, 1), device=device, dtype=torch.float32)
self.shape = list(x.shape)
return state

def init_state_from_shape(self, x_shape, device="hpu"):
state = torch.zeros((1, 1), device=device, dtype=torch.float32)
self.first = False
return state
config = get_hqt_config(mod).cfg
self.state = torch.zeros((1, 1), device="hpu", dtype=config["hp_dtype"])

def update_state(self, x):
# TODO: [SW-189690] Find better way to update self.state in MaxAbsObserver class in HQT
self.state = torch.maximum(torch.max(torch.abs(x)), self.state)
self.state.copy_(torch.maximum(torch.max(torch.abs(x)), self.state))

def measure(self, x):
if self.first:
self.state = self.init_state(x)
self.first = False
self.update_state(x)
self.used = True

Expand Down

0 comments on commit 5e3a679

Please sign in to comment.