diff --git a/docs/images/cellprofiler.jpg b/docs/images/cellprofiler.jpg new file mode 100644 index 000000000..05a89b93e Binary files /dev/null and b/docs/images/cellprofiler.jpg differ diff --git a/docs/images/cellprofiler_plugin.png b/docs/images/cellprofiler_plugin.png new file mode 100644 index 000000000..397e6d795 Binary files /dev/null and b/docs/images/cellprofiler_plugin.png differ diff --git a/monailabel/tasks/infer/basic_infer.py b/monailabel/tasks/infer/basic_infer.py index e0b12eddc..930bf6b85 100644 --- a/monailabel/tasks/infer/basic_infer.py +++ b/monailabel/tasks/infer/basic_infer.py @@ -467,11 +467,10 @@ def _get_network(self, device, data): if path: checkpoint = torch.load(path, map_location=torch.device(device)) model_state_dict = checkpoint.get(self.model_state_dict, checkpoint) - - if set(self.network.state_dict().keys()) != set(checkpoint.keys()): + if set(self.network.state_dict().keys()) != set(model_state_dict.keys()): logger.warning( f"Checkpoint keys don't match network.state_dict()! Items that exist in only one dict" - f" but not in the other: {set(self.network.state_dict().keys()) ^ set(checkpoint.keys())}" + f" but not in the other: {set(self.network.state_dict().keys()) ^ set(model_state_dict.keys())}" ) logger.warning( "The run will now continue unless load_strict is set to True. " diff --git a/monailabel/utils/others/modelzoo_list.py b/monailabel/utils/others/modelzoo_list.py index cce2daf29..ebdc9afe0 100644 --- a/monailabel/utils/others/modelzoo_list.py +++ b/monailabel/utils/others/modelzoo_list.py @@ -32,4 +32,5 @@ "wholeBrainSeg_Large_UNEST_segmentation", # whole brain segmentation for T1 MRI brain images. Added Oct 2022 "lung_nodule_ct_detection", # The first lung nodule detection task can be used for MONAI Label. Added Dec 2022 "wholeBody_ct_segmentation", # The SegResNet trained TotalSegmentator dataset with 104 tissues. Added Feb 2023 + "vista2d", # The VISTA segmentation trained on a collection of 15K public microscopy images. Added Aug 2024 ] diff --git a/plugins/cellprofiler/README.md b/plugins/cellprofiler/README.md new file mode 100644 index 000000000..308683e83 --- /dev/null +++ b/plugins/cellprofiler/README.md @@ -0,0 +1,51 @@ + + +# MONAI Label Plugin for CellProfiler +CellProfiler is a free open-source software designed to enable biologists without training in computer vision or programming to quantitatively measure phenotypes from thousands of images automatically. + +CellProfiler is designed to serve biologists as well as bioimage analysts who want a flexible system that is easy to deploy to collaborators who lack computational skills. It is commonly used for small-scale experiments involving a few images but is also routinely run on millions of images using cluster or cloud computing resources at some of the largest pharmaceutical companies and academic screening centers in the world. + + + +### Table of Contents +- [Supported Applications](#supported-applications) +- [Installing CellProfiler](#installing-cellprofiler) +- [Install MONAI Label Extension](#install-monai-label-extension) +- [Using the Plugin](#using-the-plugin) + +### Supported Applications +Users can find supported applications in the [sample-apps](../../sample-apps/vista2d/) folder and use the vista2d bundle. Currently, only this bundle can be used to create and refine labels for medical imaging tasks in CellProfiler. + +### Installing CellProfiler + +To use MONAILabel with CellProfiler, you first need to install CellProfiler from source code following the installation part of [CellProfiler WiKi](https://github.com/CellProfiler/CellProfiler/wiki). Once you have CellProfiler installed, you can install the MONAILabel plugin as shown in the `Install MONAI Label Extension` part. + +Please note these tips when installing the software: +1. After cloned the source code, you must switch to a specific version tag of the repo to start installation. +1. For Macbooks with Apple Silicon, please note [this issue](https://github.com/CellProfiler/CellProfiler/issues/4932) before you start to install. +1. Before actually running the command `pip install -e .`, you may need to install numpy if it doesn't exist in your environment. + +### Install MONAI Label Extension +1. Save the runvista2d.py file to a local path. +1. Start CellProfiler from CLI, open the `File-->preferences` option and fill in the `CellProfiler plugin directory` parameter with your local path. + +## Using the Plugin + +1. Start the MONAI Label server with vista2d bundle. +1. Add the `RunVISTA2D` module to your pipeline for object processing. +1. Make sure the MONAILabel Server URL is correctly set in the plugin input. +1. Click the `Analyse Images` or `Start Test Mode` button to execute the pipeline. +1. Update masks and perform the sequential modules to get measurement results. + + diff --git a/plugins/cellprofiler/resources/vista2d_test.tiff b/plugins/cellprofiler/resources/vista2d_test.tiff new file mode 100644 index 000000000..bb1aa9b08 Binary files /dev/null and b/plugins/cellprofiler/resources/vista2d_test.tiff differ diff --git a/plugins/cellprofiler/runvista2d.py b/plugins/cellprofiler/runvista2d.py new file mode 100644 index 000000000..9aa33ccbe --- /dev/null +++ b/plugins/cellprofiler/runvista2d.py @@ -0,0 +1,813 @@ +################################# +# +# Imports from useful Python libraries +# +################################# + +import cgi +import http.client +import json +import logging +import mimetypes +import os +import re +import ssl +import tempfile +from pathlib import Path +from urllib.parse import quote_plus, unquote, urlencode, urlparse + +import requests +import skimage +from cellprofiler_core.module.image_segmentation import ImageSegmentation +from cellprofiler_core.object import Objects +from cellprofiler_core.setting.choice import Choice +from cellprofiler_core.setting.text import Text + +################################# +# +# Imports from CellProfiler +# +################################## + + +VISTA_link = "https://doi.org/10.48550/arXiv.2406.05285" +LOGGER = logging.getLogger(__name__) + +__doc__ = f"""\ +RunVISTA2D +=========== + +**RunVISTA2D** uses a pre-trained VISTA2D model to detect cells in an image. + +This module is useful for automating simple segmentation tasks in CellProfiler. +The module accepts tiff input images and produces an object set. + +This module is a client/frontend of a MONAI Label server. A VISTA2D based MONAI Label server needs to be set up and the address of the server needs to be passed to +this module, before running. + +Installation: + +This module has no external dependencies other than the python(>3.8) build-in dependencies. + +You'll need to set up the VISTA2D based MONAI Label server based on the tutorial https://github.com/Project-MONAI/MONAILabel/tree/main/plugins/cellprofiler. After setting up the server, please +provide the server address to this plugin.s + +Yufan He, Pengfei Guo, Yucheng Tang, Andriy Myronenko, Vishwesh Nath, Ziyue Xu, Dong Yang, Can Zhao, Benjamin Simon, Mason Belue, Stephanie Harmon, Baris Turkbey, Daguang Xu, & Wenqi Li. (2024). VISTA3D: Versatile Imaging SegmenTation and Annotation model for 3D Computed Tomography.{VISTA_link} +============ ============ =============== +Supports 2D? Supports 3D? Respects masks? +============ ============ =============== +YES No NO +============ ============ =============== + +""" + + +def bytes_to_str(b): + return b.decode("utf-8") if isinstance(b, bytes) else b + + +class MONAILabelClient: + """ + Basic MONAILabel Client to invoke infer/train APIs over http/https + """ + + def __init__(self, server_url, tmpdir=None, client_id=None): + """ + :param server_url: Server URL for MONAILabel (e.g. http://127.0.0.1:8000) + :param tmpdir: Temp directory to save temporary files. If None then it uses tempfile.tempdir + :param client_id: Client ID that will be added for all basic requests + """ + + self._server_url = server_url.rstrip("/").strip() + self._tmpdir = tmpdir if tmpdir else tempfile.tempdir if tempfile.tempdir else "/tmp" + self._client_id = client_id + self._headers = {} + + def _update_client_id(self, params): + if params: + params["client_id"] = self._client_id + else: + params = {"client_id": self._client_id} + return params + + def update_auth(self, token): + if token: + self._headers["Authorization"] = f"{token['token_type']} {token['access_token']}" + + def get_server_url(self): + """ + Return server url + + :return: the url for monailabel server + """ + return self._server_url + + def set_server_url(self, server_url): + """ + Set url for monailabel server + + :param server_url: server url for monailabel + """ + self._server_url = server_url.rstrip("/").strip() + + def auth_enabled(self) -> bool: + """ + Check if Auth is enabled + + """ + selector = "/auth/" + status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector) + if status != 200: + return False + + response = bytes_to_str(response) + LOGGER.debug(f"Response: {response}") + enabled = json.loads(response).get("enabled", False) + return True if enabled else False + + def auth_token(self, username, password): + """ + Fetch Auth Token. Currently only basic authentication is supported. + + :param username: UserName for basic authentication + :param password: Password for basic authentication + """ + selector = "/auth/token" + data = urlencode({"username": username, "password": password, "grant_type": "password"}) + status, response, _, _ = MONAILabelUtils.http_method( + "POST", self._server_url, selector, data, None, "application/x-www-form-urlencoded" + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response + ) + + response = bytes_to_str(response) + LOGGER.debug(f"Response: {response}") + return json.loads(response) + + def auth_valid_token(self) -> bool: + selector = "/auth/token/valid" + status, _, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) + return True if status == 200 else False + + def info(self): + """ + Invoke /info/ request over MONAILabel Server + + :return: json response + """ + selector = "/info/" + status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def next_sample(self, strategy, params): + """ + Get Next sample + + :param strategy: Name of strategy to be used for fetching next sample + :param params: Additional JSON params as part of strategy request + :return: json response which contains information about next image selected for annotation + """ + params = self._update_client_id(params) + selector = f"/activelearning/{MONAILabelUtils.urllib_quote_plus(strategy)}" + status, response, _, _ = MONAILabelUtils.http_method( + "POST", self._server_url, selector, params, headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def create_session(self, image_in, params=None): + """ + Create New Session + + :param image_in: filepath for image to be sent to server as part of session creation + :param params: additional JSON params as part of session reqeust + :return: json response which contains session id and other details + """ + selector = "/session/" + params = self._update_client_id(params) + + status, response, _ = MONAILabelUtils.http_upload( + "PUT", self._server_url, selector, params, [image_in], headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def get_session(self, session_id): + """ + Get Session + + :param session_id: Session Id + :return: json response which contains more details about the session + """ + selector = f"/session/{MONAILabelUtils.urllib_quote_plus(session_id)}" + status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def remove_session(self, session_id): + """ + Remove any existing Session + + :param session_id: Session Id + :return: json response + """ + selector = f"/session/{MONAILabelUtils.urllib_quote_plus(session_id)}" + status, response, _, _ = MONAILabelUtils.http_method( + "DELETE", self._server_url, selector, headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def upload_image(self, image_in, image_id=None, params=None): + """ + Upload New Image to MONAILabel Datastore + + :param image_in: Image File Path + :param image_id: Force Image ID; If not provided then Server it auto generate new Image ID + :param params: Additional JSON params + :return: json response which contains image id and other details + """ + selector = f"/datastore/?image={MONAILabelUtils.urllib_quote_plus(image_id)}" + + files = {"file": image_in} + params = self._update_client_id(params) + fields = {"params": json.dumps(params) if params else "{}"} + + status, response, _, _ = MONAILabelUtils.http_multipart( + "PUT", self._server_url, selector, fields, files, headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, + f"Status: {status}; Response: {bytes_to_str(response)}", + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def save_label(self, image_id, label_in, tag="", params=None): + """ + Save/Submit Label + + :param image_id: Image Id for which label needs to saved/submitted + :param label_in: Label File path which shall be saved/submitted + :param tag: Save label against tag in datastore + :param params: Additional JSON params for the request + :return: json response + """ + selector = f"/datastore/label?image={MONAILabelUtils.urllib_quote_plus(image_id)}" + if tag: + selector += f"&tag={MONAILabelUtils.urllib_quote_plus(tag)}" + + params = self._update_client_id(params) + fields = { + "params": json.dumps(params), + } + files = {"label": label_in} + + status, response, _, _ = MONAILabelUtils.http_multipart( + "PUT", self._server_url, selector, fields, files, headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, + f"Status: {status}; Response: {bytes_to_str(response)}", + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def datastore(self): + selector = "/datastore/?output=all" + status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def download_label(self, label_id, tag): + selector = "/datastore/label?label={}&tag={}".format( + MONAILabelUtils.urllib_quote_plus(label_id), MONAILabelUtils.urllib_quote_plus(tag) + ) + status, response, _, headers = MONAILabelUtils.http_method( + "GET", self._server_url, selector, headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, f"Status: {status}; Response: {bytes_to_str(response)}", status, response + ) + + content_disposition = headers.get("content-disposition") + + if not content_disposition: + logging.warning("Filename not found. Fall back to no loaded labels") + file_name = MONAILabelUtils.get_filename(content_disposition) + + file_ext = "".join(Path(file_name).suffixes) + local_filename = tempfile.NamedTemporaryFile(dir=self._tmpdir, suffix=file_ext).name + with open(local_filename, "wb") as f: + f.write(response) + + return local_filename + + def infer(self, model, image_id, params, label_in=None, file=None, session_id=None): + """ + Run Infer + + :param model: Name of Model + :param image_id: Image Id + :param params: Additional configs/json params as part of Infer request + :param label_in: File path for label mask which is needed to run Inference (e.g. In case of Scribbles) + :param file: File path for Image (use raw image instead of image_id) + :param session_id: Session ID (use existing session id instead of image_id) + :return: response_file (label mask), response_body (json result/output params) + """ + selector = "/infer/{}?image={}".format( + MONAILabelUtils.urllib_quote_plus(model), + MONAILabelUtils.urllib_quote_plus(image_id), + ) + if session_id: + selector += f"&session_id={MONAILabelUtils.urllib_quote_plus(session_id)}" + + params = self._update_client_id(params) + fields = {"params": json.dumps(params) if params else "{}"} + files = {"label": label_in} if label_in else {} + files.update({"file": file} if file and not session_id else {}) + + status, form, files, _ = MONAILabelUtils.http_multipart( + "POST", self._server_url, selector, fields, files, headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, + f"Status: {status}; Response: {bytes_to_str(form)}", + ) + + form = json.loads(form) if isinstance(form, str) else form + params = form.get("params") if files else form + params = json.loads(params) if isinstance(params, str) else params + + image_out = MONAILabelUtils.save_result(files, self._tmpdir) + return image_out, params + + def wsi_infer(self, model, image_id, body=None, output="dsa", session_id=None): + """ + Run WSI Infer in case of Pathology App + + :param model: Name of Model + :param image_id: Image Id + :param body: Additional configs/json params as part of Infer request + :param output: Output File format (dsa|asap|json) + :param session_id: Session ID (use existing session id instead of image_id) + :return: response_file (None), response_body + """ + selector = "/infer/wsi/{}?image={}".format( + MONAILabelUtils.urllib_quote_plus(model), + MONAILabelUtils.urllib_quote_plus(image_id), + ) + if session_id: + selector += f"&session_id={MONAILabelUtils.urllib_quote_plus(session_id)}" + if output: + selector += f"&output={MONAILabelUtils.urllib_quote_plus(output)}" + + body = self._update_client_id(body if body else {}) + status, form, _, _ = MONAILabelUtils.http_method("POST", self._server_url, selector, body) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, + f"Status: {status}; Response: {bytes_to_str(form)}", + ) + + return None, form + + def train_start(self, model, params): + """ + Run Train Task + + :param model: Name of Model + :param params: Additional configs/json params as part of Train request + :return: json response + """ + params = self._update_client_id(params) + + selector = "/train/" + if model: + selector += MONAILabelUtils.urllib_quote_plus(model) + + status, response, _, _ = MONAILabelUtils.http_method( + "POST", self._server_url, selector, params, headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, + f"Status: {status}; Response: {bytes_to_str(response)}", + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def train_stop(self): + """ + Stop any running Train Task(s) + + :return: json response + """ + selector = "/train/" + status, response, _, _ = MONAILabelUtils.http_method( + "DELETE", self._server_url, selector, headers=self._headers + ) + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, + f"Status: {status}; Response: {bytes_to_str(response)}", + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + def train_status(self, check_if_running=False): + """ + Check Train Task Status + + :param check_if_running: Fast mode. Only check if training is Running + :return: boolean if check_if_running is enabled; else json response that contains of full details + """ + selector = "/train/" + if check_if_running: + selector += "?check_if_running=true" + status, response, _, _ = MONAILabelUtils.http_method("GET", self._server_url, selector, headers=self._headers) + if check_if_running: + return status == 200 + + if status != 200: + raise MONAILabelClientException( + MONAILabelError.SERVER_ERROR, + f"Status: {status}; Response: {bytes_to_str(response)}", + ) + + response = bytes_to_str(response) + logging.debug(f"Response: {response}") + return json.loads(response) + + +class MONAILabelError: + """ + Type of Inference Model + + Attributes: + SERVER_ERROR - Server Error + SESSION_EXPIRED - Session Expired + UNKNOWN - Unknown Error + """ + + SERVER_ERROR = 1 + SESSION_EXPIRED = 2 + UNKNOWN = 3 + + +class MONAILabelClientException(Exception): + """ + MONAILabel Client Exception + """ + + __slots__ = ["error", "msg"] + + def __init__(self, error, msg, status_code=None, response=None): + """ + :param error: Error code represented by MONAILabelError + :param msg: Error message + :param status_code: HTTP Response code + :param response: HTTP Response + """ + self.error = error + self.msg = msg + self.status_code = status_code + self.response = response + + +class MONAILabelUtils: + @staticmethod + def http_method(method, server_url, selector, body=None, headers=None, content_type=None): + logging.debug(f"{method} {server_url}{selector}") + + parsed = urlparse(server_url) + path = parsed.path.rstrip("/") + selector = path + "/" + selector.lstrip("/") + logging.debug(f"URI Path: {selector}") + + parsed = urlparse(server_url) + if parsed.scheme == "https": + LOGGER.debug("Using HTTPS mode") + # noinspection PyProtectedMember + conn = http.client.HTTPSConnection(parsed.hostname, parsed.port, context=ssl._create_unverified_context()) + else: + conn = http.client.HTTPConnection(parsed.hostname, parsed.port) + + headers = headers if headers else {} + if body: + if not content_type: + if isinstance(body, dict): + body = json.dumps(body) + content_type = "application/json" + else: + content_type = "text/plain" + headers.update({"content-type": content_type, "content-length": str(len(body))}) + + conn.request(method, selector, body=body, headers=headers) + return MONAILabelUtils.send_response(conn) + + @staticmethod + def http_upload(method, server_url, selector, fields, files, headers=None): + logging.debug(f"{method} {server_url}{selector}") + + url = server_url.rstrip("/") + "/" + selector.lstrip("/") + logging.debug(f"URL: {url}") + + files = [("files", (os.path.basename(f), open(f, "rb"))) for f in files] + headers = headers if headers else {} + response = ( + requests.post(url, files=files, headers=headers) + if method == "POST" + else requests.put(url, files=files, data=fields, headers=headers) + ) + return response.status_code, response.text, None + + @staticmethod + def http_multipart(method, server_url, selector, fields, files, headers={}): + logging.debug(f"{method} {server_url}{selector}") + + content_type, body = MONAILabelUtils.encode_multipart_formdata(fields, files) + headers = headers if headers else {} + headers.update({"content-type": content_type, "content-length": str(len(body))}) + + parsed = urlparse(server_url) + path = parsed.path.rstrip("/") + selector = path + "/" + selector.lstrip("/") + logging.debug(f"URI Path: {selector}") + + if parsed.scheme == "https": + LOGGER.debug("Using HTTPS mode") + # noinspection PyProtectedMember + conn = http.client.HTTPSConnection(parsed.hostname, parsed.port, context=ssl._create_unverified_context()) + else: + conn = http.client.HTTPConnection(parsed.hostname, parsed.port) + + conn.request(method, selector, body, headers) + return MONAILabelUtils.send_response(conn, content_type) + + @staticmethod + def send_response(conn, content_type="application/json"): + response = conn.getresponse() + logging.debug(f"HTTP Response Code: {response.status}") + logging.debug(f"HTTP Response Message: {response.reason}") + logging.debug(f"HTTP Response Headers: {response.getheaders()}") + + response_content_type = response.getheader("content-type", content_type) + logging.debug(f"HTTP Response Content-Type: {response_content_type}") + + if "multipart" in response_content_type: + if response.status == 200: + form, files = MONAILabelUtils.parse_multipart(response.fp if response.fp else response, response.msg) + logging.debug(f"Response FORM: {form}") + logging.debug(f"Response FILES: {files.keys()}") + return response.status, form, files, response.headers + else: + return response.status, response.read(), None, response.headers + + logging.debug("Reading status/content from simple response!") + return response.status, response.read(), None, response.headers + + @staticmethod + def save_result(files, tmpdir): + for name in files: + data = files[name] + result_file = os.path.join(tmpdir, name) + + logging.debug(f"Saving {name} to {result_file}; Size: {len(data)}") + dir_path = os.path.dirname(os.path.realpath(result_file)) + if not os.path.exists(dir_path): + os.makedirs(dir_path) + + with open(result_file, "wb") as f: + if isinstance(data, bytes): + f.write(data) + else: + f.write(data.encode("utf-8")) + + # Currently only one file per response supported + return result_file + + @staticmethod + def encode_multipart_formdata(fields, files): + limit = "----------lImIt_of_THE_fIle_eW_$" + lines = [] + for key, value in fields.items(): + lines.append("--" + limit) + lines.append('Content-Disposition: form-data; name="%s"' % key) + lines.append("") + lines.append(value) + for key, filename in files.items(): + lines.append("--" + limit) + lines.append(f'Content-Disposition: form-data; name="{key}"; filename="{filename}"') + lines.append("Content-Type: %s" % MONAILabelUtils.get_content_type(filename)) + lines.append("") + with open(filename, mode="rb") as f: + data = f.read() + lines.append(data) + lines.append("--" + limit + "--") + lines.append("") + + body = bytearray() + for line in lines: + body.extend(line if isinstance(line, bytes) else line.encode("utf-8")) + body.extend(b"\r\n") + + content_type = "multipart/form-data; boundary=%s" % limit + return content_type, body + + @staticmethod + def get_content_type(filename): + return mimetypes.guess_type(filename)[0] or "application/octet-stream" + + @staticmethod + def parse_multipart(fp, headers): + fs = cgi.FieldStorage( + fp=fp, + environ={"REQUEST_METHOD": "POST"}, + headers=headers, + keep_blank_values=True, + ) + form = {} + files = {} + if hasattr(fs, "list") and isinstance(fs.list, list): + for f in fs.list: + LOGGER.debug(f"FILE-NAME: {f.filename}; NAME: {f.name}; SIZE: {len(f.value)}") + if f.filename: + files[f.filename] = f.value + else: + form[f.name] = f.value + return form, files + + @staticmethod + def urllib_quote_plus(s): + return quote_plus(s) + + @staticmethod + def get_filename(content_disposition): + file_name = re.findall(r"filename\*=([^;]+)", content_disposition, flags=re.IGNORECASE) + if not file_name: + file_name = re.findall('filename="(.+)"', content_disposition, flags=re.IGNORECASE) + if "utf-8''" in file_name[0].lower(): + file_name = re.sub("utf-8''", "", file_name[0], flags=re.IGNORECASE) + file_name = unquote(file_name) + else: + file_name = file_name[0] + return file_name + + +class RunVISTA2D(ImageSegmentation): + category = "Object Processing" + + module_name = "RunVISTA2D" + + variable_revision_number = 1 + + doi = { + "Please cite the following when using RunVISTA2D:": "https://doi.org/10.48550/arXiv.2406.05285", + } + + def create_settings(self): + super().create_settings() + + self.server_address = Text( + text="MONAI label server address", + value="http://127.0.0.1:8000", + doc="""\ +Please set up the MONAI label server in local/cloud environment and fill the server address here. +""", + ) + + self.model_name = Choice( + text="The model for running the inference", + choices=["vista2d"], + value="vista2d", + doc=""" +Pick the model for running infernce. Now only VISTA2D is available. +""", + ) + + def settings(self): + return [ + self.x_name, + self.y_name, + self.server_address, + self.model_name, + ] + + def visible_settings(self): + return [ + self.x_name, + self.y_name, + self.server_address, + self.model_name, + ] + + def run(self, workspace): + x_name = self.x_name.value + y_name = self.y_name.value + images = workspace.image_set + x = images.get_image(x_name) + dimensions = x.dimensions + x_data = x.pixel_data + + with tempfile.TemporaryDirectory() as temp_dir: + temp_img_dir = os.path.join(temp_dir, "img") + os.makedirs(temp_img_dir, exist_ok=True) + temp_img_path = os.path.join(temp_img_dir, x_name + ".tiff") + temp_mask_dir = os.path.join(temp_dir, "mask") + os.makedirs(temp_mask_dir, exist_ok=True) + skimage.io.imsave(temp_img_path, x_data) + monailabel_client = MONAILabelClient(server_url=self.server_address.value, tmpdir=temp_mask_dir) + image_out, params = monailabel_client.infer( + model=self.model_name.value, image_id="", params={}, file=temp_img_path + ) + print(f"Image out:\n{image_out}") + print(f"Params:\n{params}") + y_data = skimage.io.imread(image_out) + + y = Objects() + y.segmented = y_data + y.parent_image = x.parent_image + objects = workspace.object_set + objects.add_objects(y, y_name) + + self.add_measurements(workspace) + + if self.show_window: + workspace.display_data.x_data = x_data + workspace.display_data.y_data = y_data + workspace.display_data.dimensions = dimensions + + def display(self, workspace, figure): + layout = (2, 1) + figure.set_subplots(dimensions=workspace.display_data.dimensions, subplots=layout) + + figure.subplot_imshow( + colormap="gray", + image=workspace.display_data.x_data, + title="Input Image", + x=0, + y=0, + ) + + figure.subplot_imshow_labels( + image=workspace.display_data.y_data, + sharexy=figure.subplot(0, 0), + title=self.y_name.value, + x=1, + y=0, + ) + + # def upgrade_settings(self, setting_values, variable_revision_number, module_name): + # ... diff --git a/plugins/cellprofiler/test_runvista2d.py b/plugins/cellprofiler/test_runvista2d.py new file mode 100644 index 000000000..47ac1f929 --- /dev/null +++ b/plugins/cellprofiler/test_runvista2d.py @@ -0,0 +1,91 @@ +import os + +import cellprofiler_core.image +import cellprofiler_core.measurement +import cellprofiler_core.object +import cellprofiler_core.pipeline +import cellprofiler_core.setting +import cellprofiler_core.workspace +import numpy +import pytest +from runvista2d import MONAILabelClient, MONAILabelClientException, MONAILabelUtils, RunVISTA2D + +IMAGE_NAME = "my_image" +OBJECTS_NAME = "my_objects" +MODEL_NAME = "vista2d" +SERVER_ADDRESS = "http://127.0.0.1:8000" + + +class MockResponse: + @staticmethod + def infer(*args, **kwargs): + filepath = os.path.abspath(__file__) + dir = os.path.dirname(filepath) + image = os.path.join(dir, "resources", "vista2d_test.tiff") + return image, {} + + +class MockErrResponse: + @staticmethod + def http_multipart(*args, **kwargs): + return 400, {}, {}, {} + + +def test_mock_failed(): + x = RunVISTA2D() + x.y_name.value = OBJECTS_NAME + x.x_name.value = IMAGE_NAME + x.server_address.value = SERVER_ADDRESS + x.model_name.value = MODEL_NAME + + img = numpy.zeros((128, 128, 3)) + image = cellprofiler_core.image.Image(img) + image_set_list = cellprofiler_core.image.ImageSetList() + image_set = image_set_list.get_image_set(0) + image_set.providers.append(cellprofiler_core.image.VanillaImage(IMAGE_NAME, image)) + object_set = cellprofiler_core.object.ObjectSet() + measurements = cellprofiler_core.measurement.Measurements() + pipeline = cellprofiler_core.pipeline.Pipeline() + + pytest.MonkeyPatch().setattr(MONAILabelUtils, "http_multipart", MockErrResponse.http_multipart) + with pytest.raises(MONAILabelClientException): + x.run(cellprofiler_core.workspace.Workspace(pipeline, x, image_set, object_set, measurements, None)) + + +def test_mock_successful(): + x = RunVISTA2D() + x.y_name.value = OBJECTS_NAME + x.x_name.value = IMAGE_NAME + x.server_address.value = SERVER_ADDRESS + x.model_name.value = MODEL_NAME + + img = numpy.zeros((128, 128, 3)) + image = cellprofiler_core.image.Image(img) + image_set_list = cellprofiler_core.image.ImageSetList() + image_set = image_set_list.get_image_set(0) + image_set.providers.append(cellprofiler_core.image.VanillaImage(IMAGE_NAME, image)) + object_set = cellprofiler_core.object.ObjectSet() + measurements = cellprofiler_core.measurement.Measurements() + pipeline = cellprofiler_core.pipeline.Pipeline() + + pytest.MonkeyPatch().setattr(MONAILabelClient, "infer", MockResponse.infer) + x.run(cellprofiler_core.workspace.Workspace(pipeline, x, image_set, object_set, measurements, None)) + assert len(object_set.object_names) == 1 + assert OBJECTS_NAME in object_set.object_names + objects = object_set.get_objects(OBJECTS_NAME) + segmented = objects.segmented + assert numpy.all(segmented == 0) + assert "Image" in measurements.get_object_names() + assert OBJECTS_NAME in measurements.get_object_names() + + assert f"Count_{OBJECTS_NAME}" in measurements.get_feature_names("Image") + count = measurements.get_current_measurement("Image", f"Count_{OBJECTS_NAME}") + assert count == 0 + assert "Location_Center_X" in measurements.get_feature_names(OBJECTS_NAME) + location_center_x = measurements.get_current_measurement(OBJECTS_NAME, "Location_Center_X") + assert isinstance(location_center_x, numpy.ndarray) + assert numpy.product(location_center_x.shape) == 0 + assert "Location_Center_Y" in measurements.get_feature_names(OBJECTS_NAME) + location_center_y = measurements.get_current_measurement(OBJECTS_NAME, "Location_Center_Y") + assert isinstance(location_center_y, numpy.ndarray) + assert numpy.product(location_center_y.shape) == 0 diff --git a/sample-apps/vista2d/README.md b/sample-apps/vista2d/README.md new file mode 100644 index 000000000..4aa8f54f9 --- /dev/null +++ b/sample-apps/vista2d/README.md @@ -0,0 +1,68 @@ + + +# VISTA2D Application +A reference app to run the inference task to segment cells. This app works in CellProfiler for now. All samples in the CellProfiler are provided from local path. + +### Table of Contents +- [Supported Viewers](#supported-viewers) +- [Pretrained Models](#pretrained-models) +- [How To Use the App](#how-to-use-the-app) + +### Supported Viewers + +The VISTA2D Application supports the following viewer: + +- [CellProfiler](../../plugins/cellprofiler/) + +### Pretrained Models + +The following are the models which are currently added into Pathology App: + +| Name | Description | +|------|-------------| +| VISTA2D | An example of instance segmentation for the cell segmentation. | + +
+ Model Details (dataset, input, outputs) + +#### Dataset + +You can use the [cellpose dataset](https://www.cellpose.org/dataset) for inference. + +#### Inputs + +TIFF Images + +#### Output + +Segmentation Masks + +
+ +### How To Use the App + +```bash +# skip this if you have already downloaded the app or using github repository (dev mode) +monailabel apps --download --name vista2d --output apps + +# Start server with vista2d model +monailabel start_server --app apps/vista2d --studies datasets --conf models vista2d --conf preload true --conf skip_trainers true +``` + +**Specify bundle version** (Optional) +Above command will download the latest bundles from Model-Zoo by default. If a specific or older bundle version is used, users can add version `_v` followed by the bundle name. Example: + +```bash +monailabel start_server --app apps/vista2d --studies datasets --conf models vista2d_v0.2.1 --conf preload true --conf skip_trainers true +``` diff --git a/sample-apps/vista2d/__init__.py b/sample-apps/vista2d/__init__.py new file mode 100644 index 000000000..1e97f8940 --- /dev/null +++ b/sample-apps/vista2d/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sample-apps/vista2d/lib/__init__.py b/sample-apps/vista2d/lib/__init__.py new file mode 100644 index 000000000..1e97f8940 --- /dev/null +++ b/sample-apps/vista2d/lib/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sample-apps/vista2d/lib/activelearning/__init__.py b/sample-apps/vista2d/lib/activelearning/__init__.py new file mode 100644 index 000000000..1e97f8940 --- /dev/null +++ b/sample-apps/vista2d/lib/activelearning/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sample-apps/vista2d/lib/infers/__init__.py b/sample-apps/vista2d/lib/infers/__init__.py new file mode 100644 index 000000000..57ae33bec --- /dev/null +++ b/sample-apps/vista2d/lib/infers/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .vista2d import VISTA2DInfer diff --git a/sample-apps/vista2d/lib/infers/vista2d.py b/sample-apps/vista2d/lib/infers/vista2d.py new file mode 100644 index 000000000..e7be48143 --- /dev/null +++ b/sample-apps/vista2d/lib/infers/vista2d.py @@ -0,0 +1,33 @@ +import logging +import os +from typing import Any, Dict, Tuple + +from monai.utils import ImageMetaKey + +from monailabel.tasks.infer.bundle import BundleInferTask + +logger = logging.getLogger(__name__) + + +class VISTA2DInfer(BundleInferTask): + """ + This provides Inference Engine for pre-trained VISTA segmentation model. + """ + + def writer(self, data: Dict[str, Any], extension=None, dtype=None) -> Tuple[Any, Any]: + d = dict(data) + output_dir = self.bundle_config.get_parsed_content("output_dir", instantiate=True) + output_ext = self.bundle_config.get_parsed_content("output_ext", instantiate=True) + image_key = self.bundle_config.get_parsed_content("image_key", instantiate=True) + output_postfix = self.bundle_config.get_parsed_content("output_postfix", instantiate=True) + + img = d.get(image_key, None) + filename = img.meta.get(ImageMetaKey.FILENAME_OR_OBJ) if img is not None else None + basename = os.path.splitext(os.path.basename(filename))[0] if filename else "mask" + output_filename = f"{basename}{'_' + output_postfix if output_postfix else ''}{output_ext}" + output_filepath = os.path.join(output_dir, output_filename) + if os.path.exists(output_filepath): + logger.info(f"Reusing the bundle output {output_filepath}.") + return output_filepath, {} + else: + return super().writer(data=data, extension=extension, dtype=dtype) diff --git a/sample-apps/vista2d/lib/trainers/__init__.py b/sample-apps/vista2d/lib/trainers/__init__.py new file mode 100644 index 000000000..1e97f8940 --- /dev/null +++ b/sample-apps/vista2d/lib/trainers/__init__.py @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/sample-apps/vista2d/main.py b/sample-apps/vista2d/main.py new file mode 100644 index 000000000..2380f1ff8 --- /dev/null +++ b/sample-apps/vista2d/main.py @@ -0,0 +1,214 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +from typing import Dict + +from lib.infers import VISTA2DInfer +from monai.transforms import Invertd, SaveImaged + +import monailabel +from monailabel.interfaces.app import MONAILabelApp +from monailabel.interfaces.tasks.infer_v2 import InferTask +from monailabel.interfaces.tasks.scoring import ScoringMethod +from monailabel.interfaces.tasks.strategy import Strategy +from monailabel.interfaces.tasks.train import TrainTask +from monailabel.tasks.activelearning.first import First +from monailabel.tasks.activelearning.random import Random +from monailabel.tasks.infer.bundle import BundleInferTask +from monailabel.tasks.scoring.epistemic_v2 import EpistemicScoring +from monailabel.tasks.train.bundle import BundleTrainTask +from monailabel.utils.others.generic import get_bundle_models, strtobool + +logger = logging.getLogger(__name__) + + +class VISTA2DApp(MONAILabelApp): + def __init__(self, app_dir, studies, conf): + self.models = get_bundle_models(app_dir, conf) + # Add Epistemic model for scoring + self.epistemic_models = ( + get_bundle_models(app_dir, conf, conf_key="epistemic_model") if conf.get("epistemic_model") else None + ) + if self.epistemic_models: + # Get epistemic parameters + self.epistemic_max_samples = int(conf.get("epistemic_max_samples", "0")) + self.epistemic_simulation_size = int(conf.get("epistemic_simulation_size", "5")) + self.epistemic_dropout = float(conf.get("epistemic_dropout", "0.2")) + + super().__init__( + app_dir=app_dir, + studies=studies, + conf=conf, + name=f"MONAILabel - Zoo/Bundle ({monailabel.__version__})", + description="DeepLearning models provided via MONAI Zoo/Bundle", + version=monailabel.__version__, + ) + + def init_infers(self) -> Dict[str, InferTask]: + infers: Dict[str, InferTask] = {} + ################################################# + # Models + ################################################# + + for n, b in self.models.items(): + i = VISTA2DInfer(b, self.conf, model_state_dict="state_dict") + logger.info(f"+++ Adding Inferer:: {n} => {i}") + infers[n] = i + + return infers + + def init_trainers(self) -> Dict[str, TrainTask]: + trainers: Dict[str, TrainTask] = {} + if strtobool(self.conf.get("skip_trainers", "false")): + return trainers + + for n, b in self.models.items(): + t = BundleTrainTask(b, self.conf) + if not t or not t.is_valid(): + continue + + logger.info(f"+++ Adding Trainer:: {n} => {t}") + trainers[n] = t + return trainers + + def init_strategies(self) -> Dict[str, Strategy]: + strategies: Dict[str, Strategy] = { + "random": Random(), + "first": First(), + } + + logger.info(f"Active Learning Strategies:: {list(strategies.keys())}") + return strategies + + def init_scoring_methods(self) -> Dict[str, ScoringMethod]: + methods: Dict[str, ScoringMethod] = {} + if not self.conf.get("epistemic_model"): + return methods + + for n, b in self.epistemic_models.items(): + # Create BundleInferTask task with dropout instantiation for scoring inference + i = BundleInferTask( + b, + self.conf, + train_mode=True, + skip_writer=True, + dropout=self.epistemic_dropout, + post_filter=[SaveImaged, Invertd], + ) + methods[n] = EpistemicScoring( + i, max_samples=self.epistemic_max_samples, simulation_size=self.epistemic_simulation_size + ) + if not methods: + continue + methods = methods if isinstance(methods, dict) else {n: methods[n]} + logger.info(f"+++ Adding Scoring Method:: {n} => {b}") + + logger.info(f"Active Learning Scoring Methods:: {list(methods.keys())}") + return methods + + +""" +Example to run train/infer/scoring task(s) locally without actually running MONAI Label Server +""" + + +def main(): + import argparse + import shutil + from pathlib import Path + + from monailabel.utils.others.generic import device_list, file_ext + + os.putenv("MASTER_ADDR", "127.0.0.1") + os.putenv("MASTER_PORT", "1234") + + logging.basicConfig( + level=logging.INFO, + format="[%(asctime)s] [%(process)s] [%(threadName)s] [%(levelname)s] (%(name)s:%(lineno)d) - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + force=True, + ) + + home = str(Path.home()) + studies = f"{home}/Datasets/cellpose" + + parser = argparse.ArgumentParser() + parser.add_argument("-s", "--studies", default=studies) + parser.add_argument("-m", "--model", default="vista2d") + parser.add_argument("-t", "--test", default="infer", choices=("train", "infer", "batch_infer")) + args = parser.parse_args() + + app_dir = os.path.dirname(__file__) + studies = args.studies + conf = { + "models": args.model, + "preload": "false", + } + + app = VISTA2DApp(app_dir, studies, conf) + + # Infer + if args.test == "infer": + sample = app.next_sample(request={"strategy": "first"}) + image_id = sample["id"] + image_path = sample["path"] + + # Run on all devices + for device in device_list(): + res = app.infer(request={"model": args.model, "image": image_id, "device": device}) + label = res["file"] + label_json = res["params"] + test_dir = os.path.join(args.studies, "test_labels") + os.makedirs(test_dir, exist_ok=True) + + label_file = os.path.join(test_dir, image_id + file_ext(image_path)) + shutil.move(label, label_file) + + print(label_json) + print(f"++++ Image File: {image_path}") + print(f"++++ Label File: {label_file}") + break + return + + # Batch Infer + if args.test == "batch_infer": + app.batch_infer( + request={ + "model": args.model, + "multi_gpu": False, + "save_label": True, + "label_tag": "original", + "max_workers": 1, + "max_batch_size": 0, + } + ) + return + + # Train + app.train( + request={ + "model": args.model, + "max_epochs": 10, + "dataset": "Dataset", # PersistentDataset, CacheDataset + "train_batch_size": 1, + "val_batch_size": 1, + "multi_gpu": False, + "val_split": 0.1, + }, + ) + + +if __name__ == "__main__": + # export PYTHONPATH=~/Projects/MONAILabel:`pwd` + # python main.py + main() diff --git a/sample-apps/vista2d/requirements.txt b/sample-apps/vista2d/requirements.txt new file mode 100644 index 000000000..1e97f8940 --- /dev/null +++ b/sample-apps/vista2d/requirements.txt @@ -0,0 +1,10 @@ +# Copyright (c) MONAI Consortium +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License.