Skip to content

Commit

Permalink
Specify builder capacities in incremental generation to avoid wasted …
Browse files Browse the repository at this point in the history
…scratch arrays. (#62285)

* Specify builder capacities in incremental generation to avoid wasted scratch arrays.

* Add comment

* Use linq
  • Loading branch information
CyrusNajmabadi authored Jul 1, 2022
1 parent 5510ef9 commit 412dc4c
Show file tree
Hide file tree
Showing 5 changed files with 59 additions and 13 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,15 @@ public static int Count<T>(this ImmutableArray<T> items, Func<T, bool> predicate
return count;
}

public static int Sum<T>(this ImmutableArray<T> items, Func<T, int> selector)
{
var sum = 0;
foreach (var item in items)
sum += selector(item);

return sum;
}

internal static Dictionary<K, ImmutableArray<T>> ToDictionary<K, T>(this ImmutableArray<T> items, Func<T, K> keySelector, IEqualityComparer<K>? comparer = null)
where K : notnull
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,8 @@ public CombineNode(IIncrementalGeneratorNode<TInput1> input1, IIncrementalGenera
return previousTable;
}

var builder = graphState.CreateTableBuilder(previousTable, _name, _comparer);
var totalEntryItemCount = input1Table.GetTotalEntryItemCount();
var builder = graphState.CreateTableBuilder(previousTable, _name, _comparer, totalEntryItemCount);

// Semantics of a join:
//
Expand Down Expand Up @@ -75,6 +76,7 @@ public CombineNode(IIncrementalGeneratorNode<TInput1> input1, IIncrementalGenera
}
}

Debug.Assert(builder.Count == totalEntryItemCount);
return builder.ToImmutableAndFree();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,10 @@ public NodeStateTable<T> GetLatestStateTableForNode<T>(IIncrementalGeneratorNode
return newTable;
}

public NodeStateTable<T>.Builder CreateTableBuilder<T>(NodeStateTable<T> previousTable, string? stepName, IEqualityComparer<T>? equalityComparer)
public NodeStateTable<T>.Builder CreateTableBuilder<T>(
NodeStateTable<T> previousTable, string? stepName, IEqualityComparer<T>? equalityComparer, int? tableCapacity = null)
{
return previousTable.ToBuilder(stepName, DriverState.TrackIncrementalSteps, equalityComparer);
return previousTable.ToBuilder(stepName, DriverState.TrackIncrementalSteps, equalityComparer, tableCapacity);
}

public DriverStateTable ToImmutable()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,9 @@ private NodeStateTable(ImmutableArray<TableEntry> states, ImmutableArray<Increme

public ImmutableArray<IncrementalGeneratorRunStep> Steps { get; }

public int GetTotalEntryItemCount()
=> _states.Sum(static e => e.Count);

public IEnumerator<NodeStateEntry<T>> GetEnumerator()
{
for (int i = 0; i < _states.Length; i++)
Expand Down Expand Up @@ -131,8 +134,8 @@ public NodeStateTable<T> AsCached()
return (_states[^1].GetItem(0), HasTrackedSteps ? Steps[^1] : null);
}

public Builder ToBuilder(string? stepName, bool stepTrackingEnabled, IEqualityComparer<T>? equalityComparer = null)
=> new(this, stepName, stepTrackingEnabled, equalityComparer);
public Builder ToBuilder(string? stepName, bool stepTrackingEnabled, IEqualityComparer<T>? equalityComparer = null, int? tableCapacity = null)
=> new(this, stepName, stepTrackingEnabled, equalityComparer, tableCapacity);

public NodeStateTable<T> CreateCachedTableWithUpdatedSteps<TInput>(NodeStateTable<TInput> inputTable, string? stepName, IEqualityComparer<T> equalityComparer)
{
Expand Down Expand Up @@ -160,9 +163,23 @@ public sealed class Builder
[MemberNotNullWhen(true, nameof(_steps))]
public bool TrackIncrementalSteps => _steps is not null;

internal Builder(NodeStateTable<T> previous, string? name, bool stepTrackingEnabled, IEqualityComparer<T>? equalityComparer)
#if DEBUG
private readonly int? _requestedTableCapacity;
#endif

internal Builder(
NodeStateTable<T> previous,
string? name,
bool stepTrackingEnabled,
IEqualityComparer<T>? equalityComparer,
int? tableCapacity)
{
_states = ArrayBuilder<TableEntry>.GetInstance(previous.Count);
#if DEBUG
_requestedTableCapacity = tableCapacity;
#endif
// If the caller specified a desired capacity, then use that. Otherwise, use the previous table's total
// entry count as a reasonable approximation for what we will need.
_states = ArrayBuilder<TableEntry>.GetInstance(tableCapacity ?? previous.GetTotalEntryItemCount());
_previous = previous;
_name = name;
_equalityComparer = equalityComparer ?? EqualityComparer<T>.Default;
Expand All @@ -172,6 +189,8 @@ internal Builder(NodeStateTable<T> previous, string? name, bool stepTrackingEnab
}
}

public int Count => _states.Count;

public bool TryRemoveEntries(TimeSpan elapsedTime, ImmutableArray<(IncrementalGeneratorRunStep InputStep, int OutputIndex)> stepInputs)
{
if (_previous._states.Length <= _states.Count)
Expand Down Expand Up @@ -367,13 +386,25 @@ public NodeStateTable<T> ToImmutableAndFree()
return NodeStateTable<T>.Empty;
}

#if DEBUG
// If the caller requested a specific capacity, then we should have added exactly that amount of items.
Debug.Assert(_requestedTableCapacity == null || _requestedTableCapacity == _states.Count);
#endif

// if we added the exact same entries as before, then we can directly embed previous' entry array,
// avoiding a costly allocation of the same data.
var finalStates = _states.Count == _previous.Count && _states.SequenceEqual(_previous._states, (e1, e2) => e1.Matches(e2, _equalityComparer))
? _previous._states
: _states.ToImmutable();

_states.Free();
ImmutableArray<TableEntry> finalStates;
if (_states.Count == _previous.Count && _states.SequenceEqual(_previous._states, (e1, e2) => e1.Matches(e2, _equalityComparer)))
{
finalStates = _previous._states;
_states.Free();
}
else
{
// Important to use ToImmutableAndFree so that we will MoveToImmutable when the requested capacity
// equals the count.
finalStates = _states.ToImmutableAndFree();
}

return new NodeStateTable<T>(
finalStates,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,8 @@ public NodeStateTable<TOutput> UpdateStateTable(DriverStateTable.Builder builder
// - Added: perform transform and add
// - Modified: perform transform and do element wise comparison with previous results

var newTable = builder.CreateTableBuilder(previousTable, _name, _comparer);
var totalEntryItemCount = sourceTable.GetTotalEntryItemCount();
var newTable = builder.CreateTableBuilder(previousTable, _name, _comparer, totalEntryItemCount);

foreach (var entry in sourceTable)
{
Expand All @@ -77,6 +78,8 @@ public NodeStateTable<TOutput> UpdateStateTable(DriverStateTable.Builder builder
}
}
}

Debug.Assert(newTable.Count == totalEntryItemCount);
return newTable.ToImmutableAndFree();
}

Expand Down

0 comments on commit 412dc4c

Please sign in to comment.