Skip to content

Commit

Permalink
prediction functions with ms
Browse files Browse the repository at this point in the history
  • Loading branch information
PatBall1 committed Dec 4, 2024
1 parent 7dae05f commit 06e1ecb
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 102 deletions.
69 changes: 34 additions & 35 deletions detectree2/models/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,72 +26,71 @@ def predict_on_data(
eval=False,
save: bool = True,
num_predictions=0,
mode="rgb",
):
"""Make predictions on tiled data.
Predicts crowns for all images present in a directory and outputs masks as JSON files.
Predicts crowns for all images (.png or .tif) present in a directory and outputs masks as JSON files.
Args:
directory (str): Directory containing images to predict on.
out_folder (str): Folder to save predictions.
predictor: The predictor object (e.g., DefaultPredictor).
directory (str): Directory containing the images.
out_folder (str): Output folder for predictions.
predictor (DefaultPredictor): The predictor object.
eval (bool): Whether to use evaluation mode.
save (bool): Whether to save the predictions.
num_predictions (int): Number of predictions to make (0 for all).
mode (str): Image mode, 'rgb' or 'ms' (multispectral).
num_predictions (int): Number of predictions to make.
Returns:
None
"""
pred_dir = os.path.join(directory, out_folder)
Path(pred_dir).mkdir(parents=True, exist_ok=True)

if eval:
dataset_dicts = get_tree_dicts(directory)
if len(dataset_dicts) > 0:
sample_file = dataset_dicts[0]["file_name"]
_, mode = get_filenames(os.path.dirname(sample_file))
else:
mode = None
else:
dataset_dicts = get_filenames(directory, mode=mode)
dataset_dicts, mode = get_filenames(directory)

total_files = len(dataset_dicts)
num_to_pred = len(
dataset_dicts) if num_predictions == 0 else num_predictions

# Decide the number of items to predict on
if num_predictions == 0:
num_to_pred = len(dataset_dicts)
else:
num_to_pred = num_predictions

print(f"Predicting {num_to_pred} files")
print(f"Predicting {num_to_pred} files in mode {mode}")

for i, d in enumerate(dataset_dicts[:num_to_pred], start=1):
if mode == "rgb":
img = cv2.imread(d["file_name"])
file_name = d["file_name"]
file_ext = os.path.splitext(file_name)[1].lower()
if file_ext == ".png":
# RGB image, read with cv2
img = cv2.imread(file_name)
if img is None:
print(f"Failed to read image {d['file_name']} with cv2.")
continue
elif mode == "ms":
try:
with rasterio.open(d["file_name"]) as src:
img = src.read() # shape is (bands, H, W)
img = np.transpose(img, (1, 2, 0)) # shape (H, W, bands)
except Exception as e:
print(
f"Failed to read image {d['file_name']} with rasterio: {e}")
print(f"Failed to read image {file_name} with cv2.")
continue
elif file_ext == ".tif":
# Multispectral image, read with rasterio
with rasterio.open(file_name) as src:
img = src.read()
# Transpose to match expected format (H, W, C)
img = np.transpose(img, (1, 2, 0))
else:
print(f"Unknown mode '{mode}'.")
print(f"Unsupported file extension {file_ext} for file {file_name}")
continue

outputs = predictor(img)

# Create the output file name
file_name_path = d["file_name"]
file_name = os.path.basename(os.path.normpath(file_name_path))
file_root, file_ext = os.path.splitext(file_name)
file_name = file_root + ".json"
output_file = os.path.join(pred_dir, f"Prediction_{file_name}")
file_name_only = os.path.basename(file_name)
file_name_json = os.path.splitext(file_name_only)[0] + ".json"
output_file = os.path.join(pred_dir, f"Prediction_{file_name_json}")

if save:
# Convert predictions to JSON and save them
# Save predictions to JSON file
evaluations = instances_to_coco_json(outputs["instances"].to("cpu"),
d["file_name"])
file_name)
with open(output_file, "w") as dest:
json.dump(evaluations, dest)

Expand Down
140 changes: 73 additions & 67 deletions detectree2/models/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,38 +690,45 @@ def combine_dicts(
return tree_dicts


def get_filenames(directory: str):
    """Collect the image tiles in a directory and report the image mode they imply.

    Only ``.png`` (treated as RGB) and ``.tif`` (treated as multispectral)
    files located directly inside ``directory`` are considered. When both
    kinds are present, the ``.png`` files take precedence and the ``.tif``
    files are ignored.

    Allows for predictions where no delineations have been manually produced.

    Args:
        directory (str): Directory of images to be predicted on.

    Returns:
        tuple: A tuple containing:
            - dataset_dicts (list): One ``{"file_name": path}`` dictionary per
              selected image; ``path`` already includes ``directory``.
            - mode (str or None): ``"rgb"`` if .png files are used, ``"ms"`` if
              .tif files are used, ``None`` if no image files were found.
    """
    png_files = glob.glob(os.path.join(directory, "*.png"))
    tif_files = glob.glob(os.path.join(directory, "*.tif"))

    if png_files:
        # Covers both "only .png" and "mixed .png/.tif": RGB tiles win.
        files, mode = png_files, "rgb"
    elif tif_files:
        files, mode = tif_files, "ms"
    else:
        # No image files found; callers receive an empty list and no mode.
        files, mode = [], None

    # glob returns entries in arbitrary (filesystem) order; sorting keeps the
    # result reproducible, which matters when callers take only the first N.
    dataset_dicts = [{"file_name": filename} for filename in sorted(files)]
    return dataset_dicts, mode


def register_train_data(train_location,
Expand Down Expand Up @@ -944,85 +951,84 @@ def predictions_on_data(
scale=1,
geos_exist=True,
num_predictions=0,
mode="rgb",
):
"""Generate predictions from a test folder and output them to a predictions folder.
"""Make predictions on test data and output them to the predictions folder.
Args:
directory: Directory containing test data.
predictor: Predictor object.
directory (str): Directory containing test data.
predictor (DefaultPredictor): The predictor object.
trees_metadata: Metadata for trees.
save: Boolean to save predictions.
scale: Scale of image.
geos_exist: Boolean to determine if geojson files exist.
num_predictions: Number of predictions to make.
mode: Image mode, 'rgb' or 'ms' (multispectral).
save (bool): Whether to save the predictions.
scale (float): Scale of the image for visualization.
geos_exist (bool): Determines if geojson files exist.
num_predictions (int): Number of predictions to make.
Returns:
None
"""
test_location = os.path.join(directory, "test")
pred_dir = os.path.join(test_location, "predictions")
pred_dir = os.path.join(directory, "predictions")
Path(pred_dir).mkdir(parents=True, exist_ok=True)

test_location = os.path.join(directory, "test")

if geos_exist:
dataset_dicts = get_tree_dicts(test_location)
if len(dataset_dicts) > 0:
sample_file = dataset_dicts[0]["file_name"]
_, mode = get_filenames(os.path.dirname(sample_file))
else:
mode = None
else:
dataset_dicts = get_filenames(test_location, mode=mode)

total_files = len(dataset_dicts)

# Decide the number of items to predict on
if num_predictions == 0:
num_to_pred = len(dataset_dicts)
else:
num_to_pred = num_predictions

print(f"Predicting {num_to_pred} files")

for i, d in enumerate(dataset_dicts[:num_to_pred], start=1):
if mode == "rgb":
img = cv2.imread(d["file_name"])
dataset_dicts, mode = get_filenames(test_location)

# Decide how many items to predict on
num_to_pred = len(
dataset_dicts) if num_predictions == 0 else num_predictions

for d in random.sample(dataset_dicts, num_to_pred):
file_name = d["file_name"]
file_ext = os.path.splitext(file_name)[1].lower()
if file_ext == ".png":
# RGB image, read with cv2
img = cv2.imread(file_name)
if img is None:
print(f"Failed to read image {d['file_name']} with cv2.")
continue
elif mode == "ms":
try:
with rasterio.open(d["file_name"]) as src:
img = src.read() # shape is (bands, H, W)
img = np.transpose(img, (1, 2, 0)) # shape (H, W, bands)
except Exception as e:
print(
f"Failed to read image {d['file_name']} with rasterio: {e}")
print(f"Failed to read image {file_name} with cv2.")
continue
# Convert BGR to RGB for visualization
img_vis = img[:, :, ::-1]
elif file_ext == ".tif":
# Multispectral image, read with rasterio
with rasterio.open(file_name) as src:
img = src.read()
# Transpose to match expected format (H, W, C)
img = np.transpose(img, (1, 2, 0))
# For visualization, convert to RGB if possible
img_vis = img[:, :, :3] if img.shape[2] >= 3 else img
else:
print(f"Unknown mode '{mode}'.")
print(f"Unsupported file extension {file_ext} for file {file_name}")
continue

outputs = predictor(img)
v = Visualizer(
img[:, :, ::-1],
img_vis,
metadata=trees_metadata,
scale=scale,
instance_mode=ColorMode.SEGMENTATION,
)
v = v.draw_instance_predictions(outputs["instances"].to("cpu"))

# Create the output file name
file_name_path = d["file_name"]
file_name = os.path.basename(os.path.normpath(file_name_path))
file_root, file_ext = os.path.splitext(file_name)
file_name = file_root + ".json"
output_file = os.path.join(pred_dir, f"Prediction_{file_name}")
file_name_only = os.path.basename(file_name)
file_name_json = os.path.splitext(file_name_only)[0] + ".json"
output_file = os.path.join(pred_dir, f"Prediction_{file_name_json}")

if save:
# Convert predictions to JSON and save them
# Save predictions to JSON file
evaluations = instances_to_coco_json(outputs["instances"].to("cpu"),
d["file_name"])
file_name)
with open(output_file, "w") as dest:
json.dump(evaluations, dest)

if i % 50 == 0:
print(f"Predicted {i} files of {total_files}")


def modify_conv1_weights(model, num_input_channels):
"""
Expand Down

0 comments on commit 06e1ecb

Please sign in to comment.