diff --git a/detectree2/models/train.py b/detectree2/models/train.py index 55d57f29..9e393f70 100644 --- a/detectree2/models/train.py +++ b/detectree2/models/train.py @@ -536,7 +536,7 @@ def build_test_loader(cls, cfg, dataset_name): """ return build_detection_test_loader(cfg, dataset_name, mapper=FlexibleDatasetMapper(cfg, is_train=False)) -def get_tree_dicts(directory: str, class_mapping = None) -> List[Dict]: +def get_tree_dicts(directory: str, class_mapping: Dict[str, int] = None) -> List[Dict]: """Get the tree dictionaries. Args: @@ -608,7 +608,7 @@ def get_tree_dicts(directory: str, class_mapping = None) -> List[Dict]: def combine_dicts(root_dir: str, val_dir: int, mode: str = "train", - class_mapping = None) -> List[Dict]: + class_mapping: Dict[str, int] = None) -> List[Dict]: """ Combine dictionaries from different directories based on the specified mode. @@ -976,7 +976,7 @@ def modify_conv1_weights(model, num_input_channels): model.backbone.bottom_up.stem.conv1.weight.copy_(new_weights) -def get_latest_model_path(output_dir) -> str: +def get_latest_model_path(output_dir: str) -> str: """ Find the model file with the highest index in the specified output directory.