Enhance reshape #9008

Merged 15 commits on Apr 2, 2018
Changes from 4 commits
100 changes: 64 additions & 36 deletions paddle/fluid/operators/reshape_op.cc
@@ -25,53 +25,82 @@ class ReshapeOp : public framework::OperatorWithKernel {
: OperatorWithKernel(type, inputs, outputs, attrs) {}

void InferShape(framework::InferShapeContext *ctx) const override {
// input check
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of ReshapeOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of ReshapeOp should not be null.");

auto shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(shape.size() > 0, "Attr(shape) shouldn't be empty.");
const std::vector<int> &shape = ctx->Attrs().Get<std::vector<int>>("shape");
PADDLE_ENFORCE(!shape.empty(),
"The shape information must be set by Attr(shape).");

std::vector<int64_t> output_shape;
auto x_dims = ctx->GetInputDim("X");
bool need_copy_dim = ValidateShape(shape, x_dims, output_shape);

if (need_copy_dim) {
// Some dimensions can only be determined during runtime. Here temporarily
// set output tensor's shape the same as that of the input tensor.
ctx->SetOutputDim("Out", x_dims);
} else {
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));

// FIXME(caoying): When shape of the output tensor is determined during
// runtime, LoD information of X will not be passed to the output.
if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}

private:
bool ValidateShape(const std::vector<int> &shape,
const framework::DDim &input_dim,
std::vector<int64_t> &output_shape) const {
// only one dimension canbe set to -1, whose size will be automatically
Contributor: canbe -> can be

Contributor Author: Done.

// infered.
const int64_t unknown_index = -1;
const auto in_size = framework::product(input_dim);
const auto x_rank = input_dim.size();

bool need_dim_copy = false;
std::vector<size_t> neg_dims_idx;
// set some dimension to -1 if it is unknown
const int unknown_size = -1;
for (size_t i = 0; i < shape.size(); ++i) {
PADDLE_ENFORCE(shape[i] > 0 || shape[i] == unknown_size,
"Each dimension of Attr(shape) must be positive or %d.",
unknown_size);
if (shape[i] == unknown_size) {
PADDLE_ENFORCE(shape[i] >= 0 || shape[i] == unknown_index,
"Each input dimension of Attr(shape) must be positive, or "
"only one input dimension can be -1.");
if (shape[i] == unknown_index) {
neg_dims_idx.push_back(i);
PADDLE_ENFORCE(neg_dims_idx.size() <= 1,
"Only one dimension of Attr(shape) can be unknown.");
} else if (shape[i] == 0) {
PADDLE_ENFORCE_LT(
i, x_rank,
"Only dimension less than rank of Input(X) can be set to 0.");
need_dim_copy = true;
}
}
PADDLE_ENFORCE_LE(
neg_dims_idx.size(), 1,
"Only one input dimension of Attr(shape) may be unknown.");
Contributor: may -> can

Contributor Author: Done.

int64_t capacity =
std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
int64_t in_size = framework::product(x_dims);
if (neg_dims_idx.size() == 1) {
// dim infer
shape[neg_dims_idx[0]] = in_size / (-capacity);
// recalculate capacity
capacity = shape[neg_dims_idx[0]] * (-capacity);
}
// capacity check
PADDLE_ENFORCE(capacity == in_size,
"The size of Input(X) mismatches with Attr(shape).");
// resize output
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(),
output_shape.resize(shape.size(), 0);
std::transform(shape.begin(), shape.end(), output_shape.begin(),
[](int a) { return static_cast<int64_t>(a); });
auto out_dims = framework::make_ddim(shape_int64);
ctx->SetOutputDim("Out", out_dims);
if (shape[0] == x_dims[0]) {
// Only pass LoD when the first dimension is equal between
// output and input.
ctx->ShareLoD("X", /*->*/ "Out");

// some dimension can only be determined during runtime.
if (need_dim_copy) return need_dim_copy;

int64_t inferred_dim = 0;
if (neg_dims_idx.size()) {
int64_t capacity = std::accumulate(shape.begin(), shape.end(), 1,
std::multiplies<int>());
inferred_dim = in_size / (-capacity);
PADDLE_ENFORCE_EQ(inferred_dim * (-capacity), in_size,
"Invalid shape is given.");
output_shape[neg_dims_idx[0]] = inferred_dim;
}
return false;
}
};
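
For reference, the rules ValidateShape enforces above (at most one -1, a 0
copies the corresponding input dimension, and the element counts must agree)
can be summarized in a short Python sketch; validate_shape below is an
illustrative stand-in, not code from this PR:

    def validate_shape(shape, in_dims):
        # At most one dimension may be -1; it is inferred from the others.
        assert sum(1 for s in shape if s == -1) <= 1
        in_size = 1
        for d in in_dims:
            in_size *= d
        out = []
        for i, s in enumerate(shape):
            assert s >= -1, "each dimension must be positive, 0, or -1"
            if s == 0:
                # A 0 copies the input dimension at the same index.
                assert i < len(in_dims), "0 is only valid below Input(X)'s rank"
                out.append(in_dims[i])
            else:
                out.append(s)
        known = 1
        for d in out:
            if d != -1:
                known *= d
        if -1 in out:
            assert in_size % known == 0, "the -1 dimension does not divide evenly"
            out[out.index(-1)] = in_size // known
        else:
            assert known == in_size, "Input(X) size mismatches Attr(shape)"
        return out

    # validate_shape([-1, 0, 2, 2], [2, 3, 8]) returns [4, 3, 2, 2].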

@@ -81,9 +110,8 @@ class ReshapeOpMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput("X", "The input tensor of reshape operator.");
AddOutput("Out", "The output tensor of reshape operator.");
AddAttr<std::vector<int>>("shape",
"(vector<int>) "
"Target shape of reshape operator.");
AddAttr<std::vector<int>>(
"shape", "(std::vector<int>) Target shape of reshape operator.");
AddAttr<bool>("inplace",
"Change the source tensor's shape without copy memory.")
Contributor: copy -> copying

Contributor Author: Done.

.SetDefault(true);
@@ -99,7 +127,7 @@ and target shape = [1, 4], the reshape operator will transform
the tensor X into a 2-D tensor: [[1, 2, 3, 4]]

One dimension in the target shape can be set -1, representing that its
size is unknown. In this case, the real dimension will be inferred from
the original shape of Input(X) and other dimensions in the target shape.
)DOC");
Contributor: Need to explain when the dimension can be set to 0 (necessary) and in-place reshape (optional) in the doc.

Contributor Author: Done.

}
33 changes: 32 additions & 1 deletion paddle/fluid/operators/reshape_op.h
@@ -26,8 +26,10 @@ class ReshapeKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* out = ctx.Output<framework::Tensor>("Out");
auto* in = ctx.Input<framework::Tensor>("X");

auto out_dims =
ValidateShape(ctx.Attr<std::vector<int>>("shape"), in->dims());
bool inplace = ctx.Attr<bool>("inplace");
auto out_dims = out->dims();
if (!inplace) {
out->mutable_data<T>(ctx.GetPlace());
framework::TensorCopy(*in, ctx.GetPlace(), ctx.device_context(), out);
@@ -37,6 +39,34 @@ class ReshapeKernel : public framework::OpKernel<T> {
out->Resize(out_dims);
}
}

private:
framework::DDim ValidateShape(const std::vector<int> shape_attr,
const framework::DDim& in_dims) const {
const int64_t in_size = framework::product(in_dims);
// only one dimension can be set to -1, whose size will be automatically
// inferred.
const int64_t unknown_index = -1;

std::vector<int64_t> output_shape(shape_attr.size(), 0);
int64_t capacity = 1;
int neg_dim_idx = -1;
for (size_t i = 0; i < shape_attr.size(); ++i) {
if (shape_attr[i] == unknown_index) neg_dim_idx = i;
capacity *= (shape_attr[i] ? shape_attr[i] : in_dims[i]);
output_shape[i] =
(shape_attr[i] ? static_cast<int64_t>(shape_attr[i]) : in_dims[i]);
}

if (neg_dim_idx != -1) {
output_shape[neg_dim_idx] = -in_size / capacity;
PADDLE_ENFORCE_EQ(output_shape[neg_dim_idx] * capacity, -in_size,
"Invalid shape is given.");
} else {
PADDLE_ENFORCE_EQ(capacity, in_size, "Invalid shape is given.");
}
return framework::make_ddim(output_shape);
}
};
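
The kernel-side check above leans on a sign trick: when exactly one entry of
shape_attr is -1, capacity comes out negative, so -in_size / capacity is the
inferred dimension. A minimal Python walkthrough of the same arithmetic,
assuming input dims [2, 3, 8] and shape attr [-1, 0, 2, 2]:

    in_dims = [2, 3, 8]
    shape_attr = [-1, 0, 2, 2]

    in_size = 2 * 3 * 8                      # 48 elements in the input
    capacity = 1
    for i, s in enumerate(shape_attr):
        capacity *= s if s else in_dims[i]   # a 0 pulls in in_dims[i]
    # capacity == (-1) * 3 * 2 * 2 == -12, negative because of the -1 entry

    inferred = -in_size // capacity          # -48 // -12 == 4
    assert inferred * capacity == -in_size   # mirrors the PADDLE_ENFORCE_EQ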

template <typename DeviceContext, typename T>
@@ -45,6 +75,7 @@ class ReshapeGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const {
auto* d_out = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* d_x = ctx.Output<framework::Tensor>(framework::GradVarName("X"));

d_x->mutable_data<T>(ctx.GetPlace());
bool inplace = ctx.Attr<bool>("inplace");

17 changes: 8 additions & 9 deletions python/paddle/fluid/layers/detection.py
@@ -19,7 +19,6 @@
from layer_function_generator import autodoc
from ..layer_helper import LayerHelper
import tensor
import ops
import nn
import math

@@ -58,7 +57,7 @@ def detection_output(loc,

This operation is to get the detection results by performing following
two steps:

1. Decode input bounding box predictions according to the prior boxes.
2. Get the final detection results by applying multi-class non maximum
suppression (NMS).
@@ -458,7 +457,7 @@ def ssd_loss(location,
num, num_prior, num_class = confidence.shape

def __reshape_to_2d(var):
return ops.reshape(x=var, shape=[-1, var.shape[-1]])
return nn.reshape(x=var, shape=[-1, var.shape[-1]])

# 1. Find matched boundding box by prior box.
# 1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
@@ -469,7 +468,7 @@ def __reshape_to_2d(var):

# 2. Compute confidence for mining hard examples
# 2.1. Get the target label based on matched indices
gt_label = ops.reshape(x=gt_label, shape=gt_label.shape + (1, ))
gt_label = nn.reshape(x=gt_label, shape=gt_label.shape + (1, ))
target_label, _ = target_assign(
gt_label, matched_indices, mismatch_value=background_label)
# 2.2. Compute confidence loss.
@@ -480,7 +479,7 @@ def __reshape_to_2d(var):
conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)

# 3. Mining hard examples
conf_loss = ops.reshape(x=conf_loss, shape=(num, num_prior))
conf_loss = nn.reshape(x=conf_loss, shape=(num, num_prior))
neg_indices = helper.create_tmp_variable(dtype='int32')
dtype = matched_indices.dtype
updated_matched_indices = helper.create_tmp_variable(dtype=dtype)
@@ -548,7 +547,7 @@ def __reshape_to_2d(var):
# 5.3 Compute overall weighted loss.
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
# reshape to [N, Np], N is the batch size and Np is the prior box number.
loss = ops.reshape(x=loss, shape=[-1, num_prior])
loss = nn.reshape(x=loss, shape=[-1, num_prior])
loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
if normalize:
normalizer = nn.reduce_sum(target_loc_weight)
@@ -696,7 +695,7 @@ def _reshape_with_axis_(input, axis=1):
new_shape = [
-1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
]
out = ops.reshape(x=input, shape=new_shape)
out = nn.reshape(x=input, shape=new_shape)
return out

def _is_list_or_tuple_(data):
@@ -793,7 +792,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
mbox_loc.shape[0],
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4
]
mbox_loc_flatten = ops.reshape(mbox_loc, shape=new_shape)
mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape)
mbox_locs.append(mbox_loc_flatten)

# get conf_loc
@@ -809,7 +808,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
conf_loc.shape[3] / num_classes, num_classes
]
conf_loc_flatten = ops.reshape(conf_loc, shape=new_shape)
conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape)
mbox_confs.append(conf_loc_flatten)

if len(box_results) == 1:
56 changes: 56 additions & 0 deletions python/paddle/fluid/layers/nn.py
@@ -70,6 +70,7 @@
'smooth_l1',
'one_hot',
'autoincreased_step_counter',
'reshape',
]


@@ -3184,6 +3185,8 @@ def one_hot(input, depth):
The one-hot tensor or LodTensor, same as input.

Examples:
.. code-block:: python

X is a LoDTensor:
X.lod = [[0, 1, 4]]
X.shape = [4, 1]
@@ -3236,3 +3239,56 @@ def autoincreased_step_counter(counter_name=None, begin=1, step=1):
counter.stop_gradient = True

return counter


def reshape(x, shape, act=None, inplace=True, name=None):
"""
Gives a new shape to Tensor without changing its data.
This layer takes a tensor as input and the attribute shape specifying the
new shape. The shape attribute must be specified. At most one dimension of
the new shape can be -1. In this case, the value is inferred from the size
of the tensor and the remaining dimensions. A dimension could also be 0,
in which case the actual dimension value is going to be copied from the
input tensor.

Args:
x(variable): The input tensor.
shape(list): The new shape. At most one dimension of the new shape can
be -1.
act (str): The non-linear activation to be applied to output variable.
inplace(bool): If this flag is set true, the output shares data with
               input x without copying; otherwise, a new output tensor
               is created and its data is copied from input x.

Returns(variable): The output tensor.

Examples:
.. code-block:: python

Given a 2-D tensor X with shape [2 x 2], and the new shape: [1, 4].
The reshape layer will change tensor X into a 2-D tensor with
shape [1 x 4] with its data unchanged.

Given a 3-D tensor x with shape [2, 3, 4] and the new shape: [3, -1].
The reshape layer will change tensor X into a 2-D tensor with shape:
[3 x 8] with its data unchanged.

Given a 3-D tensor x with shape [2, 3, 8] and the new shape:
[-1, 0, 2, 2]. The reshape layer will change tensor X into a 4-D tensor
with shape [4, 3, 2, 2] with its data unchanged.
Contributor: Also need to refine this doc.

Contributor Author: Done.


"""

if not (isinstance(shape, list) or isinstance(shape, tuple)):
raise ValueError("Input shape must be a python list or tuple.")

helper = LayerHelper("reshape", **locals())
reshaped = helper.create_tmp_variable(dtype=x.dtype)
helper.append_op(
type="reshape",
inputs={"X": x},
attrs={"shape": shape,
"inplace": inplace},
outputs={"Out": reshaped})

return helper.append_activation(reshaped)
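
A minimal usage sketch of the layer defined above, assuming the fluid API of
this era (the data-layer names and shapes here are illustrative only):

    import paddle.fluid as fluid

    # fluid.layers.data prepends a batch dimension, so x has shape [N, 3, 8].
    x = fluid.layers.data(name="x", shape=[3, 8], dtype="float32")
    # 0 keeps the batch dimension as-is; -1 absorbs the remaining 3 * 8 = 24
    # elements. inplace=False forces a copy instead of sharing x's buffer.
    flat = fluid.layers.reshape(x=x, shape=[0, -1], inplace=False)
    # act fuses a non-linearity into the reshaped output.
    out = fluid.layers.reshape(x=x, shape=[0, -1, 2, 2], act="relu")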
1 change: 0 additions & 1 deletion python/paddle/fluid/layers/ops.py
@@ -47,7 +47,6 @@
__all__ = [
'mean',
'mul',
'reshape',
'scale',
'sigmoid_cross_entropy_with_logits',
'elementwise_add',
8 changes: 4 additions & 4 deletions python/paddle/fluid/tests/unittests/op_test.py
@@ -334,7 +334,7 @@ def find_actual(target_name, fetch_list):
np.allclose(
actual_t, expect_t, atol=atol),
"Output (" + out_name + ") has diff at " + str(place) +
str(actual_t) + str(expect_t))
str(actual_t) + "\n" + str(expect_t))
if isinstance(expect, tuple):
self.assertListEqual(actual.lod(), expect[1],
"Output (" + out_name +
@@ -546,6 +546,6 @@ def _get_gradient(self, input_to_check, place, output_names, no_grad_set):

fetch_list = [g for p, g in param_grad_list]
executor = Executor(place)
return map(
np.array,
executor.run(prog, feed_dict, fetch_list, return_numpy=False))
return map(np.array,
executor.run(prog, feed_dict, fetch_list,
return_numpy=False))