Skip to content

Commit

Permalink
提供Oracle相关支持 (#688)
Browse files Browse the repository at this point in the history
1. 支持SELECT COUNT(*)查询, 不再限制limit;

2. 同时还支持执行以下SQL:
   (a)创建存储过程、函数、包、包体、触发器等对象的定义;
   (b)执行PLSQL可执行匿名块;

   暂定该类SQL以delimiter $$作为起始分隔符,以$$作为结束分隔符;
   每一个对象定义,前后套用一个起始和结束分隔符;
   为了保留该类SQL的注释,该类SQL没有做注释的过滤,提高SQL的可读性

3. 如果视图定义SQL也需保留注释,也可以delimiter $$作为起始分隔符,以$$作为结束分隔符
  • Loading branch information
ericruan-cn authored Mar 27, 2020
1 parent bf41b4a commit dd0665b
Show file tree
Hide file tree
Showing 3 changed files with 259 additions and 20 deletions.
24 changes: 24 additions & 0 deletions sql/engines/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,25 @@
import json


class SqlItem:

def __init__(self, id=0, statement='', stmt_type='SQL', object_owner='', object_type='', object_name=''):
'''
:param id: SQL序号,从0开始
:param statement: SQL Statement
:param stmt_type: SQL类型(SQL, PLSQL), 默认为SQL
:param object_owner: PLSQL Object Owner
:param object_type: PLSQL Object Type
:param object_name: PLSQL Object Name
'''
self.id = id
self.statement = statement
self.stmt_type = stmt_type
self.object_owner = object_owner
self.object_type = object_type
self.object_name = object_name


class ReviewResult:
"""审核的单条结果"""

Expand Down Expand Up @@ -42,6 +61,11 @@ def __init__(self, inception_result=None, **kwargs):
self.backup_time = kwargs.get('backup_time', '')
self.actual_affected_rows = kwargs.get('actual_affected_rows', '')

self.stmt_type = kwargs.get('stmt_type', 'SQL')
self.object_owner = kwargs.get('object_owner', '')
self.object_type = kwargs.get('object_type', '')
self.object_name = kwargs.get('object_name', '')


class ReviewSet:
"""review和执行后的结果集, rows中是review result, 有设定好的字段"""
Expand Down
80 changes: 60 additions & 20 deletions sql/engines/oracle.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import traceback
import re
import sqlparse
import simplejson as json

from common.config import SysConfig
from common.utils.timer import FuncTimer
from sql.utils.sql_utils import get_syntax_type
from sql.utils.sql_utils import get_syntax_type, get_full_sqlitem_list, get_exec_sqlitem_list
from . import EngineBase
import cx_Oracle
from .models import ResultSet, ReviewSet, ReviewResult
Expand Down Expand Up @@ -149,7 +150,10 @@ def query_check(self, db_name=None, sql=''):
def filter_sql(self, sql='', limit_num=0):
sql_lower = sql.lower()
# 对查询sql增加limit限制
if re.match(r"^select", sql_lower):
if re.match(r"^\s*select", sql_lower):
# 针对select count(*) from之类的SQL,不做limit限制
if re.match(r"^\s*select\s+count\s*\(\s*[\*|\d]\s*\)\s+from", sql_lower, re.I):
return sql.rstrip(';')
if sql_lower.find(' rownum ') == -1:
if sql_lower.find('where') == -1:
return f"{sql.rstrip(';')} WHERE ROWNUM <= {limit_num}"
Expand Down Expand Up @@ -208,33 +212,41 @@ def execute_check(self, db_name=None, sql=''):
critical_ddl_regex = config.get('critical_ddl_regex', '')
p = re.compile(critical_ddl_regex)
check_result.syntax_type = 2 # TODO 工单类型 0、其他 1、DDL,2、DML
for statement in sqlparse.split(sql):
statement = sqlparse.format(statement, strip_comments=True)

# 把所有SQL转换成SqlItem List。 如有多行(内部有多个;)执行块,约定以delimiter $$作为开始, 以$$结束
# 需要在函数里实现单条SQL做sqlparse.format(sql, strip_comments=True)
sqlitemList = get_full_sqlitem_list(sql, db_name)

for sqlitem in sqlitemList:
# 禁用语句
if re.match(r"^select", statement.lower()):
if re.match(r"^\s*select", sqlitem.statement.lower(), re.I):
check_result.is_critical = True
result = ReviewResult(id=line, errlevel=2,
stagestatus='驳回不支持语句',
errormessage='仅支持DML和DDL语句,查询语句请使用SQL查询功能!',
sql=statement)
sql=sqlitem.statement)
# 高危语句
elif critical_ddl_regex and p.match(statement.strip().lower()):
elif critical_ddl_regex and p.match(sqlitem.statement.strip().lower()):
check_result.is_critical = True
result = ReviewResult(id=line, errlevel=2,
stagestatus='驳回高危SQL',
errormessage='禁止提交匹配' + critical_ddl_regex + '条件的语句!',
sql=statement)
sql=sqlitem.statement)

# 正常语句
else:
result = ReviewResult(id=line, errlevel=0,
stagestatus='Audit completed',
errormessage='None',
sql=statement,
sql=sqlitem.statement,
stmt_type=sqlitem.stmt_type,
object_owner=sqlitem.object_owner,
object_type=sqlitem.object_type,
object_name=sqlitem.object_name,
affected_rows=0,
execute_time=0, )
# 判断工单类型
if get_syntax_type(statement) == 'DDL':
if get_syntax_type(sqlitem.statement) == 'DDL':
check_result.syntax_type = 1
check_result.rows += [result]

Expand All @@ -246,30 +258,58 @@ def execute_check(self, db_name=None, sql=''):
return check_result

def execute_workflow(self, workflow, close_conn=True):
"""执行上线单,返回Review set"""
"""执行上线单,返回Review set
原来的逻辑是根据 sql_content简单来分割SQL,进而再执行这些SQL
新的逻辑变更为根据审核结果中记录的sql来执行,
如果是PLSQL存储过程等对象定义操作,还需检查确认新建对象是否编译通过!
"""
review_content = workflow.sqlworkflowcontent.review_content
review_result = json.loads(review_content)
sqlitemList = get_exec_sqlitem_list(review_result, workflow.db_name)

sql = workflow.sqlworkflowcontent.sql_content
execute_result = ReviewSet(full_sql=sql)
# 删除注释语句,切分语句,将切换CURRENT_SCHEMA语句增加到切分结果中
sql = sqlparse.format(sql, strip_comments=True)
split_sql = [f"ALTER SESSION SET CURRENT_SCHEMA = {workflow.db_name};"] + sqlparse.split(sql)

line = 1
statement = None
try:
conn = self.get_connection()
cursor = conn.cursor()
# 逐条执行切分语句,追加到执行结果中
for statement in split_sql:
statement = statement.rstrip(';')
for sqlitem in sqlitemList:
statement = sqlitem.statement
if sqlitem.stmt_type == "SQL":
statement = statement.rstrip(';')
with FuncTimer() as t:
cursor.execute(statement)
conn.commit()

rowcount = cursor.rowcount
stagestatus = "Execute Successfully"
if sqlitem.stmt_type == "PLSQL" and sqlitem.object_name and sqlitem.object_name != 'ANONYMOUS' and sqlitem.object_name != '':
query_obj_sql = f"""SELECT OBJECT_NAME, STATUS, TO_CHAR(LAST_DDL_TIME, 'YYYY-MM-DD HH24:MI:SS') FROM ALL_OBJECTS
WHERE OWNER = '{sqlitem.object_owner}'
AND OBJECT_NAME = '{sqlitem.object_name}'
"""
cursor.execute(query_obj_sql)
row = cursor.fetchone()
if row:
status = row[1]
if status and status == "INVALID":
stagestatus = "Compile Failed. Object " + sqlitem.object_owner + "." + sqlitem.object_name + " is invalid."
else:
stagestatus = "Compile Failed. Object " + sqlitem.object_owner + "." + sqlitem.object_name + " doesn't exist."

if stagestatus != "Execute Successfully":
raise Exception(stagestatus)

execute_result.rows.append(ReviewResult(
id=line,
errlevel=0,
stagestatus='Execute Successfully',
stagestatus=stagestatus,
errormessage='None',
sql=statement,
affected_rows=cursor.rowcount,
affected_rows=rowcount,
execute_time=t.cost,
))
line += 1
Expand All @@ -288,13 +328,13 @@ def execute_workflow(self, workflow, close_conn=True):
))
line += 1
# 报错语句后面的语句标记为审核通过、未执行,追加到执行结果中
for statement in split_sql[line - 1:]:
for sqlitem in sqlitemList[line - 1:]:
execute_result.rows.append(ReviewResult(
id=line,
errlevel=0,
stagestatus='Audit completed',
errormessage=f'前序语句失败, 未执行',
sql=statement,
sql=sqlitem.statement,
affected_rows=0,
execute_time=0,
))
Expand Down
175 changes: 175 additions & 0 deletions sql/utils/sql_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import mybatis_mapper2sql
import sqlparse

from sql.engines.models import SqlItem
from sql.utils.extract_tables import extract_tables as extract_tables_by_sql_parse

__author__ = 'hhyo'
Expand Down Expand Up @@ -124,3 +125,177 @@ def generate_sql(text):
row = {"sql_id": num, "sql": statement}
rows.append(row)
return rows


def get_base_sqlitem_list(full_sql):
''' 把参数 full_sql 转变为 SqlItem列表
:param full_sql: 完整sql字符串, 每个SQL以分号;间隔, 不包含plsql执行块和plsql对象定义块
:return: SqlItem对象列表
'''
list = []
for statement in sqlparse.split(full_sql):
statement = sqlparse.format(statement, strip_comments=True)
if len(statement) <= 0:
continue
item = SqlItem(statement=statement)
list.append(item)
return list


def get_full_sqlitem_list(full_sql, db_name):
''' 获取Sql对应的SqlItem列表, 包括PLSQL部分
PLSQL语句块由delimiter $$作为开始间隔符,以$$作为结束间隔符
:param full_sql: 全部sql内容
:return: SqlItem 列表
'''
list = []

# 定义开始分隔符,两端用括号,是为了re.split()返回列表包含分隔符
regex_delimiter = r'(delimiter\s*\$\$)'
# 注意:必须把package body置于package之前,否则将永远匹配不上package body
regex_objdefine = r'create\s+or\s+replace\s+(function|procedure|trigger|package\s+body|package|view)\s+("?\w+"?\.)?"?\w+"?[\s+|\(]'
# 对象命名,两端有双引号
regex_objname = r'^".+"$'

sql_list = re.split(pattern=regex_delimiter, string=full_sql, flags=re.I)

# delimiter_flag => 分隔符标记, 0:不是, 1:是
# 遇到分隔符标记为1, 则本块SQL要去判断是否有PLSQL内容
# PLSQL内容存在判定依据, 本块SQL包含'$$'

delimiter_flag = 0
for sql in sql_list:
# 截去首尾空格和多余空字符
sql = sql.strip()

# 如果字符串长度为0, 跳过该字符串
if len(sql) <= 0:
continue

# 表示这一行是分隔符, 跳过该字符串
if re.match(regex_delimiter, sql):
delimiter_flag = 1
continue

if delimiter_flag == 1:
# 表示SQL块为delimiter $$标记之后的内容

# 查找是否存在'$$'结束符
pos = sql.find("$$")
length = len(sql)
if pos > -1:
# 该sqlitem包含结束符$$
# 处理PLSQL语句块, 这里需要先去判定语句块的类型
plsql_block = sql[0:pos].strip()
# 如果plsql_area字符串最后一个字符为/,则把/给去掉
while True:
if plsql_block[-1:] == '/':
plsql_block = plsql_block[:-1].strip()
else:
break

search_result = re.search(regex_objdefine, plsql_block, flags=re.I)

# 检索关键字, 分为两个情况
# 情况1:plsql block 为对象定义执行块
# 情况2:plsql block 为匿名执行块

if search_result:

# 检索到关键字, 属于情况1

str_plsql_match = search_result.group()
str_plsql_type = search_result.groups()[0]

idx = str_plsql_match.index(str_plsql_type)
nm_str = str_plsql_match[idx + len(str_plsql_type):].strip()

if nm_str[-1:] == '(':
nm_str = nm_str[:-1]
nm_list = nm_str.split('.')

if len(nm_list) > 1:
# 带有属主的对象名, 形如object_owner.object_name

# 获取object_owner
if re.match(regex_objname, nm_list[0]):
# object_owner两端带有双引号
object_owner = nm_list[0].strip().strip('"')
else:
# object_owner两端不带有双引号
object_owner = nm_list[0].upper().strip().strip("'")

# 获取object_name
if re.match(regex_objname, nm_list[1]):
# object_name两端带有双引号
object_name = nm_list[1].strip().strip('"')
else:
# object_name两端不带有双引号
object_name = nm_list[1].upper().strip()
else:
# 不带属主
object_owner = db_name
if re.match(regex_objname, nm_list[0]):
# object_name两端带有双引号
object_name = nm_list[0].strip().strip('"')
else:
# object_name两端不带有双引号
object_name = nm_list[0].upper().strip()

tmp_object_type = str_plsql_type.upper()
tmp_stmt_type = 'PLSQL'
if tmp_object_type == 'VIEW':
tmp_stmt_type = 'SQL'

item = SqlItem(statement=plsql_block,
stmt_type=tmp_stmt_type,
object_owner=object_owner,
object_type=tmp_object_type,
object_name=object_name)
list.append(item)
else:
# 未检索到关键字, 属于情况2, 匿名可执行块 it's ANONYMOUS
item = SqlItem(statement=plsql_block.strip(),
stmt_type='PLSQL',
object_owner=db_name,
object_type='ANONYMOUS',
object_name='ANONYMOUS')
list.append(item)

if length > pos + 2:
# 处理$$之后的那些语句, 默认为单条可执行SQL的集合
sql_area = sql[pos + 2:].strip()
if len(sql_area) > 0:
tmp_list = get_base_sqlitem_list(sql_area)
list.extend(tmp_list)

else:
# 没有匹配到$$标记, 默认为单条可执行SQL集合
tmp_list = get_base_sqlitem_list(sql)
list.extend(tmp_list)

# 处理完本次delimiter标记的内容,把delimiter_flag重置
delimiter_flag = 0
else:
# 表示当前为以;结尾的正常sql
tmp_list = get_base_sqlitem_list(sql)
list.extend(tmp_list)
return list


def get_exec_sqlitem_list(reviewResult, db_name):
""" 根据审核结果生成新的SQL列表
:param reviewResult: SQL审核结果列表
:param db_name:
:return:
"""
list = []
list.append(SqlItem(statement=f"ALTER SESSION SET CURRENT_SCHEMA = {db_name}"))

for item in reviewResult:
list.append(SqlItem(statement=item['sql'],
stmt_type=item['stmt_type'],
object_owner=item['object_owner'],
object_type=item['object_type'],
object_name=item['object_name']))
return list

0 comments on commit dd0665b

Please sign in to comment.