diff --git a/src/encoder.nr b/src/encoder.nr index 79429ce..3187981 100644 --- a/src/encoder.nr +++ b/src/encoder.nr @@ -59,7 +59,29 @@ impl Base64EncodeBE { // 240 bits fits 40 6-bit chunks and 30 8-bit chunks // we pack 40 base64 values into a field element and convert into 30 bytes // TODO: once we support arithmetic ops on generics, derive OutputBytes from InputBytes + // Calculate the number of padding characters and the length of the output without padding + let rem = InputBytes % 3; + let num_padding_chars = if rem == 1 { + 2 + } else if rem == 2 { + 1 + } else { + 0 + }; + let non_padded_output_length = OutputElements - num_padding_chars; + + // Assert that the output length is correct + if self.pad { + assert((InputBytes + 2) / 3 * 4 == OutputElements, "invalid output length"); + } else { + assert( + ((InputBytes + 2) / 3 * 4) - num_padding_chars == OutputElements, + "invalid output length", + ); + } + let mut result: [u8; OutputElements] = [0; OutputElements]; + let BASE64_ELEMENTS_PER_CHUNK: u32 = 40; let BYTES_PER_CHUNK: u32 = 30; let num_chunks = @@ -103,7 +125,10 @@ impl Base64EncodeBE { // extract the 6-bit values from the Field element let slice_base64_chunks: [u8; 40] = slice.to_be_radix(64); for i in 0..num_elements_in_final_chunk { - result[base64_offset + i] = slice_base64_chunks[i]; + // If padding is enabled, we can skip setting the last `non_padded_output_length` chars + if (!self.pad | base64_offset + i < non_padded_output_length) { + result[base64_offset + i] = slice_base64_chunks[i]; + } } result = self.encode_elements(result); @@ -112,10 +137,10 @@ impl Base64EncodeBE { // handle padding let rem = InputBytes % 3; if (rem == 1) { - result[base64_offset + num_elements_in_final_chunk - 1] = BASE64_PADDING_CHAR; - result[base64_offset + num_elements_in_final_chunk - 2] = BASE64_PADDING_CHAR; + result[OutputElements - 1] = BASE64_PADDING_CHAR; + result[OutputElements - 2] = BASE64_PADDING_CHAR; } else if (rem == 2) { - result[base64_offset + num_elements_in_final_chunk - 1] = BASE64_PADDING_CHAR; + result[OutputElements - 1] = BASE64_PADDING_CHAR; } } }