diff --git a/diffdrr/drr.py b/diffdrr/drr.py index a88fecaa0..12a0ccddf 100644 --- a/diffdrr/drr.py +++ b/diffdrr/drr.py @@ -65,7 +65,7 @@ def __init__( self.air = torch.where(self.volume <= -800) self.soft_tissue = torch.where((-800 < self.volume) & (self.volume <= 350)) self.bone = torch.where(350 < self.volume) - self.set_bone_attenuation_multiplier(bone_attenuation_multiplier) + self.bone_attenuation_multiplier = bone_attenuation_multiplier def reshape_transform(self, img, batch_size): if self.reshape: @@ -104,6 +104,8 @@ def forward( bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue ): """Generate DRR with rotational and translational parameters.""" + if not hasattr(self, "density"): + self.set_bone_attenuation_multiplier(self.bone_attenuation_multiplier) if bone_attenuation_multiplier is not None: self.set_bone_attenuation_multiplier(bone_attenuation_multiplier) diff --git a/notebooks/api/00_drr.ipynb b/notebooks/api/00_drr.ipynb index 87f8ff816..5e36ae690 100644 --- a/notebooks/api/00_drr.ipynb +++ b/notebooks/api/00_drr.ipynb @@ -157,7 +157,7 @@ " self.air = torch.where(self.volume <= -800)\n", " self.soft_tissue = torch.where((-800 < self.volume) & (self.volume <= 350))\n", " self.bone = torch.where(350 < self.volume)\n", - " self.set_bone_attenuation_multiplier(bone_attenuation_multiplier) \n", + " self.bone_attenuation_multiplier = bone_attenuation_multiplier \n", "\n", " def reshape_transform(self, img, batch_size):\n", " if self.reshape:\n", @@ -220,6 +220,8 @@ " bone_attenuation_multiplier: float = None, # Contrast ratio of bone to soft tissue\n", "):\n", " \"\"\"Generate DRR with rotational and translational parameters.\"\"\"\n", + " if not hasattr(self, \"density\"):\n", + " self.set_bone_attenuation_multiplier(self.bone_attenuation_multiplier)\n", " if bone_attenuation_multiplier is not None:\n", " self.set_bone_attenuation_multiplier(bone_attenuation_multiplier)\n", " \n",