Skip to content

Commit

Permalink
Fix the bias being added multiple times in the layer output.
Browse files Browse the repository at this point in the history
  • Loading branch information
minhthuc2502 committed Mar 4, 2024
1 parent 05a2702 commit bcefb34
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
4 changes: 2 additions & 2 deletions include/ctranslate2/layers/common.h
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ namespace ctranslate2 {
Dense(const models::Model& model,
const std::string& scope,
const ops::ActivationType* activation_type = nullptr,
const bool affected_by_tp = false);
const bool is_layer_out = false);
DataType output_type() const override;
dim_t output_size() const override;
void operator()(const StorageView& input, StorageView& output) const;
Expand All @@ -148,7 +148,7 @@ namespace ctranslate2 {
const ops::Gemm _gemm_op;
const ops::Quantize _quantize_op;
const ops::Dequantize _dequantize_op;
const bool _affected_by_tp;
const bool _is_layer_out;
};

class LayerNorm : public Layer
Expand Down
10 changes: 7 additions & 3 deletions src/layers/common.cc
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ namespace ctranslate2 {
Dense::Dense(const models::Model& model,
const std::string& scope,
const ops::ActivationType* activation_type,
const bool affected_by_tp)
const bool is_layer_out)
: _packed_weight(false)
, _weight(get_linear_weight(model, scope, &_packed_weight))
, _bias(model.get_variable_if_exists(scope + "/bias"))
Expand Down Expand Up @@ -295,7 +295,7 @@ namespace ctranslate2 {
/*shift_to_uint8=*/bool(_u8_shift_compensation),
/*round_before_cast=*/model.round_before_cast_in_quantization())
, _dequantize_op(activation_type)
, _affected_by_tp(affected_by_tp)
, _is_layer_out(is_layer_out)
{
}

Expand Down Expand Up @@ -348,7 +348,7 @@ namespace ctranslate2 {
StorageView qoutput(DataType::INT32, device);
const StorageView* pinput = &input;

if (ScopedMPISetter::getNRanks() > 1 && _affected_by_tp) {
if (ScopedMPISetter::getNRanks() > 1 && _is_layer_out) {
StorageView input_reshaped(input.shape(), input.dtype(), input.device());
Shape shape = input.shape();
dim_t batch_size = shape[0];
Expand Down Expand Up @@ -381,6 +381,8 @@ namespace ctranslate2 {
}

_gemm_op(qinput, *weight, qoutput, compensation);
if (ScopedMPISetter::getNRanks() >= 1 && ScopedMPISetter::getCurRank() == 0 && _is_layer_out)
bias = nullptr;
_dequantize_op(qoutput,
qinput_scale,
*qscale,
Expand All @@ -389,6 +391,8 @@ namespace ctranslate2 {
output,
bias);
} else {
if (ScopedMPISetter::getNRanks() >= 1 && ScopedMPISetter::getCurRank() == 0 && _is_layer_out)
bias = nullptr;
_gemm_op(input, *weight, output, nullptr, bias);
}
}
Expand Down

0 comments on commit bcefb34

Please sign in to comment.