Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor performance for montecarlo #2792

4 changes: 2 additions & 2 deletions tardis/transport/montecarlo/packet_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class BlackBodySimpleSource(BasePacketSource, HDFWriterMixin):

hdf_properties = ["radius", "temperature", "base_seed"]
hdf_name = "black_body_simple_source"
l_coef = np.pi**4 / 90.0

@classmethod
def from_simulation_state(cls, simulation_state, *args, **kwargs):
Expand Down Expand Up @@ -210,15 +211,14 @@ def create_packet_nus(self, no_of_packets, l_samples=1000):
numpy.ndarray
"""
l_array = np.cumsum(np.arange(1, l_samples, dtype=np.float64) ** -4)
l_coef = np.pi**4 / 90.0

# For testing purposes
if self.legacy_mode_enabled:
xis = np.random.random((5, no_of_packets))
else:
xis = self.rng.random((5, no_of_packets))

l = l_array.searchsorted(xis[0] * l_coef) + 1.0
l = l_array.searchsorted(xis[0] * self.l_coef) + 1.0
xis_prod = np.prod(xis[1:], 0)
x = ne.evaluate("-log(xis_prod)/l")

Expand Down
51 changes: 27 additions & 24 deletions tardis/transport/montecarlo/packet_trackers.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,31 +177,34 @@
Dataframe containing properties of RPackets as columns like status, seed, r, nu, mu, energy, shell_id, interaction_type

"""
len_df = sum([len(tracker.r) for tracker in rpacket_trackers])
index_array = np.empty([2, len_df], dtype="int")
df_dtypes = np.dtype(
[
("status", np.int64),
("seed", np.int64),
("r", np.float64),
("nu", np.float64),
("mu", np.float64),
("energy", np.float64),
("shell_id", np.int64),
("interaction_type", np.int64),
]
)
len_df = sum(len(tracker.r) for tracker in rpacket_trackers)

index_array = np.empty((2, len_df), dtype="int")
df_dtypes = np.dtype([
("status", np.int64),
("seed", np.int64),
("r", np.float64),
("nu", np.float64),
("mu", np.float64),
("energy", np.float64),
("shell_id", np.int64),
("interaction_type", np.int64),
])
rpacket_tracker_ndarray = np.empty(len_df, df_dtypes)

cur_index = 0
for rpacket_tracker in rpacket_trackers:
prev_index = cur_index
cur_index = prev_index + len(rpacket_tracker.r)
for j, column_name in enumerate(df_dtypes.fields.keys()):
rpacket_tracker_ndarray[column_name][
prev_index:cur_index
] = getattr(rpacket_tracker, column_name)
index_array[0][prev_index:cur_index] = getattr(rpacket_tracker, "index")
index_array[1][prev_index:cur_index] = range(cur_index - prev_index)
length = len(rpacket_tracker.r)
next_index = cur_index + length

for column_name in df_dtypes.names:
rpacket_tracker_ndarray[column_name][cur_index:next_index] = getattr(rpacket_tracker, column_name)

index_array[0][cur_index:next_index] = rpacket_tracker.index
index_array[1][cur_index:next_index] = np.arange(length)

cur_index = next_index

return pd.DataFrame(
rpacket_tracker_ndarray,
index=pd.MultiIndex.from_arrays(index_array, names=["index", "step"]),
Expand Down Expand Up @@ -288,7 +291,7 @@
A list containing RPacketTracker for each RPacket
"""
rpacket_trackers = List()
for i in range(no_of_packets):
for _ in range(no_of_packets):

Check warning on line 294 in tardis/transport/montecarlo/packet_trackers.py

View check run for this annotation

Codecov / codecov/patch

tardis/transport/montecarlo/packet_trackers.py#L294

Added line #L294 was not covered by tests
rpacket_trackers.append(RPacketTracker(length))
return rpacket_trackers

Expand All @@ -305,6 +308,6 @@
A list containing RPacketLastInteractionTracker for each RPacket
"""
rpacket_trackers = List()
for i in range(no_of_packets):
for _ in range(no_of_packets):

Check warning on line 311 in tardis/transport/montecarlo/packet_trackers.py

View check run for this annotation

Codecov / codecov/patch

tardis/transport/montecarlo/packet_trackers.py#L311

Added line #L311 was not covered by tests
rpacket_trackers.append(RPacketLastInteractionTracker())
return rpacket_trackers
Loading