Skip to content

Commit

Permalink
Add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
ashnair1 committed Jun 16, 2022
1 parent 2e280a2 commit 647ef0e
Showing 1 changed file with 20 additions and 2 deletions.
22 changes: 20 additions & 2 deletions predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,13 @@


def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None:
"""Write mask to specified output directory."""
"""Write mask to specified output directory with same filename as input raster.
Args:
mask (torch.Tensor): mask tensor
output_dir (str): output directory
input_filename (str): path to input raster
"""
output_path = os.path.join(output_dir, os.path.basename(input_filename))
with rio.open(input_filename) as src:
profile = src.profile
Expand All @@ -73,7 +79,19 @@ def write_mask(mask: torch.Tensor, output_dir: str, input_filename: str) -> None


def main(config_dir: str, predict_on: str, output_dir: str, device: str) -> None:
"""Main inference loop."""
"""Main inference loop.
Args:
config_dir (str): Path to config-dir to load config and ckpt
predict_on (str): Directory/Dataset to run inference on
output_dir (str): Path to output_directory to save predicted masks
device (str): Choice of device. Must be in [cuda, cpu]
Raises:
ValueError: Raised if task name is not in TASK_TO_MODULES_MAPPING
FileExistsError: Raised if specified output directory contains
files and overwrite=False.
"""
os.makedirs(output_dir, exist_ok=True)

# Load checkpoint and config
Expand Down

0 comments on commit 647ef0e

Please sign in to comment.