Skip to content

Commit

Permalink
[vulkan] Detect and set device-capabilities for aot::TargetDevice use…
Browse files Browse the repository at this point in the history
…d in offline cache (#5843)

* Detect and set device-capabilities for aot::TargetDevice in offline cache

* Fix w-error

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
PGZXB and pre-commit-ci[bot] authored Aug 24, 2022
1 parent 4cfae7b commit 814fba5
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 4 deletions.
6 changes: 6 additions & 0 deletions taichi/rhi/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,12 @@ class Device {
caps_[capability_id] = val;
}

void clone_caps(Device &dest) const {
for (const auto &[k, v] : caps_) {
dest.set_cap(k, v);
}
}

void print_all_cap() const;

struct AllocParams {
Expand Down
7 changes: 5 additions & 2 deletions taichi/runtime/gfx/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -103,10 +103,13 @@ class AotDataConverter {
} // namespace
AotModuleBuilderImpl::AotModuleBuilderImpl(
const std::vector<CompiledSNodeStructs> &compiled_structs,
Arch device_api_backend)
Arch device_api_backend,
std::unique_ptr<Device> &&target_device)
: compiled_structs_(compiled_structs),
device_api_backend_(device_api_backend) {
aot_target_device_ = std::make_unique<aot::TargetDevice>(device_api_backend_);
aot_target_device_ =
target_device ? std::move(target_device)
: std::make_unique<aot::TargetDevice>(device_api_backend_);
if (!compiled_structs.empty()) {
ti_aot_data_.root_buffer_size = compiled_structs[0].root_size;
}
Expand Down
3 changes: 2 additions & 1 deletion taichi/runtime/gfx/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
public:
explicit AotModuleBuilderImpl(
const std::vector<CompiledSNodeStructs> &compiled_structs,
Arch device_api_backend);
Arch device_api_backend,
std::unique_ptr<Device> &&target_device = nullptr);

void dump(const std::string &output_dir,
const std::string &filename) const override;
Expand Down
7 changes: 6 additions & 1 deletion taichi/runtime/program_impls/vulkan/vulkan_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,12 @@ void VulkanProgramImpl::dump_cache_data_to_disk() {
const std::unique_ptr<AotModuleBuilder>
&VulkanProgramImpl::get_caching_module_builder() {
if (!caching_module_builder_) {
caching_module_builder_ = make_aot_module_builder();
TI_ASSERT(vulkan_runtime_ && embedded_device_);
auto target_device = std::make_unique<aot::TargetDevice>(config->arch);
embedded_device_->device()->clone_caps(*target_device);
caching_module_builder_ = std::make_unique<gfx::AotModuleBuilderImpl>(
snode_tree_mgr_->get_compiled_structs(), Arch::vulkan,
std::move(target_device));
}
return caching_module_builder_;
}
Expand Down

0 comments on commit 814fba5

Please sign in to comment.