Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#12 from youth123/lilong/moe
Browse files Browse the repository at this point in the history
upload assign pos op
  • Loading branch information
lilong12 authored Sep 26, 2021
2 parents 10e5304 + 98ebe2a commit 0c459db
Show file tree
Hide file tree
Showing 6 changed files with 362 additions and 0 deletions.
78 changes: 78 additions & 0 deletions paddle/fluid/operators/collective/assign_pos_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/collective/assign_pos_op.h"

namespace paddle {
namespace operators {

class AssignPosOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("cum_count"), "Input", "cum_count", "AssignPos");
OP_INOUT_CHECK(ctx->HasInput("eff_gates_len"), "Input", "eff_gates_len", "AssignPos");
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "AssignPos");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "AssignPos");
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto data_type = OperatorWithKernel::IndicateVarDataType(ctx, "cum_count");
return framework::OpKernelType(data_type, ctx.device_context());
}
};

class AssignPosOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"The tensor which indicates the tokens belong to which topk experts.");
AddInput("cum_count",
"The cumulative sum tokens of experts.");
AddInput("eff_gates_len",
"The effective numbers of tokens should be sent.");
AddOutput("Out", "Assemble tokens in the order of experts.");

AddComment(R"DOC(
assign_pos_op Operator.
Assign pos decides which tokens should be fetched belong to
specially expert orderingly.
)DOC");
}
};



} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_WITHOUT_GRADIENT(assign_pos, ops::AssignPosOp,
ops::AssignPosOpMaker);


// REGISTER_OPERATOR(assign_pos, ops::AssignPosOp, ops::AssignPosOpMaker)

REGISTER_OP_CPU_KERNEL(assign_pos,
ops::AssignPosOpCPUKernel<int>,
ops::AssignPosOpCPUKernel<int64_t>);

90 changes: 90 additions & 0 deletions paddle/fluid/operators/collective/assign_pos_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/collective/assign_pos_op.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

static constexpr int kNumCUDAThreads = 512;
static constexpr int kNumMaxinumNumBlocks = 4096;

static inline int NumBlocks(const int N) {
return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
kNumMaxinumNumBlocks);
}

template <typename T>
__global__ void AssignPos(T* cum_count, const int* gate, int64_t* out, int64_t limit) {
CUDA_KERNEL_LOOP(i, limit) {

int gate_idx = gate[i];
if (gate_idx > -1){
int p = platform::CudaAtomicAdd(cum_count + gate_idx, -1);
out[p - 1] = i;
}
}
}

template <typename T>
class AssignPosCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
// assign pos decides which tokens should be fetched belong to specially expert orderingly.
auto cum_count = context.Input<LoDTensor>("cum_count"); // (num_expert * world_size) int32 | int64
auto gate = context.Input<LoDTensor>("X"); // (batch_size * seq_len, topk) int32
auto eff_gates_len = context.Input<LoDTensor>("eff_gates_len"); // (sum(cum_count))
auto out = context.Output<LoDTensor>("Out"); // (cum_count) value ranges from 0 to batch_size * seq_len * topk
auto place = context.GetPlace();
auto numel = gate->numel();
T* cum_data = const_cast<T*> (cum_count->data<T>());
auto cum_size = cum_count->numel();

framework::Tensor cpu_eff_gates_len;
int64_t cpu_eff_gates_len_data = 0;
if (platform::is_cpu_place(eff_gates_len->place())) {
cpu_eff_gates_len_data = eff_gates_len->data<int64_t>()[0];
} else {
framework::TensorCopySync(*eff_gates_len, platform::CPUPlace(),
&cpu_eff_gates_len);
cpu_eff_gates_len_data = cpu_eff_gates_len.data<int64_t>()[0];
}
const auto &dev_ctx =
context.template device_context<platform::CUDADeviceContext>();

framework::DDim out_dims = framework::make_ddim({cpu_eff_gates_len_data});
auto out_data = out->mutable_data<int64_t>(out_dims, place);

const int* gate_data = gate->data<int>();

int blocks = NumBlocks(numel);
int threads = kNumCUDAThreads;
AssignPos<T><<<blocks, threads, 0, dev_ctx.stream()>>>(
cum_data, gate_data, out_data, numel);

}
};


} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(assign_pos, ops::AssignPosCUDAKernel<int>,
ops::AssignPosCUDAKernel<int64_t>);

35 changes: 35 additions & 0 deletions paddle/fluid/operators/collective/assign_pos_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;

template <typename T>
class AssignPosOpCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_THROW(platform::errors::Unavailable(
"Do not support assign pos op for cpu kernel now."));
}
};

} // namespace operators
} // namespace paddle
59 changes: 59 additions & 0 deletions python/paddle/distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,68 @@
'pull_worker_log',
'global_scatter',
'global_gather',
'assign_pos',
]


def assign_pos(x,
cum_count):
"""
Assign pos decides which tokens should be fetched belong to
specially expert orderingly.
Args:
x (Tensor): Tensor. Every element in the list must be a Tensor whose data type
should be float16, float32, float64, int32 or int64.
cum_count (Tensor): The cumulative sum tokens of experts. Every element in the list must be a Tensor whose
data type should be int64.
Returns:
out (Tensor): Assemble tokens in the order of experts.
Examples:
.. code-block:: python
# required: distributed
import paddle
local_expert_count = [2, 0, 2, 0]
gate_idx = [
[0, 2],
[0, 2]
]
local_expert_count = paddle.to_tensor(local_expert_count)
gate_idx = paddle.to_tensor(gate_idx, dtype="int32")
lec_cum = paddle.cumsum(local_expert_count)
pos = paddle.distributed.utils.assign_pos(x=gate_idx, cum_count=lec_cum)
print(pos) # the result: (2, 0, 3, 1)
"""
if in_dygraph_mode():
return core.ops.assign_pos(x, cum_count, cum_count[-1])
else:
op_type = 'assign_pos'
# check_variable_and_dtype(
# x, 'x', ['float16', 'float32', 'float64', 'int32', 'int64'],
# 'global_scatter')
# check_variable_and_dtype(local_count, 'local_count', ['int64'],
# 'global_scatter')
# check_variable_and_dtype(global_count, 'global_count', ['int64'],
# 'global_scatter')

helper = LayerHelper(op_type, **locals())
out = helper.create_variable_for_type_inference(dtype=cum_count.dtype)

helper.append_op(
type=op_type,
inputs={
'X': [x],
'cum_count': [cum_count],
"eff_gates_len": [cum_count[-1]]
},
outputs={'Out': [out]})
return out



def global_scatter(x,
local_count,
global_count,
Expand Down
100 changes: 100 additions & 0 deletions python/paddle/fluid/tests/unittests/test_assign_pos_op.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function
import unittest

import numpy as np
from scipy.special import expit, erf

from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
import paddle
import paddle.nn as nn
import paddle.nn.functional as F
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid import compiler, Program, program_guard

class TestAssignPosAPI(unittest.TestCase):
def init(self):
self.dtype = 'int64'
self.shape = [10, 10] # (batch_size * seq_len, d_model)
self.topK = 2
self.num_expert = 2
self.world_size = 2
self.tot_expert = self.num_expert * self.world_size

def setUp(self):
self.init()
self.gate_idx = np.random.randint(low=0, high=self.tot_expert-1, \
size=(self.shape[0], self.topK))
local_expert_count = np.zeros(self.tot_expert).astype(self.dtype)
self.gate = self.gate_idx.flatten()
nums = len(self.gate)
for i in range(nums):
local_expert_count[self.gate[i]] += 1
self.lec_cum = np.zeros(len(local_expert_count), dtype=np.int64)
self.lec_cum[0] = local_expert_count[0]
for i in range(1, len(local_expert_count)):
self.lec_cum[i] = local_expert_count[i] + self.lec_cum[i-1]
self.lec_cum_np = self.lec_cum.copy()
self.pos_np = np.zeros((self.lec_cum[-1], ))
for i in range(0, len(self.gate)):
idx = self.gate[i]
p = self.lec_cum_np[idx]
self.lec_cum_np[idx] -= 1
self.pos_np[p-1] = i
self.place = [paddle.CUDAPlace(0)]

# def test_static_api(self):
# paddle.enable_static()

# def run(place):
# with paddle.static.program_guard(paddle.static.Program()):
# X = paddle.fluid.data('X', self.shape, dtype=self.dtype)
# out = paddle.expm1(X)
# exe = paddle.static.Executor(place)
# res = exe.run(feed={'X': self.x})
# for r in res:
# self.assertEqual(np.allclose(self.out_ref, r), True)

# for place in self.place:
# run(place)

def test_dygraph_api(self):
def run(place):
paddle.disable_static(place)
gate_idx = paddle.to_tensor(self.gate_idx, dtype="int32")
lec_cum = paddle.to_tensor(self.lec_cum, dtype="int64")
pos = paddle.distributed.utils.assign_pos(x=gate_idx, cum_count=lec_cum)
self.assertEqual(np.allclose(self.pos_np, pos), True)
paddle.enable_static()

# print("gate_idx: ", self.gate_idx)
# print("lec_cum: ", self.lec_cum)
# print("pos np: ", self.pos_np)
# print("pos my: ", pos)

for place in self.place:
run(place)

# def test_errors(self):
# paddle.enable_static()
# with paddle.static.program_guard(paddle.static.Program()):
# X = paddle.fluid.data('X', self.shape, dtype='int32')
# self.assertRaises(TypeError, paddle.expm1, X)
# # The input dtype must be float16, float32, float64.

if __name__ == "__main__":
unittest.main()
Binary file not shown.

0 comments on commit 0c459db

Please sign in to comment.