Skip to content

Commit

Permalink
xe: jit: gemm: fix type update logic
Browse files Browse the repository at this point in the history
  • Loading branch information
petercad committed Jul 16, 2024
1 parent b3162a9 commit e8e05b4
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 28 deletions.
42 changes: 14 additions & 28 deletions src/gpu/intel/jit/gemm/gen_gemm_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -493,34 +493,20 @@ status_t gen_gemm_nocopy_kernel_desc_t::select_kernel(compute::gpu_arch_t arch,

if (!entry_) return status::unimplemented;

if (mode & mode_tf32) {
if (entry_->selector.precisions[0][0] == 'T')
problem_.Ta = problem_.Ta_ext = Type::tf32;
if (entry_->selector.precisions[1][0] == 'T')
problem_.Tb = problem_.Tb_ext = Type::tf32;
}

if (mode & mode_bf16x1) {
if (utils::one_of(entry_->selector.precisions[0][0], 'B', '['))
problem_.Ta = Type::bf16;
if (utils::one_of(entry_->selector.precisions[1][0], 'B', '['))
problem_.Tb = Type::bf16;
}

if (mode & mode_f16x1) {
if (utils::one_of(entry_->selector.precisions[0][0], 'H', '['))
problem_.Ta = Type::f16;
if (utils::one_of(entry_->selector.precisions[1][0], 'H', '['))
problem_.Tb = Type::f16;
}

if (problem_.Ta.isInt4()) {
if (entry_->selector.precisions[0][0] == '[') problem_.Ta = Type::s8;
}

if (problem_.Tb.isInt4()) {
if (entry_->selector.precisions[1][0] == '[') problem_.Tb = Type::s8;
}
// Update A/B types from entry.
Type Ta_new, Ta_ext_new, Tb_new, Tb_ext_new;
parsePrecisions(entry_->selector.precisions[0], Ta_ext_new, Ta_new);
parsePrecisions(entry_->selector.precisions[1], Tb_ext_new, Tb_new);

// Replace T with T_new while keeping T's current signedness.
//   T         - type to update in place.
//   T_new     - replacement type (here, parsed from the catalog entry).
//   sz_change - when false, the update is skipped if T and T_new differ in
//               bit width; when true, a width change is permitted.
auto update_type = [](Type &T, Type T_new, bool sz_change = false) {
// Widths differ and the caller did not allow a size change: leave T as-is.
if ((T.bits() != T_new.bits()) && !sz_change) return;
// Never swap one F8 type for another F8 type.
if (T.isF8() && T_new.isF8()) return;
// Adopt T_new, converted to whichever signedness T had before the update.
T = T.isSigned() ? T_new.asSigned() : T_new.asUnsigned();
};
update_type(problem_.Ta, Ta_new, true);
update_type(problem_.Tb, Tb_new, true);
update_type(problem_.Ta_ext, Ta_ext_new);
update_type(problem_.Tb_ext, Tb_ext_new);

auto block_k = entry_->driverInfo.blocking[LoopK];
if (block_k > 0 && k > block_k && beta != 1.0f) problem_.beta = Scalar();
Expand Down
3 changes: 3 additions & 0 deletions src/gpu/intel/jit/gemm/gen_gemm_kernel_generator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ class Type {
// Reports signedness from the 0x110000 bit field of the encoded value:
// the only pattern treated as unsigned is 0x100000 (bit 0x100000 set,
// bit 0x10000 clear); every other pattern is reported as signed.
constexpr bool isSigned() const {
return !((uint32_t(val) & 0x110000) == 0x100000);
}
// Returns this type with bit 0x10000 cleared when isInteger() holds;
// for non-integer types the encoded value is returned unchanged.
constexpr Type asUnsigned() const {
return static_cast<_Type>(isInteger()
? (uint32_t(val) & ~uint32_t(0x10000))
: uint32_t(val));
}
// Returns this type with bit 0x10000 set when isInteger() holds;
// for non-integer types the encoded value is returned unchanged.
constexpr Type asSigned() const {
return static_cast<_Type>(isInteger()
? (uint32_t(val) | uint32_t(0x10000))
: uint32_t(val));
}
Expand Down
20 changes: 20 additions & 0 deletions src/gpu/intel/jit/gemm/strategy_parser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,26 @@ void adjustStrategy(HW hw, const GEMMProblem &problem, GEMMStrategy &strategy,
strategy.remHandling[LoopN] = RemainderHandling::Ignore;
}

// Consume a single precision character from s.
// At end of string nothing is consumed and precision is left untouched;
// otherwise precision receives charPrecision() of the character read.
// Returns the position following whatever was consumed.
const char *parsePrecision(const char *s, Type &precision) {
if (*s == '\0') return s;
precision = charPrecision(*s);
return s + 1;
}

// Parse a precision specifier that is either a single character "X"
// (both outputs receive the same precision) or a bracketed pair "[XY]"
// (precision1 from X, precision2 from Y).
// Returns the position following the specifier.
// Throws std::runtime_error when a bracketed pair is not closed by ']'.
const char *parsePrecisions(const char *s, Type &precision1, Type &precision2) {
// Unbracketed form: one precision shared by both outputs.
if (*s != '[') {
const char *rest = parsePrecision(s, precision1);
precision2 = precision1;
return rest;
}

// Bracketed form: skip '[', then read the two precisions in order.
const char *cursor = parsePrecision(s + 1, precision1);
cursor = parsePrecision(cursor, precision2);
if (*cursor != ']')
throw std::runtime_error("Syntax error in precisions; expected ]");

// Step over the closing bracket.
return cursor + 1;
}

} // namespace jit
} // namespace intel
} // namespace gpu
Expand Down
3 changes: 3 additions & 0 deletions src/gpu/intel/jit/gemm/strategy_parser.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ void parseStrategy(const char *str, ngen::HW hw, const GEMMProblem &problem,
void adjustStrategy(ngen::HW hw, const GEMMProblem &problem,
GEMMStrategy &strategy, const char *tags = nullptr);

// Read one precision character from s into precision and return the
// position after it. If s is at end of string, precision is left unchanged
// and s is returned as-is.
const char *parsePrecision(const char *s, Type &precision);
// Read a precision specifier from s: either a single character (both
// precision1 and precision2 receive it) or a bracketed pair "[XY]"
// (precision1 from X, precision2 from Y). Returns the position after the
// specifier; throws std::runtime_error if a bracketed pair lacks its ']'.
const char *parsePrecisions(const char *s, Type &precision1, Type &precision2);

} // namespace jit
} // namespace intel
} // namespace gpu
Expand Down

0 comments on commit e8e05b4

Please sign in to comment.