Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Improved type hinting #6

Merged
merged 2 commits into from
Feb 3, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -55,5 +55,13 @@ localtest.py
test.dcm
old/

# Environments
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# VSCode
.vscode/settings.json
2 changes: 1 addition & 1 deletion rt_utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .rtstruct import RTStruct
from .rtstruct_builder import RTStructBuilder
from .rtstruct_builder import RTStructBuilder
43 changes: 28 additions & 15 deletions rt_utils/ds_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
"""
File contains helper methods that handles DICOM header creation/formatting
"""
def create_rtstruct_dataset(series_data):
def create_rtstruct_dataset(series_data) -> FileDataset:
ds = generate_base_dataset()
add_study_and_series_information(ds, series_data)
add_patient_information(ds, series_data)
Expand Down Expand Up @@ -71,10 +71,9 @@ def add_study_and_series_information(ds: FileDataset, series_data):
ds.StudyDescription = reference_ds.StudyDescription
ds.SeriesDescription = reference_ds.SeriesDescription
ds.StudyInstanceUID = reference_ds.StudyInstanceUID
ds.SeriesInstanceUID = generate_uid() # TODO find out if random generation is ok
ds.SeriesInstanceUID = generate_uid() # TODO: find out if random generation is ok
ds.StudyID = reference_ds.StudyID
ds.SeriesNumber = "1" # TODO find out if we can just use 1 (Should be fine since its a new series)
pass
ds.SeriesNumber = "1" # TODO: find out if we can just use 1 (Should be fine since its a new series)

def add_patient_information(ds: FileDataset, series_data):
reference_ds = series_data[0] # All elements in series should have the same data
Expand All @@ -86,6 +85,7 @@ def add_patient_information(ds: FileDataset, series_data):
ds.PatientSize = reference_ds.PatientSize
ds.PatientWeight = reference_ds.PatientWeight


def add_refd_frame_of_ref_sequence(ds: FileDataset, series_data):
refd_frame_of_ref = Dataset()
refd_frame_of_ref.FrameOfReferenceUID = generate_uid() # TODO Find out if random generation is ok
Expand All @@ -95,7 +95,8 @@ def add_refd_frame_of_ref_sequence(ds: FileDataset, series_data):
ds.ReferencedFrameOfReferenceSequence = Sequence()
ds.ReferencedFrameOfReferenceSequence.append(refd_frame_of_ref)

def create_frame_of_ref_study_sequence(series_data):

def create_frame_of_ref_study_sequence(series_data) -> Sequence:
reference_ds = series_data[0] # All elements in series should have the same data
rt_refd_series = Dataset()
rt_refd_series.SeriesInstanceUID = reference_ds.SeriesInstanceUID
Expand All @@ -113,7 +114,8 @@ def create_frame_of_ref_study_sequence(series_data):
rt_refd_study_sequence.append(rt_refd_study)
return rt_refd_study_sequence

def create_contour_image_sequence(series_data):

def create_contour_image_sequence(series_data) -> Sequence:
contour_image_sequence = Sequence()

# Add each referenced image
Expand All @@ -124,7 +126,8 @@ def create_contour_image_sequence(series_data):
contour_image_sequence.append(contour_image)
return contour_image_sequence

def create_structure_set_roi(roi_data: ROIData):

def create_structure_set_roi(roi_data: ROIData) -> Dataset:
# Structure Set ROI Sequence: Structure Set ROI 1
structure_set_roi = Dataset()
structure_set_roi.ROINumber = roi_data.number
Expand All @@ -134,18 +137,21 @@ def create_structure_set_roi(roi_data: ROIData):
structure_set_roi.ROIGenerationAlgorithm = 'MANUAL'
return structure_set_roi

def create_roi_contour(roi_data: ROIData, series_data):

def create_roi_contour(roi_data: ROIData, series_data) -> Dataset:
roi_contour = Dataset()
roi_contour.ROIDisplayColor = roi_data.color
roi_contour.ContourSequence = create_contour_sequence(roi_data, series_data)
roi_contour.ReferencedROINumber = str(roi_data.number)
return roi_contour

"""
Iterate through each slice of the mask
For each connected segment within a slice, create a contour
"""
def create_contour_sequence(roi_data: ROIData, series_data):

def create_contour_sequence(roi_data: ROIData, series_data) -> Sequence:
"""
Iterate through each slice of the mask
For each connected segment within a slice, create a contour
"""

contour_sequence = Sequence()
for i, series_slice in enumerate(series_data):
mask_slice = roi_data.mask[:,:,i]
Expand All @@ -158,9 +164,11 @@ def create_contour_sequence(roi_data: ROIData, series_data):
for contour_data in contour_coords:
contour = create_contour(series_slice, contour_data)
contour_sequence.append(contour)

return contour_sequence

def create_contour(series_slice: Dataset, contour_data: np.ndarray):

def create_contour(series_slice: Dataset, contour_data: np.ndarray) -> Dataset:
contour_image = Dataset()
contour_image.ReferencedSOPClassUID = series_slice.file_meta.MediaStorageSOPClassUID
contour_image.ReferencedSOPInstanceUID = series_slice.file_meta.MediaStorageSOPInstanceUID
Expand All @@ -174,6 +182,7 @@ def create_contour(series_slice: Dataset, contour_data: np.ndarray):
contour.ContourGeometricType = 'CLOSED_PLANAR' # TODO figure out how to get this value
contour.NumberOfContourPoints = len(contour_data) / 3 # Each point has an x, y, and z value
contour.ContourData = contour_data

return contour


Expand All @@ -188,8 +197,12 @@ def create_rtroi_observation(roi_data: ROIData) -> Dataset:
rtroi_observation.ROIInterpreter = ''
return rtroi_observation


def get_contour_sequence_by_roi_number(ds, roi_number):
for roi_contour in ds.ROIContourSequence:
if str(roi_contour.ReferencedROINumber) == str(roi_number): # Ensure same type

# Ensure same type
if str(roi_contour.ReferencedROINumber) == str(roi_number):
return roi_contour.ContourSequence

raise Exception(f"Referenced ROI number '{roi_number}' not found")
87 changes: 49 additions & 38 deletions rt_utils/rtstruct.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,36 @@
import numpy as np
from typing import List

import numpy as np
from pydicom.dataset import FileDataset
from . import ds_helper, image_helper

from rt_utils.utils import ROIData
from . import ds_helper, image_helper


"""
Wrapper class to facilitate appending and extracting ROI's within an RTStruct
"""
class RTStruct:
"""
Wrapper class to facilitate appending and extracting ROI's within an RTStruct
"""

def __init__(self, series_data, ds: FileDataset):
self.series_data = series_data
self.ds = ds
self.frame_of_reference_uid = ds.ReferencedFrameOfReferenceSequence[-1].FrameOfReferenceUID # Use last strucitured set ROI
self.frame_of_reference_uid = ds.ReferencedFrameOfReferenceSequence[-1].FrameOfReferenceUID # Use last strucitured set ROI

"""
Set the series description for the RTStruct dataset
"""
def set_series_description(self, description: str):
"""
Set the series description for the RTStruct dataset
"""

self.ds.SeriesDescription = description

"""
Add a ROI to the rtstruct given a 3D binary mask for the ROI's at each slice
Optionally input a color or name for the ROI
If pin_hole is set to true, will cut a pinhole through ROI's with holes in them so that they are represented with one contour
"""

def add_roi(self, mask: np.ndarray, color=None, name=None, description='', use_pin_hole=False):
"""
Add a ROI to the rtstruct given a 3D binary mask for the ROI's at each slice
Optionally input a color or name for the ROI
If pin_hole is set to true, will cut a pinhole through ROI's with holes in them so that they are represented with one contour
"""

# TODO test if name already exists
self.validate_mask(mask)
roi_number = len(self.ds.StructureSetROISequence) + 1
Expand All @@ -34,48 +40,52 @@ def add_roi(self, mask: np.ndarray, color=None, name=None, description='', use_p
self.ds.StructureSetROISequence.append(ds_helper.create_structure_set_roi(roi_data))
self.ds.RTROIObservationsSequence.append(ds_helper.create_rtroi_observation(roi_data))

def validate_mask(self, mask: np.ndarray):
def validate_mask(self, mask: np.ndarray) -> bool:
if mask.dtype != bool:
raise RTStruct.ROIException(f"Mask data type must be boolean. Got {mask.dtype}")

if mask.ndim != 3:
raise RTStruct.ROIException(f"Mask must be 3 dimensional. Got {mask.ndim}")

if len(self.series_data) != np.shape(mask)[2]:
raise RTStruct.ROIException("Mask must have the save number of layers as input series. " +
f"Expected {len(self.series_data)}, got {np.shape(mask)[2]}")

raise RTStruct.ROIException(
"Mask must have the save number of layers as input series. " +
f"Expected {len(self.series_data)}, got {np.shape(mask)[2]}"
)

if np.sum(mask) == 0:
raise RTStruct.ROIException("Mask cannot be empty")

"""
Returns a list of the names of all ROI within the RTStruct
"""
def get_roi_names(self):
return True

def get_roi_names(self) -> List[str]:
"""
Returns a list of the names of all ROI within the RTStruct
"""

if not self.ds.StructureSetROISequence:
return []

roi_names = []
for structure_roi in self.ds.StructureSetROISequence:
roi_names.append(structure_roi.ROIName)
return roi_names

"""
Returns the 3D binary mask of the ROI with the given input name
"""
def get_roi_mask_by_name(self, name):
return [structure_roi for structure_roi in self.ds.StructureSetROISequence]
asim-shrestha marked this conversation as resolved.
Show resolved Hide resolved

def get_roi_mask_by_name(self, name) -> np.ndarray:
"""
Returns the 3D binary mask of the ROI with the given input name
"""

for structure_roi in self.ds.StructureSetROISequence:
if structure_roi.ROIName == name:
contour_sequence = ds_helper.get_contour_sequence_by_roi_number(self.ds, structure_roi.ROINumber)
return image_helper.create_series_mask_from_contour_sequence(self.series_data, contour_sequence)

raise RTStruct.ROIException(f"ROI of name `{name}` does not exist in RTStruct")

"""
Saves the RTStruct with the specified name / location
Automatically adds '.dcm' as a suffix
"""
def save(self, file_path: str):
"""
Saves the RTStruct with the specified name / location
Automatically adds '.dcm' as a suffix
"""

# Add .dcm if needed
file_path = file_path if file_path.endswith('.dcm') else file_path + '.dcm'

Expand All @@ -92,4 +102,5 @@ class ROIException(Exception):
"""
Exception class for invalid ROI masks
"""
pass

pass
38 changes: 22 additions & 16 deletions rt_utils/rtstruct_builder.py
Original file line number Diff line number Diff line change
@@ -1,36 +1,42 @@
from rt_utils.utils import SOPClassUID
from pydicom.filereader import dcmread
from .rtstruct import RTStruct

from rt_utils.utils import SOPClassUID
from . import ds_helper, image_helper
from .rtstruct import RTStruct


"""
Class to help facilitate the two ways in one can instantiate the RTStruct wrapper
"""
class RTStructBuilder():
class RTStructBuilder:
"""
Method to generate a new rt struct from a DICOM series
Class to help facilitate the two ways in one can instantiate the RTStruct wrapper
"""

@staticmethod
def create_new(dicom_series_path: str):
def create_new(dicom_series_path: str) -> RTStruct:
"""
Method to generate a new rt struct from a DICOM series
"""

series_data = image_helper.load_sorted_image_series(dicom_series_path)
ds = ds_helper.create_rtstruct_dataset(series_data)
return RTStruct(series_data, ds)

"""
Method to load an existing rt struct, given related DICOM series and existing rt struct
"""
@staticmethod
def create_from(dicom_series_path: str, rt_struct_path: str):
def create_from(dicom_series_path: str, rt_struct_path: str) -> RTStruct:
"""
Method to load an existing rt struct, given related DICOM series and existing rt struct
"""

series_data = image_helper.load_sorted_image_series(dicom_series_path)
ds = dcmread(rt_struct_path)
RTStructBuilder.validate_rtstruct(ds)

# TODO create new frame of reference?
return RTStruct(series_data, ds)

@staticmethod
def validate_rtstruct(ds):
if ds.SOPClassUID != SOPClassUID.RTSTRUCT or \
not hasattr(ds, 'ROIContourSequence') or \
not hasattr(ds, 'StructureSetROISequence') or \
not hasattr(ds, 'RTROIObservationsSequence'):
raise Exception("Please check that the existing RTStruct is valid")
not hasattr(ds, 'ROIContourSequence') or \
not hasattr(ds, 'StructureSetROISequence') or \
not hasattr(ds, 'RTROIObservationsSequence'):
raise Exception("Please check that the existing RTStruct is valid")