Skip to content

Commit

Permalink
Merge pull request #540 from karandesai-96/enums-for-cmontecarlo
Browse files Browse the repository at this point in the history
Mirroring C enums in Python for CMontecarlo tests.
  • Loading branch information
yeganer committed Apr 18, 2016
2 parents 1f600fb + 5260af7 commit a442cfc
Show file tree
Hide file tree
Showing 3 changed files with 77 additions and 14 deletions.
64 changes: 64 additions & 0 deletions tardis/montecarlo/enum.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
from ctypes import c_int


class EnumerationType(type(c_int)):
"""
From http://code.activestate.com/recipes/576415-ctype-enumeration-class
Metaclass for CEnumeration class.
"""
def __new__(metacls, name, bases, dictionary):
if "_members_" not in dictionary:
_members_ = {}
for key, value in dictionary.items():
if not key.startswith("_"):
_members_[key] = value

dictionary["_members_"] = _members_
else:
_members_ = dictionary["_members_"]

dictionary["_reverse_map_"] = {value: key for key, value in _members_.items()}
cls = type(c_int).__new__(metacls, name, bases, dictionary)

for key, value in cls._members_.items():
globals()[key] = value
return cls

def __repr__(self):
return "<Enumeration %s>" % self.__name__


class CEnumeration(c_int):
"""
From http://code.activestate.com/recipes/576415-ctype-enumeration-class
Python implementation about `enum` datatype of C.
"""
__metaclass__ = EnumerationType
_members_ = {}

def __eq__(self, other):
if isinstance(other, int):
return self.value == other
return type(self) == type(other) and self.value == other.value

def __repr__(self):
return "<%s.%s: %d>" % (self.__class__.__name__,
self._reverse_map_.get(self.value, '(unknown)'),
self.value)


class TardisError(CEnumeration):
OK = 0
BOUNDS_ERROR = 1
COMOV_NU_LESS_THAN_NU_LINE = 2


class RPacketStatus(CEnumeration):
IN_PROCESS = 0
EMITTED = 1
REABSORBED = 2


class ContinuumProcessesStatus(CEnumeration):
OFF = 0
ON = 1
5 changes: 3 additions & 2 deletions tardis/montecarlo/struct.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ctypes import Structure, POINTER, c_int, c_int64, c_double, c_ulong
from enum import RPacketStatus, ContinuumProcessesStatus


class RPacket(Structure):
Expand All @@ -21,7 +22,7 @@ class RPacket(Structure):
('d_boundary', c_double),
('d_cont', c_double),
('next_shell_id', c_int64),
('status', c_int),
('status', RPacketStatus),
('id', c_int64),
('chi_th', c_double),
('chi_cont', c_double),
Expand Down Expand Up @@ -84,7 +85,7 @@ class StorageModel(Structure):
('t_electrons', POINTER(c_double)),
('l_pop', POINTER(c_double)),
('l_pop_r', POINTER(c_double)),
('cont_status', c_int),
('cont_status', ContinuumProcessesStatus),
('virt_packet_nus', POINTER(c_double)),
('virt_packet_energies', POINTER(c_double)),
('virt_packet_last_interaction_in_nu', POINTER(c_double)),
Expand Down
22 changes: 10 additions & 12 deletions tardis/montecarlo/tests/test_cmontecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

from tardis import __path__ as path
from tardis.montecarlo.struct import RPacket, StorageModel, RKState
from tardis.montecarlo.enum import TardisError, RPacketStatus, ContinuumProcessesStatus

# Wrap the shared object containing tests for C methods, written in C.
# TODO: Shift all tests here in Python and completely remove this test design.
Expand Down Expand Up @@ -78,7 +79,7 @@ def packet():
current_continuum_id=1,
virtual_packet_flag=1,
virtual_packet=0,
status=0,
status=RPacketStatus.IN_PROCESS,
id=0
)

Expand Down Expand Up @@ -140,6 +141,7 @@ def model():

l_pop=(c_double * 20000)(*([2.0] * 20000)),
l_pop_r=(c_double * 20000)(*([3.0] * 20000)),
cont_status=ContinuumProcessesStatus.OFF
)


Expand Down Expand Up @@ -201,16 +203,16 @@ def test_compute_distance2boundary(packet_params, expected_params, packet, model
@pytest.mark.parametrize(
['packet_params', 'expected_params'],
[({'nu_line': 0.1, 'next_line_id': 0, 'last_line': 1},
{'tardis_error': 0, 'd_line': 1e+99}),
{'tardis_error': TardisError.OK, 'd_line': 1e+99}),
({'nu_line': 0.2, 'next_line_id': 1, 'last_line': 0},
{'tardis_error': 0, 'd_line': 7.792353908000001e+17}),
{'tardis_error': TardisError.OK, 'd_line': 7.792353908000001e+17}),
({'nu_line': 0.5, 'next_line_id': 1, 'last_line': 0},
{'tardis_error': 2, 'd_line': 0.0}),
{'tardis_error': TardisError.COMOV_NU_LESS_THAN_NU_LINE, 'd_line': 0.0}),
({'nu_line': 0.6, 'next_line_id': 0, 'last_line': 0},
{'tardis_error': 2, 'd_line': 0.0})]
{'tardis_error': TardisError.COMOV_NU_LESS_THAN_NU_LINE, 'd_line': 0.0})]
)
def test_compute_distance2line(packet_params, expected_params, packet, model):
packet.nu_line = packet_params['nu_line']
Expand All @@ -226,19 +228,15 @@ def test_compute_distance2line(packet_params, expected_params, packet, model):


@pytest.mark.parametrize(
['packet_params', 'model_params', 'expected_params'],
['packet_params', 'expected_params'],
[({'virtual_packet': 0},
{'cont_status': 0},
{'chi_cont': 6.652486e-16, 'd_cont': 4.359272608766106e+28}),
{'chi_cont': 6.652486e-16, 'd_cont': 4.359272608766106e+28}),
({'virtual_packet': 1},
{'cont_status': 0},
{'chi_cont': 6.652486e-16, 'd_cont': 1e+99})]
)
def test_compute_distance2continuum(packet_params, model_params,
expected_params, packet, model):
def test_compute_distance2continuum(packet_params, expected_params, packet, model):
packet.virtual_packet = packet_params['virtual_packet']
model.cont_status = model_params['cont_status']

cmontecarlo_methods.compute_distance2continuum(byref(packet), byref(model))

Expand Down

0 comments on commit a442cfc

Please sign in to comment.