Skip to content

Commit

Permalink
Convert ir_data to a dataclass
Browse files Browse the repository at this point in the history
This converts `ir_data` over to a `dataclasses.dataclass` and adds
various `FieldSpec` helpers to support that conversion. The `builder`,
`reader`, `IrDataSerializer`, `copy`, and `update` stubs are fully
implemented to support dataclasses as well.

This change results in a 38% speedup against a 75KB test file.

Fixes google#118.
  • Loading branch information
EricRahm committed Jun 12, 2024
1 parent 457234d commit c4d0956
Show file tree
Hide file tree
Showing 14 changed files with 2,348 additions and 700 deletions.
3 changes: 1 addition & 2 deletions compiler/front_end/constraints_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,10 @@ class ConstraintsTest(unittest.TestCase):
def test_error_on_missing_inner_array_size(self):
ir = _make_ir_from_emb("struct Foo:\n"
" 0 [+1] UInt:8[][1] one_byte\n")
error_array = ir.module[0].type[0].structure.field[0].type.array_type
self.assertEqual([[
error.error(
"m.emb",
error_array.base_type.array_type.element_count.source_location,
None, # This is probably a latent bug
"Array dimensions can only be omitted for the outermost dimension.")
]], error.filter_errors(constraints.check_constraints(ir)))

Expand Down
11 changes: 6 additions & 5 deletions compiler/front_end/module_ir_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from compiler.front_end import parser
from compiler.front_end import tokenizer
from compiler.util import ir_data
from compiler.util import ir_data_fields
from compiler.util import ir_data_utils
from compiler.util import test_util

Expand Down Expand Up @@ -4089,29 +4090,29 @@ def _check_all_source_locations(proto, path="", min_start=None, max_end=None):
child_end = None
# Only check the source_location value if this proto message actually has a
# source_location field.
if "source_location" in proto.raw_fields:
if proto.HasField("source_location"):
errors.extend(_check_source_location(proto.source_location,
path + "source_location",
min_start, max_end))
child_start = proto.source_location.start
child_end = proto.source_location.end

for name, spec in proto.field_specs.items():
for name, spec in ir_data_fields.field_specs(proto).items():
if name == "source_location":
continue
if not proto.HasField(name):
continue
field_path = "{}{}".format(path, name)
if isinstance(spec, ir_pb2.Repeated):
if issubclass(spec.type, ir_pb2.Message):
if spec.is_sequence:
if spec.is_dataclass:
index = 0
for i in getattr(proto, name):
item_path = "{}[{}]".format(field_path, index)
index += 1
errors.extend(
_check_all_source_locations(i, item_path, child_start, child_end))
else:
if issubclass(spec.type, ir_data.Message):
if spec.is_dataclass:
errors.extend(_check_all_source_locations(getattr(proto, name),
field_path, child_start,
child_end))
Expand Down
2 changes: 1 addition & 1 deletion compiler/front_end/symbol_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,7 @@ def _find_target_of_reference(reference, table, current_scope, visible_scopes,
name = reference.source_name[0].text
for scope in visible_scopes:
scoped_table = table[scope.module_file]
for path_element in scope.object_path:
for path_element in scope.object_path or []:
scoped_table = scoped_table[path_element]
if (name in scoped_table and
(scope == current_scope or
Expand Down
10 changes: 5 additions & 5 deletions compiler/front_end/synthetics.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ def _mark_as_synthetic(proto):
return
if hasattr(proto, "source_location"):
ir_data_utils.builder(proto).source_location.is_synthetic = True
for name, value in proto.raw_fields.items():
if name != "source_location":
if isinstance(value, ir_data.TypedScopedList):
for i in range(len(value)):
_mark_as_synthetic(value[i])
for spec, value in ir_data_utils.get_set_fields(proto):
if spec.name != "source_location" and spec.is_dataclass:
if spec.is_sequence:
for i in value:
_mark_as_synthetic(i)
else:
_mark_as_synthetic(value)

Expand Down
14 changes: 14 additions & 0 deletions compiler/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,22 @@ py_library(
name = "ir_data",
srcs = [
"ir_data.py",
"ir_data_fields.py",
"ir_data_utils.py",
],
deps = [],
)

py_test(
name = "ir_data_fields_test",
srcs = ["ir_data_fields_test.py"],
deps = [":ir_data"],
)

py_test(
name = "ir_data_utils_test",
srcs = ["ir_data_utils_test.py"],
deps = [":expression_parser", ":ir_data"],
)

py_library(
Expand Down
13 changes: 10 additions & 3 deletions compiler/util/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
]
"""

from compiler.util import ir_data_utils
from compiler.util import parser_types

# Error levels; represented by the strings that will be included in messages.
Expand Down Expand Up @@ -65,20 +66,26 @@
BOLD = "\033[0;1m"
RESET = "\033[0m"

def _copy(location):
location = ir_data_utils.copy(location)
if not location:
location = parser_types.make_location((0,0), (0,0))
return location


def error(source_file, location, message):
"""Returns an object representing an error message."""
return _Message(source_file, location, ERROR, message)
return _Message(source_file, _copy(location), ERROR, message)


def warn(source_file, location, message):
"""Returns an object representing a warning."""
return _Message(source_file, location, WARNING, message)
return _Message(source_file, _copy(location), WARNING, message)


def note(source_file, location, message):
"""Returns and object representing an informational note."""
return _Message(source_file, location, NOTE, message)
return _Message(source_file, _copy(location), NOTE, message)


class _Message(object):
Expand Down
Loading

0 comments on commit c4d0956

Please sign in to comment.