TPU Support, Remote Data Loading, and Video Classification from Tensors: a Feature-Rich Release
We are elated to announce the release of Lightning Flash v0.8, a feature-rich release with improved testing to ensure a better experience for all our lovely users! The team at Lightning AI and our community contributors have been working hard on this release, and nothing makes us happier than sharing their lovely contributions with you.
We discuss the major features and changes below. For the complete list of pull requests included in this release, scroll to the changelog at the bottom.
TPU Support 🦸🏻
Before this release, Lightning Flash worked well on a single TPU core (training, validation, and prediction) but failed comprehensively on multiple cores. This release enables training and validation on multi-core TPUs, allowing users to try out their models on TPUs using Lightning Flash. Prediction on multi-core TPUs is an ongoing effort, and we hope to bring it to you in the near future.
| | Before v0.8 | After v0.8 |
|---|---|---|
| Single core | Training, Validation, Prediction | Training, Validation, Prediction |
| Multiple cores | Not supported | Training, Validation |
As more users try TPUs with Lightning Flash, we expect that unseen errors or issues might surface, and we look forward to addressing them. So please don't hesitate to let us know about your experience!
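If you want to give it a spin, here is a minimal sketch of multi-core TPU finetuning. This is an illustrative example rather than an official snippet from the release: it assumes you are on a TPU VM with `torch_xla` installed, and it borrows the hymenoptera dataset used in the Flash docs.

```python
import flash
from flash.core.data.utils import download_data
from flash.image import ImageClassificationData, ImageClassifier

# Illustrative dataset from the Flash docs.
download_data("https://pl-flash-data.s3.amazonaws.com/hymenoptera_data.zip", "data/")

datamodule = ImageClassificationData.from_folders(
    train_folder="data/hymenoptera_data/train/",
    val_folder="data/hymenoptera_data/val/",
    batch_size=4,
)
model = ImageClassifier(backbone="resnet18", labels=datamodule.labels)

# Request all eight TPU cores. Note that prediction on multiple cores
# is not supported yet in this release.
trainer = flash.Trainer(max_epochs=1, accelerator="tpu", devices=8)
trainer.finetune(model, datamodule=datamodule, strategy="freeze")
```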
Remote Data Loading: fsspec arrives in Lightning Flash ☁️
Before this release, users had to download a dataset or file from a URL and pass the local path to our data loader classes. This was a pain point that we are happy to let go of in this release. Starting with v0.8, you no longer have to download any of those files locally; just pass the file URL and expect it to work!
**Before v0.8**, you had to download `titanic.csv` from the URL and pass the local path to the `train_file` argument:

```python
from flash.tabular import TabularClassificationData

datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Age", "Cabin"],
    numerical_fields="Fare",
    target_fields="Survived",
    train_file="titanic.csv",
    val_split=0.1,
    batch_size=8,
)
```

**After v0.8**, you can just pass the URL to the `train_file` argument:

```python
from flash.tabular import TabularClassificationData

datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Age", "Cabin"],
    numerical_fields="Fare",
    target_fields="Survived",
    train_file="https://pl-flash-data.s3.amazonaws.com/titanic.csv",
    val_split=0.1,
    batch_size=8,
)
```
For more details, feel free to check out the documentation here.
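Since the loading goes through `fsspec`, other supported protocols should work as well. As a hedged sketch (the `s3://` bucket below is hypothetical, and S3 access assumes the matching fsspec extra, `s3fs`, is installed and credentials are configured):

```python
from flash.tabular import TabularClassificationData

# Hypothetical bucket/path; any fsspec-supported protocol
# (http(s)://, s3://, gs://, ...) should work, given the matching
# fsspec filesystem implementation is installed.
datamodule = TabularClassificationData.from_csv(
    categorical_fields=["Age", "Cabin"],
    numerical_fields="Fare",
    target_fields="Survived",
    train_file="s3://my-bucket/titanic.csv",
    val_split=0.1,
    batch_size=8,
)
```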
Video Classification from Tensors 📹
At times, you need to work with raw data or pre-process videos before loading them and training a model. For Video Classification, this raw data is mostly available as tensors, and before this release you had to save them back out as video files and pass the file paths to the data loading classes in Flash. Starting with this release, Video Classification supports loading data directly from tensors.
```python
import torch

import flash
from flash.video import VideoClassifier, VideoClassificationData

# 3 channels, 5 frames, height = 10, width = 10 (C, T, H, W)
mock_tensors = torch.randint(size=(3, 5, 10, 10), low=0, high=255)

datamodule = VideoClassificationData.from_tensors(
    train_data=[mock_tensors, mock_tensors],  # can also stack: torch.stack((mock_tensors, mock_tensors))
    train_targets=["patient", "doctor"],
    predict_data=[mock_tensors],
    batch_size=1,
)

model = VideoClassifier(
    num_classes=datamodule.num_classes,
    pretrained=False,
    backbone="slow_r50",
    labels=datamodule.labels,
)

trainer = flash.Trainer(max_epochs=1)
trainer.finetune(model, datamodule=datamodule)
```
This also comes in handy for multi-modal pipelines: instead of saving a model's output to files and loading it back, you can pass the raw tensors straight to the next model, saving quite a lot of time otherwise wasted on conversion.
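Once finetuned, predictions can be generated straight from the tensors passed as `predict_data` above. A small sketch (the `output="labels"` argument, as used elsewhere in the Flash docs, maps predictions back to string labels):

```python
# Predict on the raw tensors provided via predict_data above.
predictions = trainer.predict(model, datamodule=datamodule, output="labels")
print(predictions)
```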
Refactored Transforms in Lightning Flash ⚙️
This is one of the community-driven contributions that we are proud to share. Before this release, you had to pass an input transform class for each stage (train, validation, test, and predict), which was cumbersome. With this release, you can just pass a single `transform=InputTransform(**transform_kwargs)` instance to the required method. Note that this is a breaking change; if you are not sure how to migrate, please create an issue and we'll be happy to help!
**Before v0.8**:

```python
dm = XYZTask_DataModule.from_xyz(
    train_file=train_file,
    val_file=val_file,
    test_file=test_file,
    predict_file=predict_file,
    train_transform=InputTransform,
    val_transform=InputTransform,
    test_transform=InputTransform,
    predict_transform=InputTransform,
    transform_kwargs=transform_kwargs,
)
```

**After v0.8**:

```python
dm = XYZTask_DataModule.from_xyz(
    train_file=train_file,
    val_file=val_file,
    test_file=test_file,
    predict_file=predict_file,
    transform=InputTransform(**transform_kwargs),
)
```
Note that, within your `InputTransform` class, you can define `<stage>_per_batch_transform_on_device` methods to support the various stages:
```python
from typing import Callable

from flash.core.data.io.input_transform import InputTransform


class SampleInputTransform(InputTransform):
    def per_sample_transform(self):
        def fn(x):
            return x

        return fn

    # Stage-specific hooks: each returns the batch transform applied on device.
    def train_per_batch_transform_on_device(self) -> Callable:
        return ...

    def val_per_batch_transform_on_device(self) -> Callable:
        return ...

    def test_per_batch_transform_on_device(self) -> Callable:
        return ...

    def predict_per_batch_transform_on_device(self) -> Callable:
        return ...
```
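As a quick usage sketch (reusing the placeholder `XYZTask_DataModule` from the example above, so the names here are illustrative):

```python
# Pass an instance of your transform; Flash picks the right hook per stage.
dm = XYZTask_DataModule.from_xyz(
    train_file=train_file,
    transform=SampleInputTransform(),
)
```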
Object Detection in Flash is now servable 💁
If you aren't aware yet, Lightning Flash supports serving models. Starting with this release, Object Detection joins the beautiful category of tasks that can be served using Lightning Flash. Below is an example of what the inference server code for object detection looks like:
```python
# Inference Server
from flash.image import ObjectDetector

model = ObjectDetector.load_from_checkpoint(
    "https://flash-weights.s3.amazonaws.com/0.8.0/object_detection_model.pt"
)
model.serve()
```
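With the server running, a hedged client-side sketch follows the request pattern used by other Flash serve tasks (the endpoint and payload format are assumptions based on those examples, and the image path is illustrative):

```python
# Hypothetical client for the server above.
import base64
import requests

with open("path/to/image.jpg", "rb") as f:
    imgstr = base64.b64encode(f.read()).decode("UTF-8")

body = {"session": "UUID", "payload": {"inputs": {"data": imgstr}}}
resp = requests.post("http://127.0.0.1:8000/predict", json=body)
print(resp.json())
```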
For more details, check out the documentation here.
Added
- Added support for `from_tensors` for `VideoClassification` (#1389)
- Added fine tuning strategies for DeepSpeed (with parameter loading and storing omitted) (#1377)
- Added `torchvision` as a requirement to `datatype_audio.txt` as it's used for Audio Classification (#1425)
- Added `figsize` and `limit_nb_samples` for showing batch images (#1381)
- Added support for `from_lists` for Tabular Classification and Regression (#1337)
- Added support for `from_dicts` for Tabular Classification and Regression (#1331)
- Added support for using the `ImageEmbedder` SSL training for all image classifier backbones (#1264)
- Added support for audio file formats to `AudioClassificationData` (#1085)
- Added support for Flash serve to the `ObjectDetector` (#1370)
- Added support for loading `ImageClassificationData` from PIL images with `from_images` (#1372)
- Added support for loading `ObjectDetectionData` with `from_numpy`, `from_images`, and `from_tensors` (#1372)
- Added support for remote data loading with fsspec (#1387)
- Added support for TSV files to `from_csv` methods (#1387)
- Added support for more formats when loading audio files (#1387)
- Added support to use any task as an embedder by calling `as_embedder` (#1396)
- Added support for normalization of images in `SemanticSegmentationData` (#1399)
Changed
- Changed the `ImageEmbedder` dependency on VISSL to optional (#1276)
- Changed the transforms in `SemanticSegmentationData` to use albumentations instead of Kornia (#1313)
Removed
- Removed support for audio files with the `sd2` extension, because SoundFile doesn't accept fsspec objects for that extension (#1409)
Fixed
- Fixed a case where a suitable error was not raised for image segmentation (Kornia) (#1425)
- Fixed the script for integrating `lightning-flash` with `learn2learn` (#1376)
- Fixed JIT tracing tests where the model class was not attached to the `Trainer` class (#1410)
- Fixed examples for BaaL integration by removing usage of `on_<stage>_dataloader` hooks (removed in PL 1.7.0) (#1410)
- Fixed examples for BaaL integration for the case when the `probabilities` list is empty (#1410)
- Fixed a bug where collate functions were not being attached successfully after the `DataLoader` is initialized (in PL 1.7.0 changing attributes after initialization doesn't do anything) (#1410)
- Fixed a bug where grayscale images were not properly converted to RGB when loaded (#1394)
- Fixed a bug where the size of the mask for instance segmentation didn't match the size of the original image (#1353)
- Fixed image classification data `show_train_batch` for subplots with rows > 1 (#1339)
- Fixed support for all versions (including the latest and older) of `baal` (#1315)
- Fixed a bug where a loaded `TabularClassifier` or `TabularRegressor` checkpoint could not be served (#1324)
- Fixed a bug where the `freeze_unfreeze` and `unfreeze_milestones` finetuning strategies could not be used in tandem with a `onecyclelr` LR scheduler (#1329)
- Fixed a bug where the backbone learning rate would be divided by 10 when unfrozen if using the `freeze_unfreeze` or `unfreeze_milestones` strategies (#1329)
- Fixed naming of optimizer and scheduler registries which did not allow manual optimization (#1342)
- Fixed a bug where the `processor_backbone` argument to `SpeechRecognition` was not used for decoding outputs (#1362)
- Fixed a bug where `.npy` files could not be used with `SemanticSegmentationData` (#1369)
Contributors
@akihironitta @aniketmaurya @Borda @carmocca @ethanwharris @JustinGoheen @krshrimali @ligaz @Nico995 @uakarsh
If we forgot someone, let us know 😃