Skip to content

Commit

Permalink
Fix Transpose Convolutions
Browse files Browse the repository at this point in the history
  • Loading branch information
raugarcr committed Jul 15, 2024
1 parent cc92a6b commit fa99e51
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
5 changes: 3 additions & 2 deletions src/descriptors/descriptor_convT2D.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ void ConvolDescriptorT2D::build(Tensor *A) {

I = A;

nk = A->shape[1]; //ksize[0];
//nk = A->shape[1];
nk = ksize[0];
kr = ksize[1];
kc = ksize[2];
kz = A->shape[1]/groups;
Expand Down Expand Up @@ -213,7 +214,7 @@ void ConvolDescriptorT2D::build(Tensor *A) {
in,iz,ir,ic);
cudnnCreateFilterDescriptor(&wDesc);
//CONVT we need to swap input channels with output so all other swappings (forward and backward functions) matches
cudnnSetFilter4dDescriptor(wDesc, data_type, tensor_format, nk, kz, kr, kc);
cudnnSetFilter4dDescriptor(wDesc, data_type, tensor_format, kz, nz, kr, kc);

cudnnCreateTensorDescriptor(&yDesc);
cudnnSetTensor4dDescriptor(yDesc, tensor_format, data_type, in, z,r,c);
Expand Down
3 changes: 2 additions & 1 deletion src/serialization/onnx/net/layers/conv/convT_onnx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ Layer* build_convT_layer(onnx::NodeProto *node,
string weights_name = node->input(1); // Get weights and dims
vector<float> *weights = &(map_init_values[weights_name]);
vector<int> dims = map_init_dims[weights_name];
filters = dims[0];
//filters = dims[0];
filters = dims[1];

// Deduce conv dimension from layer input
if (parent_shape.size() == 3)
Expand Down

0 comments on commit fa99e51

Please sign in to comment.