Skip to content

Commit

Permalink
[Prim][PIR] Set prim gflag for pure cpp (#60505)
Browse files Browse the repository at this point in the history
* inference support decomp

* polish code

* add decomp base define

* add decomp base define2

* change decomp infer

* fix symbol overload

* fix test case

* debug

* debug

* decomp add debug info

* add cpp flag

* revert

* remove unused flag
  • Loading branch information
cyber-pioneer authored Jan 3, 2024
1 parent 54b95ae commit 99af9f7
Showing 1 changed file with 8 additions and 2 deletions.
10 changes: 8 additions & 2 deletions paddle/fluid/prim/utils/utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,15 @@
#include "paddle/fluid/prim/utils/static/static_global_utils.h"

PADDLE_DEFINE_EXPORTED_bool(prim_enabled, false, "enable_prim or not");
PADDLE_DEFINE_EXPORTED_bool(prim_all, false, "enable prim_all or not");
PADDLE_DEFINE_EXPORTED_bool(prim_forward, false, "enable prim_forward or not");
PADDLE_DEFINE_EXPORTED_bool(prim_backward, false, "enable prim_backward not");

namespace paddle {
namespace prim {
bool PrimCommonUtils::IsBwdPrimEnabled() {
return StaticCompositeContext::Instance().IsBwdPrimEnabled();
bool res = StaticCompositeContext::Instance().IsBwdPrimEnabled();
return res || FLAGS_prim_all || FLAGS_prim_backward;
}

void PrimCommonUtils::SetBwdPrimEnabled(bool enable_prim) {
Expand All @@ -36,7 +41,8 @@ void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) {
}

bool PrimCommonUtils::IsFwdPrimEnabled() {
return StaticCompositeContext::Instance().IsFwdPrimEnabled();
bool res = StaticCompositeContext::Instance().IsFwdPrimEnabled();
return res || FLAGS_prim_all || FLAGS_prim_forward;
}

void PrimCommonUtils::SetFwdPrimEnabled(bool enable_prim) {
Expand Down

0 comments on commit 99af9f7

Please sign in to comment.