diff --git a/tardis/montecarlo/enum.py b/tardis/montecarlo/enum.py deleted file mode 100644 index e8f9e767ad3..00000000000 --- a/tardis/montecarlo/enum.py +++ /dev/null @@ -1,64 +0,0 @@ -from ctypes import c_int - - -class EnumerationType(type(c_int)): - """ - From http://code.activestate.com/recipes/576415-ctype-enumeration-class - Metaclass for CEnumeration class. - """ - def __new__(metacls, name, bases, dictionary): - if "_members_" not in dictionary: - _members_ = {} - for key, value in dictionary.items(): - if not key.startswith("_"): - _members_[key] = value - - dictionary["_members_"] = _members_ - else: - _members_ = dictionary["_members_"] - - dictionary["_reverse_map_"] = {value: key for key, value in _members_.items()} - cls = type(c_int).__new__(metacls, name, bases, dictionary) - - for key, value in cls._members_.items(): - globals()[key] = value - return cls - - def __repr__(self): - return "" % self.__name__ - - -class CEnumeration(c_int): - """ - From http://code.activestate.com/recipes/576415-ctype-enumeration-class - Python implementation about `enum` datatype of C. - """ - __metaclass__ = EnumerationType - _members_ = {} - - def __eq__(self, other): - if isinstance(other, int): - return self.value == other - return type(self) == type(other) and self.value == other.value - - def __repr__(self): - return "<%s.%s: %d>" % (self.__class__.__name__, - self._reverse_map_.get(self.value, '(unknown)'), - self.value) - - -class TardisError(CEnumeration): - OK = 0 - BOUNDS_ERROR = 1 - COMOV_NU_LESS_THAN_NU_LINE = 2 - - -class RPacketStatus(CEnumeration): - IN_PROCESS = 0 - EMITTED = 1 - REABSORBED = 2 - - -class ContinuumProcessesStatus(CEnumeration): - OFF = 0 - ON = 1 diff --git a/tardis/montecarlo/struct.py b/tardis/montecarlo/struct.py index f6b4cfc428c..976d84be834 100644 --- a/tardis/montecarlo/struct.py +++ b/tardis/montecarlo/struct.py @@ -1,5 +1,8 @@ from ctypes import Structure, POINTER, c_int, c_int64, c_double, c_ulong -from enum import RPacketStatus, ContinuumProcessesStatus + +c_tardis_error_t = c_int +c_rpacket_status_t = c_int +c_cont_status_t = c_int class RPacket(Structure): @@ -22,7 +25,7 @@ class RPacket(Structure): ('d_boundary', c_double), ('d_cont', c_double), ('next_shell_id', c_int64), - ('status', RPacketStatus), + ('status', c_rpacket_status_t), ('id', c_int64), ('chi_th', c_double), ('chi_cont', c_double), @@ -85,7 +88,7 @@ class StorageModel(Structure): ('t_electrons', POINTER(c_double)), ('l_pop', POINTER(c_double)), ('l_pop_r', POINTER(c_double)), - ('cont_status', ContinuumProcessesStatus), + ('cont_status', c_cont_status_t), ('virt_packet_nus', POINTER(c_double)), ('virt_packet_energies', POINTER(c_double)), ('virt_packet_last_interaction_in_nu', POINTER(c_double)), @@ -104,3 +107,17 @@ class RKState(Structure): ('has_gauss', c_int), ('gauss', c_double) ] + +# Variables corresponding to `tardis_error_t` enum. +TARDIS_ERROR_OK = 0 +TARDIS_ERROR_BOUNDS_ERROR = 1 +TARDIS_ERROR_COMOV_NU_LESS_THAN_NU_LINE = 2 + +# Variables corresponding to `rpacket_status_t` enum. +TARDIS_PACKET_STATUS_IN_PROCESS = 0 +TARDIS_PACKET_STATUS_EMITTED = 1 +TARDIS_PACKET_STATUS_REABSORBED = 2 + +# Variables corresponding to `ContinuumProcessesStatus` enum. +CONTINUUM_OFF = 0 +CONTINUUM_ON = 1 diff --git a/tardis/montecarlo/tests/test_cmontecarlo.py b/tardis/montecarlo/tests/test_cmontecarlo.py index 9d380953308..1c6e4c8a5ac 100644 --- a/tardis/montecarlo/tests/test_cmontecarlo.py +++ b/tardis/montecarlo/tests/test_cmontecarlo.py @@ -49,8 +49,17 @@ from numpy.testing import assert_almost_equal from tardis import __path__ as path -from tardis.montecarlo.struct import RPacket, StorageModel, RKState -from tardis.montecarlo.enum import TardisError, RPacketStatus, ContinuumProcessesStatus +from tardis.montecarlo.struct import ( + RPacket, StorageModel, RKState, + TARDIS_ERROR_OK, + TARDIS_ERROR_BOUNDS_ERROR, + TARDIS_ERROR_COMOV_NU_LESS_THAN_NU_LINE, + TARDIS_PACKET_STATUS_IN_PROCESS, + TARDIS_PACKET_STATUS_EMITTED, + TARDIS_PACKET_STATUS_REABSORBED, + CONTINUUM_OFF, + CONTINUUM_ON +) # Wrap the shared object containing tests for C methods, written in C. # TODO: Shift all tests here in Python and completely remove this test design. @@ -79,7 +88,7 @@ def packet(): current_continuum_id=1, virtual_packet_flag=1, virtual_packet=0, - status=RPacketStatus.IN_PROCESS, + status=TARDIS_PACKET_STATUS_IN_PROCESS, id=0 ) @@ -141,7 +150,7 @@ def model(): l_pop=(c_double * 20000)(*([2.0] * 20000)), l_pop_r=(c_double * 20000)(*([3.0] * 20000)), - cont_status=ContinuumProcessesStatus.OFF + cont_status=CONTINUUM_OFF ) @@ -203,16 +212,16 @@ def test_compute_distance2boundary(packet_params, expected_params, packet, model @pytest.mark.parametrize( ['packet_params', 'expected_params'], [({'nu_line': 0.1, 'next_line_id': 0, 'last_line': 1}, - {'tardis_error': TardisError.OK, 'd_line': 1e+99}), + {'tardis_error': TARDIS_ERROR_OK, 'd_line': 1e+99}), ({'nu_line': 0.2, 'next_line_id': 1, 'last_line': 0}, - {'tardis_error': TardisError.OK, 'd_line': 7.792353908000001e+17}), + {'tardis_error': TARDIS_ERROR_OK, 'd_line': 7.792353908000001e+17}), ({'nu_line': 0.5, 'next_line_id': 1, 'last_line': 0}, - {'tardis_error': TardisError.COMOV_NU_LESS_THAN_NU_LINE, 'd_line': 0.0}), + {'tardis_error': TARDIS_ERROR_COMOV_NU_LESS_THAN_NU_LINE, 'd_line': 0.0}), ({'nu_line': 0.6, 'next_line_id': 0, 'last_line': 0}, - {'tardis_error': TardisError.COMOV_NU_LESS_THAN_NU_LINE, 'd_line': 0.0})] + {'tardis_error': TARDIS_ERROR_COMOV_NU_LESS_THAN_NU_LINE, 'd_line': 0.0})] ) def test_compute_distance2line(packet_params, expected_params, packet, model): packet.nu_line = packet_params['nu_line']