Skip to content

Commit

Permalink
Merge branch 'cderb/set_f8_beta_api' of https://github.com/ROCmSoftwa…
Browse files Browse the repository at this point in the history
…rePlatform/MIOpen into cderb/set_f8_beta_api
  • Loading branch information
cderb committed Oct 3, 2023
2 parents 7bb707e + 39fdfcb commit 937197a
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 8 deletions.
9 changes: 5 additions & 4 deletions src/gemm_v2.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
#include <miopen/env.hpp>
#include <miopen/tensor.hpp>
#include <miopen/handle.hpp>
#include <miopen/datatype.hpp>

#if MIOPEN_BACKEND_HIP
#include <miopen/hipoc_kernel.hpp>
Expand Down Expand Up @@ -173,6 +174,7 @@ rocblas_status miopen_rocblas_gemm_ex3(const miopen::Handle& handle,
rocblas_gemm_algo::rocblas_gemm_algo_standard,
0,
flags); // gfx90a_alt_impl));
return rb_status;
#pragma clang diagnostic pop
#endif
MIOPEN_THROW(miopenStatusBadParm, "An appropriate version of rocBLAS is required for this op");
Expand Down Expand Up @@ -258,10 +260,9 @@ std::ostream& operator<<(std::ostream& stream, const GemmDescriptor& gemm_desc)
<< "strideC " << gemm_desc.strideC << ", "
<< "alpha " << gemm_desc.alpha << ", "
<< "beta " << gemm_desc.beta << ", "
<< "dataType " << gemm_desc.dataType << "a_cast_type" << gemm_desc.a_cast_type
<< ", "
<< "b_cast_type" << gemm_desc.b_cast_type << ", "
<< "} ";
<< "dataType " << GetDataType(gemm_desc.dataType) << ", "
<< "a_cast_type " << GetDataType(gemm_desc.a_cast_type) << ", "
<< "b_cast_type " << GetDataType(gemm_desc.b_cast_type) << "} ";
}

#if MIOPEN_USE_ROCBLAS
Expand Down
7 changes: 5 additions & 2 deletions src/invoker_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,8 +104,11 @@ void InvokerCache::Register(const Key& key, const Invoker& invoker)
auto it = invokers.find(key.first);
if(it != invokers.end())
it->second.invokers.insert({key.second, invoker});
auto& item = invokers.insert({key.first, Item{}}).first->second;
item.invokers.insert({key.second, invoker});
else
{
auto& item = invokers.insert({key.first, Item{}}).first->second;
item.invokers.insert({key.second, invoker});
}
MIOPEN_LOG_I2("Invoker registered for algorithm " << key.first << " and solver " << key.second);
}

Expand Down
4 changes: 2 additions & 2 deletions src/solver/conv_direct_naive_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,12 @@ std::string ConvDirectNaiveConvCompileOption(const ExecutionContext& ctx,
ss << " -DWEIGHTS_TYPE=" << miopen::GetDataType(problem.GetWeightsDataType());
ss << " -DOUTPUT_TYPE="
<< miopen::GetDataType(ProblemInterpreter::GetOutputDataType(problem));
const auto in_cast_type = problem.GetInCastType();
const auto in_cast_type = ProblemInterpreter::GetInputCastType(problem);
if(in_cast_type)
ss << " -DINPUT_CAST_TYPE=" << miopen::GetDataType(*in_cast_type);
const auto wei_cast_type = problem.GetWeightsCastType();
if(wei_cast_type)
ss << " -DWEIGHTS_CAST_TYPE=" << miopen::GetDataType(*(wei_cast_type));
ss << " -DWEIGHTS_CAST_TYPE=" << miopen::GetDataType(*wei_cast_type);
const auto out_cast_type = ProblemInterpreter::GetOutputCastType(problem);
if(out_cast_type)
ss << " -DOUTPUT_CAST_TYPE=" << miopen::GetDataType(*out_cast_type);
Expand Down

0 comments on commit 937197a

Please sign in to comment.