From 2855692468af411fd2123b5f0a377a48fd82f59b Mon Sep 17 00:00:00 2001
From: Xin Wang <dram.wang@gmail.com>
Date: Sun, 9 Sep 2018 18:38:55 +0800
Subject: [PATCH] Add `keccak_NNN` family functions

KECCAK-256 is needed by Ethereum, see following discussion for detail:

https://github.com/ethereum/EIPs/issues/59
---
 sources/sha3.sig | 5 +++++
 sources/sha3.sml | 9 +++++++++
 tests/main.sml   | 8 ++++++++
 3 files changed, 22 insertions(+)

diff --git a/sources/sha3.sig b/sources/sha3.sig
index ccf2916..febd6c2 100644
--- a/sources/sha3.sig
+++ b/sources/sha3.sig
@@ -3,6 +3,11 @@ signature SHA3 = sig
         : (int * int * Word8Vector.vector * Word8.word * int)
           -> Word8Vector.vector
 
+    val keccak_224 : Word8Vector.vector -> Word8Vector.vector
+    val keccak_256 : Word8Vector.vector -> Word8Vector.vector
+    val keccak_384 : Word8Vector.vector -> Word8Vector.vector
+    val keccak_512 : Word8Vector.vector -> Word8Vector.vector
+
     val shake_128 : (Word8Vector.vector * int) -> Word8Vector.vector
     val shake_256 : (Word8Vector.vector * int) -> Word8Vector.vector
 
diff --git a/sources/sha3.sml b/sources/sha3.sml
index 1d24e0f..fc813c7 100644
--- a/sources/sha3.sml
+++ b/sources/sha3.sml
@@ -252,6 +252,15 @@ fun keccak (rate : int,
     end
 end
 
+fun keccak_224 (inputBytes : Word8Vector.vector) : Word8Vector.vector =
+    keccak (1152, 448, inputBytes, 0wx01, 224 div 8)
+fun keccak_256 (inputBytes : Word8Vector.vector) : Word8Vector.vector =
+    keccak (1088, 512, inputBytes, 0wx01, 256 div 8)
+fun keccak_384 (inputBytes : Word8Vector.vector) : Word8Vector.vector =
+    keccak (832, 768, inputBytes, 0wx01, 384 div 8)
+fun keccak_512 (inputBytes : Word8Vector.vector) : Word8Vector.vector =
+    keccak (576, 1024, inputBytes, 0wx01, 512 div 8)
+
 fun shake_128 (inputBytes : Word8Vector.vector,
                outputByteLen : int) : Word8Vector.vector =
     keccak (1344, 256, inputBytes, 0wx1f, outputByteLen)
diff --git a/tests/main.sml b/tests/main.sml
index 8f8d8fe..d6365f5 100644
--- a/tests/main.sml
+++ b/tests/main.sml
@@ -15,6 +15,10 @@ fun main _ =
         val () = (printResult "value1"
                   o bytesToString) value1
 
+        val () = (printResult "value1 KECCAK-256"
+                  o bytesToString
+                  o Sha3.keccak_256) value1
+
         val () = (printResult "value1 SHA3-256"
                   o bytesToString
                   o Sha3.sha3_256) value1
@@ -24,6 +28,10 @@ fun main _ =
         val () = (printResult "value2"
                   o bytesToString) value2
 
+        val () = (printResult "value2 KECCAK-256"
+                  o bytesToString
+                  o Sha3.keccak_256) value2
+
         val () = (printResult "value2 SHA3-256"
                   o bytesToString
                   o Sha3.sha3_256) value2