diff --git a/influxql/ast.go b/influxql/ast.go index 6afafead079..080e5e21774 100644 --- a/influxql/ast.go +++ b/influxql/ast.go @@ -594,6 +594,11 @@ func (s *DropRetentionPolicyStatement) RequiredPrivileges() (ExecutionPrivileges return ExecutionPrivileges{{Admin: false, Name: s.Database, Privilege: WritePrivilege}}, nil } +// DefaultDatabase returns the default database from the statement. +func (s *DropRetentionPolicyStatement) DefaultDatabase() string { + return s.Database +} + // CreateUserStatement represents a command for creating a new user. type CreateUserStatement struct { // Name of the user to be created. @@ -704,6 +709,11 @@ func (s *GrantStatement) RequiredPrivileges() (ExecutionPrivileges, error) { return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil } +// DefaultDatabase returns the default database from the statement. +func (s *GrantStatement) DefaultDatabase() string { + return s.On +} + // GrantAdminStatement represents a command for granting admin privilege. type GrantAdminStatement struct { // Who to grant the privilege to. @@ -802,6 +812,11 @@ func (s *RevokeStatement) RequiredPrivileges() (ExecutionPrivileges, error) { return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil } +// DefaultDatabase returns the default database from the statement. +func (s *RevokeStatement) DefaultDatabase() string { + return s.On +} + // RevokeAdminStatement represents a command to revoke admin privilege from a user. type RevokeAdminStatement struct { // Who to revoke admin privilege from. @@ -868,6 +883,11 @@ func (s *CreateRetentionPolicyStatement) RequiredPrivileges() (ExecutionPrivileg return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil } +// DefaultDatabase returns the default database from the statement. +func (s *CreateRetentionPolicyStatement) DefaultDatabase() string { + return s.Database +} + // AlterRetentionPolicyStatement represents a command to alter an existing retention policy. type AlterRetentionPolicyStatement struct { // Name of policy to alter. @@ -924,6 +944,11 @@ func (s *AlterRetentionPolicyStatement) RequiredPrivileges() (ExecutionPrivilege return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil } +// DefaultDatabase returns the default database from the statement. +func (s *AlterRetentionPolicyStatement) DefaultDatabase() string { + return s.Database +} + // FillOption represents different options for filling aggregate windows. type FillOption int @@ -2824,6 +2849,11 @@ func (s *DropContinuousQueryStatement) RequiredPrivileges() (ExecutionPrivileges return ExecutionPrivileges{{Admin: false, Name: "", Privilege: WritePrivilege}}, nil } +// DefaultDatabase returns the default database from the statement. +func (s *DropContinuousQueryStatement) DefaultDatabase() string { + return s.Database +} + // ShowMeasurementsStatement represents a command for listing measurements. type ShowMeasurementsStatement struct { // Database to query. If blank, use the default database. @@ -3043,6 +3073,11 @@ func (s *CreateSubscriptionStatement) RequiredPrivileges() (ExecutionPrivileges, return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil } +// DefaultDatabase returns the default database from the statement. +func (s *CreateSubscriptionStatement) DefaultDatabase() string { + return s.Database +} + // DropSubscriptionStatement represents a command to drop a subscription to the incoming data stream. type DropSubscriptionStatement struct { Name string @@ -3060,6 +3095,11 @@ func (s *DropSubscriptionStatement) RequiredPrivileges() (ExecutionPrivileges, e return ExecutionPrivileges{{Admin: true, Name: "", Privilege: AllPrivileges}}, nil } +// DefaultDatabase returns the default database from the statement. +func (s *DropSubscriptionStatement) DefaultDatabase() string { + return s.Database +} + // ShowSubscriptionsStatement represents a command to show a list of subscriptions. type ShowSubscriptionsStatement struct { } diff --git a/influxql/ast_test.go b/influxql/ast_test.go index a1c3457f116..fb77f4aed89 100644 --- a/influxql/ast_test.go +++ b/influxql/ast_test.go @@ -2,6 +2,7 @@ package influxql_test import ( "fmt" + "go/importer" "reflect" "strings" "testing" @@ -1697,6 +1698,102 @@ func TestParse_Errors(t *testing.T) { } } +// This test checks to ensure that we have given thought to the database +// context required for security checks. If a new statement is added, this +// test will fail until it is categorized into the correct bucket below. +func Test_EnforceHasDefaultDatabase(t *testing.T) { + pkg, err := importer.Default().Import("github.com/influxdata/influxdb/influxql") + if err != nil { + fmt.Printf("error: %s\n", err.Error()) + return + } + statements := []string{} + + // this is a list of statements that do not have a database context + exemptStatements := []string{ + "CreateDatabaseStatement", + "CreateUserStatement", + "DeleteSeriesStatement", + "DeleteStatement", + "DropDatabaseStatement", + "DropMeasurementStatement", + "DropSeriesStatement", + "DropShardStatement", + "DropUserStatement", + "GrantAdminStatement", + "KillQueryStatement", + "RevokeAdminStatement", + "SelectStatement", + "SetPasswordUserStatement", + "ShowContinuousQueriesStatement", + "ShowDatabasesStatement", + "ShowDiagnosticsStatement", + "ShowFieldKeysStatement", + "ShowGrantsForUserStatement", + "ShowMeasurementsStatement", + "ShowQueriesStatement", + "ShowRetentionPoliciesStatement", + "ShowSeriesStatement", + "ShowShardGroupsStatement", + "ShowShardsStatement", + "ShowStatsStatement", + "ShowSubscriptionsStatement", + "ShowTagKeysStatement", + "ShowTagValuesStatement", + "ShowUsersStatement", + } + + exists := func(stmt string) bool { + switch stmt { + // These are functions with the word statement in them, and can be ignored + case "Statement", "MustParseStatement", "ParseStatement", "RewriteStatement": + return true + default: + // check the exempt statements + for _, s := range exemptStatements { + if s == stmt { + return true + } + } + // check the statements that passed the interface test for HasDefaultDatabase + for _, s := range statements { + if s == stmt { + return true + } + } + return false + } + } + + needsHasDefault := []interface{}{ + &influxql.AlterRetentionPolicyStatement{}, + &influxql.CreateContinuousQueryStatement{}, + &influxql.CreateRetentionPolicyStatement{}, + &influxql.CreateSubscriptionStatement{}, + &influxql.DropContinuousQueryStatement{}, + &influxql.DropRetentionPolicyStatement{}, + &influxql.DropSubscriptionStatement{}, + &influxql.GrantStatement{}, + &influxql.RevokeStatement{}, + } + + for _, stmt := range needsHasDefault { + statements = append(statements, strings.TrimPrefix(fmt.Sprintf("%T", stmt), "*influxql.")) + if _, ok := stmt.(influxql.HasDefaultDatabase); !ok { + t.Errorf("%T was expected to declare DefaultDatabase method", stmt) + } + + } + + for _, declName := range pkg.Scope().Names() { + if strings.HasSuffix(declName, "Statement") { + if !exists(declName) { + t.Errorf("unchecked statement %s. please update this test to determine if this statement needs to declare 'DefaultDatabase'", declName) + } + } + } +} + // Valuer represents a simple wrapper around a map to implement the influxql.Valuer interface. type Valuer map[string]interface{}