Add OptimItem in TBE codegen (pytorch#2518)
Summary:

This diff abstracts the optimizer arguments in the TBE code generation
scripts by introducing the `OptimizerArgsSetItem` dataclass. This lets
us avoid updating function and variable annotations every time we add
or remove fields in the optimizer argument spec.
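
For illustration, a minimal before/after sketch of the call-site change
(`eps` and its default value are hypothetical names for this sketch; the
real call sites are in the diff below):

    # Before: each argument was a positional tuple, so changing the shape
    # of the spec meant touching every annotation that spelled out
    # List[Tuple[int, str, Union[int, float]]].
    optargs = OptimizerArgsSet.create([(FLOAT, "eps", 1.0e-8)])

    # After: a named dataclass instance carries the same information, so
    # new fields (e.g. ph_tys) can be added without rippling through the
    # annotations.
    optargs = OptimizerArgsSet.create([OptimItem(FLOAT, "eps", 1.0e-8)])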

Reviewed By: spcyppt

Differential Revision: D56192736
sryap authored and facebook-github-bot committed Apr 19, 2024
1 parent dbae12b commit fe7e06c
Showing 3 changed files with 206 additions and 202 deletions.
14 changes: 11 additions & 3 deletions fbgemm_gpu/codegen/genscript/generate_index_select.py
@@ -14,19 +14,27 @@
 
 try:
     from .common import CodeTemplate
-    from .optimizer_args import FLOAT, OptimizerArgsSet
+    from .optimizer_args import (
+        FLOAT,
+        OptimizerArgsSet,
+        OptimizerArgsSetItem as OptimItem,
+    )
 except ImportError:
     # pyre-ignore[21]
     from common import CodeTemplate
 
     # pyre-ignore[21]
-    from optimizer_args import FLOAT, OptimizerArgsSet
+    from optimizer_args import (
+        FLOAT,
+        OptimizerArgsSet,
+        OptimizerArgsSetItem as OptimItem,
+    )
 
 
 class IndexSelectGenerator:
     @staticmethod
     def generate() -> None:
-        optargs = OptimizerArgsSet.create([(FLOAT, "unused")])
+        optargs = OptimizerArgsSet.create([OptimItem(FLOAT, "unused")])
         for template_file, generated_file in [
             (
                 "training/forward/embedding_forward_split_template.cu",
174 changes: 82 additions & 92 deletions fbgemm_gpu/codegen/genscript/optimizer_args.py
@@ -14,7 +14,7 @@
 import os
 import re
 from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import jinja2
 
@@ -201,6 +201,23 @@ def make_ivalue_cast(ty: int) -> str:
     return {INT: "toInt", FLOAT: "toDouble"}[ty]
 
 
+######################################################################
+# Optimizer Args Set Item
+######################################################################
+
+
+@dataclass
+class OptimizerArgsSetItem:
+    ty: int  # type
+    name: str
+    default: Union[float, int] = 0  # DEFAULT_ARG_VAL
+    ph_tys: Optional[List[int]] = None  # placeholder types
+
+
+# Alias b/c the name is too long
+OptimItem = OptimizerArgsSetItem
+
+
 ######################################################################
 # Optimizer Args
 ######################################################################
@@ -226,59 +243,49 @@ class OptimizerArgs:
     @staticmethod
     # pyre-ignore[3]
     def create(
-        split_arg_spec: List[Tuple[int, str, Union[int, float]]],
-        augmented_arg_spec: List[Tuple[int, str, Union[int, float]]],
+        split_arg_spec: List[OptimItem],
+        arg_spec: List[OptimItem],
     ):
         return OptimizerArgs(
             split_kernel_args=[
-                make_kernel_arg(ty, name, default)
-                for (ty, name, default) in split_arg_spec
+                make_kernel_arg(s.ty, s.name, s.default) for s in split_arg_spec
             ],
             split_kernel_args_no_defaults=[
-                make_kernel_arg(ty, name, None) for (ty, name, _) in split_arg_spec
+                make_kernel_arg(s.ty, s.name, None) for s in split_arg_spec
             ],
             split_kernel_arg_constructors=[
-                make_kernel_arg_constructor(ty, name)
-                for (ty, name, default) in split_arg_spec
+                make_kernel_arg_constructor(s.ty, s.name) for s in split_arg_spec
             ],
             split_cpu_kernel_args=[
-                make_cpu_kernel_arg(ty, name, default)
-                for (ty, name, default) in split_arg_spec
+                make_cpu_kernel_arg(s.ty, s.name, s.default) for s in split_arg_spec
             ],
             split_cpu_kernel_arg_constructors=[
-                make_cpu_kernel_arg_constructor(ty, name)
-                for (ty, name, default) in split_arg_spec
+                make_cpu_kernel_arg_constructor(s.ty, s.name) for s in split_arg_spec
            ],
             split_function_args=[
-                make_function_arg(ty, name, default)
-                for (ty, name, default) in split_arg_spec
+                make_function_arg(s.ty, s.name, s.default) for s in split_arg_spec
             ],
             split_function_args_no_defaults=[
-                make_function_arg(ty, name, None)
-                for (ty, name, default) in split_arg_spec
-            ],
-            split_tensors=[
-                name for (ty, name, default) in augmented_arg_spec if ty == TENSOR
+                make_function_arg(s.ty, s.name, None) for s in split_arg_spec
             ],
+            split_tensors=[s.name for s in arg_spec if s.ty == TENSOR],
             split_saved_tensors=[
-                name
-                for (ty, name, default) in split_arg_spec
-                if ty in (TENSOR, INT_TENSOR, LONG_TENSOR)
+                s.name
+                for s in split_arg_spec
+                if s.ty in (TENSOR, INT_TENSOR, LONG_TENSOR)
             ],
             saved_data=[
-                (name, make_ivalue_cast(ty))
-                for (ty, name, default) in augmented_arg_spec
-                if ty != TENSOR
+                (s.name, make_ivalue_cast(s.ty)) for s in arg_spec if s.ty != TENSOR
             ],
-            split_function_arg_names=[name for (ty, name, default) in split_arg_spec],
+            split_function_arg_names=[s.name for s in split_arg_spec],
             split_function_schemas=[
-                make_function_schema_arg(ty, name, default)
-                for (ty, name, default) in split_arg_spec
+                make_function_schema_arg(s.ty, s.name, s.default)
+                for s in split_arg_spec
             ],
             split_variables=["Variable()" for _ in split_arg_spec],
             split_ref_kernel_args=[
-                make_kernel_arg(ty, name, default, pass_by_ref=True)
-                for (ty, name, default) in split_arg_spec
+                make_kernel_arg(s.ty, s.name, s.default, pass_by_ref=True)
+                for s in split_arg_spec
             ],
         )

@@ -295,78 +302,61 @@ class OptimizerArgsSet:
     any: OptimizerArgs
 
     @staticmethod
-    def create_for_cpu(
-        augmented_arg_spec: List[Tuple[int, str, Union[float, int]]]
+    def create_optim_args(
+        arg_spec: List[OptimItem], ext_fn: Callable[[OptimItem], List[OptimItem]]
     ) -> OptimizerArgs:
         split_arg_spec = []
-        for ty, arg, default in augmented_arg_spec:
-            if ty in (FLOAT, INT):
-                split_arg_spec.append((ty, arg, default))
+        for s in arg_spec:
+            if s.ty in (FLOAT, INT):
+                split_arg_spec.append(OptimItem(s.ty, s.name, s.default))
             else:
-                assert ty == TENSOR
-                split_arg_spec.extend(
-                    [
-                        (TENSOR, f"{arg}_host", default),
-                        (INT_TENSOR, f"{arg}_placements", default),
-                        (LONG_TENSOR, f"{arg}_offsets", default),
-                    ]
-                )
-        return OptimizerArgs.create(split_arg_spec, augmented_arg_spec)
+                assert s.ty == TENSOR
+                split_arg_spec.extend(ext_fn(s))
+        return OptimizerArgs.create(split_arg_spec, arg_spec)
 
     @staticmethod
-    def create_for_cuda(
-        augmented_arg_spec: List[Tuple[int, str, Union[float, int]]]
-    ) -> OptimizerArgs:
-        split_arg_spec = []
-        for ty, arg, default in augmented_arg_spec:
-            if ty in (FLOAT, INT):
-                split_arg_spec.append((ty, arg, default))
-            else:
-                assert ty == TENSOR
-                split_arg_spec.extend(
-                    [
-                        (TENSOR, f"{arg}_dev", default),
-                        (TENSOR, f"{arg}_uvm", default),
-                        (INT_TENSOR, f"{arg}_placements", default),
-                        (LONG_TENSOR, f"{arg}_offsets", default),
-                    ]
-                )
-        return OptimizerArgs.create(split_arg_spec, augmented_arg_spec)
+    def extend_for_cpu(spec: OptimItem) -> List[OptimItem]:
+        name = spec.name
+        default = spec.default
+        return [
+            OptimItem(TENSOR, f"{name}_host", default),
+            OptimItem(INT_TENSOR, f"{name}_placements", default),
+            OptimItem(LONG_TENSOR, f"{name}_offsets", default),
+        ]
 
     @staticmethod
-    def create_for_any(
-        augmented_arg_spec: List[Tuple[int, str, Union[float, int]]]
-    ) -> OptimizerArgs:
-        split_arg_spec = []
-        for ty, arg, default in augmented_arg_spec:
-            if ty in (FLOAT, INT):
-                split_arg_spec.append((ty, arg, default))
-            else:
-                assert ty == TENSOR
-                split_arg_spec.extend(
-                    [
-                        (TENSOR, f"{arg}_host", default),
-                        (TENSOR, f"{arg}_dev", default),
-                        (TENSOR, f"{arg}_uvm", default),
-                        (INT_TENSOR, f"{arg}_placements", default),
-                        (LONG_TENSOR, f"{arg}_offsets", default),
-                    ]
-                )
-        return OptimizerArgs.create(split_arg_spec, augmented_arg_spec)
+    def extend_for_cuda(spec: OptimItem) -> List[OptimItem]:
+        name = spec.name
+        default = spec.default
+        return [
+            OptimItem(TENSOR, f"{name}_dev", default),
+            OptimItem(TENSOR, f"{name}_uvm", default),
+            OptimItem(INT_TENSOR, f"{name}_placements", default),
+            OptimItem(LONG_TENSOR, f"{name}_offsets", default),
+        ]
+
+    @staticmethod
+    def extend_for_any(spec: OptimItem) -> List[OptimItem]:
+        name = spec.name
+        default = spec.default
+        return [
+            OptimItem(TENSOR, f"{name}_host", default),
+            OptimItem(TENSOR, f"{name}_dev", default),
+            OptimItem(TENSOR, f"{name}_uvm", default),
+            OptimItem(INT_TENSOR, f"{name}_placements", default),
+            OptimItem(LONG_TENSOR, f"{name}_offsets", default),
+        ]
 
     @staticmethod
     # pyre-ignore[3]
-    def create(
-        arg_spec: List[Union[Tuple[int, str], Tuple[int, str, Union[float, int]]]]
-    ):
-        DEFAULT_ARG_VAL = 0
-        # pyre-ignore[9]
-        augmented_arg_spec: List[Tuple[int, str, Union[float, int]]] = [
-            item if len(item) == 3 else (*item, DEFAULT_ARG_VAL) for item in arg_spec
-        ]
+    def create(arg_spec: List[OptimItem]):
         return OptimizerArgsSet(
-            OptimizerArgsSet.create_for_cpu(augmented_arg_spec),
-            OptimizerArgsSet.create_for_cuda(augmented_arg_spec),
-            OptimizerArgsSet.create_for_any(augmented_arg_spec),
+            *(
+                OptimizerArgsSet.create_optim_args(arg_spec, ext_fn)
+                for ext_fn in (
+                    OptimizerArgsSet.extend_for_cpu,
+                    OptimizerArgsSet.extend_for_cuda,
+                    OptimizerArgsSet.extend_for_any,
+                )
+            )
        )
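
For reference, a minimal self-contained sketch of the new expansion
pattern (the type-code constants and the `momentum1` name here are
stand-ins for this sketch; the real definitions live in
`optimizer_args.py`): a single TENSOR argument is fanned out into
per-backend tensors by the `extend_for_*` helpers.

    from dataclasses import dataclass
    from typing import List, Optional, Union

    # Stand-in type codes; the real constants are defined in optimizer_args.py.
    FLOAT, INT, TENSOR, INT_TENSOR, LONG_TENSOR = range(5)

    @dataclass
    class OptimizerArgsSetItem:
        ty: int
        name: str
        default: Union[float, int] = 0
        ph_tys: Optional[List[int]] = None

    OptimItem = OptimizerArgsSetItem

    def extend_for_cuda(spec: OptimItem) -> List[OptimItem]:
        # Mirrors OptimizerArgsSet.extend_for_cuda above: one TENSOR arg
        # becomes dev/uvm storage plus placement and offset bookkeeping.
        return [
            OptimItem(TENSOR, f"{spec.name}_dev", spec.default),
            OptimItem(TENSOR, f"{spec.name}_uvm", spec.default),
            OptimItem(INT_TENSOR, f"{spec.name}_placements", spec.default),
            OptimItem(LONG_TENSOR, f"{spec.name}_offsets", spec.default),
        ]

    # "momentum1" is a hypothetical optimizer-state name for illustration.
    for item in extend_for_cuda(OptimItem(TENSOR, "momentum1")):
        # Prints: momentum1_dev, momentum1_uvm, momentum1_placements,
        # momentum1_offsets (one per line).
        print(item.name)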
