diff --git a/sm2/sm2.go b/sm2/sm2.go index 20b2794a..16a0280e 100644 --- a/sm2/sm2.go +++ b/sm2/sm2.go @@ -27,13 +27,17 @@ import ( "io" "math/big" + "golang.org/x/crypto/cryptobyte" + cbasn1 "golang.org/x/crypto/cryptobyte/asn1" + "github.com/tjfoc/gmsm/sm3" ) var ( default_uid = []byte{0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38} - C1C3C2=0 - C1C2C3=1 + + C1C3C2 = 0 + C1C2C3 = 1 ) type PublicKey struct { @@ -46,9 +50,6 @@ type PrivateKey struct { D *big.Int } -type sm2Signature struct { - R, S *big.Int -} type sm2Cipher struct { XCoordinate *big.Int YCoordinate *big.Int @@ -71,16 +72,28 @@ func (priv *PrivateKey) Sign(random io.Reader, msg []byte, signer crypto.SignerO if err != nil { return nil, err } - return asn1.Marshal(sm2Signature{r, s}) -} - -func (pub *PublicKey) Verify(msg []byte, sign []byte) bool { - var sm2Sign sm2Signature - _, err := asn1.Unmarshal(sign, &sm2Sign) - if err != nil { + var b cryptobyte.Builder + b.AddASN1(cbasn1.SEQUENCE, func(b *cryptobyte.Builder) { + b.AddASN1BigInt(r) + b.AddASN1BigInt(s) + }) + return b.Bytes() +} + +func (pub *PublicKey) Verify(msg []byte, sig []byte) bool { + var ( + r, s = &big.Int{}, &big.Int{} + inner cryptobyte.String + ) + input := cryptobyte.String(sig) + if !input.ReadASN1(&inner, cbasn1.SEQUENCE) || + !input.Empty() || + !inner.ReadASN1Integer(r) || + !inner.ReadASN1Integer(s) || + !inner.Empty() { return false } - return Sm2Verify(pub, msg, default_uid, sm2Sign.R, sm2Sign.S) + return Sm2Verify(pub, msg, default_uid, r, s) } func (pub *PublicKey) Sm3Digest(msg, uid []byte) ([]byte, error) { @@ -244,7 +257,7 @@ func Verify(pub *PublicKey, hash []byte, r, s *big.Int) bool { * hash * CipherText */ -func Encrypt(pub *PublicKey, data []byte, random io.Reader,mode int) ([]byte, error) { +func Encrypt(pub *PublicKey, data []byte, random io.Reader, mode int) ([]byte, error) { length := len(data) for { c := []byte{} @@ -287,42 +300,40 @@ func Encrypt(pub *PublicKey, data []byte, random io.Reader,mode int) ([]byte, er for i := 0; i < length; i++ { c[96+i] ^= data[i] } - switch mode{ - + switch mode { + case C1C3C2: return append([]byte{0x04}, c...), nil case C1C2C3: c1 := make([]byte, 64) - c2 := make([]byte, len(c) - 96) + c2 := make([]byte, len(c)-96) c3 := make([]byte, 32) - copy(c1, c[:64])//x1,y1 - copy(c3, c[64:96])//hash - copy(c2, c[96:])//密文 + copy(c1, c[:64]) //x1,y1 + copy(c3, c[64:96]) //hash + copy(c2, c[96:]) //密文 ciphertext := []byte{} ciphertext = append(ciphertext, c1...) ciphertext = append(ciphertext, c2...) ciphertext = append(ciphertext, c3...) return append([]byte{0x04}, ciphertext...), nil - default: + default: return append([]byte{0x04}, c...), nil + } } } -} - - -func Decrypt(priv *PrivateKey, data []byte,mode int) ([]byte, error) { +func Decrypt(priv *PrivateKey, data []byte, mode int) ([]byte, error) { switch mode { case C1C3C2: data = data[1:] - case C1C2C3: + case C1C2C3: data = data[1:] c1 := make([]byte, 64) - c2 := make([]byte, len(data) - 96) + c2 := make([]byte, len(data)-96) c3 := make([]byte, 32) - copy(c1, data[:64])//x1,y1 - copy(c2, data[64:len(data) - 32])//密文 - copy(c3, data[len(data) - 32:])//hash + copy(c1, data[:64]) //x1,y1 + copy(c2, data[64:len(data)-32]) //密文 + copy(c3, data[len(data)-32:]) //hash c := []byte{} c = append(c, c1...) c = append(c, c3...) @@ -362,8 +373,6 @@ func Decrypt(priv *PrivateKey, data []byte,mode int) ([]byte, error) { return c, nil } - - // keyExchange 为SM2密钥交换算法的第二部和第三步复用部分,协商的双方均调用此函数计算共同的字节串 // klen: 密钥长度 // ida, idb: 协商双方的标识,ida为密钥协商算法发起方标识,idb为响应方标识 @@ -479,7 +488,7 @@ func zeroByteSlice() []byte { sm2加密,返回asn.1编码格式的密文内容 */ func EncryptAsn1(pub *PublicKey, data []byte, rand io.Reader) ([]byte, error) { - cipher, err := Encrypt(pub, data, rand,C1C3C2) + cipher, err := Encrypt(pub, data, rand, C1C3C2) if err != nil { return nil, err } @@ -494,7 +503,7 @@ func DecryptAsn1(pub *PrivateKey, data []byte) ([]byte, error) { if err != nil { return nil, err } - return Decrypt(pub, cipher,C1C3C2) + return Decrypt(pub, cipher, C1C3C2) } /* @@ -670,5 +679,5 @@ func getLastBit(a *big.Int) uint { // crypto.Decrypter func (priv *PrivateKey) Decrypt(_ io.Reader, msg []byte, _ crypto.DecrypterOpts) (plaintext []byte, err error) { - return Decrypt(priv, msg,C1C3C2) + return Decrypt(priv, msg, C1C3C2) } diff --git a/sm2/utils.go b/sm2/utils.go index 03d92270..a2ecc1ca 100644 --- a/sm2/utils.go +++ b/sm2/utils.go @@ -20,7 +20,7 @@ func Decompress(a []byte) *PublicKey { y2 := sm2P256ToBig(&xx3) y := new(big.Int).ModSqrt(y2, sm2P256.P) - if getLastBit(y)!= uint(a[0]) { + if getLastBit(y) != uint(a[0]) { y.Sub(sm2P256.P, y) } return &PublicKey{ @@ -41,7 +41,9 @@ func Compress(a *PublicKey) []byte { return buf } - +type sm2Signature struct { + R, S *big.Int +} func SignDigitToSignData(r, s *big.Int) ([]byte, error) { return asn1.Marshal(sm2Signature{r, s}) @@ -56,4 +58,3 @@ func SignDataToSignDigit(sign []byte) (*big.Int, *big.Int, error) { } return sm2Sign.R, sm2Sign.S, nil } -