diff --git a/detectree2/models/predict.py b/detectree2/models/predict.py index 247b0c3..18bd01a 100644 --- a/detectree2/models/predict.py +++ b/detectree2/models/predict.py @@ -26,72 +26,71 @@ def predict_on_data( eval=False, save: bool = True, num_predictions=0, - mode="rgb", ): """Make predictions on tiled data. - Predicts crowns for all images present in a directory and outputs masks as JSON files. + Predicts crowns for all images (.png or .tif) present in a directory and outputs masks as JSON files. Args: - directory (str): Directory containing images to predict on. - out_folder (str): Folder to save predictions. - predictor: The predictor object (e.g., DefaultPredictor). + directory (str): Directory containing the images. + out_folder (str): Output folder for predictions. + predictor (DefaultPredictor): The predictor object. eval (bool): Whether to use evaluation mode. save (bool): Whether to save the predictions. - num_predictions (int): Number of predictions to make (0 for all). - mode (str): Image mode, 'rgb' or 'ms' (multispectral). + num_predictions (int): Number of predictions to make. + Returns: + None """ pred_dir = os.path.join(directory, out_folder) Path(pred_dir).mkdir(parents=True, exist_ok=True) if eval: dataset_dicts = get_tree_dicts(directory) + if len(dataset_dicts) > 0: + sample_file = dataset_dicts[0]["file_name"] + _, mode = get_filenames(os.path.dirname(sample_file)) + else: + mode = None else: - dataset_dicts = get_filenames(directory, mode=mode) + dataset_dicts, mode = get_filenames(directory) total_files = len(dataset_dicts) + num_to_pred = len( + dataset_dicts) if num_predictions == 0 else num_predictions - # Decide the number of items to predict on - if num_predictions == 0: - num_to_pred = len(dataset_dicts) - else: - num_to_pred = num_predictions - - print(f"Predicting {num_to_pred} files") + print(f"Predicting {num_to_pred} files in mode {mode}") for i, d in enumerate(dataset_dicts[:num_to_pred], start=1): - if mode == "rgb": - img = cv2.imread(d["file_name"]) + file_name = d["file_name"] + file_ext = os.path.splitext(file_name)[1].lower() + if file_ext == ".png": + # RGB image, read with cv2 + img = cv2.imread(file_name) if img is None: - print(f"Failed to read image {d['file_name']} with cv2.") - continue - elif mode == "ms": - try: - with rasterio.open(d["file_name"]) as src: - img = src.read() # shape is (bands, H, W) - img = np.transpose(img, (1, 2, 0)) # shape (H, W, bands) - except Exception as e: - print( - f"Failed to read image {d['file_name']} with rasterio: {e}") + print(f"Failed to read image {file_name} with cv2.") continue + elif file_ext == ".tif": + # Multispectral image, read with rasterio + with rasterio.open(file_name) as src: + img = src.read() + # Transpose to match expected format (H, W, C) + img = np.transpose(img, (1, 2, 0)) else: - print(f"Unknown mode '{mode}'.") + print(f"Unsupported file extension {file_ext} for file {file_name}") continue outputs = predictor(img) # Create the output file name - file_name_path = d["file_name"] - file_name = os.path.basename(os.path.normpath(file_name_path)) - file_root, file_ext = os.path.splitext(file_name) - file_name = file_root + ".json" - output_file = os.path.join(pred_dir, f"Prediction_{file_name}") + file_name_only = os.path.basename(file_name) + file_name_json = os.path.splitext(file_name_only)[0] + ".json" + output_file = os.path.join(pred_dir, f"Prediction_{file_name_json}") if save: - # Convert predictions to JSON and save them + # Save predictions to JSON file evaluations = instances_to_coco_json(outputs["instances"].to("cpu"), - d["file_name"]) + file_name) with open(output_file, "w") as dest: json.dump(evaluations, dest) diff --git a/detectree2/models/train.py b/detectree2/models/train.py index 7553d9c..b41d049 100644 --- a/detectree2/models/train.py +++ b/detectree2/models/train.py @@ -690,38 +690,45 @@ def combine_dicts( return tree_dicts -def get_filenames(directory: str, mode: str = "rgb"): - """Get the file names based on the mode. - - Allows for predictions where no delineations have been manually produced. +def get_filenames(directory: str): + """Get the file names from the directory, handling both RGB (.png) and multispectral (.tif) images. Args: directory (str): Directory of images to be predicted on. - mode (str): Image mode, 'rgb' or 'ms' (multispectral). Returns: - List of dictionaries containing file names. + tuple: A tuple containing: + - dataset_dicts (list): List of dictionaries with 'file_name' keys. + - mode (str): 'rgb' if .png files are used, 'ms' if .tif files are used. """ dataset_dicts = [] - if mode == "rgb": - # For RGB mode, consider common RGB image extensions - image_extensions = ["*.png", "*.jpg", "*.jpeg"] - elif mode == "ms": - # For multispectral mode, consider common multispectral image extensions - image_extensions = ["*.tif", "*.tiff"] + + # Get list of .png and .tif files + png_files = glob.glob(os.path.join(directory, "*.png")) + tif_files = glob.glob(os.path.join(directory, "*.tif")) + + if png_files and tif_files: + # Both .png and .tif files are present, select only .png files + files = png_files + mode = "rgb" + elif png_files: + # Only .png files are present + files = png_files + mode = "rgb" + elif tif_files: + # Only .tif files are present + files = tif_files + mode = "ms" else: - raise ValueError( - f"Unknown mode '{mode}'. Supported modes are 'rgb' and 'ms'.") + # No image files found + files = [] + mode = None - files = [] - for ext in image_extensions: - files.extend(glob.glob(os.path.join(directory, ext))) for filename in files: file = {} - # filename already has the full path file["file_name"] = filename dataset_dicts.append(file) - return dataset_dicts + return dataset_dicts, mode def register_train_data(train_location, @@ -944,62 +951,66 @@ def predictions_on_data( scale=1, geos_exist=True, num_predictions=0, - mode="rgb", ): - """Generate predictions from a test folder and output them to a predictions folder. + """Make predictions on test data and output them to the predictions folder. Args: - directory: Directory containing test data. - predictor: Predictor object. + directory (str): Directory containing test data. + predictor (DefaultPredictor): The predictor object. trees_metadata: Metadata for trees. - save: Boolean to save predictions. - scale: Scale of image. - geos_exist: Boolean to determine if geojson files exist. - num_predictions: Number of predictions to make. - mode: Image mode, 'rgb' or 'ms' (multispectral). + save (bool): Whether to save the predictions. + scale (float): Scale of the image for visualization. + geos_exist (bool): Determines if geojson files exist. + num_predictions (int): Number of predictions to make. + Returns: + None """ - test_location = os.path.join(directory, "test") - pred_dir = os.path.join(test_location, "predictions") + pred_dir = os.path.join(directory, "predictions") Path(pred_dir).mkdir(parents=True, exist_ok=True) + test_location = os.path.join(directory, "test") + if geos_exist: dataset_dicts = get_tree_dicts(test_location) + if len(dataset_dicts) > 0: + sample_file = dataset_dicts[0]["file_name"] + _, mode = get_filenames(os.path.dirname(sample_file)) + else: + mode = None else: - dataset_dicts = get_filenames(test_location, mode=mode) - - total_files = len(dataset_dicts) - - # Decide the number of items to predict on - if num_predictions == 0: - num_to_pred = len(dataset_dicts) - else: - num_to_pred = num_predictions - - print(f"Predicting {num_to_pred} files") - - for i, d in enumerate(dataset_dicts[:num_to_pred], start=1): - if mode == "rgb": - img = cv2.imread(d["file_name"]) + dataset_dicts, mode = get_filenames(test_location) + + # Decide how many items to predict on + num_to_pred = len( + dataset_dicts) if num_predictions == 0 else num_predictions + + for d in random.sample(dataset_dicts, num_to_pred): + file_name = d["file_name"] + file_ext = os.path.splitext(file_name)[1].lower() + if file_ext == ".png": + # RGB image, read with cv2 + img = cv2.imread(file_name) if img is None: - print(f"Failed to read image {d['file_name']} with cv2.") - continue - elif mode == "ms": - try: - with rasterio.open(d["file_name"]) as src: - img = src.read() # shape is (bands, H, W) - img = np.transpose(img, (1, 2, 0)) # shape (H, W, bands) - except Exception as e: - print( - f"Failed to read image {d['file_name']} with rasterio: {e}") + print(f"Failed to read image {file_name} with cv2.") continue + # Convert BGR to RGB for visualization + img_vis = img[:, :, ::-1] + elif file_ext == ".tif": + # Multispectral image, read with rasterio + with rasterio.open(file_name) as src: + img = src.read() + # Transpose to match expected format (H, W, C) + img = np.transpose(img, (1, 2, 0)) + # For visualization, convert to RGB if possible + img_vis = img[:, :, :3] if img.shape[2] >= 3 else img else: - print(f"Unknown mode '{mode}'.") + print(f"Unsupported file extension {file_ext} for file {file_name}") continue outputs = predictor(img) v = Visualizer( - img[:, :, ::-1], + img_vis, metadata=trees_metadata, scale=scale, instance_mode=ColorMode.SEGMENTATION, @@ -1007,22 +1018,17 @@ def predictions_on_data( v = v.draw_instance_predictions(outputs["instances"].to("cpu")) # Create the output file name - file_name_path = d["file_name"] - file_name = os.path.basename(os.path.normpath(file_name_path)) - file_root, file_ext = os.path.splitext(file_name) - file_name = file_root + ".json" - output_file = os.path.join(pred_dir, f"Prediction_{file_name}") + file_name_only = os.path.basename(file_name) + file_name_json = os.path.splitext(file_name_only)[0] + ".json" + output_file = os.path.join(pred_dir, f"Prediction_{file_name_json}") if save: - # Convert predictions to JSON and save them + # Save predictions to JSON file evaluations = instances_to_coco_json(outputs["instances"].to("cpu"), - d["file_name"]) + file_name) with open(output_file, "w") as dest: json.dump(evaluations, dest) - if i % 50 == 0: - print(f"Predicted {i} files of {total_files}") - def modify_conv1_weights(model, num_input_channels): """