Skip to content

Commit

Permalink
Add example usage
Browse files Browse the repository at this point in the history
  • Loading branch information
antoniojkim committed Aug 4, 2022
1 parent 67aa2f2 commit 876ec6e
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 0 deletions.
30 changes: 30 additions & 0 deletions python/torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
#pragma once

#include "torch/csrc/lazy/backend/backend_device.h"
#include "torch/csrc/lazy/core/tensor.h"

#include "../ops/device_data.h"


namespace torch {
namespace lazy {

inline torch::lazy::DeviceData* device_data_cast(
const at::Tensor& tensor, c10::optional<torch::lazy::BackendDevice> device = c10::nullopt
) {
if (!device) {
device = torch::lazy::GetBackendDevice(tensor);
}
TORCH_CHECK(device);
torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device);
if (lazy_tensor) {
torch::lazy::Value param_value = lazy_tensor->GetIrValue();
if (param_value && param_value->op() == torch::lazy::DeviceData::ClassOpKind()) {
return dynamic_cast<torch::lazy::DeviceData*>(param_value.node.get());
}
}
return nullptr;
}

} // namespace lazy
} // namespace torch
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <torch_mlir/csrc/base_lazy_backend/mlir_lowering_context.h>
#include <torch_mlir/csrc/base_lazy_backend/utils/sys_utils.h>
#include <torch_mlir/csrc/base_lazy_backend/utils/tensor_utils.h>

#include <exception>
#include <iostream>
Expand Down Expand Up @@ -73,6 +74,15 @@ PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) {
torch::lazy::GetLatestComputation().get());
return py::cast(computation);
});
m.def("set_parameter_name",
[](const at::Tensor& tensor, const std::string& name) -> bool {
torch::lazy::DeviceData* ir_node = torch::lazy::device_data_cast(tensor);
if (ir_node) {
ir_node->SetName(name);
return true;
}
return false;
});
m.def("_initialize", []() {
NoGilSection gil;
Initialize();
Expand Down

0 comments on commit 876ec6e

Please sign in to comment.