diff --git a/core/vm/contracts.go b/core/vm/contracts.go index f5fdb9d44c2b..a84d2ea4c4a7 100644 --- a/core/vm/contracts.go +++ b/core/vm/contracts.go @@ -1126,13 +1126,15 @@ func (c *bls12381MapG2) Run(accessibleState PrecompileAccessibleState, caller co type tomlConfigOptions struct { Oracle struct { - Mode string - OracleDBAddress string + Mode string + OracleDBAddress string + RequireRetryCount uint8 } Zk struct { Verify bool VerifyRPCAddress string + VerifyRetryCount uint8 } Tfhe struct { @@ -1294,6 +1296,10 @@ func fheEncryptToUserKey(value uint64, userAddress common.Address) ([]byte, erro return ct, nil } +func exitProcess() { + os.Exit(1) +} + type verifyCiphertext struct{} func (e *verifyCiphertext) RequiredGas(input []byte) uint64 { @@ -1301,24 +1307,33 @@ func (e *verifyCiphertext) RequiredGas(input []byte) uint64 { return 8 } -func verifyZkProof(input []byte) ([]byte, error) { - req, err := http.NewRequest(http.MethodPost, tomlConfig.Zk.VerifyRPCAddress, bytes.NewReader(input)) - if err != nil { - return nil, err - } - req.Header.Add("Content-Type", "application/msgpack") - resp, err := zkHttpClient.Do(req) - if err != nil { - return nil, err - } - if resp.StatusCode != 200 { - return nil, fmt.Errorf("failure HTTP status code on ZK verify: %d", resp.StatusCode) - } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return nil, errors.New("failed reading ZK verification response body") +// Returns the verified ciphertext on success or nil on invalid ZK proof. +// Exits the process on errors. +func verifyZkProof(input []byte) []byte { + for try := uint8(1); try <= tomlConfig.Zk.VerifyRetryCount+1; try++ { + req, err := http.NewRequest(http.MethodPost, tomlConfig.Zk.VerifyRPCAddress, bytes.NewReader(input)) + if err != nil { + continue + } + req.Header.Add("Content-Type", "application/msgpack") + resp, err := zkHttpClient.Do(req) + if err != nil { + continue + } + // The ZKPoK service returns 406 if the proof is incorrect. + if resp.StatusCode == 406 { + return nil + } else if resp.StatusCode != 200 { + continue + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + continue + } + return body } - return body, nil + exitProcess() + return nil } func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) { @@ -1331,10 +1346,9 @@ func (e *verifyCiphertext) Run(accessibleState PrecompileAccessibleState, caller // For testing: if input size <= `ciphertextSize`, treat the whole input as ciphertext. ctBytes = input[0:minInt(ciphertextSize, len(input))] } else { - var err error - ctBytes, err = verifyZkProof(input) - if err != nil { - return nil, err + ctBytes = verifyZkProof(input) + if ctBytes == nil { + return nil, fmt.Errorf("invalid ZK Proof") } } ct := new(tfheCiphertext) @@ -1425,80 +1439,82 @@ func requireURL(key *string) string { return tomlConfig.Oracle.OracleDBAddress + "/require/" + *key } -func putRequire(ciphertext []byte, value bool) error { +// Puts the given ciphertext as a require to the oracle DB or exits the process on errors. +// Returns the require value. +func putRequire(ct *tfheCiphertext) bool { + ciphertext := ct.serialize() + value := (ct.decrypt() != 0) key := requireKey(ciphertext) j, err := json.Marshal(requireMessage{value, signRequire(ciphertext, value)}) if err != nil { - return err + exitProcess() } - req, err := http.NewRequest(http.MethodPut, requireURL(&key), bytes.NewReader(j)) - if err != nil { - return err - } - resp, err := requireHttpClient.Do(req) - if err != nil { - return err - } - if resp.StatusCode != 200 { - return fmt.Errorf("failure HTTP status code on require PUT: %d", resp.StatusCode) + for try := uint8(1); try <= tomlConfig.Oracle.RequireRetryCount+1; try++ { + req, err := http.NewRequest(http.MethodPut, requireURL(&key), bytes.NewReader(j)) + if err != nil { + continue + } + resp, err := requireHttpClient.Do(req) + if err != nil { + continue + } + if resp.StatusCode != 200 { + continue + } + return value } - return nil + exitProcess() + return value } -func getRequire(ciphertext []byte) (bool, error) { +// Gets the given require from the oracle DB and returns its value. +// Exits the process on errors or signature verification failure. +func getRequire(ct *tfheCiphertext) bool { + ciphertext := ct.serialize() key := requireKey(ciphertext) - req, err := http.NewRequest(http.MethodGet, requireURL(&key), http.NoBody) - if err != nil { - return false, nil - } - resp, err := requireHttpClient.Do(req) - if err != nil { - return false, err - } - if resp.StatusCode != 200 { - return false, fmt.Errorf("require: failure HTTP status code on require GET: %d", resp.StatusCode) - } - body, err := ioutil.ReadAll(resp.Body) - if err != nil { - return false, errors.New("failed reading response body") - } - msg := requireMessage{} - if err := json.Unmarshal(body, &msg); err != nil { - return false, err - } - b := requireBytesToSign(ciphertext, msg.Value) - s, err := hex.DecodeString(msg.Signature) - if err != nil { - return false, err - } - if ed25519.Verify(publicSignatureKey, b, s) { - return msg.Value, nil + for try := uint8(1); try <= tomlConfig.Oracle.RequireRetryCount+1; try++ { + req, err := http.NewRequest(http.MethodGet, requireURL(&key), http.NoBody) + if err != nil { + continue + } + resp, err := requireHttpClient.Do(req) + if err != nil { + continue + } + if resp.StatusCode != 200 { + continue + } + body, err := ioutil.ReadAll(resp.Body) + if err != nil { + continue + } + msg := requireMessage{} + if err := json.Unmarshal(body, &msg); err != nil { + continue + } + b := requireBytesToSign(ciphertext, msg.Value) + s, err := hex.DecodeString(msg.Signature) + if err != nil { + continue + } + if !ed25519.Verify(publicSignatureKey, b, s) { + continue + } + return msg.Value } - return false, errors.New("invalid require signature") + exitProcess() + return false } func evaluateRequire(ct *tfheCiphertext) bool { switch mode := strings.ToLower(tomlConfig.Oracle.Mode); mode { case "oracle": - requireValue := ct.decrypt() - if err := putRequire(ct.serialize(), requireValue != 0); err != nil { - panic(err) - } - if requireValue == 0 { - return false - } - return true + return putRequire(ct) case "node": - requireValue, err := getRequire(ct.serialize()) - if err != nil { - panic(err) - } - if !requireValue { - return false - } - return true + return getRequire(ct) } - panic(errors.New("unimplemented require mode")) + exitProcess() + return false } func (e *require) Run(accessibleState PrecompileAccessibleState, caller common.Address, addr common.Address, input []byte, readOnly bool) ([]byte, error) {