Skip to content

Commit

Permalink
refactored src and obs loops
Browse files Browse the repository at this point in the history
  • Loading branch information
pmelchior committed May 1, 2024
1 parent 52ba78b commit fcd1b29
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 45 deletions.
6 changes: 3 additions & 3 deletions docs/0-quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -187,7 +187,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"Since we use a trivial `wcs` in this `Observation`, all coordinates are already in image pixels, otherwise RA/Dec pairs are expected as sky coordinates. Also:"
"Since we use a trivial `wcs` in this `Observation`, all coordinates are already in image pixels, otherwise RA/Dec pairs are expected as sky coordinates."
]
},
{
Expand Down Expand Up @@ -253,7 +253,7 @@
" else:\n",
" new_source = scarlet.ExtendedSource(model_frame, center, observation, compact=True)\n",
" sources.append(new_source)\n",
" \n",
"\n",
"for k, src in enumerate(sources):\n",
" print (f\"{k}: {src.__class__.__name__}\")"
]
Expand All @@ -280,7 +280,7 @@
"source": [
"## Create and Fit Model\n",
"\n",
"The `Blend` class holds the list of sources and has the machinery to fit them to the given images. In this example the code is set to run for a maximum of 100 iterations, but will end early if the likelihood and all of the constraints converge."
"The `Blend` class holds the list of sources and has the machinery to fit them to the given images. In this example the code is set to run for a maximum of 100 iterations, but will end early if the likelihood and all the constraints converge."
]
},
{
Expand Down
84 changes: 42 additions & 42 deletions scarlet/initialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,50 +508,53 @@ def set_spectra_to_match(sources, observations):
observations = (observations,)
model_frame = observations[0].model_frame

for obs in observations:
# extract model for every component
morphs = []
parameters = []
update_of = []
for i, src in enumerate(sources):
if isinstance(src, CombinedComponent):
components = src.children
else:
components = (src,)
for j, c in enumerate(components):
p = c.get_parameter(
"spectrum"
) # returns None of c doesn't have parameter "spectrum"
parameters.append(p)
model = obs.render(c.get_model(frame=model_frame))
# correct for different flux in channels to have flat-spectrum component
if p is not None:
model /= obs.renderer.map_channels(p)[:, None, None]
# check for models with identical initializations, see #282
# if duplicate: remove morph[k] from linear fit, but keep track of parameters[k]
# to set spectrum later: update_of: component index -> updated spectrum index
K_ = len(morphs)
update_of.append(K_)
for l in range(K_):
if np.allclose(model, morphs[l]):
update_of[-1] = l
message = f"Source {i}, Component {j} has a model identical to another component.\n"
message += "This is likely not intended, and the source/component should be deleted. "
message += "Spectra will be identical."
logger.warning(message)
if update_of[-1] == K_:
morphs.append(model)
morphs = np.array(morphs)
# extract multi-channel model for every non-degenerate component
parameters = []
update_of = []
models = []
for i, src in enumerate(sources):
if isinstance(src, CombinedComponent):
components = src.children
else:
components = (src,)

for j, c in enumerate(components):
p = c.get_parameter(
"spectrum"
) # returns None of c doesn't have parameter "spectrum"
parameters.append(p)
# correct for different flux in channels to have flat-spectrum component
if p is not None and not p.fixed:
p[:] = 1
model = c.get_model(frame=model_frame)

# check for models with identical initializations, see #282
# if duplicate: remove morph[k] from linear fit, but keep track of parameters[k]
# to set spectrum later: update_of: component index -> updated spectrum index
K_ = len(models)
update_of.append(K_)
for l in range(K_):
if np.allclose(model, models[l]):
update_of[-1] = l
message = f"Source {i}, Component {j} has a model identical to another component.\n"
message += "This is likely not intended, and the source/component should be deleted. "
message += "Spectra will be identical."
logger.warning(message)
if update_of[-1] == K_:
models.append(model)
models = np.array(models)
K = len(parameters)
K_ = len(models)

for obs in observations:
# independent channels, no mixing
# solve the linear inverse problem of the amplitudes in every channel
# given all the rendered morphologies
# spectrum = (M^T Sigma^-1 M)^-1 M^T Sigma^-1 * im
K = len(parameters)
K_ = len(morphs)
C = obs.C
images = obs.data
weights = obs.weights
morphs = np.stack([obs.render(model) for model in models], axis=0)
spectra = np.zeros((K_, C))
for c in range(C):
im = images[c].reshape(-1)
Expand All @@ -574,12 +577,9 @@ def set_spectra_to_match(sources, observations):

# update the parameters with the best-fit spectrum solution
for k, p in enumerate(parameters):
if p is None:
continue
if p.fixed:
continue
l = update_of[k]
obs.renderer.map_channels(p)[:] = spectra[l]
if p is not None and not p.fixed:
l = update_of[k]
obs.renderer.map_channels(p)[:] = spectra[l]

# enforce constraints
for p in parameters:
Expand Down

0 comments on commit fcd1b29

Please sign in to comment.