Skip to content

Commit

Permalink
Merge pull request #243 from slaclab/emittance_mae
Browse files Browse the repository at this point in the history
Emittance mean absolute error
  • Loading branch information
roussel-ryan authored Feb 24, 2025
2 parents 134b152 + 9710783 commit ce3cae4
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 6 deletions.
4 changes: 2 additions & 2 deletions lcls_tools/common/data/emittance.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def loss_torch(params):
params = torch.reshape(params, [*beamsize_squared.shape[:-2], 3])
sig = torch.stack(beam_matrix_tuple(params), dim=-1).unsqueeze(-1)
# sig should now be shape batchshape x 3 x 1 (column vectors)
total_squared_error = (amat @ sig - beamsize_squared).pow(2).sum()
total_squared_error = (amat @ sig - beamsize_squared).abs().sum()
return total_squared_error

def loss_jacobian(params):
Expand All @@ -130,7 +130,7 @@ def loss(params):
params = np.reshape(params, [*beamsize_squared.shape[:-2], 3])
sig = np.expand_dims(np.stack(beam_matrix_tuple(params), axis=-1), axis=-1)
# sig should now be shape batchshape x 3 x 1 (column vectors)
total_squared_error = np.sum((amat @ sig - beamsize_squared) ** 2)
total_squared_error = np.sum(np.abs(amat @ sig - beamsize_squared))
return total_squared_error

loss_jacobian = None
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/lcls_tools/common/image/test_image_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ def test_fit_and_visualization(self):

result = ImageProjectionFit().fit_image(test_image)

assert np.allclose(result.centroid, [5, 5])
assert np.allclose(result.centroid, [4.5, 4.5])
assert np.allclose(result.rms_size, [1.16, 1.16])
assert np.allclose(result.total_intensity, 1020.0)
assert np.allclose(result.image, test_image)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -175,10 +175,10 @@ def mock_function(scan_settings, function):

# check outputs based on nans in the data
assert np.equal(
result.quadrupole_pv_values[0], np.concat((k[:6], k[7:]))
result.quadrupole_pv_values[0], np.concatenate((k[:6], k[7:]))
).all()
assert np.equal(
result.quadrupole_pv_values[1], np.concat((k[:1], k[3:]))
result.quadrupole_pv_values[1], np.concatenate((k[:1], k[3:]))
).all()

assert np.allclose(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_measure(self):
assert result.rms_sizes.shape == (1, 2)
assert result.total_intensities.shape == (1,)
assert np.allclose(result.rms_sizes, np.array([8.0347, 8.0347]))
assert np.allclose(result.centroids, np.array([50, 50]))
assert np.allclose(result.centroids.flatten(), np.array([49.5, 49.5]))
assert np.allclose(result.total_intensities, np.array([102000.0]))

assert result.metadata == self.measurement.model_dump()
Expand Down

0 comments on commit ce3cae4

Please sign in to comment.