Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add @depositBits and @extractBits builtins #15285

Closed
wants to merge 18 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions doc/langref.html.in
Original file line number Diff line number Diff line change
Expand Up @@ -8292,6 +8292,33 @@ test "main" {
{#see_also|@cVaArg|@cVaCopy|@cVaEnd#}
{#header_close#}

{#header_open|@depositBits#}
<pre>{#syntax#}@depositBits(source: T, mask: T) T{#endsyntax#}</pre>
ominitay marked this conversation as resolved.
Show resolved Hide resolved
<p>
{#syntax#}T{#endsyntax#} must be an unsigned integer type, or a {#syntax#}comptime_int{#endsyntax#} (for which both parameters must be positive). {#syntax#}T{#endsyntax#} is determined by peer-type resolution.
</p>
<p>
Transfers the contiguous low-order bits of the {#syntax#}source{#endsyntax#} operand, in order, to the bit positions of the destination that are set in the {#syntax#}mask{#endsyntax#} operand. All other bits in the destination are zeroed.
</p>
<p>
Currently, only x86 processors with BMI2 enabled support this in hardware. On processors without support for the instruction, it will be emulated. AMD processors before Zen 3 implement the corresponding instruction (PDEP) in microcode. It may be faster to use an alternative method in both of these cases.
</p>
<p>
Example:
</p>

{#code_begin|test|test_depositbits_builtin#}
const std = @import("std");

test "deposit bits" {
    comptime {
        // The low 16 bits of the source (0x1234) are spread, nibble by nibble,
        // over the four nibbles selected by the mask.
        // `expectEqual` takes the expected value first.
        try std.testing.expectEqual(0x10203040, @depositBits(0x00001234, 0xf0f0f0f0));
    }
}
{#code_end#}
{#see_also|@extractBits#}
{#header_close#}

{#header_open|@divExact#}
<pre>{#syntax#}@divExact(numerator: T, denominator: T) T{#endsyntax#}</pre>
<p>
Expand Down Expand Up @@ -8462,6 +8489,33 @@ export fn @"A function name that is a complete sentence."() void {}
{#see_also|@export#}
{#header_close#}

{#header_open|@extractBits#}
<pre>{#syntax#}@extractBits(source: T, mask: T) T{#endsyntax#}</pre>
<p>
{#syntax#}T{#endsyntax#} must be an unsigned integer type, or a {#syntax#}comptime_int{#endsyntax#} (for which both parameters must be positive). {#syntax#}T{#endsyntax#} is determined by peer-type resolution.
</p>
<p>
Transfers the bits of the {#syntax#}source{#endsyntax#} operand at the positions that are set in the {#syntax#}mask{#endsyntax#} operand to the destination, writing them, in order, as contiguous low-order bits. The remaining upper bits of the destination are zeroed.
</p>
<p>
Currently, only x86 processors with BMI2 enabled support this in hardware. On processors without support for the instruction, it will be emulated. AMD processors before Zen 3 implement the corresponding instruction (PEXT) in microcode. It may be faster to use an alternative method in both of these cases.
</p>
<p>
Example:
</p>

{#code_begin|test|test_extractbits_builtin#}
const std = @import("std");

test "extract bits" {
    comptime {
        // The four nibbles selected by the mask (1, 3, 5 and 7) are packed
        // into the low 16 bits of the result.
        // `expectEqual` takes the expected value first.
        try std.testing.expectEqual(0x00001357, @extractBits(0x12345678, 0xf0f0f0f0));
    }
}
{#code_end#}
{#see_also|@depositBits#}
{#header_close#}

{#header_open|@fence#}
<pre>{#syntax#}@fence(order: AtomicOrder) void{#endsyntax#}</pre>
<p>
Expand Down
92 changes: 92 additions & 0 deletions lib/std/math/big/int.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1732,6 +1732,98 @@ pub const Mutable = struct {
y.shiftRight(y.toConst(), norm_shift);
}

// TODO this function is quite inefficient and could be optimised
/// r = @depositBits(source, mask)
///
/// Scatters the low-order bits of `source`, in order, into the bit positions
/// that are set in `mask`. Every other bit of `r` is zero.
///
/// Asserts that `source` and `mask` are positive.
///
/// `r.limbs` is zeroed before the operands are read, so `r` must not alias
/// `source` or `mask`. Only limbs of `r` at indices below `mask.limbs.len` are
/// written, so an `r.limbs.len` of at least `mask.limbs.len` is always
/// sufficient.
pub fn depositBits(r: *Mutable, source: Const, mask: Const) void {
    assert(source.positive);
    assert(mask.positive);

    r.positive = true;
    @memset(r.limbs, 0);

    // Index of the next `source` bit to deposit; advances once per set mask bit.
    var source_bit_index: usize = 0;
    outer: for (mask.limbs, 0..) |whole_mask_limb, mask_limb_index| {
        var mask_limb = whole_mask_limb;
        while (mask_limb != 0) : (source_bit_index += 1) {
            const source_limb_index = source_bit_index / limb_bits;
            // Bits past the end of `source` are treated as zero and `r` is
            // already zeroed, so nothing further can change the result.
            if (source_limb_index >= source.limbs.len) break :outer;

            // `mask_limb` is non-zero here, so `@ctz` is below `limb_bits`
            // and the cast cannot overflow.
            const mask_limb_bit = @intCast(Log2Limb, @ctz(mask_limb));
            // Clear the lowest set bit so the next iteration sees the next one.
            mask_limb &= mask_limb - 1;

            const source_limb_bit = @truncate(Log2Limb, source_bit_index);
            const source_bit = (source.limbs[source_limb_index] >> source_limb_bit) & 1;
            r.limbs[mask_limb_index] |= source_bit << mask_limb_bit;
        }
    }

    r.normalize(r.limbs.len);
}

// TODO this function is quite inefficient and could be optimised
/// r = @extractBits(source, mask)
///
/// Gathers the bits of `source` at the positions that are set in `mask` and
/// packs them, in order, into the low-order bits of `r`. Every other bit of
/// `r` is zero.
///
/// Asserts that `source` and `mask` are positive.
///
/// `r.limbs` is zeroed before the operands are read, so `r` must not alias
/// `source` or `mask`. One bit of `r` is written per set mask bit, so an
/// `r.limbs.len` of at least `mask.limbs.len` is always sufficient.
pub fn extractBits(r: *Mutable, source: Const, mask: Const) void {
    assert(source.positive);
    assert(mask.positive);

    r.positive = true;
    @memset(r.limbs, 0);

    // Index of the next bit of `r` to fill; advances once per set mask bit.
    var result_bit_index: usize = 0;
    for (mask.limbs, 0..) |whole_mask_limb, mask_limb_index| {
        // Limbs past the end of `source` are treated as zero and `r` is
        // already zeroed, so nothing further can change the result.
        if (mask_limb_index >= source.limbs.len) break;
        const source_limb = source.limbs[mask_limb_index];

        var mask_limb = whole_mask_limb;
        while (mask_limb != 0) : (result_bit_index += 1) {
            // `mask_limb` is non-zero here, so `@ctz` is below `limb_bits`
            // and the cast cannot overflow.
            const mask_limb_bit = @intCast(Log2Limb, @ctz(mask_limb));
            // Clear the lowest set bit so the next iteration sees the next one.
            mask_limb &= mask_limb - 1;

            const source_bit = (source_limb >> mask_limb_bit) & 1;
            const result_limb_bit = @truncate(Log2Limb, result_bit_index);
            r.limbs[result_bit_index / limb_bits] |= source_bit << result_limb_bit;
        }
    }

    r.normalize(r.limbs.len);
}

/// If a is positive, this passes through to truncate.
/// If a is negative, then r is set to positive with the bit pattern ~(a - 1).
/// r may alias a.
Expand Down
48 changes: 48 additions & 0 deletions lib/std/math/big/int_test.zig
Original file line number Diff line number Diff line change
Expand Up @@ -2762,6 +2762,54 @@ fn popCountTest(val: *const Managed, bit_count: usize, expected: usize) !void {
try testing.expectEqual(expected, val.toConst().popCount(bit_count));
}

test "big int extractBits" {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some possible tests to have more confidence everything is correct:

  • min + max supported sized type
  • 2 and 3 limb sizes
  • random input on your pc for varying (in CI with constant) seed (for both mask and original number): extract bit sequence, negate mask to extract the other bit sequence, applying or of extracted bit sequences must be identical to original.
  • same principle: set random bit sequence, negate mask to set the other one, applying or must be identical to original or destination number. Not sure how to best create destination number.
  • unclear and open (not feasible to solve with this PR): how to get structured testing of UUM for bigint and bigfloat

Random input can be taken like this:

const RndGen = std.rand.DefaultPrng;
    var rnd = RndGen.init(42);
    var i: u32 = 0;
    while (i < 10_000) : (i += 1) {
        var rand_num = rnd.random().int(i64);
        try test__popcountdi2(rand_num);
    }

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, the issue with creating the destination value is that we need a known-good implementation to generate it. I guess we could just write this in regular Zig, though, and convert the result to a bigint.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm actually a bit skeptical of this now, since the actual opportunities for mistakes here would be around having multiple limbs. Generating random u64s likely wouldn't help us catch these, would it?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Generating random u64s likely wouldn't help us catch these, would it?

I would expect the main problems to be in either the bit mask logic or the limbs not getting properly initialized. My suggestion was more of a brute-force approach, but listing the edge cases should also work.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm. I think we should look at a more reasoned set of tests then. The current ones effectively just test the basic bitmask logic, it'd be good to add some more to ensure that the boundaries between limbs are handled correctly, and that a difference in sizes is handled correctly. I'm fairly confident in both of these, but definitely can't hurt to strengthen this.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Uninitialised limbs aren't a worry though, since the output bigint is set to zero at the beginning.

try extractBitsTest(0x12345678, 0x0, 0x0);
try extractBitsTest(0x12345678, 0xf0f0f0f0, 0x1357);
try extractBitsTest(0x12345678, 0xff00ff00, 0x1256);
try extractBitsTest(0x12345678, 0xffff, 0x5678);

try extractBitsTest(0x12345678_90123456_78901234_56789012, 0xff << 64, 0x56);
try extractBitsTest(0x12345678_90123456_78901234_56789012, (0xff << 64) | 0xff00f, 0x56892);
}

/// Runs `Mutable.extractBits` on big-int versions of `source` and `mask` and
/// checks that the result compares equal to `expected`.
fn extractBitsTest(comptime source: comptime_int, comptime mask: comptime_int, comptime expected: comptime_int) !void {
    const gpa = testing.allocator;

    var source_managed = try Managed.initSet(gpa, source);
    defer source_managed.deinit();
    var mask_managed = try Managed.initSet(gpa, mask);
    defer mask_managed.deinit();

    // The result buffer gets as many limbs as the mask's limb buffer holds.
    const result_limbs = try gpa.alloc(Limb, mask_managed.limbs.len);
    defer gpa.free(result_limbs);

    // `positive` is assigned by `extractBits`; `len` is presumably set by its
    // final `normalize` call, so neither needs a meaningful value here.
    var actual = Mutable{ .limbs = result_limbs, .positive = undefined, .len = undefined };
    actual.extractBits(source_managed.toConst(), mask_managed.toConst());

    const order = actual.toConst().orderAgainstScalar(expected);
    try testing.expectEqual(std.math.Order.eq, order);
}

test "big int depositBits" {
    // Each row is { source, mask, expected }.
    const cases = .{
        // Empty mask, then masks that lie within the low 32 bits.
        .{ 0x12345678, 0x0, 0x0 },
        .{ 0x12345678, 0xf0f0f0f0, 0x50607080 },
        .{ 0x12345678, 0xff00ff00, 0x56007800 },
        .{ 0x12345678, 0xffff, 0x5678 },

        // Masks with bits set at and above bit 64.
        .{ 0x1234, 0xff << 64, 0x34_00000000_00000000 },
        .{ 0x12345678, (0xff << 64) | 0xff00f, 0x45_00000000_00067008 },
    };

    // `depositBitsTest` takes comptime parameters, so the loop must be unrolled.
    inline for (cases) |case| {
        try depositBitsTest(case[0], case[1], case[2]);
    }
}

/// Runs `Mutable.depositBits` on big-int versions of `source` and `mask` and
/// checks that the result compares equal to `expected`.
fn depositBitsTest(comptime source: comptime_int, comptime mask: comptime_int, comptime expected: comptime_int) !void {
    const gpa = testing.allocator;

    var source_managed = try Managed.initSet(gpa, source);
    defer source_managed.deinit();
    var mask_managed = try Managed.initSet(gpa, mask);
    defer mask_managed.deinit();

    // The result buffer gets as many limbs as the mask's limb buffer holds.
    const result_limbs = try gpa.alloc(Limb, mask_managed.limbs.len);
    defer gpa.free(result_limbs);

    // `positive` is assigned by `depositBits`; `len` is presumably set by its
    // final `normalize` call, so neither needs a meaningful value here.
    var actual = Mutable{ .limbs = result_limbs, .positive = undefined, .len = undefined };
    actual.depositBits(source_managed.toConst(), mask_managed.toConst());

    const order = actual.toConst().orderAgainstScalar(expected);
    try testing.expectEqual(std.math.Order.eq, order);
}

test "big int conversion read/write twos complement" {
var a = try Managed.initSet(testing.allocator, (1 << 493) - 1);
defer a.deinit();
Expand Down
11 changes: 11 additions & 0 deletions src/Air.zig
Original file line number Diff line number Diff line change
Expand Up @@ -822,6 +822,13 @@ pub const Inst = struct {
/// Operand is unused and set to Ref.none
work_group_id,

/// Implements @depositBits builtin.
/// Uses the `bin_op` field.
deposit_bits,
/// Implements @extractBits builtin.
/// Uses the `bin_op` field.
extract_bits,

pub fn fromCmpOp(op: std.math.CompareOperator, optimized: bool) Tag {
switch (op) {
.lt => return if (optimized) .cmp_lt_optimized else .cmp_lt,
Expand Down Expand Up @@ -1232,6 +1239,8 @@ pub fn typeOfIndex(air: Air, inst: Air.Inst.Index, ip: *const InternPool) Type {
.div_exact_optimized,
.rem_optimized,
.mod_optimized,
.deposit_bits,
.extract_bits,
=> return air.typeOf(datas[inst].bin_op.lhs, ip),

.sqrt,
Expand Down Expand Up @@ -1742,6 +1751,8 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool {
.work_item_id,
.work_group_size,
.work_group_id,
.deposit_bits,
.extract_bits,
=> false,

.assembly => @truncate(u1, air.extraData(Air.Asm, data.ty_pl.payload).data.flags >> 31) != 0,
Expand Down
21 changes: 21 additions & 0 deletions src/AstGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -8699,6 +8699,9 @@ fn builtinCall(
});
return rvalue(gz, ri, result, node);
},

.deposit_bits => return depositExtractBits(gz, scope, ri, node, params, .deposit_bits),
.extract_bits => return depositExtractBits(gz, scope, ri, node, params, .extract_bits),
}
}

Expand Down Expand Up @@ -8966,6 +8969,24 @@ fn overflowArithmetic(
return rvalue(gz, ri, result, node);
}

/// Shared lowering for the two-operand builtins `@depositBits` and
/// `@extractBits`.
///
/// Evaluates `params[0]` and then `params[1]` with no result location, emits
/// the extended instruction selected by `tag` with a `BinNode` payload holding
/// both operands, and passes the result through `rvalue` for `ri`.
fn depositExtractBits(
    gz: *GenZir,
    scope: *Scope,
    ri: ResultInfo,
    node: Ast.Node.Index,
    params: []const Ast.Node.Index,
    tag: Zir.Inst.Extended,
) InnerError!Zir.Inst.Ref {
    // Neither operand is given a result location.
    const operand_ri: ResultInfo = .{ .rl = .none };
    const source = try expr(gz, scope, operand_ri, params[0]);
    const mask = try expr(gz, scope, operand_ri, params[1]);

    const payload: Zir.Inst.BinNode = .{
        .node = gz.nodeIndexToRelative(node),
        .lhs = source,
        .rhs = mask,
    };
    return rvalue(gz, ri, try gz.addExtendedPayload(tag, payload), node);
}

fn callExpr(
gz: *GenZir,
scope: *Scope,
Expand Down
15 changes: 15 additions & 0 deletions src/BuiltinFn.zig
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub const Tag = enum {
c_va_copy,
c_va_end,
c_va_start,
deposit_bits,
div_exact,
div_floor,
div_trunc,
Expand All @@ -46,6 +47,7 @@ pub const Tag = enum {
err_set_cast,
@"export",
@"extern",
extract_bits,
fence,
field,
field_parent_ptr,
Expand Down Expand Up @@ -396,6 +398,12 @@ pub const list = list: {
.param_count = 0,
},
},
.{
"@depositBits", .{
.tag = .deposit_bits,
.param_count = 2,
},
},
.{
"@divExact",
.{
Expand Down Expand Up @@ -474,6 +482,13 @@ pub const list = list: {
.param_count = 2,
},
},
.{
"@extractBits",
.{
.tag = .extract_bits,
.param_count = 2,
},
},
.{
"@fence",
.{
Expand Down
4 changes: 4 additions & 0 deletions src/Liveness.zig
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,8 @@ pub fn categorizeOperand(
.cmp_gte_optimized,
.cmp_gt_optimized,
.cmp_neq_optimized,
.deposit_bits,
.extract_bits,
=> {
const o = air_datas[inst].bin_op;
if (o.lhs == operand_ref) return matchOperandSmallIndex(l, inst, 0, .none);
Expand Down Expand Up @@ -942,6 +944,8 @@ fn analyzeInst(
.memset,
.memset_safe,
.memcpy,
.deposit_bits,
.extract_bits,
=> {
const o = inst_datas[inst].bin_op;
return analyzeOperands(a, pass, data, inst, .{ o.lhs, o.rhs, .none });
Expand Down
2 changes: 2 additions & 0 deletions src/Liveness/Verify.zig
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,8 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void {
.memset,
.memset_safe,
.memcpy,
.deposit_bits,
.extract_bits,
=> {
const bin_op = data[inst].bin_op;
try self.verifyInstOperands(inst, .{ bin_op.lhs, bin_op.rhs, .none });
Expand Down
Loading