Skip to content

Commit

Permalink
Resolve remaining LTC CI failures (#1110)
Browse files Browse the repository at this point in the history
* Replace CHECK_EQ with TORCH_CHECK_EQ

* Check value of TORCH_MLIR_USE_INSTALLED_PYTORCH during LTC build

* Update LTC XFAIL with NewZerosModule ops

* Explicitly blacklist _like ops

* Automatically blacklist new_/_like ops

* Prune away unused Python dependencies from LTC

* Add flag to disable LTC

* Autogen dummy _REFERENCE_LAZY_BACKEND library when LTC is disabled

* Implement compute_shape_var

* Removed Var tests from XFAIL Set

* XFAIL tests using _local_scalar_dense or index.Tensor

* Add StdDim tests to XFAIL set

* Autogen aten::cat
  • Loading branch information
henrytwo committed Jul 30, 2022
1 parent 4253622 commit 2c3b360
Show file tree
Hide file tree
Showing 15 changed files with 155 additions and 100 deletions.
1 change: 1 addition & 0 deletions .github/workflows/buildAndTest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ jobs:
-DMLIR_ENABLE_BINDINGS_PYTHON=OFF \
-DTORCH_MLIR_ENABLE_MHLO=ON \
-DTORCH_MLIR_USE_INSTALLED_PYTORCH=${{ matrix.torch-binary }} \
-DTORCH_MLIR_ENABLE_LTC=OFF \
-DPython3_EXECUTABLE=$(which python) \
.
Expand Down
2 changes: 2 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ if(TORCH_MLIR_ENABLE_MHLO)
endif()
endif()

option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" ON)

torch_mlir_add_llvm_external_project(
torch-mlir-dialects
TORCH_MLIR_DIALECTS
Expand Down
5 changes: 4 additions & 1 deletion build_tools/autogen_ltc_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def lowering_function(self, schema: LazyIrSchema):
{emplace_arguments_str}
{emplace_kwarguments}
torch::lazy::TorchMlirOpVector {schema.aten_name}_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
TORCH_CHECK_EQ({schema.aten_name}_out.size(), {len(schema.returns)});
return {schema.aten_name}_out;
}}
Expand Down Expand Up @@ -236,6 +236,9 @@ def get_opnames(ops):
continue
if base in supported or op in supported:
continue
# Blacklist new_/_like ops since they are non-differentiable.
if any(o.startswith("new_") or o.endswith("_like") for o in (base, op)):
continue

if func.has_composite_implicit_autograd_kernel:
composite_implicit.add(op)
Expand Down
1 change: 0 additions & 1 deletion build_tools/autogen_ltc_backend.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ supported:
# - bernoulli
# - bernoulli_
- _to_copy
- cat
- clone
- empty.memory_format
- empty_strided
Expand Down
25 changes: 15 additions & 10 deletions e2e_testing/torchscript/xfail_sets.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@
"FullLikeModuleInt3D_basic",
"GeFloatIntModule_basic",
"GeFloatModule_basic",
"GeIntModule_basic",
"GtFloatIntModule_basic",
"GtIntModule_basic",
"HBC_basic",
Expand Down Expand Up @@ -266,6 +267,11 @@
"IndexPutImpl3DFloatNonAccumulateModule_basic",
"IndexTensorModule3dInput_basic",
"IndexTensorModule_basic",
"IndexTensorMultiInputContiguousCenter_basic",
"IndexTensorMultiInputNonContiguous_basic",
"IndexTensorMultiInputOneDim_basic",
"IndexTensorMultiInputThreeIndexers_basic",
"IndexTensorMultiInput_basic",
"IndexTensorSelectDimModule_basic",
"Matmul_dot",
"Matmul_matvec",
Expand All @@ -288,6 +294,12 @@
"NewOnesModuleFloat3D_basic",
"NewOnesModuleInt2D_basic",
"NewOnesModuleInt3D_basic",
"NewZerosModuleDefaultDtype_basic",
"NewZerosModuleFalsePinMemory_basic",
"NewZerosModuleFloat2D_basic",
"NewZerosModuleFloat3D_basic",
"NewZerosModuleInt2D_basic",
"NewZerosModuleInt3D_basic",
"OnesLikeModule_defaultDtype",
"OnesLikeModule_falsePinMemory",
"OnesLikeModule_float",
Expand All @@ -302,6 +314,9 @@
"SliceStartEqEndModule_basic",
"SqrtIntModule_basic",
"StdBiasedModule_basic",
"StdDimBiasedModule_basic",
"StdDimKeepDimFalseModule_basic",
"StdDimKeepDimTrueModule_basic",
"StdUnbiasedModule_basic",
"SubFloatModule_basic",
"SubIntModule_basic",
Expand All @@ -317,15 +332,5 @@
"UniformModule_basic",
"UniformStaticModule_basic",
"UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic",
"VarBiasedModule_basic",
"VarDimAllDimReduceModule_basic",
"VarDimBiasedModule_basic",
"VarDimKeepDimFalseModule_basic",
"VarDimModule_basic",
"VarDimMultiDimModule_basic",
"VarDimNegativeModule_basic",
"VarDimSingleDimModule_basic",
"VarDimUnbiasedModule_basic",
"VarUnbiasedModule_basic",
"ViewCollapseDynamicWithAtenSizeIntModule_basic",
}
56 changes: 33 additions & 23 deletions python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,30 @@ set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir")
# We vendor our own MLIR instance in the `torch_mlir` namespace.
add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.")

################################################################################
# PyTorch
################################################################################

option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON)

if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
# Source builds
set(ENV{PYTORCH_REPO} ${PYTORCH_REPO})
set(ENV{PYTORCH_BRANCH} ${PYTORCH_BRANCH})
set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET})
set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES})
set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER})
set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER})
execute_process(
COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/../build_tools/build_libtorch.sh
RESULT_VARIABLE _result
)
if(_result)
message(FATAL_ERROR "Failed to run `build_libtorch.sh`")
endif()
set(TORCH_INSTALL_PREFIX "libtorch")
endif()

################################################################################
# Sources
################################################################################
Expand Down Expand Up @@ -60,33 +84,17 @@ declare_mlir_python_extension(TorchMLIRPythonExtensions.Main
# Lazy Tensor Core
################################################################################

add_subdirectory(torch_mlir/csrc/base_lazy_backend)
if(TORCH_MLIR_ENABLE_LTC)
add_subdirectory(torch_mlir/csrc/base_lazy_backend)
endif()
# Reference backend has a separate check for TORCH_MLIR_ENABLE_LTC.
add_subdirectory(torch_mlir/csrc/reference_lazy_backend)

################################################################################
# Optionally handle JIT IR importer.
################################################################################

option(TORCH_MLIR_USE_INSTALLED_PYTORCH "Build from local PyTorch in environment" ON)

if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER)
if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)
# Source builds
set(ENV{PYTORCH_REPO} ${PYTORCH_REPO})
set(ENV{PYTORCH_BRANCH} ${PYTORCH_BRANCH})
set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET})
set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES})
set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER})
set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER})
execute_process(
COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/../build_tools/build_libtorch.sh
RESULT_VARIABLE _result
)
if(_result)
message(FATAL_ERROR "Failed to run `build_libtorch.sh`")
endif()
set(TORCH_INSTALL_PREFIX "libtorch")
endif()
add_subdirectory(torch_mlir/dialects/torch/importer/jit_ir)
add_subdirectory(torch_mlir_e2e_test)
endif()
Expand Down Expand Up @@ -154,8 +162,10 @@ endif()
# TODO: Add after macOS builds are fixed
#add_dependencies(TorchMLIRPythonModules torch_mlir_custom_op_example)

# Add Torch-MLIR LTC backend as dependency
add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)
add_dependencies(TorchMLIRPythonModules reference_lazy_backend)
if(TORCH_MLIR_ENABLE_LTC)
# Add Torch-MLIR LTC backend as dependency
add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend)
add_dependencies(TorchMLIRPythonModules reference_lazy_backend)
endif()

add_subdirectory(test)
10 changes: 7 additions & 3 deletions python/torch_mlir/csrc/base_lazy_backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,16 @@

list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
include(TorchMLIRPyTorch)

TorchMLIRProbeForPyTorchInstall()
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
TorchMLIRConfigurePyTorch()
else()
set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch/share/cmake/Torch")
endif()

find_package(Torch 1.11 REQUIRED)

TorchMLIRConfigurePyTorch()
set(TORCHGEN_DIR ${Torch_ROOT}/../../../torchgen)

include_directories(BEFORE
Expand Down Expand Up @@ -76,8 +82,6 @@ target_link_libraries(torch_mlir_ltc_backend
TorchMLIRAggregateCAPI
TorchMLIRJITIRImporter
${TORCH_LIBRARIES}
${Python3_LIBRARIES}
torch_python
)

message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ void TorchMlirLoweringContext::Lower(const Node* node) {
dynamic_cast<const torch::lazy::TorchMlirNode*>(node)) {
TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this);
CHECK(!ops.empty()) << "Failed to lower: " << *node;
CHECK_EQ(node->num_outputs(), ops.size());
TORCH_CHECK_EQ(node->num_outputs(), ops.size());
for (size_t i = 0; i < ops.size(); ++i) {
AssignOutputOp(torch::lazy::Output(node, i), ops[i]);
}
Expand Down
19 changes: 0 additions & 19 deletions python/torch_mlir/csrc/base_lazy_backend/mlir_native_functions.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -154,25 +154,6 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) {
// // return self;
// }

at::Tensor LazyNativeFunctions::cat(at::TensorList tensors, int64_t dim) {
TORCH_LAZY_FN_COUNTER("lazy::");
auto lazy_tensors = torch::lazy::GetLtcTensors(tensors);
std::vector<torch::lazy::Value> values;
values.reserve(lazy_tensors.size());
for (auto& tensor : lazy_tensors) {
values.emplace_back(tensor->GetIrValue());
}

auto shapes = torch::lazy::compute_shape_cat(tensors, dim);
UNIMPLEMENTED_FUNCTION_ERROR();
// auto node =
// torch::lazy::MakeNode<ir::ops::Cat>(values, dim, std::move(shapes));
// auto result = torch::lazy::CreateAtenFromLtcTensor(
// torch::lazy::LazyTensor::Create(torch::lazy::Value(node, 0),
// lazy_tensors[0]->GetDevice()));
// return result;
}

// clone is special in LT because we make it a no-op.
// This should be safe to do, because every operator in the LT is functional.
at::Tensor LazyNativeFunctions::clone(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ GenerateClone(torch::jit::Value* val, TorchMlirFunction function) {
// Type of cloned value should be identical to the original one.
TorchMlirOpVector cloned =
LowerBuiltin(at::aten::clone, {val->type()}, function, clone_arguments);
CHECK_EQ(cloned.size(), 1);
TORCH_CHECK_EQ(cloned.size(), 1);
return cloned.front();
}

Expand Down Expand Up @@ -235,7 +235,7 @@ torch::jit::Value* GenerateSlice(
c10::ArrayRef<Shape>(
compute_shape_slice(base->type(), dim, start, end, step)),
function, arguments);
CHECK_EQ(selected.size(), 1);
TORCH_CHECK_EQ(selected.size(), 1);
return selected.front();
}

Expand Down
2 changes: 1 addition & 1 deletion python/torch_mlir/csrc/base_lazy_backend/ops/to_copy.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class ToCopy : public torch::lazy::TorchMlirNode {
kwarguments.emplace_back("non_blocking", non_blocking);
kwarguments.emplace_back("memory_format", memory_format);
torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments);
CHECK_EQ(_to_copy_out.size(), 1);
TORCH_CHECK_EQ(_to_copy_out.size(), 1);

return _to_copy_out;

Expand Down
7 changes: 7 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/shape_inference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,12 @@ compute_shape_mul(const at::Tensor& self, const at::Scalar& other) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}

std::vector<torch::lazy::Shape> compute_shape_var(
const at::Tensor& self, at::OptionalIntArrayRef dim,
c10::optional<int64_t> correction, bool keepdim) {
// Result of variance is scalar tensor.
return {Shape(self.scalar_type(), {})};
}

} // namespace lazy
} // namespace torch
89 changes: 51 additions & 38 deletions python/torch_mlir/csrc/reference_lazy_backend/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,15 @@

list(APPEND CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/python/torch_mlir/cmake/modules")
include(TorchMLIRPyTorch)

TorchMLIRProbeForPyTorchInstall()
find_package(Torch 1.11 REQUIRED)
if(TORCH_MLIR_USE_INSTALLED_PYTORCH)
TorchMLIRConfigurePyTorch()
else()
set(Torch_DIR "${CMAKE_CURRENT_SOURCE_DIR}/../../../../libtorch/share/cmake/Torch")
endif()

TorchMLIRConfigurePyTorch()
find_package(Torch 1.11 REQUIRED)

###########################################################################
# Setup Python development
Expand All @@ -21,39 +26,47 @@ mlir_configure_python_dev_packages()
# Library definition
###########################################################################

include_directories(BEFORE
${TORCH_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
${Python3_INCLUDE_DIRS}
${PYTHON_H_DIR}
${PROJECT_SOURCE_DIR}/python
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib)
add_link_options(-Wl,-rpath,$ORIGIN/lib)

set(REFERENCE_LAZY_BACKEND_CSRC
backend_impl.cpp
reference_lazy_backend_pybind.cpp
)
add_library(reference_lazy_backend SHARED ${REFERENCE_LAZY_BACKEND_CSRC})
add_dependencies(reference_lazy_backend
torch_mlir_ltc_backend
)
target_link_libraries(reference_lazy_backend
${TORCH_LIBRARIES}
${Python3_LIBRARIES}
torch_python
torch_mlir_ltc_backend
)

message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
set_target_properties(reference_lazy_backend PROPERTIES
LIBRARY_OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/reference_lazy_backend"
OUTPUT_NAME _REFERENCE_LAZY_BACKEND
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
CXX_VISIBILITY_PRESET "hidden"
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
)
set(LIBRARY_OUTPUT_PATH "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/reference_lazy_backend")
set(OUTPUT_NAME "_REFERENCE_LAZY_BACKEND")

if(TORCH_MLIR_ENABLE_LTC)
include_directories(BEFORE
${TORCH_INCLUDE_DIRS}
${CMAKE_CURRENT_SOURCE_DIR}
${CMAKE_CURRENT_BINARY_DIR}
${Python3_INCLUDE_DIRS}
${PYTHON_H_DIR}
${PROJECT_SOURCE_DIR}/python
)
link_directories("${TORCH_INSTALL_PREFIX}/lib")
link_directories(${CMAKE_CURRENT_SOURCE_DIR}/lib)
add_link_options(-Wl,-rpath,$ORIGIN/lib)

add_library(reference_lazy_backend SHARED
backend_impl.cpp
reference_lazy_backend_pybind.cpp
)
add_dependencies(reference_lazy_backend
torch_mlir_ltc_backend
)
target_link_libraries(reference_lazy_backend
${TORCH_LIBRARIES}
torch_mlir_ltc_backend
)

message(STATUS "TORCH_CXXFLAGS=${TORCH_CXXFLAGS} -Wno-pedantic")
set_target_properties(reference_lazy_backend PROPERTIES
LIBRARY_OUTPUT_DIRECTORY ${LIBRARY_OUTPUT_PATH}
OUTPUT_NAME ${OUTPUT_NAME}
PREFIX "${PYTHON_MODULE_PREFIX}"
SUFFIX "${PYTHON_MODULE_EXTENSION}"
CXX_VISIBILITY_PRESET "hidden"
COMPILE_FLAGS "${TORCH_CXXFLAGS} -Wno-pedantic"
)
else()
# To avoid import errors when LTC is disabled (and a bunch of checks
# associated with that), we will generate a dummy placeholder library.
execute_process(
COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/gen_dummy_lib.py ${LIBRARY_OUTPUT_PATH} ${OUTPUT_NAME}
)
endif()
Loading

0 comments on commit 2c3b360

Please sign in to comment.