Skip to content

Commit

Permalink
Fix GPU index parsing
Browse files Browse the repository at this point in the history
  • Loading branch information
joeyballentine committed Aug 29, 2023
1 parent 9d890de commit f123321
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 4 deletions.
4 changes: 3 additions & 1 deletion backend/src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -364,8 +364,10 @@ def get_bool(self, key: str, default: bool) -> bool:
return value
raise ValueError(f"Invalid bool value for {key}: {value}")

def get_int(self, key: str, default: int) -> int:
def get_int(self, key: str, default: int, parse_str: bool = False) -> int:
value = self.__settings.get(key, default)
if parse_str and isinstance(value, str):
return int(value)
if isinstance(value, int) and not isinstance(value, bool):
return value
raise ValueError(f"Invalid str value for {key}: {value}")
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_ncnn/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@ def get_settings() -> NcnnSettings:
settings = package.get_settings()

return NcnnSettings(
gpu_index=settings.get_int("gpu_index", 0),
gpu_index=settings.get_int("gpu_index", 0, True),
)
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_onnx/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def get_settings() -> OnnxSettings:
os.makedirs(tensorrt_cache_path)

return OnnxSettings(
gpu_index=settings.get_int("gpu_index", 0),
gpu_index=settings.get_int("gpu_index", 0, True),
execution_provider=settings.get_str("execution_provider", default_provider),
tensorrt_cache_path=tensorrt_cache_path,
tensorrt_fp16_mode=settings.get_bool("tensorrt_fp16_mode", False),
Expand Down
2 changes: 1 addition & 1 deletion backend/src/packages/chaiNNer_pytorch/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,5 +78,5 @@ def get_settings() -> PyTorchSettings:
return PyTorchSettings(
use_cpu=settings.get_bool("use_cpu", False),
use_fp16=settings.get_bool("use_fp16", False),
gpu_index=settings.get_int("gpu_index", 0),
gpu_index=settings.get_int("gpu_index", 0, True),
)

0 comments on commit f123321

Please sign in to comment.