Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix backward_clip num inputs and type of clip params #15688

Merged
merged 6 commits into from
Aug 9, 2019
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/operator/tensor/matrix_op-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -1457,7 +1457,7 @@ struct ClipParam : public dmlc::Parameter<ClipParam> {
struct clip {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* datas,
DType a_min, DType a_max) {
const float a_min, const float a_max) {
DType data = datas[i];
if (data > a_max) {
out[i] = a_max;
Expand All @@ -1473,7 +1473,7 @@ struct clip {
struct clip_grad {
template<typename DType>
MSHADOW_XINLINE static void Map(index_t i, DType* out, const DType* grad, const DType* datas,
DType a_min, DType a_max) {
const float a_min, const float a_max) {
DType data = datas[i];
if (data > a_max) {
out[i] = 0;
Expand All @@ -1500,7 +1500,7 @@ void Clip(const nnvm::NodeAttrs& attrs,
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
mxnet_op::Kernel<mxnet::op::clip, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(),
DType(param.a_min), DType(param.a_max));
param.a_min, param.a_max);
});
}

Expand Down Expand Up @@ -1529,7 +1529,7 @@ void ClipGrad_(const nnvm::NodeAttrs& attrs,
Stream<xpu> *s = ctx.get_stream<xpu>();
MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
Kernel<clip_grad, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<DType>(),
inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), DType(param.a_min), DType(param.a_max));
inputs[0].dptr<DType>(), inputs[1].dptr<DType>(), param.a_min, param.a_max);
});
}

Expand Down
2 changes: 1 addition & 1 deletion src/operator/tensor/matrix_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,7 @@ parameter values:
.add_arguments(ClipParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_clip)
.set_num_inputs(1)
.set_num_inputs(2)
DickJC123 marked this conversation as resolved.
Show resolved Hide resolved
.set_num_outputs(1)
.set_attr_parser(ParamParser<ClipParam>)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
Expand Down
16 changes: 13 additions & 3 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -4174,15 +4174,25 @@ def test_special_functions_using_scipy():


@with_seed()
@unittest.skip("Flaky test, tracked at https://github.com/apache/incubator-mxnet/issues/12901")
def test_clip():
data = mx.symbol.Variable('data')
shape = (30, 30)
data_tmp = np.random.uniform(-1, 1, shape)
data_tmp = np.random.uniform(-1, 1, shape).astype('float32')
test = mx.sym.clip(data, a_max=0.6, a_min=-0.6)
check_symbolic_forward(test, [data_tmp], [np.clip(data_tmp, -0.6, 0.6)])
check_symbolic_backward(test, [data_tmp], [np.ones(shape)],
[np.where(data_tmp < 0.6, [1], [0]) * np.where(data_tmp > -0.6, [1], [0])])
[np.where(data_tmp <= 0.6, [1], [0]) * np.where(data_tmp >= -0.6, [1], [0])])

# Test monitor on symbol using clip

def simple_callback(name, arr):
pass

exe = test.simple_bind(ctx=mx.current_context(), data=shape)
exe.set_monitor_callback(simple_callback, monitor_all=True)
exe.forward(is_train=True)
exe.backward(out_grads=mx.nd.ones(shape))
mx.nd.waitall()


@with_seed()
Expand Down