Skip to content

Commit

Permalink
Merge branch 'develop' into scatter
Browse files Browse the repository at this point in the history
  • Loading branch information
lisamhy authored Nov 8, 2023
2 parents f33b83c + e5ef5d3 commit 25b7d8f
Show file tree
Hide file tree
Showing 484 changed files with 11,071 additions and 4,536 deletions.
4 changes: 2 additions & 2 deletions .clang-tidy
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ clang-analyzer-optin.portability.UnixAPI,
clang-analyzer-security.insecureAPI.vfork,
-clang-analyzer-unix.API,
-clang-analyzer-unix.DynamicMemoryModeling,
-clang-analyzer-unix.Malloc,
clang-analyzer-unix.Malloc,
-clang-analyzer-unix.MallocSizeof,
-clang-analyzer-unix.MismatchedDeallocator,
clang-analyzer-unix.Vfork,
Expand All @@ -158,7 +158,7 @@ cppcoreguidelines-explicit-virtual-functions,
cppcoreguidelines-init-variables,
cppcoreguidelines-narrowing-conversions,
cppcoreguidelines-no-malloc,
-cppcoreguidelines-pro-type-const-cast,
cppcoreguidelines-pro-type-const-cast,
-cppcoreguidelines-pro-type-member-init,
-cppcoreguidelines-slicing,
-hicpp-avoid-goto,
Expand Down
20 changes: 10 additions & 10 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -160,13 +160,13 @@ repos:
hooks:
- id: cmakelint
args: [--config=./tools/codestyle/.cmakelintrc]

- repo: local
hooks:
- id: sort-txt-file
name: sort-txt-file
description: Sorts each line string in a text file
entry: python ./tools/codestyle/sort_txt_file.py
language: python
files: test/white_list/new_ir_op_test_white_list
args: []
# Others
- repo: local
hooks:
- id: sort-txt-file
name: sort-txt-file
description: Sorts each line string in a text file
entry: python ./tools/codestyle/sort_txt_file.py
language: python
files: test/white_list/pir_op_test_white_list
args: []
8 changes: 4 additions & 4 deletions cmake/cinn.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,8 @@ cinn_cc_library(
add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
add_dependencies(cinnapi GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})
if(NOT CINN_ONLY)
target_link_libraries(cinnapi pd_op_dialect phi)
add_dependencies(cinnapi pd_op_dialect phi)
target_link_libraries(cinnapi op_dialect_vjp phi)
add_dependencies(cinnapi op_dialect_vjp phi)
endif()

target_link_libraries(cinnapi ${PYTHON_LIBRARIES})
Expand Down Expand Up @@ -222,8 +222,8 @@ function(gen_cinncore LINKTYPE)
add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ZLIB::ZLIB)
add_dependencies(${CINNCORE_TARGET} GEN_LLVM_RUNTIME_IR_HEADER ${core_deps})
if(NOT CINN_ONLY)
target_link_libraries(${CINNCORE_TARGET} pd_op_dialect phi)
add_dependencies(${CINNCORE_TARGET} pd_op_dialect phi)
target_link_libraries(${CINNCORE_TARGET} op_dialect_vjp phi)
add_dependencies(${CINNCORE_TARGET} op_dialect_vjp phi)
endif()

add_dependencies(${CINNCORE_TARGET} pybind)
Expand Down
2 changes: 1 addition & 1 deletion cmake/external/xpu.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ set(XPU_XFT_LIB_NAME "libxft.so")
set(XPU_XPTI_LIB_NAME "libxpti.so")

if(NOT DEFINED XPU_BASE_DATE)
set(XPU_BASE_DATE "20231025")
set(XPU_BASE_DATE "20231103")
endif()
set(XPU_XCCL_BASE_VERSION "1.0.53.6")
if(NOT DEFINED XPU_XFT_BASE_VERSION)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,12 @@ SearchState EvolutionarySearch::Mutate(
// ir_schedule
const auto& task_key = tune_task_.serialized_key;
InitialTaskRegistry* task_registry = InitialTaskRegistry::Global();
ir::IRSchedule new_ir_sch(
ir::IRSchedule pir_sch(
ir::ir_utils::IRCopy(task_registry->Get(task_key)->module_expr),
utils::ForkRandomState(rand_seed));
new_trace.Replay(&new_ir_sch, true);
ApplyPostScheduleRules(&new_ir_sch, post_schedule_rules_);
auto res = SearchState(std::move(new_ir_sch));
new_trace.Replay(&pir_sch, true);
ApplyPostScheduleRules(&pir_sch, post_schedule_rules_);
auto res = SearchState(std::move(pir_sch));

VLOG(5) << JoinStatesDebugString(
"EvolutionarySearch::Mutate", {state, res}, /*verbose=*/VLOG_IS_ON(6));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ TEST(MutateTileSize, Basic) {
// repeated.
utils::LinearRandomEngine::StateType rand_seed = 123;
ir::IRSchedule ir_schedule(module_expr, rand_seed);
ir::IRSchedule new_ir_schedule(ir_schedule);
ir::IRSchedule pir_schedule(ir_schedule);

// apply schedule
auto loops = ir_schedule.GetLoops("C");
Expand All @@ -74,13 +74,13 @@ TEST(MutateTileSize, Basic) {
MutateTileSize mutator;
ir::ScheduleDesc sch_desc =
mutator.Apply(ir_schedule.GetTraceDesc(), &rand_seed);
sch_desc.Replay(&new_ir_schedule, true);
sch_desc.Replay(&pir_schedule, true);
VLOG(6) << "Expr before mutate tile size: \n"
<< ir_schedule.GetModule().GetExprs()[0];
VLOG(6) << "Expr after mutate tile size: \n"
<< new_ir_schedule.GetModule().GetExprs()[0];
<< pir_schedule.GetModule().GetExprs()[0];

std::string target_new_ir = R"ROC({
std::string target_pir = R"ROC({
ScheduleBlock(root)
{
serial for (i_1, 0, 2)
Expand Down Expand Up @@ -115,7 +115,7 @@ TEST(MutateTileSize, Basic) {
ss << exprs[0];
return ss.str();
};
ASSERT_EQ(get_ir_str(&new_ir_schedule), target_new_ir);
ASSERT_EQ(get_ir_str(&pir_schedule), target_pir);

std::vector<int> last_tile_factors = {2, 16};
for (int i = 0; i < 10; ++i) {
Expand Down
4 changes: 4 additions & 0 deletions paddle/cinn/backends/nvrtc/nvrtc_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@

PD_DECLARE_string(cinn_nvcc_cmd_path);
PD_DECLARE_bool(nvrtc_compile_to_cubin);
PD_DECLARE_bool(cinn_nvrtc_cubin_with_fmad);

namespace cinn {
namespace backends {
Expand Down Expand Up @@ -106,6 +107,9 @@ std::string Compiler::CompileCudaSource(const std::string& code,
}
if (compile_to_cubin_) {
compile_options.push_back("-arch=sm_" + cc);
std::string enable_fmad =
FLAGS_cinn_nvrtc_cubin_with_fmad ? "true" : "false";
compile_options.push_back("--fmad=" + enable_fmad);
} else {
compile_options.push_back("-arch=compute_" + cc);
}
Expand Down
27 changes: 24 additions & 3 deletions paddle/cinn/common/context.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,22 @@ struct NameGenerator {
mutable std::mutex mutex_;
};

// Generates and caches human-readable unique names keyed by a hash value.
// A given hash_key always maps to the same generated name, so repeated
// queries for the same entity return a stable name.
struct PrettyNamer {
  // Returns the cached pretty name for `hash_key`; on first use, generates
  // a fresh unique name from `name_hint` and memoizes it.
  const std::string& GetOrNew(const size_t hash_key,
                              const std::string& name_hint) {
    // Single-lookup insert: the original find + operator[] + at() hashed
    // the key up to three times per call.
    auto iter = pretty_names_.find(hash_key);
    if (iter == pretty_names_.end()) {
      iter = pretty_names_.emplace(hash_key, name_generator_.New(name_hint))
                 .first;
    }
    return iter->second;
  }

  // Exposes the underlying generator for callers that need raw unique names
  // without hash-based caching.
  NameGenerator& GetNameGenerator() { return name_generator_; }

 private:
  absl::flat_hash_map<size_t, std::string> pretty_names_;  // hash -> name
  NameGenerator name_generator_;
};

class Context {
public:
static Context& Global();
Expand All @@ -61,10 +77,15 @@ class Context {
* @param name_hint The prefix.
*/
std::string NewName(const std::string& name_hint) {
return name_generator_.New(name_hint);
return pretty_namer_.GetNameGenerator().New(name_hint);
}

void ResetNameId() { name_generator_.ResetID(); }
std::string PrettyUniqName(const size_t hash_key,
const std::string& name_hint) {
return pretty_namer_.GetOrNew(hash_key, name_hint);
}

void ResetNameId() { pretty_namer_.GetNameGenerator().ResetID(); }

const std::vector<std::string>& runtime_include_dir();

Expand All @@ -82,7 +103,7 @@ class Context {
private:
Context() = default;

NameGenerator name_generator_;
PrettyNamer pretty_namer_;
std::vector<std::string> runtime_include_dir_;
mutable std::mutex mutex_;

Expand Down
14 changes: 7 additions & 7 deletions paddle/cinn/common/macros.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@
__test_global_namespace_##uniq_name##__>::value, \
msg)

#define USE_FUSION_PASS(pass_name) \
STATIC_ASSERT_GLOBAL_NAMESPACE( \
__use_fusion_pass_##pass_name, \
"USE_OP_ITSELF must be called in global namespace"); \
extern int TouchFusionPassRegistrar_##pass_name(); \
[[maybe_unused]] static int __use_fusion_pass_##pass_name##_ = \
TouchFusionPassRegistrar_##pass_name()
// Pulls a CINN fusion pass's registration into the current binary: declares
// the pass's TouchCinnFusionPassRegistrar_<pass_name>() function (defined in
// the translation unit that registers the pass) and calls it from a static
// initializer. This is the usual "touch" idiom — presumably to keep the
// linker from discarding the registrar's object file; verify against the
// registrar definition.
// NOTE(review): must be expanded at global namespace scope — the
// STATIC_ASSERT_GLOBAL_NAMESPACE line enforces this at compile time.
#define USE_FUSION_PASS(pass_name)                                    \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                     \
      __use_cinn_fusion_pass_##pass_name,                             \
      "USE_FUSION_PASS must be called in global namespace");          \
  extern int TouchCinnFusionPassRegistrar_##pass_name();              \
  [[maybe_unused]] static int __use_cinn_fusion_pass_##pass_name##_ = \
      TouchCinnFusionPassRegistrar_##pass_name()
2 changes: 1 addition & 1 deletion paddle/cinn/hlir/dialect/operator/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ if(NOT CINN_ONLY)
manual_op.cc
op_attribute.cc
DEPS
pd_op_dialect)
op_dialect_vjp)

target_include_directories(cinn_op_dialect PRIVATE ${CINN_DIALECT_BINARY_DIR})
endif()
6 changes: 3 additions & 3 deletions paddle/cinn/hlir/dialect/operator/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ if(NOT CINN_ONLY)
cinn_group_lowering_pass.cc
tensor_node.cc
DEPS
pd_op_dialect
op_dialect_vjp
pir_compiler
cinn_runtime_dialect)

Expand All @@ -18,7 +18,7 @@ if(NOT CINN_ONLY)
DEPS
drr
cinn_op_dialect
pd_op_dialect)
op_dialect_vjp)

cinn_cc_library(
add_broadcast_to_elementwise_pass
Expand All @@ -27,5 +27,5 @@ if(NOT CINN_ONLY)
DEPS
pir
cinn_op_dialect
pd_op_dialect)
op_dialect_vjp)
endif()
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,19 @@ bool IsSameDim(const phi::DDim& first, const std::vector<int64_t>& second) {
return false;
}

// Maps each axis of `in_shape` to its corresponding axis in `out_shape`
// by right-aligning the two shapes (trailing dimensions line up), and
// returns the resulting output-axis index per input axis.
std::vector<int64_t> GetBroadcastAxis(const phi::DDim& in_shape,
                                      const std::vector<int64_t>& out_shape) {
  const auto in_rank = in_shape.size();
  // Input axis k lines up with output axis k + (out_rank - in_rank).
  const int64_t axis_offset =
      static_cast<int64_t>(out_shape.size()) - static_cast<int64_t>(in_rank);

  std::vector<int64_t> broadcast_axes(in_rank, 0);
  for (int64_t axis = 0; axis < static_cast<int64_t>(in_rank); ++axis) {
    broadcast_axes[axis] = axis + axis_offset;
  }
  return broadcast_axes;
}

bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
auto x_dims = op->operand_source(0)
.type()
Expand All @@ -93,21 +106,21 @@ bool ProcessOp(pir::Operation* op, pir::PatternRewriter* rewriter) {

if (x_dims != y_dims) {
auto output_shape = GetOutputShape(x_dims, y_dims);
std::vector<int64_t> vec_dims;
for (int64_t i = 0; i < output_shape.size(); ++i) {
vec_dims.push_back(i);
}
if (!IsSameDim(x_dims, output_shape)) {
// add broadcast to input 0
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(0), vec_dims, output_shape);
op->operand_source(0),
GetBroadcastAxis(x_dims, output_shape),
output_shape);

op->operand(0).set_source(new_transpose_op->result(0));
}

if (!IsSameDim(y_dims, output_shape)) {
auto new_transpose_op = rewriter->Build<cinn::dialect::BroadcastOp>(
op->operand_source(1), vec_dims, output_shape);
op->operand_source(1),
GetBroadcastAxis(y_dims, output_shape),
output_shape);

op->operand(1).set_source(new_transpose_op->result(0));
}
Expand Down
Loading

0 comments on commit 25b7d8f

Please sign in to comment.