Skip to content

Commit

Permalink
Added SQL tests
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinXPN authored Dec 2, 2023
1 parent b64930a commit 9776177
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 31 deletions.
75 changes: 55 additions & 20 deletions coderunners/executors.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,33 +83,68 @@ def run(self, test: TestCase, **kwargs) -> RunResult:
"""
import pandas as pd
cursor = self.db.cursor()
cursor.executescript(test.input)

try:
cursor.executescript(test.input)
self.db.commit()
except sqlite3.Error as e:
cursor.close()
return RunResult(
status=Status.RUNTIME_ERROR, memory=0, time=0, return_code=0, outputs=None,
errors=str(e),
)

for filename, content in (test.input_files or {}).items():
# Load the content of the file into a dataframe and then load it into the db (filename)
print('Creating table:', filename)
csv_data = StringIO(content)
df = pd.read_csv(csv_data)
print(df.head())
df.to_sql(filename, self.db, if_exists='replace', index=False)
print('--- Done ---')
try:
print('Creating table:', filename)
csv_data = StringIO(content)
df = pd.read_csv(csv_data)
print(df.head())
df.to_sql(filename, self.db, if_exists='replace', index=False)
print('--- Done ---')
except (sqlite3.Error, pd.errors.ParserError, pd.errors.DatabaseError, ValueError) as e:
cursor.close()
return RunResult(
status=Status.RUNTIME_ERROR, memory=0, time=0, return_code=0, outputs=None,
errors=str(e),
)

self.db.commit()

# Execute the self.script as a single command and get the output
print('Executing script:', self.script)
res = pd.read_sql_query(self.script, self.db).to_string(index=False)
print('Result:', res)

r = RunResult(
status=Status.OK, memory=0, time=0, return_code=0, outputs=res,
output_files={
filename: pd.read_sql_query(f'SELECT * FROM {filename}', self.db).to_string(index=False)
for filename in (test.target_files or {}).keys()
})

cursor.close()
return r
try:
print('Executing script:', self.script)
if self.script.strip().upper().startswith('SELECT'):
res = pd.read_sql_query(self.script, self.db).to_csv(index=False)
else:
cursor.executescript(self.script)
self.db.commit()
res = ''
print('Result:', res)
except (sqlite3.Error, pd.errors.ParserError, pd.errors.DatabaseError, ValueError) as e:
cursor.close()
return RunResult(
status=Status.RUNTIME_ERROR, memory=0, time=0, return_code=0, outputs=None,
errors=str(e),
)

# Read output files into the result
try:
r = RunResult(
status=Status.OK, memory=0, time=0, return_code=0, outputs=res,
output_files={
filename: pd.read_sql_query(f'SELECT * FROM {filename}', self.db).to_csv(index=False)
for filename in (test.target_files or {}).keys()
})
cursor.close()
return r
except (sqlite3.Error, pd.errors.ParserError, pd.errors.DatabaseError, ValueError) as e:
cursor.close()
return RunResult(
status=Status.RUNTIME_ERROR, memory=0, time=0, return_code=0, outputs=None,
errors=str(e),
)

def cleanup(self, test: TestCase) -> None:
""" Drops all the tables in the database """
Expand Down
1 change: 1 addition & 0 deletions coderunners/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def check(self) -> SubmissionResult:
self.test_cases[0],
time_limit=self.time_limit, memory_limit_mb=self.memory_limit, output_limit_mb=self.output_limit
)
executor.cleanup(self.test_cases[0])
print('Done')

# Process all tests
Expand Down
144 changes: 133 additions & 11 deletions tests/integration/coderunners/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,19 +6,19 @@


class TestSQLSubmissions:
test_cases = [
TestCase(
input=dedent('''
-- Initialization script goes here
'''),
target=dedent('''
'hello world'
hello world
''').strip()),
]

def test_echo(self):
request = SubmissionRequest(test_cases=self.test_cases, return_outputs=True, language='SQL', code={
test_cases = [
TestCase(
input=dedent('''
-- Initialization script goes here
'''),
target=dedent('''
'hello world'
hello world
''').strip()),
]
request = SubmissionRequest(test_cases=test_cases, return_outputs=True, language='SQL', code={
'main.sql': dedent('''
SELECT 'hello world'
''').strip(),
Expand All @@ -29,3 +29,125 @@ def test_echo(self):
assert res.overall.score == 100
assert len(res.test_results) == 1
assert res.test_results[0].status == Status.OK

def test_create_table(self):
test_cases = [
TestCase(
input=dedent('''
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL
);
'''),
target=dedent('''
COUNT(*)
3
''').strip(),
input_files={
'users': dedent('''
id,name
1,John
2,Jane
3,Martin
''').strip(),
},
target_files={
'users': dedent('''
id,name
1,John
2,Jane
3,Martin
''').strip(),
}),
]
request = SubmissionRequest(test_cases=test_cases, return_outputs=True, language='SQL', code={
'main.sql': dedent('''
SELECT COUNT(*) FROM users;
''').strip(),
}, comparison_mode='token')
res = CodeRunner.from_language(language=request.language).invoke(lambda_client, request=request)
print(res)
assert res.overall.status == Status.OK
assert res.overall.score == 100
assert len(res.test_results) == 1
assert res.test_results[0].status == Status.OK

def test_insert(self):
test_cases = [
TestCase(
input=dedent('''
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL
);
'''),
target='',
input_files={
'users': dedent('''
id,name
1,John
2,Jane
3,Martin
''').strip(),
},
target_files={
'users': dedent('''
id,name
1,John
2,Jane
3,Martin
4,Jack
''').strip(),
}),
]
request = SubmissionRequest(test_cases=test_cases, return_outputs=True, language='SQL', code={
'main.sql': dedent('''
INSERT INTO users (id, name) VALUES (4, 'Jack');
''').strip(),
}, comparison_mode='token')
res = CodeRunner.from_language(language=request.language).invoke(lambda_client, request=request)
print(res)
assert res.overall.status == Status.OK
assert res.overall.score == 100
assert len(res.test_results) == 1
assert res.test_results[0].status == Status.OK

def test_invalid_query(self):
test_cases = [
TestCase(
input=dedent('''
CREATE TABLE users (
id INTEGER PRIMARY KEY,
name TEXT NOT NULL
);
'''),
target='',
input_files={
'users': dedent('''
id,name
1,John
2,Jane
3,Martin
''').strip(),
},
target_files={
'users': dedent('''
id,name
1,John
2,Jane
3,Martin
''').strip(),
}),
]
request = SubmissionRequest(test_cases=test_cases, return_outputs=True, language='SQL', code={
# Command should result in an error
'main.sql': dedent('''
SELECT * FROM random_table;
''').strip(),
}, comparison_mode='token')
res = CodeRunner.from_language(language=request.language).invoke(lambda_client, request=request)
print(res)
assert res.overall.status == Status.RUNTIME_ERROR
assert res.overall.score == 0
assert len(res.test_results) == 1
assert res.test_results[0].status == Status.RUNTIME_ERROR

0 comments on commit 9776177

Please sign in to comment.