Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Added column level access #135

Merged
merged 8 commits into from
Mar 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions postgresql/helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ var allowedPrivileges = map[string][]string{
"type": {"ALL", "USAGE"},
"foreign_data_wrapper": {"ALL", "USAGE"},
"foreign_server": {"ALL", "USAGE"},
"column": {"ALL", "SELECT", "INSERT", "UPDATE", "REFERENCES"},
}

// validatePrivileges checks that privileges to apply are allowed for this object type.
Expand Down Expand Up @@ -284,6 +285,22 @@ func setToPgIdentList(schema string, idents *schema.Set) string {
return strings.Join(quotedIdents, ",")
}

func setToPgIdentListWithoutSchema(idents *schema.Set) string {
quotedIdents := make([]string, idents.Len())
for i, ident := range idents.List() {
quotedIdents[i] = pq.QuoteIdentifier(ident.(string))
}
return strings.Join(quotedIdents, ",")
}

func setToPgIdentSimpleList(idents *schema.Set) string {
quotedIdents := make([]string, idents.Len())
for i, ident := range idents.List() {
quotedIdents[i] = ident.(string)
}
return strings.Join(quotedIdents, ",")
}

// startTransaction starts a new DB transaction on the specified database.
// If the database is specified and different from the one configured in the provider,
// it will create a new connection pool if needed.
Expand Down
162 changes: 156 additions & 6 deletions postgresql/resource_postgresql_grant.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ var allowedObjectTypes = []string{
"table",
"foreign_data_wrapper",
"foreign_server",
"column",
}

var objectTypes = map[string]string{
Expand All @@ -36,8 +37,9 @@ var objectTypes = map[string]string{
func resourcePostgreSQLGrant() *schema.Resource {
return &schema.Resource{
Create: PGResourceFunc(resourcePostgreSQLGrantCreate),
// As create revokes and grants we can use it to update too
Update: PGResourceFunc(resourcePostgreSQLGrantCreate),
// Since all of this resource's arguments force a recreation
// there's no need for an Update function
// Update:
Read: PGResourceFunc(resourcePostgreSQLGrantRead),
Delete: PGResourceFunc(resourcePostgreSQLGrantDelete),

Expand Down Expand Up @@ -75,9 +77,18 @@ func resourcePostgreSQLGrant() *schema.Resource {
Set: schema.HashString,
Description: "The specific objects to grant privileges on for this role (empty means all objects of the requested type)",
},
"columns": {
Type: schema.TypeSet,
Optional: true,
ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
Description: "The specific columns to grant privileges on for this role",
},
"privileges": {
Type: schema.TypeSet,
Required: true,
ForceNew: true,
Elem: &schema.Schema{Type: schema.TypeString},
Set: schema.HashString,
Description: "The list of privileges to grant",
Expand Down Expand Up @@ -130,6 +141,18 @@ func resourcePostgreSQLGrantCreate(db *DBConnection, d *schema.ResourceData) err
if d.Get("objects").(*schema.Set).Len() > 0 && (objectType == "database" || objectType == "schema") {
return fmt.Errorf("cannot specify `objects` when `object_type` is `database` or `schema`")
}
if d.Get("columns").(*schema.Set).Len() > 0 && (objectType != "column") {
return fmt.Errorf("cannot specify `columns` when `object_type` is not `column`")
}
if d.Get("columns").(*schema.Set).Len() == 0 && (objectType == "column") {
return fmt.Errorf("must specify `columns` when `object_type` is `column`")
}
if d.Get("privileges").(*schema.Set).Len() != 1 && (objectType == "column") {
return fmt.Errorf("must specify exactly 1 `privileges` when `object_type` is `column`")
}
if (d.Get("objects").(*schema.Set).Len() != 1) && (objectType == "column") {
return fmt.Errorf("must specify exactly 1 table in the `objects` field when `object_type` is `column`")
}
if d.Get("objects").(*schema.Set).Len() != 1 && (objectType == "foreign_data_wrapper" || objectType == "foreign_server") {
return fmt.Errorf("one element must be specified in `objects` when `object_type` is `foreign_data_wrapper` or `foreign_server`")
}
Expand Down Expand Up @@ -310,6 +333,84 @@ WHERE grantee = $2
return nil
}

func readColumnRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error {
objects := d.Get("objects").(*schema.Set)

missingColumns := d.Get("columns").(*schema.Set) // Getting columns from state.
// If the query returns a column, it is a removed from the missingColumns.

var rows *sql.Rows

// The attacl column of pg_attribute contains information only about explicit column grants
query := `
SELECT relname AS table_name, attname AS column_name, array_agg(privilege_type) AS column_privileges
FROM (SELECT relname, attname, (aclexplode(attacl)).*
FROM pg_class
JOIN pg_namespace ON pg_class.relnamespace = pg_namespace.oid
JOIN pg_attribute ON pg_class.oid = attrelid
WHERE nspname = $2
AND relname = $3
AND relkind = $4)
AS col_privs
JOIN pg_roles ON pg_roles.oid = col_privs.grantee
WHERE rolname = $1
AND privilege_type = $5
GROUP BY col_privs.relname, col_privs.attname, col_privs.privilege_type
ORDER BY col_privs.attname
;`
rows, err := txn.Query(
query, d.Get("role").(string), d.Get("schema"), objects.List()[0], objectTypes["table"], d.Get("privileges").(*schema.Set).List()[0],
)

if err != nil {
return err
}

for rows.Next() {
var objName string
var colName string
var privileges pq.ByteaArray

if err := rows.Scan(&objName, &colName, &privileges); err != nil {
return err
}

if objects.Len() > 0 && !objects.Contains(objName) {
continue
}

if missingColumns.Contains(colName) {
missingColumns.Remove(colName)
}

privilegesSet := pgArrayToSet(privileges)

if !privilegesSet.Equal(d.Get("privileges").(*schema.Set)) {
// If any object doesn't have the same privileges as saved in the state,
// we return its privileges to force an update.
log.Printf(
"[DEBUG] %s %s has not the expected privileges %v for role %s",
strings.ToTitle("column"), objName, privileges, d.Get("role"),
)
d.Set("privileges", privilegesSet)
break
}
}

if missingColumns.Len() > 0 {
// If missingColumns is not empty by the end of the result processing loop
// it means that a column is missing
remainingColumns := d.Get("columns").(*schema.Set).Difference(missingColumns)
log.Printf(
"[DEBUG] Role %s does not have the expected privileges on columns",
d.Get("role"),
)
d.Set("columns", remainingColumns)
}

return nil
}

func readRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error {
role := d.Get("role").(string)
objectType := d.Get("object_type").(string)
Expand Down Expand Up @@ -356,6 +457,9 @@ GROUP BY pg_proc.proname
query, roleOID, d.Get("schema"),
)

case "column":
return readColumnRolePrivileges(txn, d)

default:
query = `
SELECT pg_class.relname, array_remove(array_agg(privilege_type), NULL)
Expand Down Expand Up @@ -448,6 +552,15 @@ func createGrantQuery(d *schema.ResourceData, privileges []string) string {
pq.QuoteIdentifier(srvName.(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
)
case "COLUMN":
objects := d.Get("objects").(*schema.Set)
query = fmt.Sprintf(
"GRANT %s (%s) ON TABLE %s TO %s",
strings.Join(privileges, ","),
setToPgIdentListWithoutSchema(d.Get("columns").(*schema.Set)),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
)
case "TABLE", "SEQUENCE", "FUNCTION", "PROCEDURE", "ROUTINE":
objects := d.Get("objects").(*schema.Set)
if objects.Len() > 0 {
Expand Down Expand Up @@ -506,15 +619,44 @@ func createRevokeQuery(d *schema.ResourceData) string {
pq.QuoteIdentifier(srvName.(string)),
pq.QuoteIdentifier(d.Get("role").(string)),
)
case "TABLE", "SEQUENCE", "FUNCTION", "PROCEDURE", "ROUTINE":
case "COLUMN":
objects := d.Get("objects").(*schema.Set)
if objects.Len() > 0 {
columns := d.Get("columns").(*schema.Set)
privileges := d.Get("privileges").(*schema.Set)
if privileges.Len() == 0 || columns.Len() == 0 {
// No privileges to revoke, so don't revoke anything
query = "SELECT NULL"
kda-jt marked this conversation as resolved.
Show resolved Hide resolved
} else {
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON %s %s FROM %s",
strings.ToUpper(d.Get("object_type").(string)),
"REVOKE %s (%s) ON TABLE %s FROM %s",
kda-jt marked this conversation as resolved.
Show resolved Hide resolved
setToPgIdentSimpleList(privileges),
setToPgIdentListWithoutSchema(columns),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
)
}
case "TABLE", "SEQUENCE", "FUNCTION", "PROCEDURE", "ROUTINE":
objects := d.Get("objects").(*schema.Set)
privileges := d.Get("privileges").(*schema.Set)
kda-jt marked this conversation as resolved.
Show resolved Hide resolved
if objects.Len() > 0 {
if privileges.Len() > 0 {
// Revoking specific privileges instead of all privileges
// to avoid messing with column level grants
query = fmt.Sprintf(
"REVOKE %s ON %s %s FROM %s",
setToPgIdentSimpleList(privileges),
strings.ToUpper(d.Get("object_type").(string)),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
)
} else {
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON %s %s FROM %s",
strings.ToUpper(d.Get("object_type").(string)),
setToPgIdentList(d.Get("schema").(string), objects),
pq.QuoteIdentifier(d.Get("role").(string)),
)
}
} else {
query = fmt.Sprintf(
"REVOKE ALL PRIVILEGES ON ALL %sS IN SCHEMA %s FROM %s",
Expand Down Expand Up @@ -547,6 +689,10 @@ func grantRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error {

func revokeRolePrivileges(txn *sql.Tx, d *schema.ResourceData) error {
query := createRevokeQuery(d)
if len(query) == 0 {
// Query is empty, don't run anything
return nil
}
if _, err := txn.Exec(query); err != nil {
return fmt.Errorf("could not execute revoke query: %w", err)
}
Expand Down Expand Up @@ -621,6 +767,10 @@ func generateGrantID(d *schema.ResourceData) string {
parts = append(parts, object.(string))
}

for _, column := range d.Get("columns").(*schema.Set).List() {
parts = append(parts, column.(string))
}

return strings.Join(parts, "_")
}

Expand Down
Loading