Merge branch 'ludwig-ai:master' into improve-ludwig-feature-dict
Showing 29 changed files with 1,541 additions and 92 deletions.
@@ -0,0 +1,57 @@
import logging
import os
import shutil

import pandas as pd
import torch
import yaml
from torchvision.utils import save_image

from ludwig.api import LudwigModel
from ludwig.datasets import camseq

# clean out prior results
shutil.rmtree("./results", ignore_errors=True)

# set up Python dictionary to hold model training parameters
with open("./config_camseq.yaml") as f:
    config = yaml.safe_load(f.read())

# define the Ludwig model object that drives model training
model = LudwigModel(config, logging_level=logging.INFO)

# load the CamSeq dataset
df = camseq.load(split=False)

pred_set = df[0:1]  # prediction hold-out: 1 image
data_set = df[1:]  # train, test, and validate on the remaining images

# initiate model training
(
    train_stats,  # training statistics
    _,
    output_directory,  # location where training results are saved to disk
) = model.train(
    dataset=data_set,
    experiment_name="simple_image_experiment",
    model_name="single_model",
    skip_save_processed_input=True,
)

# print("{}".format(model.model))

# predict
pred_set.reset_index(inplace=True)
pred_out_df, results = model.predict(pred_set)

if not isinstance(pred_out_df, pd.DataFrame):
    pred_out_df = pred_out_df.compute()
pred_out_df["image_path"] = pred_set["image_path"]
pred_out_df["mask_path"] = pred_set["mask_path"]

for index, row in pred_out_df.iterrows():
    pred_mask = torch.from_numpy(row["mask_path_predictions"])
    pred_mask_path = os.path.dirname(os.path.realpath(__file__)) + "/predicted_" + os.path.basename(row["mask_path"])
    print(f"\nSaving predicted mask to {pred_mask_path}")
    if torch.any(pred_mask.gt(1)):
        pred_mask = pred_mask.float() / 255  # scale to [0, 1] for save_image
    save_image(pred_mask, pred_mask_path)
    print("Input image_path: {}".format(row["image_path"]))
    print("Label mask_path: {}".format(row["mask_path"]))
    print(f"Predicted mask_path: {pred_mask_path}")
@@ -0,0 +1,33 @@
input_features:
  - name: image_path
    type: image
    preprocessing:
      num_processes: 6
      infer_image_max_height: 1024
      infer_image_max_width: 1024
    encoder: unet

output_features:
  - name: mask_path
    type: image
    preprocessing:
      num_processes: 6
      infer_image_max_height: 1024
      infer_image_max_width: 1024
      infer_image_num_classes: true
      num_classes: 32
    decoder:
      type: unet
      num_fc_layers: 0
    loss:
      type: softmax_cross_entropy

combiner:
  type: concat
  num_fc_layers: 0

trainer:
  epochs: 100
  early_stop: -1
  batch_size: 1
  max_batch_size: 1
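One small note on the example script above: `LudwigModel` also accepts the path to a YAML config directly, so the explicit `yaml.safe_load` step is optional. A minimal sketch, assuming the file is saved as `config_camseq.yaml` next to the script:

# Hedged alternative to the yaml.safe_load step in the example script:
# LudwigModel accepts either a config dict or a path to a YAML file.
import logging

from ludwig.api import LudwigModel

model = LudwigModel(config="./config_camseq.yaml", logging_level=logging.INFO)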
@@ -0,0 +1,21 @@
version: 1.0
name: camseq
kaggle_dataset_id: carlolepelaars/camseq-semantic-segmentation
archive_filenames: camseq-semantic-segmentation.zip
sha256:
  camseq-semantic-segmentation.zip: ea3aeba2661d9b3e3ea406668e7d9240cb2ba0c7e374914bb6d866147faff502
loader: camseq.CamseqLoader
preserve_paths:
  - images
  - masks
description: |
  CamSeq01 Cambridge Labeled Objects in Video
  https://www.kaggle.com/datasets/carlolepelaars/camseq-semantic-segmentation
columns:
  - name: image_path
    type: image
  - name: mask_path
    type: image
output_features:
  - name: mask_path
    type: image
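This dataset config is what backs the `camseq.load()` call in the example script. A hedged usage sketch; note the Kaggle-hosted archive requires Kaggle API credentials to download, and the row count below is taken from the Kaggle dataset description:

# Hedged sketch: load the CamSeq dataset as a single dataframe (no split),
# exactly as the example script does. The first call downloads and extracts
# the Kaggle archive; later calls reuse the processed copy on disk.
from ludwig.datasets import camseq

df = camseq.load(split=False)
print(sorted(df.columns))  # ["image_path", "mask_path"], per the columns above
print(len(df))             # 101 frame/mask pairs in CamSeq01 (per the Kaggle page)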
@@ -0,0 +1,61 @@
# Copyright (c) 2023 Aizen Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import os
from typing import List

import pandas as pd

from ludwig.datasets.loaders.dataset_loader import DatasetLoader
from ludwig.utils.fs_utils import makedirs


class CamseqLoader(DatasetLoader):
    def transform_files(self, file_paths: List[str]) -> List[str]:
        if not os.path.exists(self.processed_dataset_dir):
            os.makedirs(self.processed_dataset_dir)

        # move images and masks into separate directories
        source_dir = self.raw_dataset_dir
        images_dir = os.path.join(source_dir, "images")
        masks_dir = os.path.join(source_dir, "masks")
        makedirs(images_dir, exist_ok=True)
        makedirs(masks_dir, exist_ok=True)

        data_files = []
        for f in os.listdir(source_dir):
            if f.endswith("_L.png"):  # masks
                dest_file = os.path.join(masks_dir, f)
            elif f.endswith(".png"):  # images
                dest_file = os.path.join(images_dir, f)
            else:
                continue
            source_file = os.path.join(source_dir, f)
            os.replace(source_file, dest_file)
            data_files.append(dest_file)

        return super().transform_files(data_files)

    def load_unprocessed_dataframe(self, file_paths: List[str]) -> pd.DataFrame:
        """Creates a dataframe of image paths and mask paths."""
        images_dir = os.path.join(self.processed_dataset_dir, "images")
        masks_dir = os.path.join(self.processed_dataset_dir, "masks")
        images = []
        masks = []
        for f in os.listdir(images_dir):
            images.append(os.path.join(images_dir, f))
            mask_f = f[:-4] + "_L.png"
            masks.append(os.path.join(masks_dir, mask_f))

        return pd.DataFrame({"image_path": images, "mask_path": masks})
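A small, self-contained illustration of the file split performed by `transform_files` above; the filenames are hypothetical examples following CamSeq's `_L` mask-suffix convention:

# Hedged illustration (plain Python, no Ludwig dependency): masks carry the
# "_L.png" suffix, every other ".png" is an input frame, anything else is skipped.
files = ["0016E5_07959.png", "0016E5_07959_L.png", "label_colors.txt"]

images = [f for f in files if f.endswith(".png") and not f.endswith("_L.png")]
masks = [f for f in files if f.endswith("_L.png")]

print(images)  # ['0016E5_07959.png']
print(masks)   # ['0016E5_07959_L.png']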
@@ -1,5 +1,6 @@
# register all decoders
import ludwig.decoders.generic_decoders  # noqa
import ludwig.decoders.image_decoders  # noqa
import ludwig.decoders.llm_decoders  # noqa
import ludwig.decoders.sequence_decoders  # noqa
import ludwig.decoders.sequence_tagger  # noqa
@@ -0,0 +1,91 @@
#! /usr/bin/env python
# Copyright (c) 2023 Aizen Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import logging
from typing import Dict, Optional, Type

import torch

from ludwig.api_annotations import DeveloperAPI
from ludwig.constants import ENCODER_OUTPUT_STATE, HIDDEN, IMAGE, LOGITS, PREDICTIONS
from ludwig.decoders.base import Decoder
from ludwig.decoders.registry import register_decoder
from ludwig.modules.convolutional_modules import UNetUpStack
from ludwig.schema.decoders.image_decoders import ImageDecoderConfig, UNetDecoderConfig

logger = logging.getLogger(__name__)


@DeveloperAPI
@register_decoder("unet", IMAGE)
class UNetDecoder(Decoder):
    def __init__(
        self,
        input_size: int,
        height: int,
        width: int,
        num_channels: int = 1,
        num_classes: int = 2,
        conv_norm: Optional[str] = None,
        decoder_config=None,
        **kwargs,
    ):
        super().__init__()
        self.config = decoder_config
        self.num_classes = num_classes

        logger.debug(f" {self.name}")
        if num_classes < 2:
            raise ValueError(f"Invalid `num_classes` {num_classes} for unet decoder")
        if height % 16 or width % 16:
            raise ValueError(f"Invalid `height` {height} or `width` {width} for unet decoder")

        self.unet = UNetUpStack(
            img_height=height,
            img_width=width,
            out_channels=num_classes,
            norm=conv_norm,
        )

        self.input_reshape = list(self.unet.input_shape)
        self.input_reshape.insert(0, -1)
        self._output_shape = (height, width)

    def forward(self, combiner_outputs: Dict[str, torch.Tensor], target: torch.Tensor):
        hidden = combiner_outputs[HIDDEN]
        skips = combiner_outputs[ENCODER_OUTPUT_STATE]

        # unflatten combiner outputs
        hidden = hidden.reshape(self.input_reshape)

        logits = self.unet(hidden, skips)
        predictions = logits.argmax(dim=1).squeeze(1).byte()

        return {LOGITS: logits, PREDICTIONS: predictions}

    def get_prediction_set(self):
        return {LOGITS, PREDICTIONS}

    @staticmethod
    def get_schema_cls() -> Type[ImageDecoderConfig]:
        return UNetDecoderConfig

    @property
    def output_shape(self) -> torch.Size:
        return torch.Size(self._output_shape)

    @property
    def input_shape(self) -> torch.Size:
        return self.unet.input_shape
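To make the decoder's output contract concrete, a hedged sketch with dummy tensors (no real `UNetDecoder` instance, just the same `argmax` step used in `forward` above):

# Hedged sketch of the shapes flowing out of UNetDecoder.forward: the up-stack
# emits per-pixel class logits, and predictions are the argmax over the class dim.
import torch

batch, num_classes, height, width = 1, 32, 128, 128
logits = torch.randn(batch, num_classes, height, width)
predictions = logits.argmax(dim=1).byte()  # class ids per pixel

print(logits.shape)       # torch.Size([1, 32, 128, 128])
print(predictions.shape)  # torch.Size([1, 128, 128])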