-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
c1e6aa2
commit 1ab79f2
Showing
2 changed files
with
242 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,241 @@ | ||
import { describe, it, expect, vi, beforeEach } from 'vitest' | ||
import { applyRLS, loadPolicies } from './index' | ||
import { DataSource, QueryResult } from '../types' | ||
import { StarbaseDBConfiguration } from '../handler' | ||
|
||
const mockDataSource = { | ||
source: 'internal', | ||
rpc: { | ||
executeQuery: vi.fn(), | ||
}, | ||
context: { sub: 'user123' }, | ||
} as any | ||
|
||
const mockConfig: StarbaseDBConfiguration = { | ||
outerbaseApiKey: 'mock-api-key', | ||
role: 'client', | ||
features: { allowlist: true, rls: true, rest: true }, | ||
} | ||
|
||
describe('loadPolicies - Policy Fetching and Parsing', () => { | ||
it('should load and parse policies correctly', async () => { | ||
vi.mocked(mockDataSource.rpc.executeQuery).mockResolvedValue([ | ||
{ | ||
actions: 'SELECT', | ||
schema: 'public', | ||
table: 'users', | ||
column: 'user_id', | ||
value: 'context.id()', | ||
value_type: 'string', | ||
operator: '=', | ||
}, | ||
] as any) | ||
|
||
const policies = await loadPolicies(mockDataSource) | ||
|
||
expect(mockDataSource.rpc.executeQuery).toHaveBeenCalledTimes(1) | ||
expect(policies).toEqual([ | ||
{ | ||
action: 'SELECT', | ||
condition: { | ||
type: 'binary_expr', | ||
operator: '=', | ||
left: { | ||
type: 'column_ref', | ||
table: 'public.users', | ||
column: 'user_id', | ||
}, | ||
right: { | ||
type: 'string', | ||
value: '__CONTEXT_ID__', | ||
}, | ||
}, | ||
}, | ||
]) | ||
}) | ||
|
||
it('should return an empty array if an error occurs', async () => { | ||
const consoleErrorSpy = vi | ||
.spyOn(console, 'error') | ||
.mockImplementation(() => {}) | ||
vi.mocked(mockDataSource.rpc.executeQuery).mockRejectedValue( | ||
new Error('Database error') | ||
) | ||
|
||
const policies = await loadPolicies(mockDataSource) | ||
|
||
expect(policies).toEqual([]) | ||
}) | ||
}) | ||
|
||
describe('applyRLS - Query Modification', () => { | ||
beforeEach(() => { | ||
vi.resetAllMocks() | ||
mockDataSource.context.sub = 'user123' | ||
vi.mocked(mockDataSource.rpc.executeQuery).mockResolvedValue([ | ||
{ | ||
actions: 'SELECT', | ||
schema: 'public', | ||
table: 'users', | ||
column: 'user_id', | ||
value: 'context.id()', | ||
value_type: 'string', | ||
operator: '=', | ||
}, | ||
]) | ||
}) | ||
|
||
it('should modify SELECT queries with WHERE conditions', async () => { | ||
const sql = 'SELECT * FROM users' | ||
const modifiedSql = await applyRLS({ | ||
sql, | ||
isEnabled: true, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
}) | ||
|
||
console.log('Final SQL:', modifiedSql) | ||
expect(modifiedSql).toContain("WHERE `user_id` = 'user123'") | ||
}) | ||
it('should modify DELETE queries by adding policy-based WHERE clause', async () => { | ||
const sql = "DELETE FROM users WHERE name = 'Alice'" | ||
const modifiedSql = await applyRLS({ | ||
sql, | ||
isEnabled: true, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
}) | ||
|
||
expect(modifiedSql).toContain("WHERE `name` = 'Alice'") | ||
}) | ||
|
||
it('should modify UPDATE queries with additional WHERE clause', async () => { | ||
const sql = "UPDATE users SET name = 'Bob' WHERE age = 25" | ||
const modifiedSql = await applyRLS({ | ||
Check failure on line 114 in src/rls/index.test.ts
|
||
sql, | ||
isEnabled: true, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
}) | ||
|
||
expect(modifiedSql).toContain("`name` = 'Bob' WHERE `age` = 25") | ||
}) | ||
|
||
it('should modify INSERT queries to enforce column values', async () => { | ||
const sql = "INSERT INTO users (user_id, name) VALUES (1, 'Alice')" | ||
const modifiedSql = await applyRLS({ | ||
sql, | ||
isEnabled: true, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
}) | ||
|
||
expect(modifiedSql).toContain("VALUES (1,'Alice')") | ||
}) | ||
}) | ||
|
||
describe('applyRLS - Edge Cases', () => { | ||
it('should not modify SQL if RLS is disabled', async () => { | ||
const sql = 'SELECT * FROM users' | ||
const modifiedSql = await applyRLS({ | ||
sql, | ||
isEnabled: false, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
}) | ||
|
||
expect(modifiedSql).toBe(sql) | ||
}) | ||
|
||
it('should not modify SQL if user is admin', async () => { | ||
mockConfig.role = 'admin' | ||
|
||
const sql = 'SELECT * FROM users' | ||
const modifiedSql = await applyRLS({ | ||
sql, | ||
isEnabled: true, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
}) | ||
|
||
expect(modifiedSql).toBe(sql) | ||
}) | ||
}) | ||
|
||
describe('applyRLS - Multi-Table Queries', () => { | ||
beforeEach(() => { | ||
vi.mocked(mockDataSource.rpc.executeQuery).mockResolvedValue([ | ||
{ | ||
actions: 'SELECT', | ||
schema: 'public', | ||
table: 'users', | ||
column: 'user_id', | ||
value: 'context.id()', | ||
value_type: 'string', | ||
operator: '=', | ||
}, | ||
{ | ||
actions: 'SELECT', | ||
schema: 'public', | ||
table: 'orders', | ||
column: 'user_id', | ||
value: 'context.id()', | ||
value_type: 'string', | ||
operator: '=', | ||
}, | ||
] as any) | ||
}) | ||
|
||
it('should apply RLS policies to tables in JOIN conditions', async () => { | ||
const sql = ` | ||
SELECT users.name, orders.total | ||
FROM users | ||
JOIN orders ON users.id = orders.user_id | ||
` | ||
|
||
const modifiedSql = await applyRLS({ | ||
sql, | ||
isEnabled: true, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
}) | ||
|
||
expect(modifiedSql).toContain("WHERE `users.user_id` = 'user123'") | ||
expect(modifiedSql).toContain("AND `orders.user_id` = 'user123'") | ||
}) | ||
|
||
it('should apply RLS policies to multiple tables in a JOIN', async () => { | ||
const sql = ` | ||
SELECT users.name, orders.total | ||
FROM users | ||
JOIN orders ON users.id = orders.user_id | ||
` | ||
|
||
const modifiedSql = await applyRLS({ | ||
sql, | ||
isEnabled: true, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
}) | ||
|
||
expect(modifiedSql).toContain("WHERE (users.user_id = 'user123')") | ||
expect(modifiedSql).toContain("AND (orders.user_id = 'user123')") | ||
}) | ||
|
||
it('should apply RLS policies to subqueries inside FROM clause', async () => { | ||
const sql = ` | ||
SELECT * FROM ( | ||
SELECT * FROM users WHERE age > 18 | ||
) AS adults | ||
` | ||
|
||
const modifiedSql = await applyRLS({ | ||
sql, | ||
isEnabled: true, | ||
dataSource: mockDataSource, | ||
config: mockConfig, | ||
Check failure on line 236 in src/rls/index.test.ts
|
||
}) | ||
|
||
expect(modifiedSql).toContain("WHERE `users.user_id` = 'user123'") | ||
}) | ||
}) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters