Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

use Flops Profiler to test model.generate() #2515

Merged
merged 13 commits into from
Jun 22, 2023
48 changes: 29 additions & 19 deletions deepspeed/profiling/flops_profiler/profiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1180,20 +1180,19 @@ def get_module_duration(module):
return duration


def get_model_profile(
model,
input_shape=None,
args=[],
kwargs={},
print_profile=True,
detailed=True,
module_depth=-1,
top_modules=1,
warm_up=1,
as_string=True,
output_file=None,
ignore_modules=None,
):
def get_model_profile(model,
input_shape=None,
args=[],
kwargs={},
print_profile=True,
detailed=True,
module_depth=-1,
top_modules=1,
warm_up=1,
as_string=True,
output_file=None,
ignore_modules=None,
mode='forward'):
"""Returns the total floating-point operations, MACs, and parameters of a model.

Example:
Expand Down Expand Up @@ -1239,18 +1238,29 @@ def get_model_profile(

args = [input]
assert (len(args) > 0) or (len(kwargs) > 0), "args and/or kwargs must be specified if input_shape is None"

for _ in range(warm_up):
if kwargs:
_ = model(*args, **kwargs)
if mode == 'forward':
_ = model(*args, **kwargs)
if mode == 'generate':
_ = model.generate(*args, **kwargs)
else:
_ = model(*args)
if mode == 'forward':
_ = model(*args)
if mode == 'generate':
_ = model.generate(*args)
prof.start_profile(ignore_list=ignore_modules)

if kwargs:
_ = model(*args, **kwargs)
if mode == 'forward':
_ = model(*args, **kwargs)
if mode == 'generate':
_ = model.generate(*args, **kwargs)
else:
_ = model(*args)
if mode == 'forward':
_ = model(*args)
if mode == 'generate':
_ = model.generate(*args)

flops = prof.get_total_flops()
macs = prof.get_total_macs()
Expand Down