Skip to content

Commit

Permalink
feat(//core/conversion/converter/Arg): Add typechecking to the unwrap
Browse files Browse the repository at this point in the history
functions

Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Mar 30, 2020
1 parent 8be79e1 commit 73bfd4c
Showing 1 changed file with 32 additions and 17 deletions.
49 changes: 32 additions & 17 deletions core/conversion/converters/Arg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ std::string Arg::type_name() const {
default:
return "None";
}

}

const torch::jit::IValue* Arg::IValue() const {
if (type_ == Type::kIValue) {
return ptr_.ivalue;
Expand Down Expand Up @@ -150,7 +150,7 @@ double Arg::unwrapToDouble(double default_val) {

double Arg::unwrapToDouble() {
return this->unwrapTo<double>();
}
}

bool Arg::unwrapToBool(bool default_val) {
return this->unwrapTo<bool>(default_val);
Expand Down Expand Up @@ -194,26 +194,41 @@ c10::List<bool> Arg::unwrapToBoolList() {

template<typename T>
T Arg::unwrapTo(T default_val) {
if (isIValue()) {
// TODO: implement Tag Checking
return ptr_.ivalue->to<T>();
try {
return this->unwrapTo<T>();
} catch(trtorch::Error& e) {
LOG_DEBUG("In arg unwrapping, returning default value provided (" << e.what() << ")");
return default_val;
}
LOG_DEBUG("In arg unwrapping, returning default value provided");
return default_val;
}


template<typename T>
T Arg::unwrapTo() {
if (isIValue()) {
//TODO: Implement Tag checking
return ptr_.ivalue->to<T>();
//TODO: Exception
//LOG_INTERNAL_ERROR("Requested unwrapping of arg IValue assuming it was " << typeid(T).name() << " however type is " << ptr_.ivalue->type());

TRTORCH_CHECK(isIValue(), "Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
auto ivalue = ptr_.ivalue;
bool correct_type = false;
if (typeid(T) == typeid(double)) {
correct_type = ivalue->isDouble();
} else if (typeid(T) == typeid(bool)) {
correct_type = ivalue->isBool();
} else if (typeid(T) == typeid(int64_t)) {
correct_type = ivalue->isInt();
} else if (typeid(T) == typeid(at::Tensor)) {
correct_type = ivalue->isTensor();
} else if (typeid(T) == typeid(c10::Scalar)) {
correct_type = ivalue->isScalar();
} else if (typeid(T) == typeid(c10::List<int64_t>)) {
correct_type = ivalue->isIntList();
} else if (typeid(T) == typeid(c10::List<double>)) {
correct_type = ivalue->isDoubleList();
} else if (typeid(T) == typeid(c10::List<bool>)) {
correct_type = ivalue->isBoolList();
} else {
TRTORCH_THROW_ERROR("Requested unwrapping of arg to an unsupported type: " << typeid(T).name());
}
TRTORCH_THROW_ERROR("Requested unwrapping of arg assuming it was an IValue, however arg type is " << type_name());
return T();

TRTORCH_CHECK(correct_type, "Requested unwrapping of arg IValue assuming it was " << typeid(T).name() << " however type is " << *(ptr_.ivalue->type()));
return ptr_.ivalue->to<T>();
}


Expand Down

0 comments on commit 73bfd4c

Please sign in to comment.