diff --git a/runtime/bindings/python/hal.cc b/runtime/bindings/python/hal.cc index d102c00bb07e..61bd8a3180a8 100644 --- a/runtime/bindings/python/hal.cc +++ b/runtime/bindings/python/hal.cc @@ -6,6 +6,11 @@ #include "./hal.h" +#include +#include + +#include + #include "./local_dlpack.h" #include "./numpy_interop.h" #include "./vm.h" @@ -1066,12 +1071,34 @@ HalDevice HalDriver::CreateDeviceByURI(std::string& device_uri, // HAL module //------------------------------------------------------------------------------ -// TODO(multi-device): allow for multiple devices to be passed in. -VmModule CreateHalModule(VmInstance* instance, HalDevice* device) { - iree_hal_device_t* device_ptr = device->raw_ptr(); +VmModule CreateHalModule(VmInstance* instance, std::optional device, + std::optional devices) { + if (device && devices) { + PyErr_SetString( + PyExc_ValueError, + "\"device\" and \"devices\" are mutually exclusive arguments."); + } + std::vector devices_vector; + iree_hal_device_t* device_ptr; + iree_hal_device_t** devices_ptr; + iree_host_size_t device_count; iree_vm_module_t* module = NULL; - CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), /*device_count=*/1, - &device_ptr, IREE_HAL_MODULE_FLAG_NONE, + if (device) { + device_ptr = device.value()->raw_ptr(); + devices_ptr = &device_ptr; + device_count = 1; + } else { + // Set device related arguments in the case of multiple devices. + devices_vector.reserve(devices->size()); + for (auto devicesIt = devices->begin(); devicesIt != devices->end(); + ++devicesIt) { + devices_vector.push_back(py::cast(*devicesIt)->raw_ptr()); + } + devices_ptr = devices_vector.data(); + device_count = devices_vector.size(); + } + CheckApiStatus(iree_hal_module_create(instance->raw_ptr(), device_count, + devices_ptr, IREE_HAL_MODULE_FLAG_NONE, iree_allocator_system(), &module), "Error creating hal module"); return VmModule::StealFromRawPtr(module); @@ -1085,7 +1112,8 @@ void SetupHalBindings(nanobind::module_ m) { py::dict driver_cache; // Built-in module creation. - m.def("create_hal_module", &CreateHalModule); + m.def("create_hal_module", &CreateHalModule, py::arg("instance"), + py::arg("device") = py::none(), py::arg("devices") = py::none()); // Enums. py::enum_(m, "MemoryType") diff --git a/runtime/bindings/python/iree/runtime/_binding.pyi b/runtime/bindings/python/iree/runtime/_binding.pyi index 803e10f39c6c..040b92f81373 100644 --- a/runtime/bindings/python/iree/runtime/_binding.pyi +++ b/runtime/bindings/python/iree/runtime/_binding.pyi @@ -4,7 +4,11 @@ from typing import overload import asyncio -def create_hal_module(instance: VmInstance, device: HalDevice) -> VmModule: ... +def create_hal_module( + instance: VmInstance, + device: Optional[HalDevice] = None, + devices: Optional[List[HalDevice]] = None, +) -> VmModule: ... def create_io_parameters_module( instance: VmInstance, *providers: ParameterProvider ) -> VmModule: ... diff --git a/runtime/bindings/python/tests/vm_test.py b/runtime/bindings/python/tests/vm_test.py index 8aae0b7561c6..bc659262d57a 100644 --- a/runtime/bindings/python/tests/vm_test.py +++ b/runtime/bindings/python/tests/vm_test.py @@ -219,6 +219,15 @@ def test_synchronous_invoke_function_new_abi(self): logging.info("result: %s", result) np.testing.assert_allclose(result, [4.0, 10.0, 18.0, 28.0]) + def test_create_vm_module_with_multiple_devices(self): + """Sanity test that we can create a VM module with 2 devices.""" + devices = [ + iree.runtime.get_device("local-task"), + iree.runtime.get_device("local-sync"), + ] + module = iree.runtime.create_hal_module(self.instance, devices=devices) + assert isinstance(module, iree.runtime.VmModule) + if __name__ == "__main__": logging.basicConfig(level=logging.DEBUG)