diff --git a/taichi/backends/vulkan/vulkan_api.cpp b/taichi/backends/vulkan/vulkan_api.cpp index 9e837edcf9213..895c5eb1b4bb9 100644 --- a/taichi/backends/vulkan/vulkan_api.cpp +++ b/taichi/backends/vulkan/vulkan_api.cpp @@ -273,9 +273,6 @@ void EmbeddedVulkanDevice::create_instance() { } auto extensions = get_required_extensions(); - for (auto ext : params_.additional_instance_extensions) { - extensions.push_back(ext); - } #ifdef TI_VULKAN_DEBUG glfwInit(); @@ -302,6 +299,11 @@ void EmbeddedVulkanDevice::create_instance() { extensions.push_back(ext.extensionName); } } + if (std::find(params_.additional_instance_extensions.begin(), + params_.additional_instance_extensions.end(), + name) != params_.additional_instance_extensions.end()) { + extensions.push_back(ext.extensionName); + } } create_info.enabledExtensionCount = (uint32_t)extensions.size(); @@ -401,10 +403,6 @@ void EmbeddedVulkanDevice::create_logical_device() { // Detect extensions std::vector enabled_extensions; - for (auto ext : params_.additional_device_extensions) { - enabled_extensions.push_back(ext); - } - uint32_t extension_count = 0; vkEnumerateDeviceExtensionProperties(physical_device_, nullptr, &extension_count, nullptr); @@ -458,6 +456,10 @@ void EmbeddedVulkanDevice::create_logical_device() { capability_.has_float16 = true; capability_.has_int8 = true; enabled_extensions.push_back(ext.extensionName); + } else if (std::find(params_.additional_device_extensions.begin(), + params_.additional_device_extensions.end(), + name) != params_.additional_device_extensions.end()) { + enabled_extensions.push_back(ext.extensionName); } } diff --git a/taichi/backends/vulkan/vulkan_api.h b/taichi/backends/vulkan/vulkan_api.h index 5a8ed62ab7e24..487d58dc499b9 100644 --- a/taichi/backends/vulkan/vulkan_api.h +++ b/taichi/backends/vulkan/vulkan_api.h @@ -146,8 +146,8 @@ class EmbeddedVulkanDevice { struct Params { std::optional api_version; bool is_for_ui{false}; - std::vector additional_instance_extensions; - std::vector additional_device_extensions; + std::vector additional_instance_extensions; + std::vector additional_device_extensions; // the VkSurfaceKHR needs to be created after creating the VkInstance, but // before creating the VkPhysicalDevice thus, we allow the user to pass in a // custom surface creator