diff --git a/aws_xray_sdk/core/models/entity.py b/aws_xray_sdk/core/models/entity.py index 293f220b..34378d4a 100644 --- a/aws_xray_sdk/core/models/entity.py +++ b/aws_xray_sdk/core/models/entity.py @@ -280,13 +280,16 @@ def to_dict(self): subsegments.append(subsegment.to_dict()) entity_dict[key] = subsegments elif key == 'cause': - entity_dict[key] = {} - entity_dict[key]['working_directory'] = self.cause['working_directory'] - # exceptions are stored as List - throwables = [] - for throwable in value['exceptions']: - throwables.append(throwable.to_dict()) - entity_dict[key]['exceptions'] = throwables + if isinstance(self.cause, dict): + entity_dict[key] = {} + entity_dict[key]['working_directory'] = self.cause['working_directory'] + # exceptions are stored as List + throwables = [] + for throwable in value['exceptions']: + throwables.append(throwable.to_dict()) + entity_dict[key]['exceptions'] = throwables + else: + entity_dict[key] = self.cause elif key == 'metadata': entity_dict[key] = metadata_to_dict(value) elif key != 'sampled' and key != ORIGIN_TRACE_HEADER_ATTR_KEY: diff --git a/tests/test_serialize_entities.py b/tests/test_serialize_entities.py index da5d50a2..d46d737b 100644 --- a/tests/test_serialize_entities.py +++ b/tests/test_serialize_entities.py @@ -260,7 +260,7 @@ class TestException(Exception): def __init__(self, message): super(TestException, self).__init__(message) - segment = Segment('test') + segment_one = Segment('test') stack_one = [ ('/path/to/test.py', 10, 'module', 'another_function()'), @@ -275,18 +275,18 @@ def __init__(self, message): exception_one = TestException('test message one') exception_two = TestException('test message two') - segment.add_exception(exception_one, stack_one, True) - segment.add_exception(exception_two, stack_two, False) + segment_one.add_exception(exception_one, stack_one, True) + segment_one.add_exception(exception_two, stack_two, False) - segment.close() + segment_one.close() - expected_segment_dict = { - "id": segment.id, + expected_segment_one_dict = { + "id": segment_one.id, "name": "test", - "start_time": segment.start_time, + "start_time": segment_one.start_time, "in_progress": False, "cause": { - "working_directory": segment.cause['working_directory'], + "working_directory": segment_one.cause['working_directory'], "exceptions": [ { "id": exception_one._cause_id, @@ -326,14 +326,39 @@ def __init__(self, message): } ] }, - "trace_id": segment.trace_id, + "trace_id": segment_one.trace_id, "fault": True, - "end_time": segment.end_time + "end_time": segment_one.end_time } - actual_segment_dict = entity_to_dict(segment) + segment_two = Segment('test') + subsegment = Subsegment('test', 'local', segment_two) + + subsegment.add_exception(exception_one, stack_one, True) + subsegment.add_exception(exception_two, stack_two, False) + subsegment.close() - assert expected_segment_dict == actual_segment_dict + # will record cause id instead as same exception already recorded in its subsegment + segment_two.add_exception(exception_one, stack_one, True) + + segment_two.close() + + expected_segment_two_dict = { + "id": segment_two.id, + "name": "test", + "start_time": segment_two.start_time, + "in_progress": False, + "cause": exception_one._cause_id, + "trace_id": segment_two.trace_id, + "fault": True, + "end_time": segment_two.end_time + } + + actual_segment_one_dict = entity_to_dict(segment_one) + actual_segment_two_dict = entity_to_dict(segment_two) + + assert expected_segment_one_dict == actual_segment_one_dict + assert expected_segment_two_dict == actual_segment_two_dict def test_serialize_subsegment():