diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index de4d700cdf80ee..7d3cfd5d8dd43c 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -191,6 +191,19 @@ backward : as_strided_grad no_need_buffer : input +- op : asgd_ + args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor d, Tensor y, Tensor n, Tensor master_param, bool multi_precision=false) + output : Tensor(param_out), Tensor(d_out), Tensor(y_out), Tensor(master_param_out) + infer_meta : + func : ASGDInferMeta + kernel : + func : asgd + data_type : param + data_transform : + support_trans_dtype : learning_rate, n + optional : master_param, master_param_out + inplace : (param -> param_out), (d -> d_out), (y -> y_out), (master_param -> master_param_out) + - op : asin args : (Tensor x) output : Tensor(out) diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc index 5b9708b38a17e1..882de5cd512eaa 100644 --- a/paddle/phi/infermeta/multiary.cc +++ b/paddle/phi/infermeta/multiary.cc @@ -481,6 +481,48 @@ void AddNTensorArrayInferMeta(const std::vector& x, } } +void ASGDInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& d, + const MetaTensor& y, + const MetaTensor& n, + const MetaTensor& master_param, + bool multi_precision, + MetaTensor* param_out, + MetaTensor* d_out, + MetaTensor* y_out, + MetaTensor* master_param_out) { + PADDLE_ENFORCE_NOT_NULL( + param_out, + phi::errors::InvalidArgument( + "Output(ParamOut) of ASGDOp should not be null.")); + + PADDLE_ENFORCE_NOT_NULL(d_out, + phi::errors::InvalidArgument( + "Output(DOut) of ASGDOp should not be null.")); + + PADDLE_ENFORCE_NOT_NULL(y_out, + phi::errors::InvalidArgument( + "Output(YOut) of ASGDOp should not be null.")); + + param_out->set_dims(param.dims()); + param_out->set_dtype(param.dtype()); + d_out->set_dims(d.dims()); + d_out->set_dtype(d.dtype()); + y_out->set_dims(y.dims()); + y_out->set_dtype(y.dtype()); + if (multi_precision) { + master_param_out->set_dims(master_param.dims()); + if (DataType::FLOAT16 == master_param.dtype() || + DataType::BFLOAT16 == master_param.dtype()) { + master_param_out->set_dtype(DataType::FLOAT32); + } else { + master_param_out->set_dtype(master_param.dtype()); + } + } +} + void AucInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& stat_pos, diff --git a/paddle/phi/infermeta/multiary.h b/paddle/phi/infermeta/multiary.h index f51c3dacb19095..fd1a1bc31cccb8 100644 --- a/paddle/phi/infermeta/multiary.h +++ b/paddle/phi/infermeta/multiary.h @@ -138,6 +138,19 @@ void AddNTensorArrayInferMeta(const std::vector& x, MetaTensor* out, MetaConfig config); +void ASGDInferMeta(const MetaTensor& param, + const MetaTensor& grad, + const MetaTensor& learning_rate, + const MetaTensor& d, + const MetaTensor& y, + const MetaTensor& n, + const MetaTensor& master_param, + bool multi_precision, + MetaTensor* param_out, + MetaTensor* d_out, + MetaTensor* y_out, + MetaTensor* master_param_out); + void AucInferMeta(const MetaTensor& input, const MetaTensor& label, const MetaTensor& stat_pos, diff --git a/paddle/phi/kernels/asgd_kernel.h b/paddle/phi/kernels/asgd_kernel.h new file mode 100644 index 00000000000000..9f5f9761d3b6dc --- /dev/null +++ b/paddle/phi/kernels/asgd_kernel.h @@ -0,0 +1,37 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include "paddle/phi/core/dense_tensor.h"
+#include "paddle/phi/core/selected_rows.h"
+
+namespace phi {
+
+template <typename T, typename Context>
+void ASGDKernel(const Context& dev_ctx,
+                const DenseTensor& param,
+                const DenseTensor& grad,
+                const DenseTensor& learning_rate,
+                const DenseTensor& d,
+                const DenseTensor& y,
+                const DenseTensor& n,
+                const paddle::optional<DenseTensor>& master_param,
+                bool multi_precision,
+                DenseTensor* param_out,
+                DenseTensor* d_out,
+                DenseTensor* y_out,
+                DenseTensor* master_param_out);
+
+}  // namespace phi
diff --git a/paddle/phi/kernels/cpu/asgd_kernel.cc b/paddle/phi/kernels/cpu/asgd_kernel.cc
new file mode 100644
index 00000000000000..610806133846e2
--- /dev/null
+++ b/paddle/phi/kernels/cpu/asgd_kernel.cc
@@ -0,0 +1,73 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/asgd_kernel.h" + +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" +#include "paddle/phi/kernels/funcs/jit/kernels.h" + +namespace phi { + +template +void ASGDKernelCPUImpl(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& d, + const DenseTensor& y, + const DenseTensor& n, + DenseTensor* param_out, + DenseTensor* d_out, + DenseTensor* y_out) { + auto param_eigen = EigenVector::Flatten(param); + auto grad_eigen = EigenVector::Flatten(grad); + auto d_eigen = EigenVector::Flatten(d); + auto y_eigen = EigenVector::Flatten(y); + auto param_out_eigen = EigenVector::Flatten(*param_out); + auto d_out_eigen = EigenVector::Flatten(*d_out); + auto y_out_eigen = EigenVector::Flatten(*y_out); + T learning_rate_T = learning_rate.data()[0]; + T n_T = n.data()[0]; + + d_out_eigen = d_eigen - y_eigen + grad_eigen; + y_out_eigen = grad_eigen; + param_out_eigen = param_eigen - (learning_rate_T / n_T) * d_out_eigen; +} + +template +void ASGDKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& d, + const DenseTensor& y, + const DenseTensor& n, + const paddle::optional& master_param UNUSED, + bool multi_precision UNUSED, + DenseTensor* param_out, + DenseTensor* d_out, + DenseTensor* y_out, + DenseTensor* master_param_out UNUSED) { + dev_ctx.template Alloc(param_out); + dev_ctx.template Alloc(d_out); + dev_ctx.template Alloc(y_out); + ASGDKernelCPUImpl( + dev_ctx, param, grad, learning_rate, d, y, n, param_out, d_out, y_out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(asgd, CPU, ALL_LAYOUT, phi::ASGDKernel, float, double) {} diff --git a/paddle/phi/kernels/gpu/asgd_kernel.cu b/paddle/phi/kernels/gpu/asgd_kernel.cu new file mode 100644 index 00000000000000..11418ec0e2c0bf --- /dev/null +++ b/paddle/phi/kernels/gpu/asgd_kernel.cu @@ -0,0 +1,106 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/asgd_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/backends/gpu/gpu_helper.h" +#include "paddle/phi/backends/gpu/gpu_primitives.h" +#include "paddle/phi/common/amp_type_traits.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/mixed_vector.h" + +namespace phi { + +template +__global__ void ASGDKernelGPUImpl(const T* param, + const T* grad, + const T* learning_rate, + const T* d, + const T* y, + const T* n, + const MT* master_param, + int num, + T* param_out, + T* d_out, + T* y_out, + MT* master_param_out) { + MT learning_rate_MT = static_cast(learning_rate[0]); + MT n_MT = static_cast(n[0]); + CUDA_KERNEL_LOOP(i, num) { + MT param_data = master_param ? 
master_param[i] : static_cast(param[i]); + MT grad_data = static_cast(grad[i]); + MT d_data = static_cast(d[i]); + MT y_data = static_cast(y[i]); + d_data = d_data - y_data + grad_data; + y_data = grad_data; + param_data = param_data - (learning_rate_MT / n_MT) * d_data; + param_out[i] = static_cast(param_data); + d_out[i] = static_cast(d_data); + y_out[i] = static_cast(y_data); + if (master_param_out) { + master_param_out[i] = param_data; + } + } +} + +template +void ASGDKernel(const Context& dev_ctx, + const DenseTensor& param, + const DenseTensor& grad, + const DenseTensor& learning_rate, + const DenseTensor& d, + const DenseTensor& y, + const DenseTensor& n, + const paddle::optional& master_param, + bool multi_precision, + DenseTensor* param_out, + DenseTensor* d_out, + DenseTensor* y_out, + DenseTensor* master_param_out) { + using MPDType = typename phi::dtype::MPTypeTrait::Type; + const MPDType* master_in_data = + multi_precision ? master_param->data() : nullptr; + MPDType* master_out_data = + multi_precision ? dev_ctx.template Alloc(master_param_out) + : nullptr; + + int block = 512; + int grid = (param.numel() + block - 1) / block; + + ASGDKernelGPUImpl<<>>( + param.data(), + grad.data(), + learning_rate.data(), + d.data(), + y.data(), + n.data(), + master_in_data, + param.numel(), + dev_ctx.template Alloc(param_out), + dev_ctx.template Alloc(d_out), + dev_ctx.template Alloc(y_out), + master_out_data); +} + +} // namespace phi + +PD_REGISTER_KERNEL(asgd, + GPU, + ALL_LAYOUT, + phi::ASGDKernel, + phi::dtype::float16, + phi::dtype::bfloat16, + float, + double) {} diff --git a/python/paddle/optimizer/__init__.py b/python/paddle/optimizer/__init__.py index 516779cd924f6a..744aed81a545c3 100644 --- a/python/paddle/optimizer/__init__.py +++ b/python/paddle/optimizer/__init__.py @@ -18,6 +18,7 @@ from .adam import Adam from .adamax import Adamax from .adamw import AdamW +from .asgd import ASGD from .lamb import Lamb from .lbfgs import LBFGS from .momentum import Momentum @@ -32,6 +33,7 @@ 'Adam', 'AdamW', 'Adamax', + 'ASGD', 'RMSProp', 'Adadelta', 'SGD', diff --git a/python/paddle/optimizer/asgd.py b/python/paddle/optimizer/asgd.py new file mode 100644 index 00000000000000..36e08a6b12057f --- /dev/null +++ b/python/paddle/optimizer/asgd.py @@ -0,0 +1,359 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import paddle +from paddle import _C_ops +from paddle.tensor.creation import to_tensor + +from ..base import framework +from ..base.dygraph import no_grad +from ..base.framework import in_dygraph_mode, in_pir_mode +from .optimizer import Optimizer + +__all__ = [] + + +class ASGD(Optimizer): + r""" + Optimizer of the ASGD algorithm.Please refer to this for details: + `Minimizing Finite Sums with the Stochastic Average Gradient `_. + + .. 
+    .. math::
+
+        \begin{aligned}
+            &\hspace{0mm} d=0,\ y_i=0\ \textbf{for}\ i=1,2,...,n \\
+            &\hspace{0mm} \textbf{for}\ \: m=0,1,...\ \textbf{do} \: \\
+            &\hspace{5mm} i=m\ \%\ n \\
+            &\hspace{5mm} d=d-y_i+f_i{}'(x) \\
+            &\hspace{5mm} y_i=f_i{}'(x) \\
+            &\hspace{5mm} x=x-learning\_rate(\frac{d}{\mathrm{min}(m+1,\ n)}+\lambda x) \\
+            &\hspace{0mm} \textbf{end for} \\
+        \end{aligned}
+
+    Parameters:
+        learning_rate (float|Tensor|LearningRateDecay, optional): The learning rate used to update ``Parameter``.
+            It can be a float value, a ``Tensor`` with a float type or a LearningRateDecay. The default value is 0.001.
+        batch_num (int, optional): The number of batches needed to complete one epoch.
+            Assuming the total number of samples is ``all``,
+            it is recommended to set ``batch_num`` to ``all`` / ``batch_size``.
+            If GPU memory is tight, ``batch_num`` can be reduced appropriately.
+            The default value is 1.
+        parameters (list|tuple, optional): List/Tuple of ``Tensor`` to update to minimize ``loss``.
+            This parameter is required in dygraph mode.
+            The default value is None in static graph mode, at which point all parameters will be updated.
+        weight_decay (float|WeightDecayRegularizer, optional): The strategy of regularization.
+            It can be a float value as coeff of L2 regularization or :ref:`api_paddle_regularizer_L1Decay`, :ref:`api_paddle_regularizer_L2Decay`.
+            If a parameter has already set a regularizer using :ref:`api_paddle_ParamAttr`,
+            the regularization setting here in the optimizer will be ignored for this parameter.
+            Otherwise, the regularization setting here in the optimizer will take effect.
+            Default None, meaning there is no regularization.
+        grad_clip (GradientClipBase, optional): Gradient clipping strategy, it's an instance of some derived class of ``GradientClipBase`` .
+            There are three clipping strategies ( :ref:`api_paddle_nn_ClipGradByGlobalNorm` , :ref:`api_paddle_nn_ClipGradByNorm` , :ref:`api_paddle_nn_ClipGradByValue` ).
+            Default None, meaning there is no gradient clipping.
+        multi_precision (bool, optional): Whether to use multi-precision updates in mixed precision training on GPU,
+            mainly to ensure the numerical stability of gradient updates.
+            When set to True, the optimizer keeps an FP32 master copy of each FP16 parameter.
+            During the update, the gradient is first cast to FP32 and applied to the FP32 master copy;
+            the updated FP32 value is then cast back to FP16 and assigned to the actual
+            FP16 parameter that participates in the computation.
+            The default value is False.
+        name (str, optional): The default value is None. Normally there is no need for user to set this property.
+            For more information, please refer to :ref:`api_guide_Name` .
+
+    Examples:
+        .. 
code-block:: python + + >>> import paddle + + >>> inp = paddle.uniform(min=-0.1, max=0.1, shape=[10, 10], dtype='float32') + >>> linear = paddle.nn.Linear(10, 10) + >>> inp = paddle.to_tensor(inp) + >>> out = linear(inp) + >>> loss = paddle.mean(out) + >>> asgd = paddle.optimizer.ASGD(learning_rate=0.001, batch_num=10, parameters=linear.parameters(), weight_decay=0.01) + >>> out.backward() + >>> asgd.step() + >>> asgd.clear_grad() + """ + _d_acc_str = "d" + _y_acc_str = "y" + _m_acc_str = "m" + + def __init__( + self, + learning_rate=0.001, + batch_num=1, + parameters=None, + weight_decay=None, + grad_clip=None, + multi_precision=False, + name=None, + ): + if learning_rate is None: + raise ValueError("learning_rate should not be none") + if batch_num is None: + raise ValueError("batch_num should not be none") + if not 0 < batch_num: + raise ValueError("batch_num should be greater than 0") + super().__init__( + learning_rate=learning_rate, + parameters=parameters, + weight_decay=weight_decay, + grad_clip=grad_clip, + name=name, + ) + self.type = "asgd" + self._multi_precision = multi_precision + self._master_weights = {} + self._n = batch_num + self._n_tensor = None + + def _create_accumulators(self, block, parameters): + assert isinstance(block, framework.Block) + if isinstance(parameters, dict): + parameters = self._update_param_group(parameters) + + for p in parameters: + if p.name in self._already_create_accumulater: + continue + p_new = p + if self._multi_precision and self._is_dtype_fp16_or_bf16(p.dtype): + master_p = self._create_master_weight(p) + p_new = master_p + if ( + self._is_dtype_fp16_or_bf16(p.dtype) + and not self._multi_precision + ): + warnings.warn( + "Accumulating with FP16/BF16 in optimizer can lead to poor accuracy or slow convergence." + "Consider using multi_precision=True option of the Adam optimizer." 
+ ) + + self._add_accumulator( + self._d_acc_str, + p_new, + p.dtype, + 0, + ) + + # Sometimes p.shape is a tuple, so we need to change it to a list + self._add_accumulator( + self._y_acc_str, + p_new, + p.dtype, + 0, + [self._n] + list(p.shape), + ) + + self._add_accumulator( + self._m_acc_str, + p_new, + "int64", + 0, + [1], + ) + + self._already_create_accumulater.add(p.name) + + def _assign_accumulator_master( + self, block, name, param, assign_value, index + ): + if self._name is not None: + name = self._name + "_" + name + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param.dtype + ) + target_param = ( + self._master_weights[param.name] if find_master else param + ) + target_name = target_param.name + if ( + name not in self._accumulators + or target_name not in self._accumulators[name] + ): + raise Exception( + f"Accumulator {name} does not exist for parameter {target_name}" + ) + + if in_pir_mode(): + if index is None: + self._accumulators[name][target_name] = paddle.assign( + assign_value + ) + else: + self._accumulators[name][target_name][index] = paddle.assign( + assign_value + ) + else: + assert isinstance(block, framework.Block) + + assign_inputs = { + "X": assign_value, + } + + assign_outputs = { + "Out": self._accumulators[name][target_name], + } + + block.append_op( + type="assign", + inputs=assign_inputs, + outputs=assign_outputs, + ) + + @no_grad + def _append_optimize_op(self, block, param_and_grad): + if isinstance(param_and_grad, dict): + param_and_grad = self._update_param_group(param_and_grad) + + if self._n_tensor is None: + self._n_tensor = to_tensor( + [self._n], + ) + + d = self._get_accumulator_master(self._d_acc_str, param_and_grad[0]) + + m = self._get_accumulator_master(self._m_acc_str, param_and_grad[0]) + + ys = self._get_accumulator_master(self._y_acc_str, param_and_grad[0]) + index = paddle.mod(m, self._n_tensor).item() + y = paddle.assign(ys[index]) + + find_master = self._multi_precision and self._is_dtype_fp16_or_bf16( + param_and_grad[0].dtype + ) + master_weight = ( + self._master_weights[param_and_grad[0].name] + if find_master + else None + ) + + lr = self._create_param_lr(param_and_grad) + + if in_dygraph_mode(): + m.add_(to_tensor([1], dtype=m.dtype)) + + _C_ops.asgd_( + param_and_grad[0], + param_and_grad[1], + lr, + d, + ys[index], + paddle.fmin(m, self._n_tensor), + master_weight, + find_master, + ) + + return None + elif in_pir_mode(): + m = paddle.assign(paddle.add(m, to_tensor([1], dtype=m.dtype))) + self._assign_accumulator_master( + block, self._m_acc_str, param_and_grad[0], m, None + ) + + # The y in the static graph has one more dimension than the y in the dynamic graph. + # So we should unify the shape of y in both dynamic and static graph. + # eg: + # dynamic graph: y.shape is [2, 2] + # static graph: y.shape is [1, 2, 2] + # so we should do + # static graph: y = y[0] + y = y[0] + + _C_ops.asgd_( + param_and_grad[0], + param_and_grad[1], + lr, + d, + y, + paddle.fmin(m, self._n_tensor), + master_weight, + find_master, + ) + + self._assign_accumulator_master( + block, self._y_acc_str, param_and_grad[0], y, index + ) + + return None + else: + assert isinstance(block, framework.Block) + # create the optimize op + add_inputs = { + "X": m, + "Y": to_tensor([1], dtype=m.dtype), + } + + add_outputs = { + "Out": m, + } + + block.append_op( + type="elementwise_add", + inputs=add_inputs, + outputs=add_outputs, + ) + + # The y in the static graph has one more dimension than the y in the dynamic graph. 
+ # So we should unify the shape of y in both dynamic and static graph. + # eg: + # dynamic graph: y.shape is [2, 2] + # static graph: y.shape is [1, 2, 2] + # so we should do + # static graph: y = y[0] + y = y[0] + + asgd_inputs = { + "param": param_and_grad[0], + "grad": param_and_grad[1], + "learning_rate": lr, + "d": d, + "y": y, + "n": paddle.fmin(m, self._n_tensor), + } + + asgd_outputs = { + "param_out": param_and_grad[0], + "d_out": d, + "y_out": y, + } + + asgd_attrs = {"multi_precision": find_master} + + if find_master: + asgd_inputs["master_param"] = master_weight + asgd_outputs["master_param_out"] = master_weight + + asgd_op = block.append_op( + type=self.type, + inputs=asgd_inputs, + outputs=asgd_outputs, + attrs=asgd_attrs, + stop_gradient=True, + ) + + ys = paddle.static.setitem(ys, index, y) + + self._assign_accumulator_master( + block, self._y_acc_str, param_and_grad[0], ys, None + ) + + return asgd_op + + def _update_param_group(self, parameters): + parameters = parameters.get('params') + return parameters diff --git a/test/legacy_test/test_asgd_op.py b/test/legacy_test/test_asgd_op.py new file mode 100644 index 00000000000000..b3518f30974e2a --- /dev/null +++ b/test/legacy_test/test_asgd_op.py @@ -0,0 +1,492 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest + +import numpy as np +from op_test import ( + OpTest, + convert_float_to_uint16, +) +from utils import dygraph_guard + +import paddle +from paddle.base import core + +paddle.enable_static() + + +def asgd_wrapper( + param, + grad, + learning_rate, + d, + y, + n, + master_param=None, + multi_precision=False, +): + paddle._C_ops.asgd_( + param, + grad, + learning_rate, + d, + y, + n, + None, + False, + ) + + +class TestASGDOpMixin: + def setUp(self): + self.init_basic_info() + self.init_input() + self.update_input_dtype() + self.init_output() + self.update_output_dtype() + + self.inputs = { + "param": self.params, + "grad": self.grads, + "learning_rate": self.learning_rate, + "d": self.ds, + "y": self.ys, + "n": self.n, + } + + self.outputs = { + "param_out": self.params_out, + "d_out": self.ds_out, + "y_out": self.ys_out, + } + + def init_basic_info(self): + self.op_type = "asgd" + self.python_api = asgd_wrapper + self.python_out_sig = ['Out'] + self.h = 102 + self.w = 105 + + def init_input(self): + self.params = np.random.random((self.h, self.w)) + self.learning_rate = np.array([0.001]) + self.n = np.array([1000]) + self.grads = np.random.random((self.h, self.w)) + self.ds = np.random.random((self.h, self.w)) + self.ys = np.random.random((self.h, self.w)) + + def init_output(self): + self.ds_out = self.ds - self.ys + self.grads + self.ys_out = self.grads.copy() + self.params_out = ( + self.params - (self.learning_rate / self.n) * self.ds_out + ) + + def update_input_dtype(self): + pass + + def update_output_dtype(self): + pass + + def test_check_output(self): + self.check_output(check_pir=True) + + +class TestASGDOp(TestASGDOpMixin, OpTest): + pass + + +class TestCase1(TestASGDOp): + def update_input_dtype(self): + self.params = self.params.astype("float32") + self.learning_rate = self.learning_rate.astype("float32") + self.n = self.n.astype("float32") + self.grads = self.grads.astype("float32") + self.ds = self.ds.astype("float32") + self.ys = self.ys.astype("float32") + + +class TestCase2(TestASGDOp): + def update_input_dtype(self): + self.params = self.params.astype("float16") + self.learning_rate = self.learning_rate.astype("float16") + self.n = self.n.astype("float16") + self.grads = self.grads.astype("float16") + self.ds = self.ds.astype("float16") + self.ys = self.ys.astype("float16") + + def test_check_output(self): + if core.is_compiled_with_cuda(): + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) + + +class TestCase3(TestASGDOp): + def update_input_dtype(self): + self.params = convert_float_to_uint16(self.params) + self.learning_rate = convert_float_to_uint16(self.learning_rate) + self.n = convert_float_to_uint16(self.n) + self.grads = convert_float_to_uint16(self.grads) + self.ds = convert_float_to_uint16(self.ds) + self.ys = convert_float_to_uint16(self.ys) + + def update_output_dtype(self): + self.ds_out = convert_float_to_uint16(self.ds_out) + self.ys_out = convert_float_to_uint16(self.ys_out) + self.params_out = convert_float_to_uint16(self.params_out) + + def test_check_output(self): + if core.is_compiled_with_cuda(): + self.check_output_with_place(core.CUDAPlace(0), check_pir=True) + + +class TestCase4(TestASGDOp): + def init_input(self): + self.params = np.random.random((self.h, self.w)) + self.learning_rate = np.array([0.001]) + self.n = np.array([1]) + self.grads = np.random.random((self.h, self.w)) + self.ds = np.random.random((self.h, self.w)) + self.ys = np.random.random((self.h, self.w)) + + +class TestASGDV2(unittest.TestCase): 
+ def test_asgd_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear = paddle.nn.Linear(13, 5) + + asgd = paddle.optimizer.ASGD( + learning_rate=0.001, + batch_num=2, + parameters=linear.parameters(), + ) + out = linear(a) + out.backward() + asgd.step() + asgd.clear_gradients() + + def test_raise_error(self): + self.assertRaises( + ValueError, + paddle.optimizer.ASGD, + batch_num=2, + learning_rate=None, + ) + self.assertRaises( + ValueError, + paddle.optimizer.ASGD, + batch_num=None, + ) + self.assertRaises( + ValueError, + paddle.optimizer.ASGD, + batch_num=-2, + ) + + def test_asgd_group_dygraph(self): + paddle.disable_static() + value = np.arange(26).reshape(2, 13).astype("float32") + a = paddle.to_tensor(value) + linear_1 = paddle.nn.Linear(13, 5) + linear_2 = paddle.nn.Linear(5, 3) + asgd = paddle.optimizer.ASGD( + learning_rate=0.001, + batch_num=2, + parameters=[ + {'params': linear_1.parameters()}, + { + 'params': linear_2.parameters(), + 'learning_rate': 0.0001, + }, + ], + ) + out = linear_1(a) + out = linear_2(out) + out.backward() + asgd.step() + asgd.clear_gradients() + + +class TestASGDMultiPrecision(unittest.TestCase): + def dygraph_asgd_mp(self, mp): + paddle.disable_static() + paddle.seed(10) + paddle.set_device('gpu') + input = paddle.randn((2, 2)) + model = paddle.nn.Linear(2, 2) + optimizer = paddle.optimizer.ASGD( + batch_num=2, parameters=model.parameters(), multi_precision=mp + ) + if mp: + model = paddle.amp.decorate(models=model, level='O2') + scaler = paddle.amp.GradScaler(init_loss_scaling=1024) + for idx in range(5): + if mp: + with paddle.amp.auto_cast(level='O2'): + output = model(input) + loss = paddle.mean(output) + scaled = scaler.scale(loss) + scaled.backward() + scaler.minimize(optimizer, scaled) + optimizer.clear_grad() + else: + output = model(input) + loss = paddle.mean(output) + optimizer.step() + optimizer.clear_grad() + + return output, model.parameters() + + def static_asgd_mp(self, mp): + paddle.enable_static() + paddle.seed(10) + np.random.seed(10) + exe = paddle.static.Executor('gpu') + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + optimizer = paddle.optimizer.ASGD(batch_num=2, multi_precision=mp) + + if mp: + optimizer = paddle.static.amp.decorate( + optimizer, + init_loss_scaling=128.0, + use_dynamic_loss_scaling=True, + use_pure_fp16=True, + use_fp16_guard=False, + ) + with paddle.static.program_guard(train_program, startup_program): + if mp: + data = paddle.static.data( + shape=[2, 2], name='X', dtype='float16' + ) + else: + data = paddle.static.data( + shape=[2, 2], name='X', dtype='float32' + ) + hidden = paddle.static.nn.fc(x=data, size=10) + loss = paddle.mean(hidden) + optimizer.minimize(loss) + exe.run(startup_program) + + if mp: + optimizer.amp_init( + place=paddle.CUDAPlace(0), scope=paddle.static.global_scope() + ) + x = np.random.random(size=(2, 2)).astype('float16') + else: + x = np.random.random(size=(2, 2)).astype('float32') + out = [] + for idx in range(5): + (loss_data,) = exe.run( + train_program, feed={"X": x}, fetch_list=[loss.name] + ) + out.append(loss_data) + return out + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + "Test dygraph mode" + output1_dy, params1_dy = self.dygraph_asgd_mp(mp=True) + output2_dy, params2_dy = self.dygraph_asgd_mp(mp=False) + np.testing.assert_allclose( + output1_dy.astype('float32').numpy(), + output2_dy.astype('float32').numpy(), + 
rtol=1e-05, + atol=0.1, + ) + for idx in range(len(params1_dy)): + np.testing.assert_allclose( + params1_dy[idx].astype('float32').numpy(), + params2_dy[idx].astype('float32').numpy(), + rtol=1e-05, + atol=0.1, + ) + "Test static graph mode" + output1_st = self.static_asgd_mp(mp=True) + output2_st = self.static_asgd_mp(mp=False) + for idx in range(len(output1_st)): + np.testing.assert_allclose( + output1_st[idx].astype('float32'), + output2_st[idx].astype('float32'), + rtol=1e-05, + atol=0.1, + ) + + +class TestASGDSimple(unittest.TestCase): + def setUp(self) -> None: + self.data = np.random.random(size=(2, 2)).astype('float32') + + def run_static(self): + with paddle.pir_utils.IrGuard(): + paddle.seed(10) + np.random.seed(10) + + exe = paddle.static.Executor('gpu') + train_program = paddle.static.Program() + startup_program = paddle.static.Program() + + with paddle.static.program_guard(train_program, startup_program): + input = paddle.static.data( + shape=[2, 2], name='input', dtype='float32' + ) + model = paddle.nn.Linear(2, 2) + output = model(input) + loss = paddle.mean(output) + + optimizer = paddle.optimizer.ASGD( + batch_num=3, + ) + + optimizer.minimize(loss) + + exe.run(startup_program) + out = [] + for _ in range(10): + (loss_data,) = exe.run( + train_program, feed={"input": self.data}, fetch_list=[loss] + ) + out.append(loss_data) + return out + + def run_dygraph(self): + with dygraph_guard(): + paddle.seed(10) + np.random.seed(10) + + out = [] + model = paddle.nn.Linear(2, 2) + optimizer = paddle.optimizer.ASGD( + batch_num=3, parameters=model.parameters() + ) + for _ in range(10): + output = model(paddle.to_tensor(self.data)) + loss = paddle.mean(output) + out.append(loss.numpy()) + loss.backward() + optimizer.step() + optimizer.clear_grad() + return out + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + out1 = self.run_dygraph() + out2 = self.run_static() + np.testing.assert_allclose(out1, out2) + + +class TestASGDValidation: + def setUp(self) -> None: + self.init_all_size() + self.init_batch_size() + self.init_batch_num() + self.data = np.random.random(size=(self.all_size, 2)).astype('float32') + + def init_all_size(self): + self.all_size = 64 + + def init_batch_size(self): + self.batch_size = 8 + + def init_batch_num(self): + self.batch_num = (int)(self.all_size / self.batch_size) + + def run_validation(self) -> None: + with dygraph_guard(): + paddle.seed(10) + np.random.seed(10) + + param_validation = {} + grad_validation = {} + lr_validation = {} + d_validation = {} + ys_validation = {} + y_validation = {} + n_validation = {} + + model = paddle.nn.Linear(2, 2) + optimizer = paddle.optimizer.ASGD( + batch_num=self.batch_num, parameters=model.parameters() + ) + + for param in model.parameters(): + d_validation[param.name] = np.zeros(param.shape) + ys_validation[param.name] = np.zeros( + [self.batch_num] + param.shape + ) + + for i in range(5): + data_start = i * self.batch_size % self.all_size + data_end = data_start + self.batch_size + cur_data = self.data[data_start:data_end] + output = model(paddle.to_tensor(cur_data)) + loss = paddle.mean(output) + loss = output + loss.backward() + + for param in model.parameters(): + param_validation[param.name] = param.numpy() + + optimizer.step() + + for param in model.parameters(): + grad_validation[param.name] = param.grad.numpy() + lr_validation[param.name] = optimizer.get_lr() + y_validation[param.name] = ys_validation[param.name][ + i % self.batch_num + ] + d_validation[param.name] = ( + 
d_validation[param.name] + - y_validation[param.name] + + grad_validation[param.name] + ) + ys_validation[param.name][ + i % self.batch_num + ] = grad_validation[param.name] + n_validation[param.name] = min(i + 1, self.batch_num) + param_validation[param.name] = ( + param_validation[param.name] + - lr_validation[param.name] + * d_validation[param.name] + / n_validation[param.name] + ) + + np.testing.assert_allclose( + param.numpy(), + param_validation[param.name], + ) + + optimizer.clear_grad() + + def test_main(self): + if not paddle.is_compiled_with_cuda(): + return + self.run_validation() + + +class TestASGDValidationCase1(TestASGDValidation, unittest.TestCase): + pass + + +class TestASGDValidationCase2(TestASGDValidationCase1): + def init_batch_num(self): + self.batch_num = 2 + + +if __name__ == "__main__": + unittest.main()
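
Reviewer note: below is a minimal NumPy sketch of the update rule that the new CPU/GPU kernels implement and that `TestASGDOpMixin.init_output` checks. The helper name `asgd_step` is illustrative only and not part of this PR; weight decay, gradient clipping and the FP16 master-weight path are handled in the Python optimizer, outside this core update.

```python
import numpy as np


def asgd_step(param, grad, lr, d, y, n):
    """One ASGD update matching the kernel math."""
    d_out = d - y + grad                  # running sum of the latest gradient per slot
    y_out = grad.copy()                   # remember this sample's gradient for next time
    param_out = param - (lr / n) * d_out  # step by the gradient averaged over at most n slots
    return param_out, d_out, y_out


# Toy usage with the shapes and scalars from TestASGDOpMixin (h=102, w=105, lr=0.001, n=1000).
h, w = 102, 105
param, grad = np.random.random((h, w)), np.random.random((h, w))
d, y = np.random.random((h, w)), np.random.random((h, w))
param_out, d_out, y_out = asgd_step(param, grad, lr=0.001, d=d, y=y, n=1000.0)
```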