From e1d54a3788ac3aeac693b50d809269d241d7fab9 Mon Sep 17 00:00:00 2001 From: Wang Zhou Date: Sun, 7 Jan 2024 18:55:33 -0800 Subject: [PATCH] Fix test for WeightDecayMode.NONE Summary: Fix a bug in test for `WeightDecayMode.NONE`: `weights_ref` is not defined Differential Revision: D52590024 --- fbgemm_gpu/test/split_table_batched_embeddings_test.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fbgemm_gpu/test/split_table_batched_embeddings_test.py b/fbgemm_gpu/test/split_table_batched_embeddings_test.py index 7418958b8f..52a3b555c6 100644 --- a/fbgemm_gpu/test/split_table_batched_embeddings_test.py +++ b/fbgemm_gpu/test/split_table_batched_embeddings_test.py @@ -3690,10 +3690,6 @@ def execute_backward_optimizers_( # noqa C901 weights_ref = bs[t].weight.cpu() - lr * ( dense_cpu_grad / denom + weight_decay * bs[t].weight.cpu() ) - elif weight_decay_mode == WeightDecayMode.L2: - # pyre-fixme[58]: `/` is not supported for operand types `float` - # and `Tensor`. - weights_ref = bs[t].weight.cpu() - lr * dense_cpu_grad / denom elif weight_decay_mode == WeightDecayMode.COUNTER: max_counter = cc.max_counter.item() weights_ref = self._get_wts_from_counter_adagrad_using_counter( @@ -3720,6 +3716,10 @@ def execute_backward_optimizers_( # noqa C901 lr, weight_decay, ) + else: # WeightDecayMode.L2 or WeightDecayMode.NONE + # pyre-fixme[58]: `/` is not supported for operand types `float` + # and `Tensor`. + weights_ref = bs[t].weight.cpu() - lr * dense_cpu_grad / denom else: # pyre-fixme[58]: `/` is not supported for operand types `float` # and `Tensor`.