Skip to content

Commit

Permalink
Introduce API key (#220)
Browse files Browse the repository at this point in the history
  • Loading branch information
ridenaio authored Dec 11, 2019
1 parent b35c4bc commit 60bc4d1
Show file tree
Hide file tree
Showing 14 changed files with 170 additions and 22 deletions.
34 changes: 34 additions & 0 deletions config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

const (
datadirPrivateKey = "nodekey" // Path within the datadir to the node's private key
apiKeyFileName = "api.key"
LowPowerProfile = "lowpower"
)

Expand Down Expand Up @@ -133,6 +134,33 @@ func (c *Config) KeyStoreDataDir() (string, error) {
return instanceDir, nil
}

func (c *Config) SetApiKey() error {
shouldSaveKey := true
if c.RPC.APIKey == "" {
apiKeyFile := filepath.Join(c.DataDir, apiKeyFileName)
data, _ := ioutil.ReadFile(apiKeyFile)
key := string(data)
if key == "" {
randomKey, _ := crypto.GenerateKey()
key = hex.EncodeToString(crypto.FromECDSA(randomKey)[:16])
} else {
shouldSaveKey = false
}
c.RPC.APIKey = key
}

if shouldSaveKey {
f, err := os.OpenFile(filepath.Join(c.DataDir, apiKeyFileName), os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0666)
if err != nil {
return err
}
defer f.Close()
_, err = f.WriteString(c.RPC.APIKey)
return err
}
return nil
}

func MakeMobileConfig(path string, cfg string) (*Config, error) {
conf := getDefaultConfig(filepath.Join(path, DefaultDataDir))

Expand Down Expand Up @@ -301,6 +329,12 @@ func applyRpcFlags(ctx *cli.Context, cfg *Config) {
if ctx.IsSet(RpcPortFlag.Name) {
cfg.RPC.HTTPPort = ctx.Int(RpcPortFlag.Name)
}
if ctx.IsSet(ApiKeyFlag.Name) {
cfg.RPC.APIKey = ctx.String(ApiKeyFlag.Name)
if cfg.RPC.APIKey != "" {
cfg.RPC.UseApiKey = true
}
}
}

func applyGenesisFlags(ctx *cli.Context, cfg *Config) {
Expand Down
4 changes: 4 additions & 0 deletions config/flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,4 +102,8 @@ var (
Name: "ipfsportstatic",
Usage: "Enable static ipfs port",
}
ApiKeyFlag = cli.StringFlag{
Name: "apikey",
Usage: "Set RPC api key",
}
)
1 change: 1 addition & 0 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func main() {
config.ForceFullSyncFlag,
config.ProfileFlag,
config.IpfsPortStaticFlag,
config.ApiKeyFlag,
}

app.Action = func(context *cli.Context) error {
Expand Down
18 changes: 15 additions & 3 deletions node/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ import (
"github.com/idena-network/idena-go/rpc"
"github.com/idena-network/idena-go/secstore"
"github.com/idena-network/idena-go/stats/collector"
"github.com/pkg/errors"
"net"
"os"
"path/filepath"
Expand Down Expand Up @@ -140,6 +141,11 @@ func NewNodeWithInjections(config *config.Config, bus eventbus.Bus, blockStatsCo
return nil, err
}

err = config.SetApiKey()
if err != nil {
return nil, errors.Wrap(err, "cannot set API key")
}

ipfsProxy, err := ipfs.NewIpfsProxy(config.IpfsConf)
if err != nil {
return nil, err
Expand Down Expand Up @@ -280,7 +286,13 @@ func (node *Node) startRPC() error {
// Gather all the possible APIs to surface
apis := node.apis()

if err := node.startHTTP(node.config.RPC.HTTPEndpoint(), apis, node.config.RPC.HTTPModules, node.config.RPC.HTTPCors, node.config.RPC.HTTPVirtualHosts, node.config.RPC.HTTPTimeouts); err != nil {
// TODO: remove later
apiKey := node.config.RPC.APIKey
if !node.config.RPC.UseApiKey {
apiKey = ""
}

if err := node.startHTTP(node.config.RPC.HTTPEndpoint(), apis, node.config.RPC.HTTPModules, node.config.RPC.HTTPCors, node.config.RPC.HTTPVirtualHosts, node.config.RPC.HTTPTimeouts, apiKey); err != nil {
return err
}

Expand All @@ -289,12 +301,12 @@ func (node *Node) startRPC() error {
}

// startHTTP initializes and starts the HTTP RPC endpoint.
func (node *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts) error {
func (node *Node) startHTTP(endpoint string, apis []rpc.API, modules []string, cors []string, vhosts []string, timeouts rpc.HTTPTimeouts, apiKey string) error {
// Short circuit if the HTTP endpoint isn't being exposed
if endpoint == "" {
return nil
}
listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts)
listener, handler, err := rpc.StartHTTPEndpoint(endpoint, apis, modules, cors, vhosts, timeouts, apiKey)
if err != nil {
return err
}
Expand Down
2 changes: 1 addition & 1 deletion rpc/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,7 @@ func TestClientReconnect(t *testing.T) {
}

func newTestServer(serviceName string, service interface{}) *Server {
server := NewServer()
server := NewServer("")
if err := server.RegisterName(serviceName, service); err != nil {
panic(err)
}
Expand Down
3 changes: 3 additions & 0 deletions rpc/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type Config struct {
// default zero value is/ valid and will pick a port number randomly (useful
// for ephemeral nodes).
HTTPPort int `toml:",omitempty"`

APIKey string
UseApiKey bool
}

func (c *Config) HTTPEndpoint() string {
Expand Down
8 changes: 4 additions & 4 deletions rpc/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,14 @@ import (
)

// StartHTTPEndpoint starts the HTTP RPC endpoint, configured with cors/vhosts/modules
func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts) (net.Listener, *Server, error) {
func StartHTTPEndpoint(endpoint string, apis []API, modules []string, cors []string, vhosts []string, timeouts HTTPTimeouts, apiKey string) (net.Listener, *Server, error) {
// Generate the whitelist based on the allowed modules
whitelist := make(map[string]bool)
for _, module := range modules {
whitelist[module] = true
}
// Register all the APIs exposed by the services
handler := NewServer()
handler := NewServer(apiKey)
for _, api := range apis {
if whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
Expand Down Expand Up @@ -60,7 +60,7 @@ func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []
whitelist[module] = true
}
// Register all the APIs exposed by the services
handler := NewServer()
handler := NewServer("")
for _, api := range apis {
if exposeAll || whitelist[api.Namespace] || (len(whitelist) == 0 && api.Public) {
if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
Expand All @@ -85,7 +85,7 @@ func StartWSEndpoint(endpoint string, apis []API, modules []string, wsOrigins []
// StartIPCEndpoint starts an IPC endpoint.
func StartIPCEndpoint(ipcEndpoint string, apis []API) (net.Listener, *Server, error) {
// Register all the APIs exposed by the services.
handler := NewServer()
handler := NewServer("")
for _, api := range apis {
if err := handler.RegisterName(api.Namespace, api.Service); err != nil {
return nil, nil, err
Expand Down
7 changes: 7 additions & 0 deletions rpc/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,3 +64,10 @@ type shutdownError struct{}
func (e *shutdownError) ErrorCode() int { return -32000 }

func (e *shutdownError) Error() string { return "server is shutting down" }

// invalid api key
type invalidApiKeyError struct{}

func (e *invalidApiKeyError) ErrorCode() int { return -32800 }

func (e *invalidApiKeyError) Error() string { return "the provided API key is invalid" }
17 changes: 9 additions & 8 deletions rpc/json.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ const (
)

type jsonRequest struct {
Key string `json:"key"`
Method string `json:"method"`
Version string `json:"jsonrpc"`
Id json.RawMessage `json:"id,omitempty"`
Expand Down Expand Up @@ -181,7 +182,7 @@ func parseRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error) {

// subscribe are special, they will always use `subscribeMethod` as first param in the payload
if strings.HasSuffix(in.Method, subscribeMethodSuffix) {
reqs := []rpcRequest{{id: &in.Id, isPubSub: true}}
reqs := []rpcRequest{{id: &in.Id, isPubSub: true, key: in.Key}}
if len(in.Payload) > 0 {
// first param must be subscription name
var subscribeMethod [1]string
Expand All @@ -199,7 +200,7 @@ func parseRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error) {

if strings.HasSuffix(in.Method, unsubscribeMethodSuffix) {
return []rpcRequest{{id: &in.Id, isPubSub: true,
method: in.Method, params: in.Payload}}, false, nil
method: in.Method, params: in.Payload, key: in.Key}}, false, nil
}

elems := strings.Split(in.Method, serviceMethodSeparator)
Expand All @@ -209,10 +210,10 @@ func parseRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error) {

// regular RPC call
if len(in.Payload) == 0 {
return []rpcRequest{{service: elems[0], method: elems[1], id: &in.Id}}, false, nil
return []rpcRequest{{service: elems[0], method: elems[1], id: &in.Id, key: in.Key}}, false, nil
}

return []rpcRequest{{service: elems[0], method: elems[1], id: &in.Id, params: in.Payload}}, false, nil
return []rpcRequest{{service: elems[0], method: elems[1], id: &in.Id, params: in.Payload, key: in.Key}}, false, nil
}

// parseBatchRequest will parse a batch request into a collection of requests from the given RawMessage, an indication
Expand All @@ -233,7 +234,7 @@ func parseBatchRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error)

// subscribe are special, they will always use `subscriptionMethod` as first param in the payload
if strings.HasSuffix(r.Method, subscribeMethodSuffix) {
requests[i] = rpcRequest{id: id, isPubSub: true}
requests[i] = rpcRequest{id: id, isPubSub: true, key: r.Key}
if len(r.Payload) > 0 {
// first param must be subscription name
var subscribeMethod [1]string
Expand All @@ -251,14 +252,14 @@ func parseBatchRequest(incomingMsg json.RawMessage) ([]rpcRequest, bool, Error)
}

if strings.HasSuffix(r.Method, unsubscribeMethodSuffix) {
requests[i] = rpcRequest{id: id, isPubSub: true, method: r.Method, params: r.Payload}
requests[i] = rpcRequest{id: id, isPubSub: true, method: r.Method, params: r.Payload, key: r.Key}
continue
}

if len(r.Payload) == 0 {
requests[i] = rpcRequest{id: id, params: nil}
requests[i] = rpcRequest{id: id, params: nil, key: r.Key}
} else {
requests[i] = rpcRequest{id: id, params: r.Payload}
requests[i] = rpcRequest{id: id, params: r.Payload, key: r.Key}
}
if elem := strings.Split(r.Method, serviceMethodSeparator); len(elem) == 2 {
requests[i].service, requests[i].method = elem[0], elem[1]
Expand Down
2 changes: 1 addition & 1 deletion rpc/json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ func (rwc *RWC) Close() error {
}

func TestJSONRequestParsing(t *testing.T) {
server := NewServer()
server := NewServer("")
service := new(Service)

if err := server.RegisterName("calc", service); err != nil {
Expand Down
8 changes: 7 additions & 1 deletion rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,9 @@ const (
)

// NewServer will create a new server instance with no registered handlers.
func NewServer() *Server {
func NewServer(apiKey string) *Server {
server := &Server{
apiKey: apiKey,
services: make(serviceRegistry),
codecs: mapset.NewSet(),
run: 1,
Expand Down Expand Up @@ -389,6 +390,11 @@ func (s *Server) readRequest(codec ServerCodec) ([]*serverRequest, bool, Error)
continue
}

if s.apiKey != "" && r.key != s.apiKey {
requests[i] = &serverRequest{id: r.id, err: &invalidApiKeyError{}}
continue
}

if r.isPubSub && strings.HasSuffix(r.method, unsubscribeMethodSuffix) {
requests[i] = &serverRequest{id: r.id, isUnsubscribe: true}
argTypes := []reflect.Type{reflect.TypeOf("")} // expect subscription id as first arg
Expand Down
82 changes: 80 additions & 2 deletions rpc/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ func (s *Service) Subscription(ctx context.Context) (*Subscription, error) {
}

func TestServerRegisterName(t *testing.T) {
server := NewServer()
server := NewServer("")
service := new(Service)

if err := server.RegisterName("calc", service); err != nil {
Expand All @@ -102,7 +102,7 @@ func TestServerRegisterName(t *testing.T) {
}

func testServerMethodExecution(t *testing.T, method string) {
server := NewServer()
server := NewServer("")
service := new(Service)

if err := server.RegisterName("test", service); err != nil {
Expand Down Expand Up @@ -160,3 +160,81 @@ func TestServerMethodExecution(t *testing.T) {
func TestServerMethodWithCtx(t *testing.T) {
testServerMethodExecution(t, "echoWithCtx")
}

func TestServerMethodExecutionWithApiKeyProvided(t *testing.T) {
testApiKey(t, true)
}

func TestServerMethodExecutionWithoutApiKeyProvided(t *testing.T) {
testApiKey(t, false)
}

func testApiKey(t *testing.T, sendKey bool) {
apiKey := "tempKey"
server := NewServer(apiKey)
service := new(Service)

if err := server.RegisterName("test", service); err != nil {
t.Fatalf("%v", err)
}

stringArg := "string arg"
intArg := 1122
argsArg := &Args{"abcde"}
params := []interface{}{stringArg, intArg, argsArg}

request := map[string]interface{}{
"id": 12345,
"method": "test_echo",
"version": "2.0",
"params": params,
}

if sendKey {
request["key"] = apiKey
}

clientConn, serverConn := net.Pipe()
defer clientConn.Close()

go server.ServeCodec(NewJSONCodec(serverConn), OptionMethodInvocation)

out := json.NewEncoder(clientConn)
in := json.NewDecoder(clientConn)

if err := out.Encode(request); err != nil {
t.Fatal(err)
}

if sendKey {
response := jsonSuccessResponse{Result: &Result{}}
if err := in.Decode(&response); err != nil {
t.Fatal(err)
}

if result, ok := response.Result.(*Result); ok {
if result.String != stringArg {
t.Errorf("expected %s, got : %s\n", stringArg, result.String)
}
if result.Int != intArg {
t.Errorf("expected %d, got %d\n", intArg, result.Int)
}
if !reflect.DeepEqual(result.Args, argsArg) {
t.Errorf("expected %v, got %v\n", argsArg, result)
}
} else {
t.Fatalf("invalid response: expected *Result - got: %T", response.Result)
}

} else {
response := jsonErrResponse{Error: jsonError{}}
if err := in.Decode(&response); err != nil {
t.Fatal(err)
}

invalidKeyError := &invalidApiKeyError{}
if response.Error.Message != invalidKeyError.Error() {
t.Errorf("expected %v, got %v\n", invalidKeyError.Error(), response.Error.Message)
}
}
}
Loading

0 comments on commit 60bc4d1

Please sign in to comment.