diff --git a/cmd/api/handlers/error.go b/cmd/api/handlers/error.go index 9a09de8d6..d4aa24aec 100644 --- a/cmd/api/handlers/error.go +++ b/cmd/api/handlers/error.go @@ -4,6 +4,7 @@ import ( "errors" "net/http" + "github.com/baking-bad/bcdhub/internal/bcd/ast" "github.com/baking-bad/bcdhub/internal/logger" "github.com/baking-bad/bcdhub/internal/models" sentrygin "github.com/getsentry/sentry-go/gin" @@ -39,6 +40,9 @@ func getErrorCode(err error, repo models.GeneralRepository) int { if repo.IsRecordNotFound(err) { return http.StatusNotFound } + if errors.Is(err, ast.ErrValidation) { + return http.StatusBadRequest + } return http.StatusInternalServerError } diff --git a/internal/bcd/ast/address.go b/internal/bcd/ast/address.go index 170854f1a..20565cb51 100644 --- a/internal/bcd/ast/address.go +++ b/internal/bcd/ast/address.go @@ -115,8 +115,7 @@ func (a *Address) Distinguish(x Distinguishable) (*MiguelNode, error) { // FromJSONSchema - func (a *Address) FromJSONSchema(data map[string]interface{}) error { - setOptimizedJSONSchema(&a.Default, data, forge.UnforgeContract) - return nil + return setOptimizedJSONSchema(&a.Default, data, forge.UnforgeContract, AddressValidator) } // FindByName - diff --git a/internal/bcd/ast/baker_hash.go b/internal/bcd/ast/baker_hash.go index f4431a3c6..4eb842f56 100644 --- a/internal/bcd/ast/baker_hash.go +++ b/internal/bcd/ast/baker_hash.go @@ -82,8 +82,7 @@ func (s *BakerHash) Distinguish(x Distinguishable) (*MiguelNode, error) { // FromJSONSchema - func (s *BakerHash) FromJSONSchema(data map[string]interface{}) error { - setOptimizedJSONSchema(&s.Default, data, forge.UnforgeBakerHash) - return nil + return setOptimizedJSONSchema(&s.Default, data, forge.UnforgeBakerHash, BakerHashValidator) } // FindByName - diff --git a/internal/bcd/ast/chain_id.go b/internal/bcd/ast/chain_id.go index 053d39168..4ea6acfb0 100644 --- a/internal/bcd/ast/chain_id.go +++ b/internal/bcd/ast/chain_id.go @@ -89,8 +89,7 @@ func (c *ChainID) Distinguish(x Distinguishable) (*MiguelNode, error) { // FromJSONSchema - func (c *ChainID) FromJSONSchema(data map[string]interface{}) error { - setOptimizedJSONSchema(&c.Default, data, forge.UnforgeChainID) - return nil + return setOptimizedJSONSchema(&c.Default, data, forge.UnforgeChainID, ChainIDValidator) } // FindByName - diff --git a/internal/bcd/ast/default.go b/internal/bcd/ast/default.go index c159b6f1f..93abc914b 100644 --- a/internal/bcd/ast/default.go +++ b/internal/bcd/ast/default.go @@ -194,12 +194,11 @@ func (d *Default) ToJSONSchema() (*JSONSchema, error) { // FromJSONSchema - func (d *Default) FromJSONSchema(data map[string]interface{}) error { - for key := range data { - if key == d.GetName() { - d.Value = data[key] - d.ValueKind = valueKindString - break - } + key := d.GetName() + + if value, ok := data[key]; ok { + d.Value = value + d.ValueKind = valueKindString } return nil } diff --git a/internal/bcd/ast/jsonschema.go b/internal/bcd/ast/jsonschema.go index 787c6d76a..95cc3ab36 100644 --- a/internal/bcd/ast/jsonschema.go +++ b/internal/bcd/ast/jsonschema.go @@ -40,6 +40,9 @@ type JSONSchema struct { XOptions map[string]interface{} `json:"x-options,omitempty"` } +// Optimizer - +type Optimizer func(string) (string, error) + func getStringJSONSchema(d Default) *JSONSchema { return wrapObject(&JSONSchema{ Prim: d.Prim, @@ -70,49 +73,57 @@ func getAddressJSONSchema(d Default) *JSONSchema { } func setIntJSONSchema(d *Default, data map[string]interface{}) { - for key := range data { - if key == d.GetName() { - switch v := data[key].(type) { - case float64: - i := big.NewInt(0) - i, _ = big.NewFloat(v).Int(i) - d.Value = &types.BigInt{Int: i} - case string: - d.Value = types.NewBigIntFromString(v) - } - d.ValueKind = valueKindInt - break + key := d.GetName() + if value, ok := data[key]; ok { + switch v := value.(type) { + case float64: + i := big.NewInt(0) + i, _ = big.NewFloat(v).Int(i) + d.Value = &types.BigInt{Int: i} + case string: + d.Value = types.NewBigIntFromString(v) } + d.ValueKind = valueKindInt } } func setBytesJSONSchema(d *Default, data map[string]interface{}) error { - for key := range data { - if key == d.GetName() { - if _, err := hex.DecodeString(data[key].(string)); err != nil { - return errors.Errorf("invalid bytes string: %s=%v", key, data[key]) + key := d.GetName() + if value, ok := data[key]; ok { + if str, ok := value.(string); ok { + if err := BytesValidator(str); err != nil { + return err + } + if _, err := hex.DecodeString(str); err != nil { + return errors.Errorf("bytes decoding error: %s=%v", key, value) } - d.Value = data[key] + d.Value = value d.ValueKind = valueKindBytes - return nil } } return nil } -func setOptimizedJSONSchema(d *Default, data map[string]interface{}, optimizer func(string) (string, error)) { - for key, value := range data { - if key == d.GetName() { - val, err := optimizer(value.(string)) +func setOptimizedJSONSchema(d *Default, data map[string]interface{}, optimizer Optimizer, validator Validator[string]) error { + key := d.GetName() + if value, ok := data[key]; ok { + if str, ok := value.(string); ok { + if validator != nil { + if err := validator(str); err != nil { + return err + } + } + + val, err := optimizer(str) if err != nil { - val = value.(string) + val = str } d.ValueKind = valueKindString d.Value = val - break } } + return nil } type mergeFields struct { diff --git a/internal/bcd/ast/key.go b/internal/bcd/ast/key.go index 0b4f23d9e..33b6b9367 100644 --- a/internal/bcd/ast/key.go +++ b/internal/bcd/ast/key.go @@ -83,8 +83,7 @@ func (k *Key) Distinguish(x Distinguishable) (*MiguelNode, error) { // FromJSONSchema - func (k *Key) FromJSONSchema(data map[string]interface{}) error { - setOptimizedJSONSchema(&k.Default, data, forge.UnforgePublicKey) - return nil + return setOptimizedJSONSchema(&k.Default, data, forge.UnforgePublicKey, PublicKeyValidator) } // FindByName - diff --git a/internal/bcd/ast/key_hash.go b/internal/bcd/ast/key_hash.go index 082ad339e..750986f19 100644 --- a/internal/bcd/ast/key_hash.go +++ b/internal/bcd/ast/key_hash.go @@ -83,8 +83,7 @@ func (k *KeyHash) Distinguish(x Distinguishable) (*MiguelNode, error) { // FromJSONSchema - func (k *KeyHash) FromJSONSchema(data map[string]interface{}) error { - setOptimizedJSONSchema(&k.Default, data, forge.UnforgeAddress) - return nil + return setOptimizedJSONSchema(&k.Default, data, forge.UnforgeAddress, nil) } // FindByName - diff --git a/internal/bcd/ast/signature.go b/internal/bcd/ast/signature.go index bfbbaafb5..7ced8e89c 100644 --- a/internal/bcd/ast/signature.go +++ b/internal/bcd/ast/signature.go @@ -82,8 +82,7 @@ func (s *Signature) Distinguish(x Distinguishable) (*MiguelNode, error) { // FromJSONSchema - func (s *Signature) FromJSONSchema(data map[string]interface{}) error { - setOptimizedJSONSchema(&s.Default, data, forge.UnforgeSignature) - return nil + return setOptimizedJSONSchema(&s.Default, data, forge.UnforgeSignature, SignatureValidator) } // FindByName - diff --git a/internal/bcd/ast/timestamp.go b/internal/bcd/ast/timestamp.go index 3c39b229f..c3aa35a9e 100644 --- a/internal/bcd/ast/timestamp.go +++ b/internal/bcd/ast/timestamp.go @@ -77,21 +77,20 @@ func (t *Timestamp) ToBaseNode(optimized bool) (*base.Node, error) { // FromJSONSchema - func (t *Timestamp) FromJSONSchema(data map[string]interface{}) error { - for key := range data { - if key == t.GetName() { - t.ValueKind = valueKindInt - switch val := data[key].(type) { - case string: - ts, err := time.Parse(time.RFC3339, val) - if err != nil { - return err - } - t.Value = types.NewBigInt(ts.UTC().Unix()) - case float64: - t.Value = types.NewBigInt(int64(val)) + key := t.GetName() + if value, ok := data[key]; ok { + t.ValueKind = valueKindInt + switch val := value.(type) { + case string: + ts, err := time.Parse(time.RFC3339, val) + if err != nil { + return errors.Wrapf(ErrValidation, "time should be in RFC3339 %s=%s", key, val) } - break + t.Value = types.NewBigInt(ts.UTC().Unix()) + case float64: + t.Value = types.NewBigInt(int64(val)) } + } return nil } diff --git a/internal/bcd/ast/validators.go b/internal/bcd/ast/validators.go new file mode 100644 index 000000000..89a234d38 --- /dev/null +++ b/internal/bcd/ast/validators.go @@ -0,0 +1,136 @@ +package ast + +import ( + "regexp" + "strings" + + "github.com/baking-bad/bcdhub/internal/bcd" + "github.com/baking-bad/bcdhub/internal/bcd/encoding" + "github.com/pkg/errors" +) + +// errors +var ( + ErrValidation = errors.New("validation error") +) + +// ValidatorConstraint - +type ValidatorConstraint interface { + ~string +} + +// Validator - +type Validator[T ValidatorConstraint] func(T) error + +var ( + hexRegex = regexp.MustCompile("^[0-9a-fA-F]+$") +) + +// AddressValidator - +func AddressValidator(value string) error { + switch len(value) { + case 44, 42: + if !hexRegex.MatchString(value) { + return errors.Wrapf(ErrValidation, "address '%s' should be hexademical without prefixes", value) + } + case 36: + if !bcd.IsAddressLazy(value) { + return errors.Wrapf(ErrValidation, "invalid address '%s'", value) + } + if !bcd.IsAddress(value) { + return errors.Wrapf(ErrValidation, "invalid address '%s'", value) + } + default: + return errors.Wrap(ErrValidation, "invalid address length") + } + + return nil +} + +// BakerHashValidator - +func BakerHashValidator(value string) error { + switch len(value) { + case 40: + if !hexRegex.MatchString(value) { + return errors.Wrapf(ErrValidation, "baker hash '%s' should be hexademical without prefixes", value) + } + case 36: + if !bcd.IsBakerHash(value) { + return errors.Wrapf(ErrValidation, "invalid baker hash '%s'", value) + } + default: + return errors.Wrap(ErrValidation, "invalid baker hash length") + } + + return nil +} + +// PublicKeyValidator - +func PublicKeyValidator(value string) error { + switch len(value) { + case 68, 66: + if !hexRegex.MatchString(value) { + return errors.Wrapf(ErrValidation, "public key '%s' should be hexademical without prefixes", value) + } + case 55, 54: + if strings.HasPrefix(value, encoding.PrefixED25519PublicKey) || + strings.HasPrefix(value, encoding.PrefixP256PublicKey) || + strings.HasPrefix(value, encoding.PrefixSecp256k1PublicKey) { + return nil + } + return errors.Wrapf(ErrValidation, "invalid public key '%s'", value) + default: + return errors.Wrap(ErrValidation, "invalid public key length") + } + + return nil +} + +// BytesValidator - +func BytesValidator(value string) error { + if len(value)%2 > 0 { + return errors.Wrapf(ErrValidation, "invalid bytes in hex length '%s'", value) + } + if !hexRegex.MatchString(value) { + return errors.Wrapf(ErrValidation, "bytes '%s' should be hexademical without prefixes", value) + } + return nil +} + +// ChainIDValidator - +func ChainIDValidator(value string) error { + switch len(value) { + case 8: + if !hexRegex.MatchString(value) { + return errors.Wrapf(ErrValidation, "chain id '%s' should be hexademical without prefixes", value) + } + case 15: + if strings.HasPrefix(value, encoding.PrefixChainID) { + return nil + } + return errors.Wrapf(ErrValidation, "invalid chain id '%s'", value) + default: + return errors.Wrap(ErrValidation, "invalid chain id length") + } + + return nil +} + +// SignatureValidator - +func SignatureValidator(value string) error { + switch len(value) { + case 128: + if !hexRegex.MatchString(value) { + return errors.Wrapf(ErrValidation, "signature '%s' should be hexademical without prefixes", value) + } + case 96: + if strings.HasPrefix(value, encoding.PrefixGenericSignature) { + return nil + } + return errors.Wrapf(ErrValidation, "invalid signature '%s'", value) + default: + return errors.Wrap(ErrValidation, "invalid signature length") + } + + return nil +} diff --git a/internal/bcd/ast/validators_test.go b/internal/bcd/ast/validators_test.go new file mode 100644 index 000000000..bd822b931 --- /dev/null +++ b/internal/bcd/ast/validators_test.go @@ -0,0 +1,195 @@ +package ast + +import ( + "testing" +) + +func TestAddressValidator(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + { + name: "test 1", + value: "016e4943f7a23ab9cbe56f48ff72f6c27e8956762400", + wantErr: false, + }, { + name: "test 2", + value: "KT1JdufSdfg3WyxWJcCRNsBFV9V3x9TQBkJ2", + wantErr: false, + }, { + name: "test 3", + value: "tz1KfEsrtDaA1sX7vdM4qmEPWuSytuqCDp5j", + wantErr: false, + }, { + name: "test 4", + value: "tz1KfEsrtDaA1sX7vdM4qmEPWuSytuqCDp5", + wantErr: true, + }, { + name: "test 5", + value: "0x6e4943f7a23ab9cbe56f48ff72f6c27e8956762400", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := AddressValidator(tt.value); (err != nil) != tt.wantErr { + t.Errorf("AddressValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestBakerHashValidator(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + { + name: "test 1", + value: "94697e9229c88fac7d19d62e139ca6735f9569dd", + wantErr: false, + }, { + name: "test 2", + value: "SG1d1wsgMKvSstzZQ8L4WoskCesdWGzVt5k4", + wantErr: false, + }, { + name: "test 3", + value: "SG1d1wsgMKvSstzZQ8L4WoskCesdWGzVt5k", + wantErr: true, + }, { + name: "test 4", + value: "0x697e9229c88fac7d19d62e139ca6735f9569dd", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := BakerHashValidator(tt.value); (err != nil) != tt.wantErr { + t.Errorf("BakerHashValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestPublicKeyValidator(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + { + name: "test 1", + value: "sppk7bMuoa8w2LSKz3XEuPsKx1WavsMLCWgbWG9CZNAsJg9eTmkXRPd", + wantErr: false, + }, { + name: "test 2", + value: "030ed412d33412ab4b71df0aaba07df7ddd2a44eb55c87bf81868ba09a358bc0e0", + wantErr: false, + }, { + name: "test 3", + value: "0x0ed412d33412ab4b71df0aaba07df7ddd2a44eb55c87bf81868ba09a358bc0e0", + wantErr: true, + }, { + name: "test 4", + value: "spsk7bMuoa8w2LSKz3XEuPsKx1WavsMLCWgbWG9CZNAsJg9eTmkXRPd", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := PublicKeyValidator(tt.value); (err != nil) != tt.wantErr { + t.Errorf("PublicKeyValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestBytesValidator(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + { + name: "test 1", + value: "sppk7bMuoa8w2LSKz3XEuPsKx1WavsMLCWgbWG9CZNAsJg9eTmkXRPd", + wantErr: true, + }, { + name: "test 2", + value: "030ed412d33412ab4b71df0aaba07df7ddd2a44eb55c87bf81868ba09a358bc0e0", + wantErr: false, + }, { + name: "test 3", + value: "0x030ed412d33412ab4b71df0aaba07df7ddd2a44eb55c87bf81868ba09a358bc0e0", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := BytesValidator(tt.value); (err != nil) != tt.wantErr { + t.Errorf("BytesValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestChainIDValidator(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + { + name: "test 1", + value: "7a06a770", + wantErr: false, + }, { + name: "test 2", + value: "NetXdQprcVkpaWU", + wantErr: false, + }, { + name: "test 3", + value: "0x06a770", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := ChainIDValidator(tt.value); (err != nil) != tt.wantErr { + t.Errorf("ChainIDValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func TestSignatureValidator(t *testing.T) { + tests := []struct { + name string + value string + wantErr bool + }{ + { + name: "test 1", + value: "a04991b4e938cc42d6c01c42be3649a81a9f80d244d9b90e7ec4edf8e0a7b68b6c212da2fef076e48fed66802fa83442b960a36afdb3e60c3cf14d4010f41f03", + wantErr: false, + }, { + name: "test 2", + value: "sigixZejtj1GfDpyiWAQAmvbtnNmCXKyADqVvCaXJH9xHyhSnYYV8696Z3kkns5DNV7oMnMPfNzo3qm84DfEx1XG6saZmHiA", + wantErr: false, + }, { + name: "test 3", + value: "0x4991b4e938cc42d6c01c42be3649a81a9f80d244d9b90e7ec4edf8e0a7b68b6c212da2fef076e48fed66802fa83442b960a36afdb3e60c3cf14d4010f41f03", + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if err := SignatureValidator(tt.value); (err != nil) != tt.wantErr { + t.Errorf("SignatureValidator() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} diff --git a/internal/bcd/literal.go b/internal/bcd/literal.go index 845b790f7..a76ebb70f 100644 --- a/internal/bcd/literal.go +++ b/internal/bcd/literal.go @@ -32,8 +32,9 @@ func IsAddressLazy(str string) bool { } var ( - addressRegex = regexp.MustCompile("(tz1|tz2|tz3|KT1)[0-9A-Za-z]{33}") - contractRegex = regexp.MustCompile("(KT1)[0-9A-Za-z]{33}") + addressRegex = regexp.MustCompile("(tz1|tz2|tz3|KT1)[0-9A-Za-z]{33}") + contractRegex = regexp.MustCompile("(KT1)[0-9A-Za-z]{33}") + bakerHashRegex = regexp.MustCompile("(SG1)[0-9A-Za-z]{33}") ) // IsAddress - @@ -45,3 +46,8 @@ func IsAddress(str string) bool { func IsContract(str string) bool { return contractRegex.MatchString(str) } + +// IsBakerHash - +func IsBakerHash(str string) bool { + return bakerHashRegex.MatchString(str) +} diff --git a/internal/bcd/literal_test.go b/internal/bcd/literal_test.go index 15c3cd2f3..bd027ea0c 100644 --- a/internal/bcd/literal_test.go +++ b/internal/bcd/literal_test.go @@ -30,3 +30,54 @@ func TestIsContract(t *testing.T) { }) } } + +func TestIsAddress(t *testing.T) { + tests := []struct { + name string + address string + want bool + }{ + { + name: "KT1Ap287P1NzsnToSJdA4aqSNjPomRaHBZSr", + address: "KT1Ap287P1NzsnToSJdA4aqSNjPomRaHBZSr", + want: true, + }, { + name: "tz1dMH7tW7RhdvVMR4wKVFF1Ke8m8ZDvrTTE", + address: "tz1dMH7tW7RhdvVMR4wKVFF1Ke8m8ZDvrTTE", + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsAddress(tt.address); got != tt.want { + t.Errorf("IsAddress() = %v, want %v", got, tt.want) + + } + }) + } +} + +func TestIsBakerHash(t *testing.T) { + tests := []struct { + name string + str string + want bool + }{ + { + name: "SG1d1wsgMKvSstzZQ8L4WoskCesdWGzVt5k4", + str: "SG1d1wsgMKvSstzZQ8L4WoskCesdWGzVt5k4", + want: true, + }, { + name: "SG1d1wsgMKvSstzZQ8L4WoskCesdWGzVt5k", + str: "SG1d1wsgMKvSstzZQ8L4WoskCesdWGzVt5k", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := IsBakerHash(tt.str); got != tt.want { + t.Errorf("IsBakerHash() = %v, want %v", got, tt.want) + } + }) + } +}