Skip to content

Commit

Permalink
[crypto] Speed up computation of R^2 in RSA.
Browse files Browse the repository at this point in the history
Implements the same algorithm we used in the specialized RSA-3072
implementation to speed up computation of R^2.

Also fixes an outdated comment in rsa_1024_enc_test.

Signed-off-by: Jade Philipoom <[email protected]>
  • Loading branch information
jadephilipoom authored and sameo committed Jun 19, 2024
1 parent ed4c1e7 commit a6b37ac
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 145 deletions.
271 changes: 127 additions & 144 deletions sw/otbn/crypto/montmul.s
Original file line number Diff line number Diff line change
Expand Up @@ -92,69 +92,82 @@ m0inv:
ret

/**
* Constant time conditional subtraction of modulus from a bigint
* Doubles a number and reduces modulo M in-place.
*
* Returns C <= C-s*M
* with C being a bigint of length 256..4096 bit
* M being the modulus of length 256..4096 bit
* s being a boolean value [0,1]
* Returns: C = (A + A) mod M
*
* Conditionally subtracts the modulus located in dmem from the bigint
* located in a buffer in the wide regfile (starting at w5). The subtracted
* value is selected when FG1.C equals 1, otherwise the unmodified value is
* selected.
* Requires that A < M < 2^(256*N). Writes the output over the A buffer in the
* wide register file, [w4:w(4+N-1)].
*
* Note that the interpretation of the subtrahend as a modulus is only
* contextual. In theory, it can be any bigint. However, the subtrahend is
* expected in dmem at a location that is reserved for the modulus according
* to the calling conventions within this library.
* This routine runs in constant time.
*
* Flags: When leaving this subroutine, flags of FG0 depend on a
* potentially discarded value and therefore are not usable after
* return.
* FG1 is not modified in this subroutine.
* Flags: Flags have no meaning beyond the scope of this subroutine.
*
* @param[in] x16: dptr_m, pointer to 1st limb of modulus M
* @param[in] x30: N, number of 256 bit limbs in modulus and bigint
* @param[in] x16: dmem pointer to first limb of modulus M
* @param[in] x30: N, number of limbs
* @param[in] [w4:w(4+N-1)]: operand A
* @param[in] w31: all-zero
* @param[in] FG1.C: s, selection flag
* @param[out] [w[5+N-1]:w5]: new bigint value
* @param[in] FG0.C: needs to be set to 0
* @param[out] [w4:w(4+N-1)]: result C
*
* clobbered registers: x8, x10, x11, x16, w2, w3, w4, w5 to w[5+N-1]
* clobbered flag groups: FG0
* clobbered registers: x2, x3, x8, x10 to x13
* w2, w3, w4 to w(4+N-1), w24, w29, w30
* clobbered Flag Groups: FG0, FG1
*/
cond_sub_mod:

/* setup pointers */
li x8, 5
li x10, 3
li x11, 2

/* reset flags for FG0 */
bn.add w31, w31, w31

/* iterate over all limbs for limb-wise subtraction + conditional selection*/
double_and_reduce:
/* Clear carry flags. */
bn.sub w31, w31, w31
bn.sub w31, w31, w31, FG1

/* Double the input and compare the sum to the modulus.
[w4:w(4+N-1)] <= (A+A) mod 2^(256*N)
FG1.C <= (A+A-M) < 0 */
li x2, 2
li x3, 3
li x10, 4
addi x11, x16, 0
loop x30, 5
/* w3 <= a[i] */
bn.movr x3, x10
/* FG0.C, w3 <= w3 + w3 + FG0.C */
bn.addc w3, w3, w3
/* w2 <= M[i] */
bn.lid x2, 0(x11++)
/* FG1.C <= (w3 - M[i] - FG1.C) < 0 */
bn.cmpb w3, w2, FG1
/* w[4+i] <= w3 */
bn.movr x10++, x3

/* Now, FG0.C is 1 if (A + A) >= 2^(256*N) and 0 otherwise, and FG1.C is 1 if
(A + A) mod 2^(256*N) < M. So we have the following cases:
1) FG0.C is 0, FG1.C is 0 : A+A < 2^(256*N) and A + A >= M
2) FG0.C is 0, FG1.C is 1 : A+A < 2^(256*N) and A + A < M
3) FG0.C is 1, FG1.C is 0 : A+A >= 2^(256*N) and (A + A) mod 2^(256*N) >= M
4) FG0.C is 1, FG1.C is 1 : A+A >= 2^(256*N) and (A + A) mod 2^(256*N) < M
Case (3) is impossible given the bounds on A and M: it would require
A + A >= 2^(256*N) + M, but A < M < 2^(256*N) implies
A + A < 2^(256*N) + M. Case (2) is the only one in which we don't need
to subtract the modulus, since A + A < M. In cases (1) and (4) we need
to subtract the modulus. */

/* Clear FG0.C, and set FG1.C to 1 if and only if the previous values of
FG0.C and FG1.C match, i.e. in cases (1) and (4), where M must be
subtracted.
FG0.C <= 0
FG1.C <= ((FG0.C ^ FG1.C) < 1) */
bn.addc w2, w31, w31
bn.addc w3, w31, w31, FG1
bn.xor w2, w2, w3
bn.subi w2, w2, 1, FG1

/* Conditionally subtract M.
[w4:w(4+N-1)] <= [w4:w(4+N-1)] - FG1.C * M = (A + A) mod M */
li x8, 4
addi x10, x16, 0
jal x1, cond_sub_to_reg

/* load a limb of modulus from dmem to w3 */
bn.lid x10, 0(x16++)

/* load the limb of bigint buffer to w2 */
bn.movr x11, x8

/* subtract the current limb of the modulus from current limb of bigint */
bn.subb w4, w2, w3

/* conditionally select subtraction result or unmodified limb */
bn.sel w3, w4, w2, FG1.C

/* move back result from w3 to bigint buffer */
bn.movr x8++, x10
/* Restore modulus pointer (clobbered by cond_sub_to_reg). */
addi x16, x10, 0

ret


/**
* Compute square of Montgomery modulus
*
Expand All @@ -171,107 +184,74 @@ cond_sub_mod:
* not usable after return.
*
* @param[in] x16: dptr_M, pointer to first limb of modulus in dmem
* @param[in] x17: dptr_m0d, dmem pointer to Montgomery Constant m0'
* @param[in] x18: dptr_RR: dmem pointer to first limb of output buffer for RR
* @param[in] x30: N, number of limbs
* @param[in] x31: N-1, number of limbs minus 1
* @param[in] w31: all-zero
* @param[out] dmem[dptr_RR+N*32:dptr_RR]: computed RR
*
* clobbered registers: x3, x8, x10, x11, x22
* clobbered registers: x3, x8, x10, x11
* w0, w2, w3, w4, w5 to w20 depending on N
* clobbered flag groups: FG0, FG1
*/
compute_rr:
/* save pointer to modulus */
addi x22, x16, 0

/* zeroize w3 */
bn.xor w3, w3, w3

/* compute full length of current bigint size in bits
N*w = x24 = N*256 = N*2^8 = x30 << 8 */
slli x24, x30, 8

/* reg pointers */
li x8, 5
li x10, 3

/* zeroize w3 */
bn.xor w3, w3, w3
/* Prepare all-zero register and clear FG0.C. */
bn.sub w31, w31, w31

/* Initialize the buffer with R mod M = 2^(256*N) - M. Because of the bounds
on M, the subtraction will never underflow.
[w4:w(4+N-1)] <= (0 - M) mod 2^(256*N) = R mod M */
addi x10, x16, 0
li x11, 4
li x3, 3
loop x30, 3
/* w3 <= M[i] */
bn.lid x3, 0(x10++)
/* FG0.C, w3 <= (0 - M[i] - FG0.C) */
bn.subb w3, w31, w3
/* w[4+i] <= w3 */
bn.movr x11++, x3

/* Repeatedly double R until five Montgomery squarings are enough to reach
R^2; that is, we compute T = (2^(256*N / 32) * R) mod M
= (2^(8*N) * R) mod M using 8*N doublings. We could use different cutoffs
for switching from doubling to squaring, but this cutoff is empirically
fastest for RSA-3072.
[w4:w(4+N-1)] = (2^(8*N) * [w4:w(4+N-1)]) mod M = T */
slli x10, x30, 3
loop x10, 2
jal x1, double_and_reduce
nop

/* Store T in output buffer (in preparation for montmul).
dmem[dptr_RR] <= [w4:w(4+N-1)] = T */
li x8, 4
addi x21, x18, 0
loop x30, 2
bn.sid x8, 0(x21++)
addi x8, x8, 1

/* zeroize all limbs of bigint in regfile */
loop x30, 1
bn.movr x8++, x10
/* Prepare pointers to temp regs for montmul. */
li x9, 3
li x10, 4
li x11, 2

/* compute R-M
since R = 2^(N*w), this can be computed as R-M = unsigned(0-M) */
bn.addi w0, w31, 1
bn.sub w3, w31, w0, FG1
addi x16, x22, 0
jal x1, cond_sub_mod

/* Compute R^2 mod M = R*2^(N*w) mod M.
=> R^2 mod M can be computed by performing N*w duplications of R.
We directly perform a modulo reduction in each step such that the
final result will already be reduced. */
loop x24, 18
/* reset pointer */
li x8, 5

/* zeroize w3 reset flags of FG1 */
bn.sub w3, w3, w3, FG1

/* Duplicate the intermediate bigint result. This can overflow such that
bit 2^(N*w) (represented by the carry bit after the final loop cycle)
is set. */
loop x30, 3
/* copy current limb of bigint to w2 */
bn.movr x11, x8

/* perform the doubling */
bn.addc w2, w2, w2, FG1

/* copy result back to bigint in regfile */
bn.movr x8++, x11

/* Conditionally subtract the modulus from the current bigint Y if there
was an overflow. Again, just considering the lowest N*w bits is
sufficient, since (in case of an overflow) we can write
2*Y as 2^(N*w) + X with M > X >= 0.
Then, 2*Y - M = 2^(N*w) + X - M = X + unsigned(0-M) */
addi x16, x22, 0
jal x1, cond_sub_mod

/* reset pointer to 1st limb of bigint in regfile */
li x8, 5

/* reset pointer to modulus in dmem */
addi x16, x22, 0

/* reset flags of FG1 */
bn.sub w3, w3, w3, FG1

/* compare intermediate bigint y with modulus
subtract modulus if Y > M */
loop x30, 3
bn.lid x10, 0(x16++)
bn.movr x11, x8++
bn.cmpb w3, w2, FG1
addi x16, x22, 0
jal x1, cond_sub_mod

li x0, 0

/* reset pointer to 1st limb of bigint in regfile */
li x8, 5

/* reset pointer to modulus */
addi x16, x22, 0
/* Prepare a pointer to the w4 register for storing the result. */
li x8, 4

/* store computed RR in dmem */
addi x3, x18, 0
loop x30, 2
bn.sid x8, 0(x3++)
addi x8, x8, 1
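/* Five Montgomery squarings to compute RR. Since a Montgomery
multiplication of a and b yields (a * b * R^-1) mod M, each squaring maps
(2^k * R) mod M to (2^(2*k) * R) mod M. Starting from
T = (2^(8*N) * R) mod M, five squarings give
(2^(256*N) * R) mod M = R^2 mod M = RR. */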
loopi 5,9
/* [w4:w(4+N-1)] <= montmul(dmem[rr], dmem[rr]) */
addi x19, x18, 0
addi x20, x18, 0
jal x1, montmul
/* Store result: dmem[rr] <= [w4:w(4+N-1)] */
addi x2, x18, 0
addi x3, x8, 0
loop x30, 2
bn.sid x3, 0(x2++)
addi x3, x3, 1
nop

ret

Expand Down Expand Up @@ -368,7 +348,7 @@ mul256_w30xw2:
* @param[in] x30: number of limbs
* @param[in] FG0.C: needs to be set to 0
*
* clobbered registers: x8, x16, w24, w29, w30, w[x8] to w[x8+N-1]
* clobbered registers: x8, x12, x13, x16, w24, w29, w30, w[x8] to w[x8+N-1]
* clobbered Flag Groups: FG0
*/
cond_sub_to_reg:
Expand All @@ -378,7 +358,7 @@ cond_sub_to_reg:
li x13, 24

/* iterate over all limbs for conditional limb-wise subtraction */
loop x30, 6
loop x30, 5
/* load limb of subtrahend (input B) to w24 */
bn.lid x13, 0(x16++)

Expand All @@ -388,8 +368,6 @@ cond_sub_to_reg:
/* perform subtraction for a limb */
bn.subb w29, w30, w24

bn.movr x8, x13

/* conditionally select subtraction result or unmodified limb */
bn.sel w24, w29, w30, FG1.C

Expand Down Expand Up @@ -591,7 +569,7 @@ mont_loop:
* @param[in] x11: pointer to temp reg, must be set to 2
* @param[out] [w[4+N-1]:w4]: result C
*
* clobbered registers: x5, x6, x7, x8, x10, x12, x13, x20, x22
* clobbered registers: x5 to x9, x12, x13, x20, x22
* w2, w3, w4 to w[4+N-1], w24 to w30
* clobbered Flag Groups: FG0, FG1
*/
Expand Down Expand Up @@ -626,6 +604,7 @@ montmul:
/* restore pointers */
li x8, 4
li x10, 4
li x11, 2

ret

Expand All @@ -652,6 +631,10 @@ modload:
li x8, 28
bn.lid x8, 0(x16)

/* x31 <= N - 1 */
li x2, 1
sub x31, x30, x2

/* Compute Montgomery constant */
jal x1, m0inv

Expand Down
2 changes: 1 addition & 1 deletion sw/otbn/crypto/tests/rsa_1024_enc_test.s
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
* .data segment in this file.
*
* Copies the encrypted message to wide registers for comparison (starting at
* w0). See comment at the end of the file for expected values.
* w0).
*/
run_rsa_1024_enc:
/* Init all-zero register. */
Expand Down

0 comments on commit a6b37ac

Please sign in to comment.