Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add patch_tensorboard #204

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion dvclive/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,35 @@
from dvclive.error import DataAlreadyLoggedError


def _is_np(val):
return val.__class__.__module__ == "numpy"


def _is_tf(val):
return val.__class__.__module__.split(".")[0] == "tensorflow"


class Data(abc.ABC):
def __init__(self, name: str, output_folder: str) -> None:
self.name = name
self.output_folder: Path = Path(output_folder) / self.subfolder
self._step: Optional[int] = None
self.val = None
self._val = None
self._step_none_logged: bool = False
self._dump_kwargs = None

@property
def step(self) -> int:
return self._step

@property
def val(self):
return self._val

@val.setter
def val(self, x):
self._val = x

@step.setter
def step(self, val: int) -> None:
if self._step_none_logged and val == self._step:
Expand Down
30 changes: 20 additions & 10 deletions dvclive/data/image.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from pathlib import Path

from .base import Data
from .base import Data, _is_np, _is_tf


class Image(Data):
Expand All @@ -24,10 +24,27 @@ def output_path(self) -> Path:
def could_log(val: object) -> bool:
if val.__class__.__module__ == "PIL.Image":
return True
if val.__class__.__module__ == "numpy":
if _is_np(val):
return True
if _is_tf(val) and val.ndim in [2, 3]:
return True
return False

@property
def val(self):
    """Return the stored value."""
    return self._val

@val.setter
def val(self, x):
    """Store *x*, converting numpy and TensorFlow values to a PIL image.

    Any other value is stored unchanged.
    """
    from_tensor = _is_tf(x)
    if not (from_tensor or _is_np(x)):
        self._val = x
        return

    # PIL is imported lazily so it is only needed when a conversion happens.
    from PIL import Image as ImagePIL

    pixels = x.numpy().squeeze() if from_tensor else x
    self._val = ImagePIL.fromarray(pixels)

def first_step_dump(self) -> None:
if self.no_step_output_path.exists():
self.no_step_output_path.rename(self.output_path)
Expand All @@ -36,14 +53,7 @@ def no_step_dump(self) -> None:
self.step_dump()

def step_dump(self) -> None:
if self.val.__class__.__module__ == "numpy":
from PIL import Image as ImagePIL

_val = ImagePIL.fromarray(self.val)
else:
_val = self.val

_val.save(self.output_path)
self.val.save(self.output_path)

@property
def summary(self):
Expand Down
20 changes: 19 additions & 1 deletion dvclive/data/scalar.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from dvclive.utils import nested_set

from .base import Data
from .base import Data, _is_tf


class Scalar(Data):
Expand All @@ -17,8 +17,26 @@ class Scalar(Data):
def could_log(val: object) -> bool:
    """Tell whether *val* is an ``int``/``float`` or a TensorFlow object."""
    return isinstance(val, (int, float)) or _is_tf(val)

@property
def val(self):
    """Return the stored value."""
    return self._val

@val.setter
def val(self, x):
    """Store *x*, converting TensorFlow values to a built-in ``int`` or ``float``.

    Values that are not TensorFlow objects are stored unchanged.
    """
    if _is_tf(x):
        import numpy as np

        x = x.numpy().squeeze()
        # Decide on the dtype rather than on the Python type of the squeezed
        # value: squeezing a one-element array gives a 0-d ndarray, which is
        # not an ``np.integer`` instance even when its dtype is an integer.
        if np.issubdtype(x.dtype, np.integer):
            x = int(x)
        else:
            x = float(x)
    self._val = x

@property
def output_path(self) -> Path:
_path = self.output_folder / self.name
Expand Down
53 changes: 53 additions & 0 deletions dvclive/tensorboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
import gorilla
import tensorflow as tf

from dvclive import Live

# pylint: disable=unused-argument, no-member


def patch_tensorboard(override: bool = True, **kwargs):
live = Live(**kwargs)
Copy link
Contributor Author

@daavoo daavoo Feb 10, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure whether `patch_tensorboard` should instead optionally accept a `Live` instance.

settings = gorilla.Settings(allow_hit=True, store_hit=True)

original_scalar = gorilla.Patch(
tf.summary, "original_scalar", tf.summary.scalar, settings=settings
)
gorilla.apply(original_scalar)

def log_scalar(name, data, step=None, description=None):
    """Stand-in for ``tf.summary.scalar`` that logs the value through ``live``.

    Args:
        name: Name the value is logged under.
        data: Value to log.
        step: Optional step; when given and different from the current step
            of ``live``, ``live`` is moved to it before logging.
        description: Passed to the original summary function when
            ``override`` is false.
    """
    if step is not None and step != live.get_step():
        live.set_step(step)
    live.log(name, data)
    if not override:
        # Forward the caller's arguments instead of hard-coded ``None`` so the
        # original summary records the same step and description.
        tf.summary.original_scalar(
            name, data, step=step, description=description
        )

original_image = gorilla.Patch(
tf.summary, "original_image", tf.summary.image, settings=settings
)
gorilla.apply(original_image)

def log_image(name, data, step=None, max_outputs=3, description=None):
    """Stand-in for ``tf.summary.image`` that logs the images through ``live``.

    Args:
        name: Summary name; ".png" is appended to build the logged file name.
        data: Sized, indexable batch of images.
        step: Optional step; when given and different from the current step
            of ``live``, ``live`` is moved to it before logging.
        max_outputs: Maximum number of images logged from a batch of several.
        description: Passed to the original summary function when
            ``override`` is false.
    """
    # Keep ``name`` untouched so the original summary gets the caller's tag.
    file_name = name + ".png"
    if step is not None and step != live.get_step():
        live.set_step(step)

    if len(data) == 1:
        live.log_image(file_name, data[0])
    else:
        # An empty batch falls through this loop without logging anything.
        for n, image in enumerate(data):
            # ``n`` is zero-based, so stop once ``max_outputs`` were logged.
            if n >= max_outputs:
                break
            live.log_image(f"sample-{n}-{file_name}", image)

    if not override:
        # Forward the caller's arguments instead of hard-coded ``None``.
        tf.summary.original_image(
            name,
            data,
            step=step,
            max_outputs=max_outputs,
            description=description,
        )

scalar_patch = gorilla.Patch(tf.summary, "scalar", log_scalar, settings)
gorilla.apply(scalar_patch)

image_patch = gorilla.Patch(tf.summary, "image", log_image, settings)
gorilla.apply(image_patch)

return original_scalar, original_image, scalar_patch, image_patch
17 changes: 16 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,22 @@ def run(self):
catalyst = ["catalyst<=21.12"]
fastai = ["fastai"]
pl = ["pytorch_lightning"]
image = ["pillow"]
tensorboard = tf + ["gorilla"]

all_libs = mmcv + tf + xgb + lgbm + hugginface + catalyst + fastai + pl + plots
all_libs = (
mmcv
+ tf
+ xgb
+ lgbm
+ hugginface
+ catalyst
+ fastai
+ pl
+ image
+ plots
+ tensorboard
)

tests_requires = [
"pylint==2.5.3",
Expand Down Expand Up @@ -83,6 +97,7 @@ def run(self):
"sklearn": plots,
"image": image,
"plots": plots,
"tensorboard": tensorboard,
},
keywords="data-science metrics machine-learning developer-tools ai",
python_requires=">=3.6",
Expand Down
11 changes: 10 additions & 1 deletion tests/test_data/test_image.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import numpy as np
import pytest
import tensorflow as tf
from PIL import Image

# pylint: disable=unused-argument
# pylint: disable=unused-argument,no-value-for-parameter
from dvclive import Live
from dvclive.data import Image as LiveImage

Expand Down Expand Up @@ -31,6 +32,14 @@ def test_numpy(tmp_dir, shape):
assert (tmp_dir / dvclive.dir / LiveImage.subfolder / "image.png").exists()


@pytest.mark.parametrize("shape", [(500, 500), (500, 500, 3), (500, 500, 4)])
def test_tensorflow(tmp_dir, shape):
    """A uint8 TensorFlow tensor given to ``log_image`` produces the image file."""
    live = Live()
    tensor = tf.zeros(shape=shape, dtype=tf.uint8)
    live.log_image("image.png", tensor)

    expected_file = tmp_dir / live.dir / LiveImage.subfolder / "image.png"
    assert expected_file.exists()


def test_step_formatting(tmp_dir):
dvclive = Live()
img = np.ones((500, 500, 3), np.uint8)
Expand Down
19 changes: 19 additions & 0 deletions tests/test_data/test_scalar.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import tensorflow as tf

from dvclive import Live
from tests.test_main import _parse_json

# pylint: disable=unused-argument


def test_tensorflow(tmp_dir):
    """TensorFlow constants are written to the summary as built-in numbers."""
    live = Live()
    for key, value in (("int", 1), ("float", 1.5)):
        live.log(key, tf.constant(value))

    summary = _parse_json("dvclive.json")

    expectations = (("int", int, 1), ("float", float, 1.5))
    for key, expected_type, expected_value in expectations:
        assert isinstance(summary[key], expected_type)
        assert summary[key] == expected_value
34 changes: 34 additions & 0 deletions tests/test_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,3 +179,37 @@ def test_keras_None_model_file_skip_load(
)

assert load_weights.call_count == 0


def test_patch_tensorboard_keras_callback(tmp_dir, xor_model, mocker):
import gorilla
import tensorflow as tf

from dvclive.tensorboard import patch_tensorboard

scalar = mocker.spy(tf.summary, "scalar")
image = mocker.spy(tf.summary, "image")

patches = patch_tensorboard(path="logs")

model, x, y = xor_model()

model.fit(
x,
y,
epochs=2,
batch_size=1,
callbacks=[tf.keras.callbacks.TensorBoard()],
)

assert not scalar.call_count
assert not image.call_count

assert os.path.exists("logs")
logs, _ = read_logs("logs/scalars")

assert "epoch_accuracy" in logs
assert len(logs["epoch_accuracy"]) == 2

for patch in patches:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So when using dvclive with TensorBoard, we would need to revert the patches after execution so that the modified TensorBoard is not used later in our project?

Shouldn't patch_tensorboard be a context manager that would revert the patch after exiting the context?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to test it, but I think that the patch doesn't persist across Python executions.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What about within the same process? I think the safest way would be to let the user control when dvclive patches TensorBoard.

gorilla.revert(patch)
72 changes: 72 additions & 0 deletions tests/test_tensorboard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import os

import gorilla
import tensorflow as tf

from dvclive.tensorboard import patch_tensorboard
from tests.test_main import _parse_json

# pylint: disable=unused-argument, no-value-for-parameter


def test_patch_tensorboard(tmp_dir, mocker):
    """With the default override, summaries go to dvclive and the originals are not called."""
    scalar = mocker.spy(tf.summary, "scalar")
    image = mocker.spy(tf.summary, "image")

    patches = patch_tensorboard()
    try:
        tf.summary.scalar("m", 0.5)
        tf.summary.image("image", [tf.zeros(shape=[8, 8, 1], dtype=tf.uint8)])

        assert not scalar.call_count
        assert not image.call_count

        summary = _parse_json("dvclive.json")
        image_path = os.path.join("dvclive", "images", "image.png")
        assert summary["m"] == 0.5
        assert os.path.exists(image_path)
    finally:
        # Revert even when an assertion fails, so a failure here does not
        # leave tf.summary patched for the tests that run afterwards.
        for patch in patches:
            gorilla.revert(patch)


def test_patch_tensorboard_no_override(tmp_dir, mocker):
    """With ``override=False``, summaries go to dvclive and the originals are still called."""
    scalar = mocker.spy(tf.summary, "scalar")
    image = mocker.spy(tf.summary, "image")

    patches = patch_tensorboard(override=False)
    try:
        tf.summary.scalar("m", 0.5)
        tf.summary.image("image", [tf.zeros(shape=[8, 8, 1], dtype=tf.uint8)])

        assert scalar.call_count
        assert image.call_count

        summary = _parse_json("dvclive.json")
        image_path = os.path.join("dvclive", "images", "image.png")
        assert summary["m"] == 0.5
        assert os.path.exists(image_path)
    finally:
        # Revert even when an assertion fails, so a failure here does not
        # leave tf.summary patched for the tests that run afterwards.
        for patch in patches:
            gorilla.revert(patch)


def test_patch_tensorboard_live_args(tmp_dir, mocker):
    """Keyword arguments of ``patch_tensorboard`` reach ``Live`` (here: a custom ``path``)."""
    scalar = mocker.spy(tf.summary, "scalar")
    image = mocker.spy(tf.summary, "image")

    patches = patch_tensorboard(path="logs")
    try:
        tf.summary.scalar("m", 0.5)
        tf.summary.image("image", [tf.zeros(shape=[8, 8, 1], dtype=tf.uint8)])

        assert not scalar.call_count
        assert not image.call_count

        summary = _parse_json("logs.json")
        image_path = os.path.join("logs", "images", "image.png")
        assert summary["m"] == 0.5
        assert os.path.exists(image_path)
    finally:
        # Revert even when an assertion fails, so a failure here does not
        # leave tf.summary patched for the tests that run afterwards.
        for patch in patches:
            gorilla.revert(patch)