Skip to content

Commit

Permalink
Make emitted egal code more loopy
Browse files Browse the repository at this point in the history
The strategy here is to look at (data, padding) pairs and RLE
them into loops, so that repeated adjacent patterns use a loop
rather than getting unrolled. On the test case from #54109,
this makes compilation essentially instant, while also being
faster at runtime (turns out LLVM spends a massive amount of time on the
unrolled comparison AND the code it produces is bad).

There's some obvious further enhancements possible here:
1. The `memcmp` constant is small. LLVM has a pass to inline these
   with better code. However, we don't have it turned on. We should
   consider vendoring it, though we may want to add some shortcutting
   to it to avoid having it iterate through each function.
2. This only does one level of sequence matching. It could be recursed
   to turn things into nested loops.

However, this solves the immediate issue, so hopefully it's a useful
start. Fixes #54109.
  • Loading branch information
Keno committed Apr 24, 2024
1 parent 7ba1b33 commit b5a5ea1
Show file tree
Hide file tree
Showing 4 changed files with 214 additions and 11 deletions.
137 changes: 136 additions & 1 deletion src/codegen.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3358,6 +3358,58 @@ static Value *emit_bitsunion_compare(jl_codectx_t &ctx, const jl_cgval_t &arg1,
return phi;
}

// Run-length-encoded description of one span of an egal (===) bits
// comparison: `nrepeats` adjacent repetitions of `data_bytes` bytes of
// meaningful data, each followed by `padding_bytes` bytes of padding that
// must be excluded from the comparison.
struct egal_desc {
    size_t offset;        // byte offset of the span within the object.
                          // NOTE(review): nothing in this commit assigns it
                          // (it stays at its zero-init value) -- the emitter
                          // must track span positions itself; confirm.
    size_t nrepeats;      // how many adjacent (data, padding) repetitions
    size_t data_bytes;    // bytes to actually compare per repetition
    size_t padding_bytes; // trailing bytes to skip per repetition
};

// Walk the layout of `aty` and run-length-encode its (data, padding) byte
// ranges, invoking `emit_desc(current_desc)` whenever a finished span is
// superseded by a differently-shaped one. Repeated adjacent spans are merged
// via `nrepeats` so the caller can emit them as a loop instead of unrolling
// (the fix for the compile-time blowup in #54109).
//
//  - emit_desc:    callback receiving each finished egal_desc span
//  - aty:          concrete datatype whose layout is being walked
//  - current_desc: in/out accumulator for the span being built; the final
//                  span is left here and NOT emitted -- the caller must
//                  flush it after this returns
//
// Returns the number of trailing data bytes (after the last padding) not yet
// part of any span; at the top level the caller must compare these
// separately, and in recursive calls they merge into the parent's run.
// NOTE(review): `current_desc.offset` is never written here -- the callback
// must track each span's absolute position itself; confirm callers do.
template <typename callback>
static size_t emit_masked_bits_compare(callback &emit_desc, jl_datatype_t *aty, egal_desc &current_desc)
{
    // Memcmp, but with masked padding
    size_t data_bytes = 0;    // contiguous data accumulated since the last padding
    size_t padding_bytes = 0;
    size_t nfields = jl_datatype_nfields(aty);
    size_t total_size = jl_datatype_size(aty);
    for (size_t i = 0; i < nfields; ++i) {
        size_t offset = jl_field_offset(aty, i);
        // A field's range extends to the start of the next field (or to the
        // end of the struct), so inter-field padding is attributed to the
        // field that precedes it.
        size_t fend = i == nfields - 1 ? total_size : jl_field_offset(aty, i + 1);
        size_t fsz = jl_field_size(aty, i);
        jl_datatype_t *fty = (jl_datatype_t*)jl_field_type(aty, i);
        if (jl_field_isptr(aty, i) || !fty->layout->flags.haspadding) {
            // The field has no internal padding
            data_bytes += fsz;
            if (offset + fsz == fend) {
                // The field has no padding after. Merge this into the current
                // comparison range and go to next field.
            } else {
                padding_bytes = fend - offset - fsz;
                // Found padding. Either merge this into the current comparison
                // range, or emit the old one and start a new one.
                if (current_desc.data_bytes == data_bytes &&
                    current_desc.padding_bytes == padding_bytes) {
                    // Same as the previous range, just note that down, so we
                    // emit this as a loop.
                    current_desc.nrepeats += 1;
                } else {
                    if (current_desc.nrepeats != 0)
                        emit_desc(current_desc);
                    current_desc.nrepeats = 1;
                    current_desc.data_bytes = data_bytes;
                    current_desc.padding_bytes = padding_bytes;
                }
                data_bytes = 0;
            }
        } else {
            // The field may have internal padding. Recurse this.
            // NOTE(review): this branch assumes offset + fsz == fend for
            // recursed fields -- any alignment gap between this field and
            // the next would be counted as data and compared; confirm such
            // gaps cannot occur for isbitsegal layouts.
            data_bytes += emit_masked_bits_compare(emit_desc, fty, current_desc);
        }
    }
    return data_bytes;
}

static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t arg2)
{
++EmittedBitsCompares;
Expand Down Expand Up @@ -3396,7 +3448,7 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
if (at->isAggregateType()) { // Struct or Array
jl_datatype_t *sty = (jl_datatype_t*)arg1.typ;
size_t sz = jl_datatype_size(sty);
if (sz > 512 && !sty->layout->flags.haspadding) {
if (sz > 512 && !sty->layout->flags.haspadding && sty->layout->flags.isbitsegal) {
Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
value_to_pointer(ctx, arg1).V;
Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
Expand Down Expand Up @@ -3433,6 +3485,89 @@ static Value *emit_bits_compare(jl_codectx_t &ctx, jl_cgval_t arg1, jl_cgval_t a
}
return ctx.builder.CreateICmpEQ(answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
}
else if (sz > 512 && jl_struct_try_layout(sty) && sty->layout->flags.isbitsegal) {
    // Large struct that contains padding but whose egality is decided
    // entirely by its non-padding bits: walk its (data, padding) spans and
    // memcmp each span's data bytes, emitting repeated spans as a loop
    // (see emit_masked_bits_compare) instead of fully unrolling.
    Type *TInt8 = getInt8Ty(ctx.builder.getContext());
    Type *TpInt8 = getInt8PtrTy(ctx.builder.getContext());
    Type *TInt1 = getInt1Ty(ctx.builder.getContext());
    Value *varg1 = arg1.ispointer() ? data_pointer(ctx, arg1) :
        value_to_pointer(ctx, arg1).V;
    Value *varg2 = arg2.ispointer() ? data_pointer(ctx, arg2) :
        value_to_pointer(ctx, arg2).V;
    varg1 = emit_pointer_from_objref(ctx, varg1);
    varg2 = emit_pointer_from_objref(ctx, varg2);
    varg1 = emit_bitcast(ctx, varg1, TpInt8);
    varg2 = emit_bitcast(ctx, varg2, TpInt8);

    // Conjunction of all span comparisons emitted so far (nullptr until the
    // first span has been emitted).
    Value *answer = nullptr;
    // BUGFIX: emit_masked_bits_compare never assigns `desc.offset` (the
    // recursion cannot know its absolute position within the outer object),
    // so the previous code -- which GEP'd by `desc.offset`, always 0 --
    // compared every span against the first span's bytes and left all later
    // spans unchecked. Spans arrive in address order and tile the object
    // starting at byte 0, so track each span's absolute start offset here.
    size_t current_offset = 0;
    auto emit_desc = [&](egal_desc desc) {
        size_t stride = desc.data_bytes + desc.padding_bytes;
        Value *ptr1 = varg1;
        Value *ptr2 = varg2;
        if (current_offset != 0) {
            ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, current_offset);
            ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr2, current_offset);
        }

        Value *new_ptr1 = ptr1;
        Value *endptr1 = nullptr;
        BasicBlock *postBB = nullptr;
        BasicBlock *loopBB = nullptr;
        PHINode *answerphi = nullptr;
        if (desc.nrepeats != 1) {
            // Set up a loop advancing both pointers by `stride` bytes per
            // iteration until `endptr1` is reached.
            endptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, ptr1, desc.nrepeats * stride);

            BasicBlock *currBB = ctx.builder.GetInsertBlock();
            loopBB = BasicBlock::Create(ctx.builder.getContext(), "egal_loop", ctx.f);
            postBB = BasicBlock::Create(ctx.builder.getContext(), "post", ctx.f);
            ctx.builder.CreateBr(loopBB);

            ctx.builder.SetInsertPoint(loopBB);
            // Carries the partial answer across iterations (seeded with the
            // answer from previously emitted spans, or `true` if none).
            answerphi = ctx.builder.CreatePHI(TInt1, 2);
            answerphi->addIncoming(answer ? answer : ConstantInt::get(TInt1, 1), currBB);
            answer = answerphi;

            PHINode *itr1 = ctx.builder.CreatePHI(ptr1->getType(), 2);
            PHINode *itr2 = ctx.builder.CreatePHI(ptr2->getType(), 2);

            new_ptr1 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr1, stride);
            itr1->addIncoming(ptr1, currBB);
            itr1->addIncoming(new_ptr1, loopBB);

            Value *new_ptr2 = ctx.builder.CreateConstInBoundsGEP1_32(TInt8, itr2, stride);
            itr2->addIncoming(ptr2, currBB);
            itr2->addIncoming(new_ptr2, loopBB);

            ptr1 = itr1;
            ptr2 = itr2;
        }

        // Emit memcmp. TODO: LLVM has a pass to expand this for additional
        // performance.
        Value *this_answer = ctx.builder.CreateCall(prepare_call(memcmp_func),
                    { ptr1,
                      ptr2,
                      ConstantInt::get(ctx.types().T_size, desc.data_bytes) });
        this_answer = ctx.builder.CreateICmpEQ(this_answer, ConstantInt::get(getInt32Ty(ctx.builder.getContext()), 0));
        answer = answer ? ctx.builder.CreateAnd(answer, this_answer) : this_answer;
        if (endptr1) {
            answerphi->addIncoming(answer, loopBB);
            Value *loopend = ctx.builder.CreateICmpEQ(new_ptr1, endptr1);
            ctx.builder.CreateCondBr(loopend, postBB, loopBB);
            ctx.builder.SetInsertPoint(postBB);
        }
        // Advance past everything this span covered (data and padding).
        current_offset += desc.nrepeats * stride;
    };
    egal_desc current_desc = {0};
    size_t trailing_data_bytes = emit_masked_bits_compare(emit_desc, sty, current_desc);
    // The walker never emits its final accumulated span; flush it here.
    // (This path is only reached for layouts with padding, so at least one
    // span must have been accumulated.)
    assert(current_desc.nrepeats != 0);
    emit_desc(current_desc);
    if (trailing_data_bytes != 0) {
        // Data after the last padding run: compare it as one final span.
        current_desc.nrepeats = 1;
        current_desc.data_bytes = trailing_data_bytes;
        current_desc.padding_bytes = 0;
        emit_desc(current_desc);
    }
    return answer;
}
else {
jl_svec_t *types = sty->types;
Value *answer = ConstantInt::get(getInt1Ty(ctx.builder.getContext()), 1);
Expand Down
32 changes: 23 additions & 9 deletions src/datatype.c
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz,
uint32_t npointers,
uint32_t alignment,
int haspadding,
int isbitsegal,
int arrayelem,
jl_fielddesc32_t desc[],
uint32_t pointers[]) JL_NOTSAFEPOINT
Expand Down Expand Up @@ -226,6 +227,7 @@ static jl_datatype_layout_t *jl_get_layout(uint32_t sz,
flddesc->nfields = nfields;
flddesc->alignment = alignment;
flddesc->flags.haspadding = haspadding;
// BUGFIX: both the parameter (see the signature above) and the bitfield
// declared in julia.h are named `isbitsegal`; the previous
// `flags.isbitscomparable = isbitscomparable` referenced two nonexistent
// identifiers and would not compile.
flddesc->flags.isbitsegal = isbitsegal;
flddesc->flags.fielddesc_type = fielddesc_type;
flddesc->flags.arrayelem_isboxed = arrayelem == 1;
flddesc->flags.arrayelem_isunion = arrayelem == 2;
Expand Down Expand Up @@ -504,6 +506,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st)
int isunboxed = jl_islayout_inline(eltype, &elsz, &al) && (kind != (jl_value_t*)jl_atomic_sym || jl_is_datatype(eltype));
int isunion = isunboxed && jl_is_uniontype(eltype);
int haspadding = 1; // we may want to eventually actually compute this more precisely
int isbitsegal = 0;
int nfields = 0; // aka jl_is_layout_opaque
int npointers = 1;
int zi;
Expand Down Expand Up @@ -562,7 +565,7 @@ void jl_get_genericmemory_layout(jl_datatype_t *st)
else
arrayelem = 0;
assert(!st->layout);
st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, arrayelem, NULL, pointers);
st->layout = jl_get_layout(elsz, nfields, npointers, al, haspadding, isbitsegal, arrayelem, NULL, pointers);
st->zeroinit = zi;
//st->has_concrete_subtype = 1;
//st->isbitstype = 0;
Expand Down Expand Up @@ -673,6 +676,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
size_t alignm = 1;
int zeroinit = 0;
int haspadding = 0;
int isbitsegal = 1;
int homogeneous = 1;
int needlock = 0;
uint32_t npointers = 0;
Expand All @@ -687,19 +691,30 @@ void jl_compute_field_offsets(jl_datatype_t *st)
throw_ovf(should_malloc, desc, st, fsz);
desc[i].isptr = 0;
if (jl_is_uniontype(fld)) {
haspadding = 1;
fsz += 1; // selector byte
zeroinit = 1;
// TODO: Some unions could be bits comparable.
isbitsegal = 0;
}
else {
uint32_t fld_npointers = ((jl_datatype_t*)fld)->layout->npointers;
if (((jl_datatype_t*)fld)->layout->flags.haspadding)
haspadding = 1;
if (!((jl_datatype_t*)fld)->layout->flags.isbitsegal)
isbitsegal = 0;
if (i >= nfields - st->name->n_uninitialized && fld_npointers &&
fld_npointers * sizeof(void*) != fsz) {
// field may be undef (may be uninitialized and contains pointer),
// and contains non-pointer fields of non-zero sizes.
haspadding = 1;
// For field types that contain pointers, we allow inlinealloc
// as long as the field type itself is always fully initialized.
// In such a case, we use the first pointer in the inlined field
// as the #undef marker (if it is zero, we treat the whole inline
// struct as #undef). However, we do not zero-initialize the whole
// struct, so the non-pointer parts of the inline allocation may
// be arbitrary, but still need to compare egal (because all #undef)
// representations are egal. Because of this, we cannot bitscompare
// them.
// TODO: Consider zero-initializing the whole struct.
isbitsegal = 0;
}
if (!zeroinit)
zeroinit = ((jl_datatype_t*)fld)->zeroinit;
Expand All @@ -715,8 +730,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
zeroinit = 1;
npointers++;
if (!jl_pointer_egal(fld)) {
// this somewhat poorly named flag says whether some of the bits can be non-unique
haspadding = 1;
isbitsegal = 0;
}
}
if (isatomic && fsz > MAX_ATOMIC_SIZE)
Expand Down Expand Up @@ -777,7 +791,7 @@ void jl_compute_field_offsets(jl_datatype_t *st)
}
}
assert(ptr_i == npointers);
st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, 0, desc, pointers);
st->layout = jl_get_layout(sz, nfields, npointers, alignm, haspadding, isbitsegal, 0, desc, pointers);
if (should_malloc) {
free(desc);
if (npointers)
Expand Down Expand Up @@ -931,7 +945,7 @@ JL_DLLEXPORT jl_datatype_t *jl_new_primitivetype(jl_value_t *name, jl_module_t *
bt->ismutationfree = 1;
bt->isidentityfree = 1;
bt->isbitstype = (parameters == jl_emptysvec);
bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 0, NULL, NULL);
bt->layout = jl_get_layout(nbytes, 0, 0, alignm, 0, 1, 0, NULL, NULL);
bt->instance = NULL;
return bt;
}
Expand Down
5 changes: 4 additions & 1 deletion src/julia.h
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,10 @@ typedef struct {
// metadata bit only for GenericMemory eltype layout
uint16_t arrayelem_isboxed : 1;
uint16_t arrayelem_isunion : 1;
uint16_t padding : 11;
// If set, this type's egality can be determined entirely by comparing
// the non-padding bits of this datatype.
uint16_t isbitsegal : 1;
uint16_t padding : 10;
} flags;
// union {
// jl_fielddesc8_t field8[nfields];
Expand Down
51 changes: 51 additions & 0 deletions test/compiler/codegen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -873,3 +873,54 @@ if Sys.ARCH === :x86_64
end
end
end

# #54109 - Excessive LLVM time for egal
# A 16-byte (T = Float64) or 4-byte (T = Int16) element whose Bool field is
# followed by trailing padding -- presumably the repeated (data, padding)
# pattern that the RLE loop emitter targets; confirm against the layout code.
struct DefaultOr54109{T}
    x::T
    default::Bool
end

@eval struct Torture1_54109
    $((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:897)...)
end
Torture1_54109() = Torture1_54109((DefaultOr54109(1.0, false) for i = 1:897)...)

# Two differently-shaped field runs, forcing more than one comparison span.
@eval struct Torture2_54109
    $((Expr(:(::), Symbol("x$i"), DefaultOr54109{Float64}) for i = 1:400)...)
    $((Expr(:(::), Symbol("x$(i+400)"), DefaultOr54109{Int16}) for i = 1:400)...)
end
Torture2_54109() = Torture2_54109((DefaultOr54109(1.0, false) for i = 1:400)..., (DefaultOr54109(Int16(1), false) for i = 1:400)...)

# The compilerbarrier hides the type of `y`, so `===` must go through the
# generic egal path rather than constant-folding.
@noinline egal_any54109(x, @nospecialize(y::Any)) = x === Base.compilerbarrier(:type, y)

let ir1 = get_llvm(egal_any54109, Tuple{Torture1_54109, Any}),
    ir2 = get_llvm(egal_any54109, Tuple{Torture2_54109, Any})

    # We can't really do timing on CI, so instead, let's look at the length of
    # the optimized IR. The original version had tens of thousands of lines and
    # was slower, so just check here that we only have < 500 lines. If somebody,
    # implements a better comparison that's larger than that, just re-benchmark
    # this and adjust the threshold.

    @test count(==('\n'), ir1) < 500
    @test count(==('\n'), ir2) < 500
end

## Regression test for egal of a struct of this size without padding, but with
## non-bitsegal, to make sure that it doesn't accidentally go down the accelerated
## path.
@eval struct BigStructAnyInt
    $((Expr(:(::), Symbol("x$i"), Tuple{Any, Int}) for i = 1:33)...)
end
BigStructAnyInt() = BigStructAnyInt(((Union{Base.inferencebarrier(Float64), Int}, i) for i = 1:33)...)
@test egal_any54109(BigStructAnyInt(), BigStructAnyInt())

## For completeness, also test correctness, since we don't have a lot of
## large-struct tests.

# The two allocations of the same struct will likely have different padding,
# we want to make sure we find them egal anyway - a naive memcmp would
# accidentally look at it.
@test egal_any54109(Torture1_54109(), Torture1_54109())
@test egal_any54109(Torture2_54109(), Torture2_54109())
@test !egal_any54109(Torture1_54109(), Torture1_54109((DefaultOr54109(2.0, false) for i = 1:897)...))
# Added: Torture2 previously had no inequality coverage at all. Differ in the
# first (Float64) field run so the test exercises the multi-span comparison.
# NOTE(review): a companion test differing only in the second (Int16) run
# would additionally pin down correct per-span offsets; consider adding one.
@test !egal_any54109(Torture2_54109(), Torture2_54109((DefaultOr54109(2.0, false) for i = 1:400)..., (DefaultOr54109(Int16(1), false) for i = 1:400)...))

0 comments on commit b5a5ea1

Please sign in to comment.