Skip to content

Commit

Permalink
TensorRT EP - timing cache [patch] (#15113)
Browse files Browse the repository at this point in the history
### Description
Patch #14767 in order to
make two provider options `force_timing_cache` and `detailed_build_log`
can be updated. Otherwise, they only use default value.
`timing_cache_enable` is good.
  • Loading branch information
chilo-ms authored Mar 21, 2023
1 parent 0ace27f commit abb2418
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ TensorrtExecutionProviderInfo TensorrtExecutionProviderInfo::FromProviderOptions
.AddAssignmentToReference(tensorrt::provider_option_names::kLayerNormFP32Fallback, info.layer_norm_fp32_fallback)
.AddAssignmentToReference(tensorrt::provider_option_names::kTimingCacheEnable, info.timing_cache_enable)
.AddAssignmentToReference(tensorrt::provider_option_names::kForceTimingCacheMatch, info.force_timing_cache)
.AddAssignmentToReference(tensorrt::provider_option_names::kDetailedBuildLog, info.detailed_build_log)
.Parse(options)); // add new provider option here.

return info;
Expand Down Expand Up @@ -99,6 +100,8 @@ ProviderOptions TensorrtExecutionProviderInfo::ToProviderOptions(const TensorrtE
{tensorrt::provider_option_names::kContextMemorySharingEnable, MakeStringWithClassicLocale(info.context_memory_sharing_enable)},
{tensorrt::provider_option_names::kLayerNormFP32Fallback, MakeStringWithClassicLocale(info.layer_norm_fp32_fallback)},
{tensorrt::provider_option_names::kTimingCacheEnable, MakeStringWithClassicLocale(info.timing_cache_enable)},
{tensorrt::provider_option_names::kForceTimingCacheMatch, MakeStringWithClassicLocale(info.force_timing_cache)},
{tensorrt::provider_option_names::kDetailedBuildLog, MakeStringWithClassicLocale(info.detailed_build_log)},
};
return options;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ struct Tensorrt_Provider : Provider {
info.context_memory_sharing_enable = options.trt_context_memory_sharing_enable != 0;
info.layer_norm_fp32_fallback = options.trt_layer_norm_fp32_fallback != 0;
info.timing_cache_enable = options.trt_timing_cache_enable != 0;
info.force_timing_cache = options.trt_force_timing_cache != 0;
info.detailed_build_log = options.trt_detailed_build_log != 0;
return std::make_shared<TensorrtProviderFactory>(info);
}
Expand Down Expand Up @@ -141,6 +142,7 @@ struct Tensorrt_Provider : Provider {
trt_options.trt_layer_norm_fp32_fallback = internal_options.layer_norm_fp32_fallback;
trt_options.trt_timing_cache_enable = internal_options.timing_cache_enable;
trt_options.trt_force_timing_cache = internal_options.force_timing_cache;
trt_options.trt_detailed_build_log = internal_options.detailed_build_log;
}

ProviderOptions GetProviderOptions(const void* provider_options) override {
Expand Down

0 comments on commit abb2418

Please sign in to comment.