Fix concat_op #9337
Changes from all commits
```diff
@@ -20,19 +20,35 @@
 class TestConcatOp(OpTest):
     def setUp(self):
         self.op_type = "concat"
-        x0 = np.random.random((2, 1, 4, 5)).astype('float32')
-        x1 = np.random.random((2, 2, 4, 5)).astype('float32')
-        x2 = np.random.random((2, 3, 4, 5)).astype('float32')
-        axis = 1
-        self.inputs = {'X': [('x0', x0), ('x1', x1), ('x2', x2)]}
-        self.attrs = {'axis': axis}
-        self.outputs = {'Out': np.concatenate((x0, x1, x2), axis=axis)}
+        self.init_test_data()
+        self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
+        self.attrs = {'axis': self.axis}
+        self.outputs = {
+            'Out': np.concatenate(
+                (self.x0, self.x1, self.x2), axis=self.axis)
+        }

     def test_check_output(self):
         self.check_output()

     def test_check_grad(self):
         self.check_grad(['x0'], 'Out')
         self.check_grad(['x1'], 'Out')
         self.check_grad(['x2'], 'Out')

+    def init_test_data(self):
+        self.x0 = np.random.random((2, 1, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 2, 4, 5)).astype('float32')
+        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.axis = 1
+
+
+class TestConcatOp2(TestConcatOp):
+    def init_test_data(self):
+        self.x0 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x1 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.x2 = np.random.random((2, 3, 4, 5)).astype('float32')
+        self.axis = 1
+
+
 if __name__ == '__main__':
```

Review thread on the new `TestConcatOp2` class:

Review comment: Seems that this case is somehow duplicated with the one above; how about changing it to test `axis == 0`?

Reply: No, there are two CUDA kernels for …
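As a reference for what the two cases exercise, below is a minimal standalone numpy sketch mirroring the shapes used in the tests: `TestConcatOp` concatenates inputs whose sizes differ along the concat axis, while `TestConcatOp2` concatenates equally shaped inputs; the `axis == 0` variant suggested in the review is included as a hypothetical third case. The helper name `reference_concat` is an illustrative assumption, not part of the PR.

```python
import numpy as np


def reference_concat(inputs, axis):
    # Reference result the op test compares against: plain numpy concatenation.
    return np.concatenate(inputs, axis=axis)


# Case 1 (TestConcatOp): input sizes differ along the concat axis (axis=1).
xs = [np.random.random((2, k, 4, 5)).astype('float32') for k in (1, 2, 3)]
out = reference_concat(xs, axis=1)
assert out.shape == (2, 6, 4, 5)

# Case 2 (TestConcatOp2): all inputs share the same shape.
ys = [np.random.random((2, 3, 4, 5)).astype('float32') for _ in range(3)]
out2 = reference_concat(ys, axis=1)
assert out2.shape == (2, 9, 4, 5)

# Hypothetical extra case, following the reviewer's axis == 0 suggestion.
out3 = reference_concat(ys, axis=0)
assert out3.shape == (6, 3, 4, 5)
```

Per the reply above, the same-shape case is retained because concat has two CUDA kernels, so the two tests presumably hit different kernel paths rather than duplicating each other.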
Review comment: Please double check whether `static_cast` is more suitable.