Skip to content

Commit

Permalink
Add log image to mlflow (#2416)
Browse files Browse the repository at this point in the history
  • Loading branch information
eracah authored Sep 15, 2023
1 parent 57f29bd commit 1997062
Show file tree
Hide file tree
Showing 2 changed files with 126 additions and 2 deletions.
73 changes: 72 additions & 1 deletion composer/loggers/mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,13 @@

import os
import pathlib
import textwrap
import time
from typing import Any, Dict, List, Optional, Union
import warnings
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np
import torch

from composer.core.state import State
from composer.loggers.logger import Logger
Expand Down Expand Up @@ -145,6 +150,30 @@ def log_hyperparameters(self, hyperparameters: Dict[str, Any]):
)
self._optimized_mlflow_client.flush(synchronous=False)

def log_images(
    self,
    images: Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]],
    name: str = 'image',
    channels_last: bool = False,
    step: Optional[int] = None,
    masks: Optional[Dict[str, Union[np.ndarray, torch.Tensor, Sequence[Union[np.ndarray, torch.Tensor]]]]] = None,
    mask_class_labels: Optional[Dict[int, str]] = None,
    use_table: bool = True,
):
    """Log one or more images to the current MLflow run as ``.png`` artifacts.

    Each image is converted with ``_convert_to_mlflow_image`` and passed to
    ``self._mlflow_client.log_image`` under the artifact file name
    ``{name}_{step}_{index}.png``, where ``index`` is the image's position in
    ``images``.

    Args:
        images: A single array/tensor, or a sequence of them. A non-sequence
            input with at most 3 dimensions is treated as one image; anything
            else is iterated over, one image per element.
        name: Prefix of the artifact file name.
        channels_last: Forwarded to ``_convert_to_mlflow_image``.
        step: Embedded verbatim in the artifact file name.
        masks: Not supported by this logger; a truthy value triggers a warning.
        mask_class_labels: Not supported by this logger; a truthy value
            triggers a warning.
        use_table: Accepted but not read by this method.
    """
    # ``masks`` / ``mask_class_labels`` exist only for wandb; tell the caller
    # they are being ignored (this happens even when the logger is disabled).
    if masks or mask_class_labels:
        warnings.warn(
            textwrap.dedent('MLFlowLogger does not support masks, class labels, or tables of images,\n'
                            f'but got masks={masks}, mask_class_labels={mask_class_labels}'))

    if not self._enabled:
        return

    # A lone array/tensor of rank <= 3 is a single image: wrap it so that the
    # loop below always walks over individual images.
    is_single_image = not isinstance(images, Sequence) and images.ndim <= 3
    batch = [images] if is_single_image else images

    for index, raw_image in enumerate(batch):
        converted = _convert_to_mlflow_image(raw_image, channels_last)
        self._mlflow_client.log_image(
            image=converted,
            artifact_file=f'{name}_{step}_{index}.png',
            run_id=self._run_id,
        )

def post_close(self):
if self._enabled:
# We use MlflowClient for run termination because MlflowAutologgingQueueingClient's
Expand All @@ -155,3 +184,45 @@ def post_close(self):
def _flush(self):
    """Test-only helper: flush all queued metrics synchronously.

    Returns:
        Whatever ``self._optimized_mlflow_client.flush(synchronous=True)`` returns.
    """
    queueing_client = self._optimized_mlflow_client
    return queueing_client.flush(synchronous=True)


def _convert_to_mlflow_image(image: Union[np.ndarray, torch.Tensor], channels_last: bool) -> np.ndarray:
    """Convert a single image into a 3-D, channels-last NumPy array.

    Tensors are detached, moved to CPU and converted to NumPy. All singleton
    axes are then squeezed away, and 1-D / 2-D results are padded back to three
    dimensions with trailing length-one axes (so they are treated as
    channels-last with a single channel). Remaining 3-D channels-first inputs
    are transposed to ``(dim1, dim2, dim0)``.

    Note:
        Because *every* singleton axis is squeezed, a channels-first input
        whose height or width is 1 (e.g. shape ``(3, 1, W)``) collapses to 2-D
        and is interpreted as a single-channel image.

    Args:
        image: The image to convert.
        channels_last: Whether a 3-D ``image`` already has its channel axis
            last. Ignored for inputs that squeeze down to 1-D or 2-D.

    Returns:
        A 3-D array whose last axis has length 1, 3, or 4.

    Raises:
        ValueError: If any axis of ``image`` has length 0, if the squeezed
            image is not 1-, 2-, or 3-dimensional, or if the final channel
            count is not 1, 3, or 4.
    """
    if isinstance(image, torch.Tensor):
        # ``detach()`` is the supported replacement for the legacy ``.data`` attribute.
        image = image.detach().cpu().numpy()

    # Error out for empty arrays, i.e. arrays with an axis of length 0.
    if 0 in image.shape:
        raise ValueError(f'Got an image (shape {image.shape}) with at least one dimension being 0! ')

    # Squeeze any singleton dimensions; they are added back below if the image
    # ends up with fewer than 3 dimensions.
    image = image.squeeze()

    # Add in length-one dimensions to get back up to 3, putting channels last.
    if image.ndim == 1:
        image = np.expand_dims(image, (1, 2))
        channels_last = True
    if image.ndim == 2:
        image = np.expand_dims(image, 2)
        channels_last = True

    if image.ndim != 3:
        # Messages are assembled from adjacent literals so that no stray line
        # breaks or source indentation leak into the text shown to the user.
        raise ValueError(f'Input image must be 3 dimensions, but instead got {image.ndim} dims '
                         f'at shape: {image.shape}. Your input image was interpreted as a batch of '
                         f'{image.ndim}-dimensional images because you either specified a '
                         f'{image.ndim + 1}D image or a list of {image.ndim}D images. '
                         'Please specify either a 4D image or a list of 3D images.')

    # Narrows the type for static checkers; both input types are ndarrays by now.
    assert isinstance(image, np.ndarray)
    if not channels_last:
        image = image.transpose(1, 2, 0)
    if image.shape[-1] not in [1, 3, 4]:
        raise ValueError(f'Input image must have 1, 3, or 4 channels, but instead got {image.shape[-1]} '
                         f'channels at shape: {image.shape}. Please specify either a 1-, 3-, or 4-channel '
                         'image or a list of 1-, 3-, or 4-channel images.')
    return image
55 changes: 54 additions & 1 deletion tests/loggers/test_mlflow_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,17 @@

import csv
import json
import os
from pathlib import Path
from unittest.mock import MagicMock

import numpy as np
import pytest
import yaml
from torch.utils.data import DataLoader

from composer.loggers import MLFlowLogger
from composer.core import Callback, State
from composer.loggers import Logger, MLFlowLogger
from composer.trainer import Trainer
from tests.common.datasets import RandomImageDataset
from tests.common.markers import device
Expand Down Expand Up @@ -293,3 +296,53 @@ def test_mlflow_logging_works(tmp_path, device):

expected_params_list = ['num_cpus_per_node', 'node_name', 'num_nodes', 'rank_zero_seed']
assert set(expected_params_list) == set(actual_params_list)


@device('cpu')
def test_mlflow_log_image_works(tmp_path, device):
    """Train for a few batches while logging each input batch as images.

    Afterwards, the run's ``artifacts`` directory must contain exactly
    ``num_batches * batch_size`` entries.
    """

    class ImageLogger(Callback):
        """Logs every input batch as images, once plainly and once with the unsupported mask arguments."""

        def before_forward(self, state: State, logger: Logger):
            batch_inputs = state.batch_get_item(key=0)
            images = batch_inputs.data.cpu().numpy()
            logger.log_images(images, step=state.timestamp.batch.value)
            # Masks and class labels are not supported by MLFlowLogger, so the
            # second call is expected to emit a UserWarning.
            with pytest.warns(UserWarning):
                logger.log_images(
                    images,
                    step=state.timestamp.batch.value,
                    masks={'a': np.ones((2, 2))},
                    mask_class_labels={1: 'a'},
                )

    tracking_dir = tmp_path / Path('my-test-mlflow-uri')
    experiment = 'mlflow_logging_test'
    mlflow_logger = MLFlowLogger(tracking_uri=tracking_dir, experiment_name=experiment)

    dataset_size = 64
    batch_size = 4
    num_batches = 4
    eval_every = '1ba'

    # Arguments are built inline, in the same order, so object construction
    # (and any random-number consumption) happens in the same sequence.
    trainer = Trainer(
        model=SimpleConvModel(),
        loggers=mlflow_logger,
        train_dataloader=DataLoader(RandomImageDataset(size=dataset_size), batch_size),
        eval_dataloader=DataLoader(RandomImageDataset(size=dataset_size), batch_size),
        max_duration=f'{num_batches}ba',
        eval_interval=eval_every,
        callbacks=ImageLogger(),
        device=device,
    )

    trainer.fit()
    mlflow_logger._flush()

    latest_run = _get_latest_mlflow_run(
        experiment_name=experiment,
        tracking_uri=tracking_dir,
    )
    run_id = latest_run.info.run_id
    experiment_id = latest_run.info.experiment_id

    artifacts_dir = tracking_dir / Path(experiment_id) / Path(run_id) / Path('artifacts')
    logged_files = os.listdir(artifacts_dir)
    assert len(logged_files) == num_batches * batch_size

0 comments on commit 1997062

Please sign in to comment.