Skip to content

Commit

Permalink
Only update IntervalTrees when we need them
Browse files Browse the repository at this point in the history
  • Loading branch information
jdorn-gt committed Jul 31, 2024
1 parent 6c11d2e commit 79372d7
Show file tree
Hide file tree
Showing 5 changed files with 168 additions and 41 deletions.
39 changes: 20 additions & 19 deletions python/gtirb/byteinterval.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import itertools
import typing
from uuid import UUID

from intervaltree import IntervalTree
from sortedcontainers import SortedDict

from .block import ByteBlock, CodeBlock, DataBlock
from .lazyintervaltree import LazyIntervalTree
from .node import Node, _NodeMessage
from .proto import ByteInterval_pb2, SymbolicExpression_pb2
from .symbolicexpression import SymAddrAddr, SymAddrConst, SymbolicExpression
Expand Down Expand Up @@ -162,15 +161,19 @@ def __init__(
raise ValueError("initialized_size must be <= size!")

super().__init__(uuid=uuid)
self._interval_tree: "IntervalTree[int, ByteBlock]" = IntervalTree()
self._section: typing.Optional["Section"] = None
self.address = address
self.size = size
self.contents = bytearray(contents)
self.initialized_size = initialized_size
self.blocks: SetWrapper[ByteBlock] = ByteInterval._BlockSet(
self, blocks

# Both blocks and _interval_tree must exist before adding any blocks.
self.blocks: SetWrapper[ByteBlock] = ByteInterval._BlockSet(self)
self._interval_tree = LazyIntervalTree[int, ByteBlock](
self.blocks, _offset_interval
)
self.blocks.update(blocks)

self._symbolic_expressions = ByteInterval._SymbolicExprDict(
self, symbolic_expressions
)
Expand All @@ -186,20 +189,14 @@ def _index_add_multiple(
old_blocks: typing.Collection[ByteBlock],
new_blocks: typing.Collection[ByteBlock],
) -> None:
if len(old_blocks) < len(new_blocks):
self._interval_tree = IntervalTree(
_offset_interval(block)
for block in itertools.chain(old_blocks, new_blocks)
)
else:
for block in new_blocks:
self._index_add(block)
for block in new_blocks:
self._interval_tree.add(block)

def _index_add(self, block: ByteBlock) -> None:
self._interval_tree.add(_offset_interval(block))
self._interval_tree.add(block)

def _index_discard(self, block: ByteBlock) -> None:
self._interval_tree.discard(_offset_interval(block))
self._interval_tree.discard(block)

@property
def initialized_size(self) -> int:
Expand Down Expand Up @@ -444,7 +441,7 @@ def byte_blocks_on(
return ()

return _nodes_on_interval_tree(
self._interval_tree, addrs, -self.address
self._interval_tree.get(), addrs, -self.address
)

def byte_blocks_at(
Expand All @@ -460,7 +457,7 @@ def byte_blocks_at(
return ()

return _nodes_at_interval_tree(
self._interval_tree, addrs, -self.address
self._interval_tree.get(), addrs, -self.address
)

def code_blocks_on(
Expand Down Expand Up @@ -524,7 +521,9 @@ def byte_blocks_on_offset(
:param offsets: Either a ``range`` object or a single offset.
"""

return _nodes_on_interval_tree_offset(self._interval_tree, offsets)
return _nodes_on_interval_tree_offset(
self._interval_tree.get(), offsets
)

def byte_blocks_at_offset(
self, offsets: typing.Union[int, range]
Expand All @@ -535,7 +534,9 @@ def byte_blocks_at_offset(
:param offsets: Either a ``range`` object or a single offset.
"""

return _nodes_at_interval_tree_offset(self._interval_tree, offsets)
return _nodes_at_interval_tree_offset(
self._interval_tree.get(), offsets
)

def code_blocks_on_offset(
self, offsets: typing.Union[int, range]
Expand Down
118 changes: 118 additions & 0 deletions python/gtirb/lazyintervaltree.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
"""
Implements a simple wrapper that lazily initializes and updates an
IntervalTree.

GTIRB uses IntervalTrees to accelerate certain operations. However, these
operations are not always needed for a given GTIRB object or by a given GTIRB
analysis. To prevent scripts that do not need the IntervalTrees from wasting
time updating the data structures, the LazyIntervalTree in this module delays
instantiating or updating the tree. Instead, it queues the updates so they can
be rapidly applied when the script invokes an operation that requires an
up-to-date tree.
"""

import enum
from typing import (
Collection,
Generic,
Iterator,
List,
Optional,
Protocol,
Tuple,
TypeVar,
)

from intervaltree import Interval, IntervalTree

_K = TypeVar("_K")
_Kco = TypeVar("_Kco", covariant=True)
_V = TypeVar("_V")


class _EventType(enum.Enum):
    """Kind of pending modification recorded for an interval.

    ``ADDED`` marks an interval that must be inserted into the tree;
    ``DISCARDED`` marks one that must be removed from it.
    """

    # Explicit values; these are the same numbers enum.auto() assigns.
    ADDED = 1
    DISCARDED = 2


class IntervalBuilder(Protocol[_Kco, _V]):
    """Callable protocol that maps a value to the interval it occupies.

    Implementations return ``None`` for any value that currently has no
    interval available.
    """

    def __call__(self, value: _V, /) -> Optional["Interval[_Kco, _V]"]:
        """Return the interval for ``value``, or ``None`` if it has none."""


class LazyIntervalTree(Generic[_K, _V]):
    """Simple wrapper to lazily initialize and update an IntervalTree.

    The underlying IntervalTree can be retrieved by calling get(). This will
    ensure that the tree is up-to-date with all intermediate modifications
    before returning it.

    In many algorithms, the tree may receive large numbers of modifications,
    adding and removing the same intervals several times before querying. In
    these cases, it may be faster to rebuild the tree from scratch rather than
    perform all of the intermediate modifications. For this reason, get() is
    not guaranteed to always return the same tree object. That is, the tree
    returned by get() should not be cached; calling get() may return a new tree
    rather than updating the tree it returned previously.

    Pending modifications are only recorded while they can still be useful:
    nothing is queued before the first tree has been built, and the queue is
    dropped as soon as it is long enough that get() would rebuild the tree
    anyway. This keeps the memory held by the queue bounded by the size of
    the value collection.
    """

    def __init__(
        self,
        values: Collection[_V],
        make_interval: IntervalBuilder[_K, _V],
    ):
        """Create a new lazy tree.

        :param values: collection of values from which the tree can be rebuilt
        :param make_interval: callable to get an interval for a value
        """
        # None means "no usable tree"; get() then rebuilds from the values.
        self._interval_index: Optional["IntervalTree[_K, _V]"] = None
        self._interval_events: List[Tuple[_EventType, "Interval[_K, _V]"]] = []
        self._value_collection = values
        self._make_interval = make_interval

    def _rebuild_is_cheaper(self) -> bool:
        """Whether a rebuild takes no more updates than replaying the queue."""
        # Constructing a new tree involves one update for each value.
        return len(self._value_collection) <= len(self._interval_events)

    def _intervals(self) -> Iterator["Interval[_K, _V]"]:
        """Yield the interval of every value that currently has one."""
        for value in self._value_collection:
            interval = self._make_interval(value)
            if interval is not None:
                yield interval

    def _record(self, event: _EventType, value: _V) -> None:
        """Queue a pending modification if replaying it could ever be useful.

        :param event: whether the value is being added or discarded
        :param value: the value whose interval is affected
        """
        if self._interval_index is None:
            # There is no tree to patch: the next get() rebuilds it from the
            # value collection, so queued events would never be replayed.
            return

        # The interval is captured now because the caller may change the
        # value right after this call.
        interval = self._make_interval(value)
        if interval is None:
            return

        self._interval_events.append((event, interval))
        if self._rebuild_is_cheaper():
            # get() would rebuild from scratch in this state, so stop
            # tracking changes. Dropping the queue also releases the
            # intervals (and whatever they reference) that it was holding.
            self._interval_index = None
            self._interval_events.clear()

    def add(self, value: _V) -> None:
        """Add a value to the tree."""
        self._record(_EventType.ADDED, value)

    def discard(self, value: _V) -> None:
        """Remove a value from the tree.

        Does nothing if the interval with that value is not present.
        """
        self._record(_EventType.DISCARDED, value)

    def get(self) -> "IntervalTree[_K, _V]":
        """Get the most up-to-date tree reflecting all pending updates.

        :returns: a tree containing the interval of every value in the
            collection that has one. This may be a different object from the
            one returned by an earlier call.
        """
        index = self._interval_index
        if index is None or self._rebuild_is_cheaper():
            index = IntervalTree(self._intervals())
        else:
            # There are fewer updates than constructing a new tree would use.
            for event, interval in self._interval_events:
                if event is _EventType.ADDED:
                    index.add(interval)
                else:
                    index.discard(interval)

        self._interval_events.clear()
        self._interval_index = index
        return index
36 changes: 20 additions & 16 deletions python/gtirb/section.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,9 @@
from enum import Enum
from uuid import UUID

from intervaltree import IntervalTree

from .block import ByteBlock, CodeBlock, DataBlock
from .byteinterval import ByteInterval, SymbolicExpressionElement
from .lazyintervaltree import LazyIntervalTree
from .node import Node, _NodeMessage
from .proto import Section_pb2
from .util import (
Expand Down Expand Up @@ -102,24 +101,27 @@ def __init__(
"""

super().__init__(uuid)
self._interval_index: "IntervalTree[int,ByteInterval]" = IntervalTree()
self._module: typing.Optional["Module"] = None
self.name = name
self.byte_intervals = Section._ByteIntervalSet(self, byte_intervals)

# Both byte_intervals and _interval_index must exist before adding any
# intervals.
self.byte_intervals = Section._ByteIntervalSet(self)
self._interval_index = LazyIntervalTree[int, ByteInterval](
self.byte_intervals, _address_interval
)
self.byte_intervals.update(byte_intervals)

self.flags = set(flags)

# Use the property setter to ensure correct invariants.
self.module = module

def _index_add(self, byte_interval: ByteInterval) -> None:
address_interval = _address_interval(byte_interval)
if address_interval:
self._interval_index.add(address_interval)
self._interval_index.add(byte_interval)

def _index_discard(self, byte_interval: ByteInterval) -> None:
address_interval = _address_interval(byte_interval)
if address_interval:
self._interval_index.discard(address_interval)
self._interval_index.discard(byte_interval)

@classmethod
def _decode_protobuf(
Expand Down Expand Up @@ -233,8 +235,9 @@ def address(self) -> typing.Optional[int]:
size, so it will be ``None`` in that case.
"""

if 0 < len(self._interval_index) == len(self.byte_intervals):
return self._interval_index.begin()
index = self._interval_index.get()
if 0 < len(index) == len(self.byte_intervals):
return index.begin()

return None

Expand All @@ -251,8 +254,9 @@ def size(self) -> typing.Optional[int]:
it has no address or size, so it will be ``None`` in that case.
"""

if 0 < len(self._interval_index) == len(self.byte_intervals):
return self._interval_index.span() - 1
index = self._interval_index.get()
if 0 < len(index) == len(self.byte_intervals):
return index.span() - 1

return None

Expand All @@ -265,7 +269,7 @@ def byte_intervals_on(
:param addrs: Either a ``range`` object or a single address.
"""

return _nodes_on_interval_tree(self._interval_index, addrs)
return _nodes_on_interval_tree(self._interval_index.get(), addrs)

def byte_intervals_at(
self, addrs: typing.Union[int, range]
Expand All @@ -276,7 +280,7 @@ def byte_intervals_at(
:param addrs: Either a ``range`` object or a single address.
"""

return _nodes_at_interval_tree(self._interval_index, addrs)
return _nodes_at_interval_tree(self._interval_index.get(), addrs)

def byte_blocks_on(
self, addrs: typing.Union[int, range]
Expand Down
4 changes: 2 additions & 2 deletions python/stubs/intervaltree/interval.pyi
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from typing import Generic, TypeVar

PointT = TypeVar("PointT")
DataT = TypeVar("DataT")
PointT = TypeVar("PointT", covariant=True)
DataT = TypeVar("DataT", covariant=True)

class Interval(Generic[PointT, DataT]):
begin: PointT
Expand Down
12 changes: 8 additions & 4 deletions python/tests/test_blocks_at_offset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,28 +10,32 @@ class BlocksAtOffsetTests(unittest.TestCase):
def test_blocks_at_offset_simple(self):
ir, m, s, bi = create_interval_etc(address=None, size=4)

# Ensure we always have a couple blocks in the index beyond what we
# are querying so that we don't just rebuild the tree from scratch
# every time.
code_block = gtirb.CodeBlock(offset=0, size=1, byte_interval=bi)
code_block2 = gtirb.CodeBlock(offset=1, size=1, byte_interval=bi)
code_block3 = gtirb.CodeBlock(offset=2, size=1, byte_interval=bi)

found = set(bi.byte_blocks_at_offset(0))
self.assertEqual(found, {code_block})

# Change the offset to verify we update the index
code_block.offset = 2
code_block.offset = 3
found = set(bi.byte_blocks_at_offset(0))
self.assertEqual(found, set())

found = set(bi.byte_blocks_at_offset(2))
found = set(bi.byte_blocks_at_offset(3))
self.assertEqual(found, {code_block})

# Discard the block to verify we update the index
bi.blocks.discard(code_block)
found = set(bi.byte_blocks_at_offset(2))
found = set(bi.byte_blocks_at_offset(3))
self.assertEqual(found, set())

# Now add it back to verify we update the index
bi.blocks.add(code_block)
found = set(bi.byte_blocks_at_offset(2))
found = set(bi.byte_blocks_at_offset(3))
self.assertEqual(found, {code_block})

def test_blocks_at_offset_overlapping(self):
Expand Down

0 comments on commit 79372d7

Please sign in to comment.