Add output_dtype in v0 version of training TBE op (#18)
Summary:
Pull Request resolved: pytorch/torchrec#18

Addressing s251694

Reviewed By: xw285cornell

Differential Revision: D32399931

fbshipit-source-id: 7a3fc8b12ce2093173b5f21945850efbd5254737
jianyuh authored and facebook-github-bot committed Nov 16, 2021
1 parent f3d3b95 commit 7e1183c
Showing 1 changed file with 3 additions and 56 deletions.
fbgemm_gpu/codegen/embedding_backward_split_host_template.cpp (3 additions, 56 deletions)

@@ -333,59 +333,8 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function(
     bool gradient_clipping,
     double max_gradient,
     bool stochastic_rounding,
-    {{ args.split_function_args | join(", ") }}) {
-  return SplitLookupFunction_{{ optimizer }}_Op::apply(
-      {% if not dense %}
-      placeholder_autograd_tensor,
-      0 /* hardcode output dtype to float*/,
-      {% endif %}
-      dev_weights,
-      uvm_weights,
-      lxu_cache_weights,
-      weights_placements,
-      weights_offsets,
-      D_offsets,
-      total_D,
-      max_D,
-      hash_size_cumsum,
-      total_hash_size_bits,
-      indices,
-      offsets,
-      pooling_mode,
-      indice_weights,
-      feature_requires_grad,
-      lxu_cache_locations,
-      gradient_clipping,
-      max_gradient,
-      stochastic_rounding,
-      {{ args.split_function_arg_names | join(", ") }})[0];
-}
-
-Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_v2(
-    {% if not dense %}
-    Tensor placeholder_autograd_tensor,
-    int64_t output_dtype,
-    {% endif %}
-    Tensor dev_weights,
-    Tensor uvm_weights,
-    Tensor lxu_cache_weights,
-    Tensor weights_placements,
-    Tensor weights_offsets,
-    Tensor D_offsets,
-    int64_t total_D,
-    int64_t max_D,
-    Tensor hash_size_cumsum,
-    int64_t total_hash_size_bits,
-    Tensor indices,
-    Tensor offsets,
-    int64_t pooling_mode,
-    c10::optional<Tensor> indice_weights,
-    c10::optional<Tensor> feature_requires_grad,
-    Tensor lxu_cache_locations,
-    bool gradient_clipping,
-    double max_gradient,
-    bool stochastic_rounding,
-    {{ args.split_function_args | join(", ") }}) {
+    {{ args.split_function_args | join(", ") }},
+    int64_t output_dtype) {
   return SplitLookupFunction_{{ optimizer }}_Op::apply(
       {% if not dense %}
       placeholder_autograd_tensor,
@@ -414,8 +363,6 @@ Tensor split_embedding_codegen_lookup_{{ optimizer }}_function_v2(
 }

 TORCH_LIBRARY_FRAGMENT(fb, m) {
-  m.def("split_embedding_codegen_lookup_{{ optimizer }}_function({% if not dense %} Tensor placeholder_autograd_tensor, {% endif %}Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}) -> Tensor");
+  m.def("split_embedding_codegen_lookup_{{ optimizer }}_function({% if not dense %} Tensor placeholder_autograd_tensor, {% endif %}Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}, int output_dtype=0) -> Tensor");
   m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function", torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function)));
-  m.def("split_embedding_codegen_lookup_{{ optimizer }}_function_v2({% if not dense %} Tensor placeholder_autograd_tensor, int output_dtype, {% endif %}Tensor dev_weights, Tensor uvm_weights, Tensor lxu_cache_weights, Tensor weights_placements, Tensor weights_offsets, Tensor D_offsets, int total_D, int max_D, Tensor hash_size_cumsum, int total_hash_size_bits, Tensor indices, Tensor offsets, int pooling_mode, Tensor? indice_weights, Tensor? feature_requires_grad, Tensor lxu_cache_locations, bool gradient_clipping, float max_gradient, bool stochastic_rounding, {{ args.split_function_schemas | join(", ") }}) -> Tensor");
-  m.impl("split_embedding_codegen_lookup_{{ optimizer }}_function_v2", torch::dispatch(c10::DispatchKey::CUDA, TORCH_FN(split_embedding_codegen_lookup_{{ optimizer }}_function_v2)));
 }
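
Note on the registration pattern in the hunk above: because the new argument is appended to the end of the schema with a default (int output_dtype=0), call sites written against the old signature still resolve, and they get the float behavior that was previously supplied by the removed hardcoded argument (0 /* hardcode output dtype to float*/). The snippet below is a minimal sketch of that pattern using a made-up toy::toy_lookup operator, not the generated fbgemm_gpu code; treating output_dtype == 0 as float follows the removed comment, and the non-zero branch is purely illustrative.

#include <ATen/ATen.h>
#include <torch/library.h>

namespace {

// Hypothetical lookup op: casts the weights to the requested output dtype.
// 0 selects float, matching the default the commit gives output_dtype; any
// other value is mapped to half here purely for illustration.
at::Tensor toy_lookup(const at::Tensor& weights, int64_t output_dtype) {
  return weights.to(output_dtype == 0 ? at::kFloat : at::kHalf);
}

} // anonymous namespace

TORCH_LIBRARY_FRAGMENT(toy, m) {
  // The trailing "int output_dtype=0" in the schema string is what keeps
  // existing callers (which omit the argument) working unchanged.
  m.def("toy_lookup(Tensor weights, int output_dtype=0) -> Tensor");
  m.impl(
      "toy_lookup",
      torch::dispatch(c10::DispatchKey::CPU, TORCH_FN(toy_lookup)));
}

Exposing the dtype through the v0 schema with a default also removes the need for the separate _v2 entry point, which this commit deletes.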
