Skip to content

Commit

Permalink
Support multiple Observers
Browse files Browse the repository at this point in the history
  • Loading branch information
pattonw committed Feb 7, 2024
1 parent 8b9d59f commit 1fed55d
Show file tree
Hide file tree
Showing 7 changed files with 225 additions and 57 deletions.
26 changes: 3 additions & 23 deletions gunpowder/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,36 +25,17 @@ def __exit__(self, type, value, traceback):
logger.debug("tear down completed")


import neuroglancer
from .neuroglancer.event import step_next
from .observers import NeuroglancerObserver


class build_neuroglancer(object):
def __init__(self, pipeline):
self.pipeline = pipeline
self.observer = NeuroglancerObserver("neuroglancer", pipeline)

def __enter__(self):
neuroglancer.set_server_bind_address("0.0.0.0")
viewer = neuroglancer.Viewer()

viewer.actions.add("continue", step_next)

with viewer.config_state.txn() as s:
s.input_event_bindings.data_view["keyt"] = "continue"
with viewer.txn() as s:
s.layout = neuroglancer.row_layout(
[
neuroglancer.column_layout(
[
neuroglancer.LayerGroupViewer(layers=[]),
neuroglancer.LayerGroupViewer(layers=[]),
]
),
]
)

try:
self.pipeline.setup(viewer)
self.pipeline.setup([self.observer])
except:
logger.error(
"something went wrong during the setup of the pipeline, calling tear down"
Expand All @@ -63,7 +44,6 @@ def __enter__(self):
logger.debug("tear down completed")
raise

print(viewer)
return self.pipeline

def __exit__(self, type, value, traceback):
Expand Down
8 changes: 0 additions & 8 deletions gunpowder/neuroglancer/add_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,6 @@ def parse_dims(array):
spatial_dims = array.spec.roi.dims
channel_dims = dims - spatial_dims

print("dims :", dims)
print("spatial dims:", spatial_dims)
print("channel dims:", channel_dims)

return dims, spatial_dims, channel_dims


Expand All @@ -73,10 +69,6 @@ def create_coordinate_space(array, spatial_dim_names, channel_dim_names, unit):
units = [""] * channel_dims + [unit] * spatial_dims
scales = [1] * channel_dims + list(array.spec.voxel_size)

print("Names :", names)
print("Units :", units)
print("Scales :", scales)

return neuroglancer.CoordinateSpace(
names=names,
units=units,
Expand Down
34 changes: 21 additions & 13 deletions gunpowder/nodes/batch_provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,7 @@
from gunpowder.array_spec import ArraySpec
from gunpowder.graph import GraphKey
from gunpowder.graph_spec import GraphSpec
from gunpowder.neuroglancer.event import wait_for_step
from gunpowder.neuroglancer.visualize import visualize
from gunpowder.observers import Observer

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -54,7 +53,13 @@ class BatchProvider(object):
instead.
"""

viewer = None
_observers = None

@property
def observers(self):
if self._observers is None:
self._observers = []
return self._observers

def add_upstream_provider(self, provider):
self.get_upstream_providers().append(provider)
Expand Down Expand Up @@ -219,20 +224,23 @@ def request_batch(self, request):

return batch

def setup_viewer(self, viewer):
self.viewer = viewer
def register_observer(self, observer: Observer):
self.observers.append(observer)
self.observe_sources(observer)

def observe_sources(self, observer):
"""
to be implemented in subclasses
"""
pass

def observe_request(self, request):
if self.viewer is not None:
print("Waiting for step...")
visualize(self.viewer, request)
wait_for_step()
for observer in self.observers:
observer.update(request, self)

def observe_batch(self, batch):
if self.viewer is not None:
print("Waiting for step...")
visualize(self.viewer, batch)
wait_for_step()
for observer in self.observers:
observer.update(batch, self)

def set_seeds(self, request):
seed = request.random_seed
Expand Down
20 changes: 10 additions & 10 deletions gunpowder/nodes/zarr_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,16 +125,16 @@ def setup(self):

self.provides(array_key, spec)

def setup_viewer(self, viewer):
self.viewer = viewer
with viewer.txn() as s:
with self._open_file(self.store) as data_file:
for array_key, ds_name in self.datasets.items():
if ds_name not in data_file:
raise RuntimeError("%s not in %s" % (ds_name, self.store))

spec = self.__read_spec(array_key, data_file, ds_name)
add_layer(s, Array(data_file[ds_name], spec), f"{array_key}_SOURCE")
def observe_sources(self, observer):
with self._open_file(self.store) as data_file:
for array_key, ds_name in self.datasets.items():
if ds_name not in data_file:
raise RuntimeError("%s not in %s" % (ds_name, self.store))

spec = self.__read_spec(array_key, data_file, ds_name)
observer.add_source(
Array(data_file[ds_name], spec), f"{array_key}_SOURCE"
)

def provide(self, request):
timing = Timing(self)
Expand Down
109 changes: 109 additions & 0 deletions gunpowder/observers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from .batch import Batch
from .batch_request import BatchRequest
from .neuroglancer.event import step_next
from .neuroglancer.event import wait_for_step
from .neuroglancer.visualize import visualize
from .neuroglancer.add_layer import add_layer

# from .nodes import BatchProvider

import neuroglancer

from abc import ABC
from typing import Optional


class Observer(ABC):
def __init__(self, name, pipeline):
self.name = name
self.pipeline = pipeline

def update(self, request_or_batch: BatchRequest or Batch):
"""
Take a BatchRequest or Batch and update the observer's state with
their contents
"""
pass

def add_source(self, *args, **kwargs):
"""
Add a source to the observer. This is a no-op for observers that do not
provide an array source.
"""
pass


class NeuroglancerObserver(Observer):
def __init__(self, name, pipeline, host="0.0.0.0", port=0):
super().__init__(name, pipeline)
self.host = host
self.port = port

neuroglancer.set_server_bind_address(self.host, self.port)
self.viewer = neuroglancer.Viewer()
self.viewer.actions.add("continue", step_next)

with self.viewer.config_state.txn() as s:
s.input_event_bindings.data_view["keyt"] = "continue"

with self.viewer.txn() as s:
s.layout = neuroglancer.row_layout(
[
neuroglancer.column_layout(
[
neuroglancer.LayerGroupViewer(layers=[]),
neuroglancer.LayerGroupViewer(layers=[]),
]
),
]
)

print(self.viewer)
print("Hit T in neuroglancer viewer to step through the pipeline")

def update(self, request_or_batch: BatchRequest or Batch, node: Optional = None):
visualize(self.viewer, request_or_batch)
string = self.pipeline.to_string(bold=node)
print(
"\r"
+ (
"REQUESTING: "
if isinstance(request_or_batch, BatchRequest)
else "PROVIDING: "
)
+ string
+ " " * 2,
end="",
)
# print(self.pipeline.to_string(bold=node))
wait_for_step()

def add_source(
self,
array,
name,
):
spatial_dim_names = ["t", "z", "y", "x"]
channel_dim_names = ["b^", "c^"]
opacity = None
shader = None
rgb_channels = None
color = None
visible = True
value_scale_factor = 1.0
units = "nm"
with self.viewer.txn() as s:
add_layer(
s,
array,
name,
spatial_dim_names,
channel_dim_names,
opacity,
shader,
rgb_channels,
color,
visible,
value_scale_factor,
units,
)
21 changes: 18 additions & 3 deletions gunpowder/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from gunpowder.nodes import BatchProvider
from gunpowder.nodes.batch_provider import BatchRequestError
from .observers import Observer

import logging
import traceback
from typing import Optional

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -77,10 +79,12 @@ def copy(self):

return pipeline

def setup(self, viewer=None):
def setup(self, observers: Optional[list[Observer]] = None):
"""Connect all batch providers in the pipeline and call setup for
each, from source to sink."""

observers = observers if observers is not None else []

def connect(node):
for child in node.children:
node.output.add_upstream_provider(child.output)
Expand All @@ -94,8 +98,8 @@ def connect(node):
def node_setup(node):
try:
node.output.setup()
if viewer is not None:
node.output.setup_viewer(viewer)
for observer in observers:
node.output.register_observer(observer)
except Exception as e:
raise PipelineSetupError(node.output) from e

Expand Down Expand Up @@ -198,6 +202,17 @@ def to_string(node):
reprs = self.traverse(to_string, reverse=True)

return self.__rec_repr__(reprs)

def to_string(self, bold=None):
def to_string(node):
if node.output == bold:
return f"\033[1m{node.output.name()}\033[0m"
else:
return node.output.name()

reprs = self.traverse(to_string, reverse=True)

return self.__rec_repr__(reprs)

def __rec_repr__(self, reprs):
if not isinstance(reprs, list):
Expand Down
64 changes: 64 additions & 0 deletions neuroglancer_fun.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import matplotlib.pyplot as plt
import numpy as np
import random
import zarr
import torch
from skimage import data
from skimage import filters

# make sure we all see the same
torch.manual_seed(1961923)
np.random.seed(1961923)
random.seed(1961923)

# open a sample image (channels first)
raw_data = data.astronaut().transpose(2, 0, 1)

# create some dummy "ground-truth" to train on
gt_data = filters.gaussian(raw_data[0], sigma=3.0) > 0.75
gt_data = gt_data[np.newaxis, :].astype(np.float32)

# store image in zarr container
f = zarr.open("sample_data.zarr", "w")
f["raw"] = raw_data
f["raw"].attrs["resolution"] = (1, 1)
f["ground_truth"] = gt_data
f["ground_truth"].attrs["resolution"] = (1, 1)

import gunpowder as gp

# declare arrays to use in the pipeline
raw = gp.ArrayKey("RAW")
gt = gp.ArrayKey("GT")

# create "pipeline" consisting only of a data source
source = gp.ZarrSource(
"sample_data.zarr", # the zarr container
{raw: "raw", gt: "ground_truth"}, # which dataset to associate to the array key
{
raw: gp.ArraySpec(interpolatable=True),
gt: gp.ArraySpec(interpolatable=False),
}, # meta-information
)
pipeline = source
pipeline += gp.Normalize(raw)
pipeline += gp.RandomLocation()
pipeline += gp.DeformAugment(
gp.Coordinate(5, 5),
gp.Coordinate(2, 2),
graph_raster_voxel_size=gp.Coordinate(1, 1),
)

# formulate a request for "raw"
request = gp.BatchRequest()
request.add(raw, gp.Coordinate(64, 64), gp.Coordinate(1, 1))
request.add(gt, gp.Coordinate(32, 32), gp.Coordinate(1, 1))

# build the pipeline...
with gp.build_neuroglancer(pipeline):
for _ in range(10):
# ...and request a batch
batch = pipeline.request_batch(request)

# show the content of the batch
print(f"batch returned: {batch}")

0 comments on commit 1fed55d

Please sign in to comment.