Add cache conflict miss support (backend) (#2596)
Summary: Pull Request resolved: #2596

Differential Revision: D55998215
sryap authored and facebook-github-bot committed May 20, 2024
1 parent 37c283c commit 9435b51
Showing 24 changed files with 1,173 additions and 477 deletions.
50 changes: 42 additions & 8 deletions fbgemm_gpu/FbgemmGpu.cmake
@@ -97,6 +97,11 @@ set(GWD_OPTIMIZERS
set(DEFUSED_OPTIMIZERS
rowwise_adagrad)

# Optimizers with the SSD support
set(SSD_OPTIMIZERS
rowwise_adagrad
sgd)

set(WEIGHT_OPTIONS
weighted
unweighted_nobag
@@ -143,6 +148,7 @@ set(gen_gpu_kernel_source_files
"gen_embedding_forward_split_unweighted_codegen_cuda.cu"
"gen_embedding_backward_dense_indice_weights_codegen_cuda.cu"
"gen_embedding_backward_split_indice_weights_codegen_cuda.cu"
"gen_embedding_backward_ssd_indice_weights_codegen_cuda.cu"
"gen_embedding_forward_split_weighted_vbe_codegen_cuda.cu"
"gen_embedding_forward_split_unweighted_vbe_codegen_cuda.cu"
"gen_batch_index_select_dim0_forward_codegen_cuda.cu"
@@ -153,10 +159,13 @@ set(gen_gpu_kernel_source_files
"gen_batch_index_select_dim0_backward_kernel_warp.cu"
"gen_embedding_backward_split_grad_embedding_ops.cu"
"gen_embedding_backward_split_grad_index_select.cu"
"gen_embedding_backward_common_split_device_kernel.cuh"
"gen_embedding_backward_batch_index_select_split_device_kernel.cuh"
"gen_embedding_backward_split_common_device_kernel.cuh"
"gen_embedding_backward_split_batch_index_select_device_kernel.cuh"
"gen_embedding_forward_split_weighted_gwd_codegen_cuda.cu"
"gen_embedding_forward_split_unweighted_gwd_codegen_cuda.cu"
"gen_embedding_forward_ssd_weighted_codegen_cuda.cu"
"gen_embedding_forward_ssd_unweighted_codegen_cuda.cu"
"gen_embedding_forward_ssd_unweighted_nobag_kernel_small.cu"
)

if(NOT USE_ROCM)
@@ -179,7 +188,8 @@ foreach(wdesc ${WEIGHT_OPTIONS})
"gen_embedding_backward_dense_split_${wdesc}_kernel_cta.cu"
"gen_embedding_backward_dense_split_${wdesc}_kernel_warp.cu"
"gen_embedding_forward_split_${wdesc}_kernel.cu"
"gen_embedding_backward_${wdesc}_split_device_kernel.cuh")
"gen_embedding_forward_ssd_${wdesc}_kernel.cu"
"gen_embedding_backward_split_${wdesc}_device_kernel.cuh")

foreach(etype fp32 fp16 fp8 int8 int4 int2)
list(APPEND gen_gpu_kernel_source_files
@@ -191,7 +201,7 @@ endforeach()
foreach(wdesc weighted unweighted)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_forward_split_${wdesc}_vbe_kernel.cu"
"gen_embedding_backward_${wdesc}_vbe_split_device_kernel.cuh")
"gen_embedding_backward_split_${wdesc}_vbe_device_kernel.cuh")
endforeach()

# Generate GWD files
@@ -207,22 +217,31 @@ set(gen_cpu_source_files

set(gen_python_source_files
${CMAKE_BINARY_DIR}/__init__.py
${CMAKE_BINARY_DIR}/lookup_args.py)
${CMAKE_BINARY_DIR}/lookup_args.py
${CMAKE_BINARY_DIR}/lookup_args_ssd.py
)

# For each of the optimizers, generate the backward split variant by adding
# the Python, CPU-only, GPU host, and GPU kernel source files

# Generate the Python functions only if there is the backend support
# Generate the Python functions only if there is the backend support (for all
# optimizers)
foreach(optimizer
${COMMON_OPTIMIZERS}
${CPU_ONLY_OPTIMIZERS}
${GPU_ONLY_OPTIMIZERS})
list(APPEND gen_python_source_files
"${CMAKE_BINARY_DIR}/lookup_${optimizer}.py")
list(APPEND gen_python_source_files
"${CMAKE_BINARY_DIR}/lookup_${optimizer}.py"
"${CMAKE_BINARY_DIR}/lookup_${optimizer}_pt2.py")
endforeach()

# Generate the Python functions only if there is the backend support (for SSD
# optimizers)
foreach(optimizer ${SSD_OPTIMIZERS})
list(APPEND gen_python_source_files
"${CMAKE_BINARY_DIR}/lookup_${optimizer}_ssd.py")
endforeach()

# Generate the backend API for all optimizers to preserve the backward
# compatibility
list(APPEND gen_cpu_source_files
@@ -285,6 +304,21 @@ foreach(optimizer ${DEFUSED_OPTIMIZERS})
"${CMAKE_BINARY_DIR}/split_embedding_optimizer_${optimizer}.py")
endforeach()

foreach(optimizer ${SSD_OPTIMIZERS})
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_optimizer_${optimizer}_ssd_device_kernel.cuh"
"gen_embedding_backward_ssd_${optimizer}.cpp"
)

foreach(wdesc weighted unweighted unweighted_nobag)
list(APPEND gen_gpu_kernel_source_files
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_cuda.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_cta.cu"
"gen_embedding_backward_${optimizer}_ssd_${wdesc}_kernel_warp.cu")
endforeach()

endforeach()
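As a rough illustration of what the SSD loop above appends, the generated source names expand as follows for each optimizer in SSD_OPTIMIZERS. This is only a sketch that mirrors the CMake foreach; the authoritative list is FbgemmGpu.cmake itself, and the expanded names below assume the templates render exactly as written.

```python
# Illustration only: mirrors the CMake foreach over SSD_OPTIMIZERS above and
# expands the ${optimizer}/${wdesc} variables to concrete generated file names.
ssd_optimizers = ["rowwise_adagrad", "sgd"]
wdescs = ["weighted", "unweighted", "unweighted_nobag"]

ssd_sources = []
for optimizer in ssd_optimizers:
    ssd_sources += [
        f"gen_embedding_optimizer_{optimizer}_ssd_device_kernel.cuh",
        f"gen_embedding_backward_ssd_{optimizer}.cpp",
    ]
    for wdesc in wdescs:
        ssd_sources += [
            f"gen_embedding_backward_{optimizer}_ssd_{wdesc}_cuda.cu",
            f"gen_embedding_backward_{optimizer}_ssd_{wdesc}_kernel_cta.cu",
            f"gen_embedding_backward_{optimizer}_ssd_{wdesc}_kernel_warp.cu",
        ]

# e.g. "gen_embedding_backward_rowwise_adagrad_ssd_weighted_cuda.cu",
#      "gen_embedding_backward_sgd_ssd_unweighted_nobag_kernel_warp.cu", ...
```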

list(APPEND gen_defused_optim_py_files
${CMAKE_BINARY_DIR}/optimizer_args.py)

150 changes: 104 additions & 46 deletions fbgemm_gpu/codegen/genscript/generate_backward_split.py
@@ -8,6 +8,7 @@
# pyre-strict
# flake8: noqa F401

import itertools
import sys

try:
@@ -39,28 +40,44 @@ def render_backward_templates(
) -> None:
if not kwargs.get("has_gpu_support"):
return

weighted_options = [True, False]
nobag_options = [True, False] if (not is_gwd) else [False]
vbe_options = (
[True, False] if (kwargs.get("has_vbe_support") and not is_gwd) else [False]
)
ssd_options = [True, False] if kwargs.get("has_ssd_support") else [False]
template = CodeTemplate.load(template_filepath)

for weighted in [True, False]:
for nobag in [True, False] if (not is_gwd) else [False]:
for vbe in vbe_options:
if (not nobag or (not weighted and not vbe)) and (
not kwargs.get("dense") or not vbe
):
wdesc = f"{ 'weighted' if weighted else 'unweighted' }{ '_nobag' if nobag else '' }{ '_vbe' if vbe else '' }"
template.write(
filename_format.format(optimizer, wdesc),
weighted=weighted,
nobag=nobag,
vbe=vbe,
is_index_select=False,
kdesc=wdesc,
**kwargs,
is_gwd=is_gwd,
)
for weighted, nobag, vbe, ssd in itertools.product(
weighted_options, nobag_options, vbe_options, ssd_options
):
if nobag and (weighted or vbe):
continue
if kwargs.get("dense") and (vbe or ssd):
continue
if ssd and (vbe or is_gwd):
continue

kdesc = "".join(
[
f"{ 'weighted' if weighted else 'unweighted' }",
f"{ '_nobag' if nobag else '' }",
f"{ '_vbe' if vbe else '' }",
]
)
desc = "_".join([f"{ 'ssd' if ssd else 'split' }", kdesc])
template.write(
filename_format.format(optimizer, desc),
weighted=weighted,
nobag=nobag,
vbe=vbe,
is_index_select=False,
kdesc=kdesc,
is_gwd=is_gwd,
ssd=ssd,
**kwargs,
)
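The combination filtering above can be hard to scan in diff form. The standalone sketch below reproduces the skip rules under the assumption that the optimizer has GPU, VBE, and SSD support and is neither dense nor GWD; it is not the real generator, just the selection logic.

```python
# Minimal sketch of the skip rules in render_backward_templates(), assuming
# has_gpu_support/has_vbe_support/has_ssd_support are all True and the
# optimizer is neither dense nor GWD.
import itertools

def surviving_descriptors() -> list:
    descs = []
    for weighted, nobag, vbe, ssd in itertools.product([True, False], repeat=4):
        if nobag and (weighted or vbe):  # nobag kernels: unweighted, non-VBE only
            continue
        if ssd and vbe:                  # no SSD + VBE variant
            continue
        kdesc = (
            ("weighted" if weighted else "unweighted")
            + ("_nobag" if nobag else "")
            + ("_vbe" if vbe else "")
        )
        descs.append(("ssd" if ssd else "split") + "_" + kdesc)
    return sorted(descs)

# surviving_descriptors() ==
# ['split_unweighted', 'split_unweighted_nobag', 'split_unweighted_vbe',
#  'split_weighted', 'split_weighted_vbe',
#  'ssd_unweighted', 'ssd_unweighted_nobag', 'ssd_weighted']
```

Each descriptor is then substituted into filename_format, yielding names such as gen_embedding_backward_rowwise_adagrad_ssd_weighted_cuda.cu.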

@staticmethod
def generate_backward_split_gpu(**kwargs: Any) -> None:
@@ -73,19 +90,19 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
for template_filepath, filename_format in [
(
"training/backward/embedding_backward_split_template.cu",
"gen_embedding_backward_{}_split_{}_cuda.cu",
"gen_embedding_backward_{}_{}_cuda.cu",
),
(
"training/backward/embedding_backward_split_meta_template.cpp",
"gen_embedding_backward_{}_split_{}_meta.cpp",
"gen_embedding_backward_{}_{}_meta.cpp",
),
(
"training/backward/embedding_backward_split_kernel_cta_template.cu",
"gen_embedding_backward_{}_split_{}_kernel_cta.cu",
"gen_embedding_backward_{}_{}_kernel_cta.cu",
),
(
"training/backward/embedding_backward_split_kernel_warp_template.cu",
"gen_embedding_backward_{}_split_{}_kernel_warp.cu",
"gen_embedding_backward_{}_{}_kernel_warp.cu",
),
]:
BackwardSplitGenerator.render_backward_templates(
@@ -94,20 +111,21 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
filename_format,
kwargs,
)

# Generate the global weight decay CUDA kernels
if kwargs.get("has_global_weight_decay_support"):
for template_filepath, filename_format in [
(
"training/backward/embedding_backward_split_kernel_cta_template.cu",
"gen_embedding_backward_{}_split_{}_gwd_kernel_cta.cu",
"gen_embedding_backward_{}_{}_gwd_kernel_cta.cu",
),
(
"training/backward/embedding_backward_split_kernel_warp_template.cu",
"gen_embedding_backward_{}_split_{}_gwd_kernel_warp.cu",
"gen_embedding_backward_{}_{}_gwd_kernel_warp.cu",
),
(
"training/backward/embedding_backward_split_template.cu",
"gen_embedding_backward_{}_split_{}_gwd_cuda.cu",
"gen_embedding_backward_{}_{}_gwd_cuda.cu",
),
]:
BackwardSplitGenerator.render_backward_templates(
Expand All @@ -118,23 +136,38 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
is_gwd=True,
)

# Generate optimizer kernel
CodeTemplate.load(
"training/optimizer/embedding_optimizer_split_device_kernel_template.cuh"
).write(
f"gen_embedding_optimizer_{optimizer}_split_device_kernel.cuh", **kwargs
)
for ssd in (
[True, False]
if kwargs.get("has_ssd_support") and not kwargs.get("dense")
else [False]
):
desc = f"{ 'ssd' if ssd else 'split' }"
# Generate optimizer kernel
CodeTemplate.load(
"training/optimizer/embedding_optimizer_split_device_kernel_template.cuh"
).write(
f"gen_embedding_optimizer_{optimizer}_{desc}_device_kernel.cuh",
ssd=ssd,
**kwargs,
)

# Generate the backward splits (non-dense)
# We generate only the API to preserve the backward compatibility if
# has_gpu_support=True
if not kwargs.get("dense"):
# Generate CUDA autograd, PT2 unified autograd, and PT2 backward wrapper
# Generate CUDA autograd
template_filepath = (
"training/backward/embedding_backward_split_host_template.cpp"
)
for ssd in [True, False]:
desc = "ssd" if ssd else "split"
filename = f"gen_embedding_backward_{desc}_{optimizer}.cpp"
CodeTemplate.load(template_filepath).write(
filename, is_forward=False, ssd=ssd, **kwargs
)

# Generate PT2 unified autograd, and PT2 backward wrapper
for template_filepath, filename in [
(
"training/backward/embedding_backward_split_host_template.cpp",
f"gen_embedding_backward_split_{optimizer}.cpp",
),
(
"training/pt2/embedding_split_host_pt2_autograd_template.cpp",
f"gen_embedding_split_{optimizer}_pt2_autograd.cpp",
Expand All @@ -153,11 +186,15 @@ def generate_backward_split_gpu(**kwargs: Any) -> None:
template = CodeTemplate.load(
"training/python/split_embedding_codegen_lookup_invoker.template"
)
for filename in [
f"lookup_{optimizer}.py",
f"lookup_{optimizer}_pt2.py",
]:
template.write(filename, is_fbcode=args.is_fbcode, **kwargs)
for ssd in [True, False]:
sdesc = "_ssd" if ssd else ""
for filename in [
f"lookup_{optimizer}{sdesc}.py",
f"lookup_{optimizer}{sdesc}_pt2.py",
]:
template.write(
filename, is_fbcode=args.is_fbcode, ssd=ssd, **kwargs
)
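A small sketch of the lookup-invoker filenames that loop writes per optimizer; the names mirror the f-strings above, and the optimizer name used below is only an example.

```python
# Illustrative only: mirrors the lookup-invoker filename loop above.
def lookup_module_names(optimizer: str) -> list:
    names = []
    for ssd in [True, False]:
        sdesc = "_ssd" if ssd else ""
        names += [
            f"lookup_{optimizer}{sdesc}.py",
            f"lookup_{optimizer}{sdesc}_pt2.py",
        ]
    return names

# lookup_module_names("rowwise_adagrad") ==
# ['lookup_rowwise_adagrad_ssd.py', 'lookup_rowwise_adagrad_ssd_pt2.py',
#  'lookup_rowwise_adagrad.py', 'lookup_rowwise_adagrad_pt2.py']
```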

@staticmethod
def generate_backward_split_cpu(**kwargs: Any) -> None:
@@ -213,18 +250,19 @@ def generate_backward_device() -> None:
BackwardSplitGenerator.render_backward_templates(
template_filepath,
"",
"{}gen_embedding_backward_{}_split_device_kernel.cuh",
"{}gen_embedding_backward_{}_device_kernel.cuh",
{
"has_gpu_support": True,
"has_vbe_support": True,
"has_ssd_support": True,
"dense": False,
"gen_once": False,
},
)

# Generate common backward device kernels (generate only once)
CodeTemplate.load(template_filepath).write(
"gen_embedding_backward_common_split_device_kernel.cuh",
"gen_embedding_backward_split_common_device_kernel.cuh",
gen_once=True,
)

@@ -242,16 +280,27 @@ def generate_backward_indices() -> None:
template = CodeTemplate.load(
"training/backward/embedding_backward_split_indice_weights_template.cu"
)
for dense in [True, False]:
dense_options = [True, False]
ssd_options = [True, False]
for dense, ssd in itertools.product(dense_options, ssd_options):
if dense and ssd:
continue
desc = "dense" if dense else ("ssd" if ssd else "split")
template.write(
f"gen_embedding_backward_{'dense' if dense else 'split'}_indice_weights_codegen_cuda.cu",
f"gen_embedding_backward_{ desc }_indice_weights_codegen_cuda.cu",
dense=dense,
ssd=ssd,
)
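For completeness, the dense/ssd product above reduces to three indice-weights variants; a tiny sketch of the same selection logic:

```python
# Sketch of the descriptor selection in generate_backward_indices():
# dense and ssd are mutually exclusive, so three files are generated.
import itertools

descs = []
for dense, ssd in itertools.product([True, False], repeat=2):
    if dense and ssd:
        continue
    descs.append("dense" if dense else ("ssd" if ssd else "split"))

# descs == ["dense", "ssd", "split"], i.e.
# gen_embedding_backward_{dense,ssd,split}_indice_weights_codegen_cuda.cu
```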

@staticmethod
def generate_python_sources() -> None:
CodeTemplate.load("training/python/__init__.template").write("__init__.py")
CodeTemplate.copy_to_root("training/python/lookup_args.py")

template = CodeTemplate.load("training/python/lookup_args.template")
for ssd in [True, False]:
sdesc = "_ssd" if ssd else ""
filename = f"lookup_args{sdesc}.py"
template.write(filename, ssd=ssd)

@staticmethod
def generate() -> None:
@@ -276,8 +325,17 @@ def generate() -> None:
none_optimizer(),
]

ssd_tensors = [
"row_addrs",
"inserted_rows",
"post_bwd_evicted_indices",
"actions_count",
]

for optimizer in optimizers:
BackwardSplitGenerator.generate_backward_split(**optimizer)
BackwardSplitGenerator.generate_backward_split(
ssd_tensors=ssd_tensors, **optimizer
)

# Generate common device kernels for backwards
BackwardSplitGenerator.generate_backward_device()
