Skip to content

Commit

Permalink
Avoid converting to compType from dstDataType before writing the output…
Browse files Browse the repository at this point in the history
… value
  • Loading branch information
qianfengz committed Sep 15, 2021
1 parent f098bfb commit ba91b99
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
Expand All @@ -200,11 +204,11 @@ struct GridwiseReduction_xy_to_x_blockwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}

auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand All @@ -218,7 +222,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id));

threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
}
};

Expand Down Expand Up @@ -345,6 +349,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
Expand All @@ -368,11 +376,11 @@ struct GridwiseReduction_xy_to_x_blockwise
make_tuple(I0),
priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}

auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand Down Expand Up @@ -400,7 +408,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id));

threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
Expand Down Expand Up @@ -547,6 +555,10 @@ struct GridwiseReduction_xy_to_x_blockwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
Expand All @@ -570,11 +582,11 @@ struct GridwiseReduction_xy_to_x_blockwise
make_tuple(I0),
priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}

auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand Down Expand Up @@ -602,7 +614,7 @@ struct GridwiseReduction_xy_to_x_blockwise
make_multi_index(block_global_1d_id));

threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
Expand All @@ -166,11 +170,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}

auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand All @@ -184,7 +188,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id));

threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
};

template <>
Expand Down Expand Up @@ -271,6 +275,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
Expand All @@ -290,11 +298,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}

auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand Down Expand Up @@ -322,7 +330,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id));

threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
};
Expand Down Expand Up @@ -430,6 +438,10 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load = ThreadwiseTensorSliceTransfer_v2<dstDataType,
Expand All @@ -449,11 +461,11 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_val_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}

auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand Down Expand Up @@ -481,7 +493,7 @@ struct GridwiseReduction_xy_to_x_direct_threadwise
make_multi_index(thread_global_1d_id));

threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
Expand All @@ -176,11 +180,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
threadwise_dst_load.Run(
dst1dDesc, dst_global_buf, ReducedDataDesc, make_tuple(I0), priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf(I0) * beta);
dstValue_buf(I0) += priorDstValue_buf(I0) * beta;
}

auto threadwise_dst_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand All @@ -194,7 +198,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id));

threadwise_dst_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_buf);
}
};

Expand Down Expand Up @@ -291,6 +295,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
Expand All @@ -314,11 +322,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_tuple(I0),
priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}

auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand Down Expand Up @@ -346,7 +354,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id));

threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
Expand Down Expand Up @@ -466,6 +474,10 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
if(!float_equal_one{}(alpha))
accuValue_buf(I0) *= type_convert<compType>{}(alpha);

StaticBuffer<AddressSpaceEnum_t::Vgpr, dstDataType, 1, true> dstValue_buf;

dstValue_buf(I0) = type_convert<dstDataType>{}(accuValue_buf[I0]);

if(!float_equal_zero{}(beta))
{
auto threadwise_dst_load =
Expand All @@ -489,11 +501,11 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_tuple(I0),
priorDstValue_buf);

accuValue_buf(I0) += type_convert<compType>{}(priorDstValue_buf[I0] * beta);
dstValue_buf(I0) += priorDstValue_buf[I0] * beta;
}

auto threadwise_dst_val_store =
ThreadwiseTensorSliceTransfer_v1r3<compType,
ThreadwiseTensorSliceTransfer_v1r3<dstDataType,
dstDataType,
decltype(ReducedDataDesc),
dst1dDescType,
Expand Down Expand Up @@ -521,7 +533,7 @@ struct GridwiseReduction_xy_to_x_direct_warpwise
make_multi_index(warp_global_1d_id));

threadwise_dst_val_store.Run(
ReducedDataDesc, make_tuple(I0), accuValue_buf, dst1dDesc, dst_global_val_buf);
ReducedDataDesc, make_tuple(I0), dstValue_buf, dst1dDesc, dst_global_val_buf);
threadwise_dst_idx_store.Run(
ReducedDataDesc, make_tuple(I0), accuIndex_buf, dst1dDesc, dst_global_idx_buf);
}
Expand Down

0 comments on commit ba91b99

Please sign in to comment.