diff --git a/tools/profiler.py b/tools/profiler.py index 822bb83255..81dcc0b996 100644 --- a/tools/profiler.py +++ b/tools/profiler.py @@ -80,8 +80,8 @@ def __init__(self, model): self.model = model @TimeCounter.count_time(Backend.PYTORCH.value) - def forward(self, *args, **kwargs): - return self.model(*args, **kwargs) + def test_step(self, *args, **kwargs): + return self.model.test_step(*args, **kwargs) def main():