Skip to content

Commit

Permalink
black reformatted
Browse files Browse the repository at this point in the history
  • Loading branch information
Mohinta2892 authored Nov 2, 2023
1 parent 6f6ff2d commit 0fd3bff
Showing 1 changed file with 11 additions and 9 deletions.
20 changes: 11 additions & 9 deletions gunpowder/torch/nodes/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,14 @@ class Predict(GenericPredict):
"""

def __init__(
self,
model,
inputs: Dict[str, ArrayKey],
outputs: Dict[Union[str, int], ArrayKey],
array_specs: Dict[ArrayKey, ArraySpec] = None,
checkpoint: str = None,
device="cuda",
spawn_subprocess=False,
self,
model,
inputs: Dict[str, ArrayKey],
outputs: Dict[Union[str, int], ArrayKey],
array_specs: Dict[ArrayKey, ArraySpec] = None,
checkpoint: str = None,
device="cuda",
spawn_subprocess=False,
):
self.array_specs = array_specs if array_specs is not None else {}

Expand All @@ -87,7 +87,9 @@ def __init__(

def start(self):
# Issue #188
self.use_cuda = torch.cuda.is_available() and self.device_string.__contains__("cuda")
self.use_cuda = torch.cuda.is_available() and self.device_string.__contains__(
"cuda"
)

logger.info(f"Predicting on {'gpu' if self.use_cuda else 'cpu'}")
self.device = torch.device(self.device_string if self.use_cuda else "cpu")
Expand Down

0 comments on commit 0fd3bff

Please sign in to comment.