diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000000..9db33d6b64 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +import-order-style = google +max-line-length = 90 +exclude = .tox,build,docs,bin,examples,flask_appbuilder/templates,flask_appbuilder/static,venv +ignore = E203,W503,W605 diff --git a/.travis.yml b/.travis.yml index d18eab16ef..c38835a10c 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,7 +5,8 @@ matrix: include: - python: 3.6 env: TOXENV=flake8 - + - python: 3.6 + env: TOXENV=black - python: 3.6 env: TOXENV=mysql services: @@ -14,7 +15,6 @@ matrix: - mysql -u root -e "DROP DATABASE IF EXISTS app; CREATE DATABASE app DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci" - mysql -u root -e "CREATE USER 'mysqluser'@'localhost' IDENTIFIED BY 'mysqluserpassword';" - mysql -u root -e "GRANT ALL ON app.* TO 'mysqluser'@'localhost';" - - python: 3.6 env: TOXENV=mssql services: diff --git a/flask_appbuilder/api/convert.py b/flask_appbuilder/api/convert.py index 73597e7820..89d47b7f8d 100644 --- a/flask_appbuilder/api/convert.py +++ b/flask_appbuilder/api/convert.py @@ -19,7 +19,7 @@ class Tree: """ def __init__(self): - self.root = TreeNode('+') + self.root = TreeNode("+") def add(self, data): node = TreeNode(data) @@ -45,18 +45,14 @@ def __repr__(self): def columns2Tree(columns): tree = Tree() for column in columns: - if '.' in column: - tree.add_child( - column.split('.')[0], - column.split('.')[1] - ) + if "." in column: + tree.add_child(column.split(".")[0], column.split(".")[1]) else: tree.add(column) return tree class BaseModel2SchemaConverter(object): - def __init__(self, datamodel, validators_columns): """ :param datamodel: SQLAInterface @@ -95,18 +91,22 @@ def _meta_schema_factory(self, columns, model, class_mixin): """ _model = model if columns: + class MetaSchema(ModelSchema, class_mixin): class Meta: model = _model fields = columns strict = True sqla_session = self.datamodel.session + else: + class MetaSchema(ModelSchema, class_mixin): class Meta: model = _model strict = True sqla_session = self.datamodel.session + return MetaSchema def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False): @@ -124,11 +124,7 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False) required = not datamodel.is_nullable(column.data) nested_model = datamodel.get_related_model(column.data) lst = [item.data for item in column.childs] - nested_schema = self.convert( - lst, - nested_model, - nested=False - ) + nested_schema = self.convert(lst, nested_model, nested=False) if datamodel.is_relation_many_to_one(column.data): many = False elif datamodel.is_relation_many_to_many(column.data): @@ -141,9 +137,10 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False) return field # Handle bug on marshmallow-sqlalchemy #163 elif datamodel.is_relation(column.data): - if (datamodel.is_relation_many_to_many(column.data) or - datamodel.is_relation_one_to_many(column.data)): - if datamodel.get_info(column.data).get('required', False): + if datamodel.is_relation_many_to_many( + column.data + ) or datamodel.is_relation_one_to_many(column.data): + if datamodel.get_info(column.data).get("required", False): required = True else: required = False @@ -157,8 +154,7 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False) elif datamodel.is_enum(column.data): required = not datamodel.is_nullable(column.data) enum_class = datamodel.list_columns[column.data].info.get( - 'enum_class', - datamodel.list_columns[column.data].type + "enum_class", datamodel.list_columns[column.data].type ) if enum_dump_by_name: enum_dump_by = EnumField.NAME @@ -168,10 +164,10 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False) field.unique = datamodel.is_unique(column.data) return field # is custom property method field? - if hasattr(getattr(_model, column.data), 'fget'): + if hasattr(getattr(_model, column.data), "fget"): return fields.Raw(dump_only=True) # is a normal model field not a function? - if not hasattr(getattr(_model, column.data), '__call__'): + if not hasattr(getattr(_model, column.data), "__call__"): field = field_for(_model, column.data) field.unique = datamodel.is_unique(column.data) if column.data in self.validators_columns: @@ -180,13 +176,13 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False) @staticmethod def get_column_child_model(column): - if '.' in column: - return column.split('.')[0] + if "." in column: + return column.split(".")[0] return column @staticmethod def is_column_dotted(column): - return '.' in column + return "." in column def convert(self, columns, model=None, nested=True, enum_dump_by_name=False): """ @@ -198,11 +194,7 @@ def convert(self, columns, model=None, nested=True, enum_dump_by_name=False): :param nested: Generate relation with nested schemas :return: ModelSchema object """ - super(Model2SchemaConverter, self).convert( - columns, - model=model, - nested=nested - ) + super(Model2SchemaConverter, self).convert(columns, model=model, nested=nested) class SchemaMixin: pass @@ -217,10 +209,7 @@ class SchemaMixin: for column in tree_columns.root.childs: # Get child model is column is dotted notation ma_sqla_fields_override[column.data] = self._column2field( - _datamodel, - column, - nested, - enum_dump_by_name + _datamodel, column, nested, enum_dump_by_name ) _columns.append(column.data) for k, v in ma_sqla_fields_override.items(): diff --git a/flask_appbuilder/api/manager.py b/flask_appbuilder/api/manager.py index a5ad24c92c..73efbbcaea 100644 --- a/flask_appbuilder/api/manager.py +++ b/flask_appbuilder/api/manager.py @@ -9,10 +9,10 @@ class OpenApi(BaseApi): - route_base = '/api' + route_base = "/api" allow_browser_login = True - @expose('//_openapi') + @expose("//_openapi") @protect() @safe def get(self, version): @@ -56,28 +56,28 @@ def _create_api_spec(version): openapi_version="3.0.2", info=dict(description=current_app.appbuilder.app_name), plugins=[MarshmallowPlugin()], - servers=[{'url': "/api/{}".format(version)}] + servers=[{"url": "/api/{}".format(version)}], ) class SwaggerView(BaseView): - default_view = 'ui' - openapi_uri = '/api/{}/_openapi' + default_view = "ui" + openapi_uri = "/api/{}/_openapi" - @expose('/') + @expose("/") @has_access def show(self, version): return self.render_template( - 'appbuilder/swagger/swagger.html', - openapi_uri=self.openapi_uri.format(version) + "appbuilder/swagger/swagger.html", + openapi_uri=self.openapi_uri.format(version), ) class OpenApiManager(BaseManager): def register_views(self): - if not self.appbuilder.app.config.get('FAB_ADD_SECURITY_VIEWS', True): + if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True): return - if self.appbuilder.get_app.config.get('FAB_API_SWAGGER_UI', False): + if self.appbuilder.get_app.config.get("FAB_API_SWAGGER_UI", False): self.appbuilder.add_api(OpenApi) self.appbuilder.add_view_no_menu(SwaggerView) diff --git a/flask_appbuilder/babel/views.py b/flask_appbuilder/babel/views.py index 91cec67e77..c42a85cb5f 100644 --- a/flask_appbuilder/babel/views.py +++ b/flask_appbuilder/babel/views.py @@ -5,13 +5,13 @@ class LocaleView(BaseView): - route_base = '/lang' + route_base = "/lang" - default_view = 'index' + default_view = "index" - @expose('/') + @expose("/") def index(self, locale): - session['locale'] = locale + session["locale"] = locale refresh() self.update_redirect() return redirect(self.get_redirect()) diff --git a/flask_appbuilder/base.py b/flask_appbuilder/base.py index 5f86e00778..3378e2edbd 100644 --- a/flask_appbuilder/base.py +++ b/flask_appbuilder/base.py @@ -14,7 +14,7 @@ LOGMSG_ERR_FAB_ADDON_PROCESS, LOGMSG_INF_FAB_ADD_VIEW, LOGMSG_INF_FAB_ADDON_ADDED, - LOGMSG_WAR_FAB_VIEW_EXISTS + LOGMSG_WAR_FAB_VIEW_EXISTS, ) from .filters import TemplateFilters from .menu import Menu, MenuApiManager @@ -109,7 +109,7 @@ def __init__( static_folder="static/appbuilder", static_url_path="/appbuilder", security_manager_class=None, - update_perms=True + update_perms=True, ): """ AppBuilder constructor @@ -167,42 +167,34 @@ def init_app(self, app, session): self.app = app - self.base_template = app.config.get( - "FAB_BASE_TEMPLATE", - self.base_template, - ) - self.static_folder = app.config.get( - "FAB_STATIC_FOLDER", - self.static_folder, - ) + self.base_template = app.config.get("FAB_BASE_TEMPLATE", self.base_template) + self.static_folder = app.config.get("FAB_STATIC_FOLDER", self.static_folder) self.static_url_path = app.config.get( - "FAB_STATIC_URL_PATH", - self.static_url_path, + "FAB_STATIC_URL_PATH", self.static_url_path ) - _index_view = app.config.get('FAB_INDEX_VIEW', None) + _index_view = app.config.get("FAB_INDEX_VIEW", None) if _index_view is not None: - self.indexview = dynamic_class_import( - _index_view - ) + self.indexview = dynamic_class_import(_index_view) else: self.indexview = self.indexview or IndexView - _menu = app.config.get('FAB_MENU', None) + _menu = app.config.get("FAB_MENU", None) if _menu is not None: - self.menu = dynamic_class_import( - _menu - ) + self.menu = dynamic_class_import(_menu) else: self.menu = self.menu or Menu() if self.update_perms: # default is True, if False takes precedence from config - self.update_perms = app.config.get('FAB_UPDATE_PERMS', True) - _security_manager_class_name = app.config.get('FAB_SECURITY_MANAGER_CLASS', None) + self.update_perms = app.config.get("FAB_UPDATE_PERMS", True) + _security_manager_class_name = app.config.get( + "FAB_SECURITY_MANAGER_CLASS", None + ) if _security_manager_class_name is not None: self.security_manager_class = dynamic_class_import( _security_manager_class_name ) if self.security_manager_class is None: from flask_appbuilder.security.sqla.manager import SecurityManager + self.security_manager_class = SecurityManager self._addon_managers = app.config["ADDON_MANAGERS"] @@ -635,7 +627,7 @@ def _process_inner_views(self): for inner_class in view.get_uninit_inner_views(): for v in self.baseviews: if ( - isinstance(v, inner_class) and - v not in view.get_init_inner_views() + isinstance(v, inner_class) + and v not in view.get_init_inner_views() ): view.get_init_inner_views().append(v) diff --git a/flask_appbuilder/charts/jsontools.py b/flask_appbuilder/charts/jsontools.py index c112fefd41..cebb875bc9 100644 --- a/flask_appbuilder/charts/jsontools.py +++ b/flask_appbuilder/charts/jsontools.py @@ -19,24 +19,24 @@ def dict_to_json(xcol, ycols, labels, value_columns): # pragma: no cover """ json_data = dict() - json_data['cols'] = [{'id': xcol, - 'label': as_unicode(labels[xcol]), - 'type': 'string'}] + json_data["cols"] = [ + {"id": xcol, "label": as_unicode(labels[xcol]), "type": "string"} + ] for ycol in ycols: - json_data['cols'].append({'id': ycol, - 'label': as_unicode(labels[ycol]), - 'type': 'number'}) - json_data['rows'] = [] + json_data["cols"].append( + {"id": ycol, "label": as_unicode(labels[ycol]), "type": "number"} + ) + json_data["rows"] = [] for value in value_columns: - row = {'c': []} + row = {"c": []} if isinstance(value[xcol], datetime.date): - row['c'].append({'v': (str(value[xcol]))}) + row["c"].append({"v": (str(value[xcol]))}) else: - row['c'].append({'v': (value[xcol])}) + row["c"].append({"v": (value[xcol])}) for ycol in ycols: if value[ycol]: - row['c'].append({'v': (value[ycol])}) + row["c"].append({"v": (value[ycol])}) else: - row['c'].append({'v': 0}) - json_data['rows'].append(row) + row["c"].append({"v": 0}) + json_data["rows"].append(row) return json_data diff --git a/flask_appbuilder/charts/widgets.py b/flask_appbuilder/charts/widgets.py index 3d9ef74f52..dfd4dfbcdb 100644 --- a/flask_appbuilder/charts/widgets.py +++ b/flask_appbuilder/charts/widgets.py @@ -2,12 +2,12 @@ class ChartWidget(RenderTemplateWidget): - template = 'appbuilder/general/widgets/chart.html' + template = "appbuilder/general/widgets/chart.html" class DirectChartWidget(RenderTemplateWidget): - template = 'appbuilder/general/widgets/direct_chart.html' + template = "appbuilder/general/widgets/direct_chart.html" class MultipleChartWidget(RenderTemplateWidget): - template = 'appbuilder/general/widgets/multiple_chart.html' + template = "appbuilder/general/widgets/multiple_chart.html" diff --git a/flask_appbuilder/cli.py b/flask_appbuilder/cli.py index d0a96857b6..88ee1486a6 100644 --- a/flask_appbuilder/cli.py +++ b/flask_appbuilder/cli.py @@ -8,9 +8,7 @@ from flask import current_app from flask.cli import with_appcontext -from . const import ( - AUTH_DB, AUTH_LDAP, AUTH_OAUTH, AUTH_OID, AUTH_REMOTE_USER -) +from .const import AUTH_DB, AUTH_LDAP, AUTH_OAUTH, AUTH_OID, AUTH_REMOTE_USER SQLA_REPO_URL = ( @@ -35,7 +33,7 @@ def fab(): pass -@fab.command('create-admin') +@fab.command("create-admin") @click.option("--username", default="admin", prompt="Username") @click.option("--firstname", default="admin", prompt="User first name") @click.option("--lastname", default="user", prompt="User last name") @@ -118,7 +116,7 @@ def version(): click.style( "F.A.B Version: {0}.".format(current_app.appbuilder.version), bg="blue", - fg="white" + fg="white", ) ) @@ -134,7 +132,9 @@ def security_cleanup(): @fab.command("security-converge") -@click.option('--dry-run', '-d', is_flag=True, help="Dry run & print state transitions.") +@click.option( + "--dry-run", "-d", is_flag=True, help="Dry run & print state transitions." +) @with_appcontext def security_converge(dry_run=False): """ @@ -144,16 +144,16 @@ def security_converge(dry_run=False): if dry_run: click.echo(click.style("Computed security converge:", fg="green")) click.echo(click.style("Add to Roles:", fg="green")) - for _from, _to in state_transitions['add'].items(): + for _from, _to in state_transitions["add"].items(): click.echo(f"Where {_from} add {_to}") click.echo(click.style("Del from Roles:", fg="green")) - for pvm in state_transitions['del_role_pvm']: + for pvm in state_transitions["del_role_pvm"]: click.echo(pvm) click.echo(click.style("Remove views:", fg="green")) - for views in state_transitions['del_views']: + for views in state_transitions["del_views"]: click.echo(views) click.echo(click.style("Remove permissions:", fg="green")) - for perms in state_transitions['del_perms']: + for perms in state_transitions["del_perms"]: click.echo(perms) else: click.echo(click.style("Finished security converge", fg="green")) @@ -293,7 +293,7 @@ def collect_static(static_folder): click.echo( click.style( "Static folder does not exist creating: %s" % app_static_path, - fg="green" + fg="green", ) ) os.makedirs(app_static_path) @@ -304,8 +304,7 @@ def collect_static(static_folder): except Exception: click.echo( click.style( - "Appbuilder static folder already exists on your project", - fg="red" + "Appbuilder static folder already exists on your project", fg="red" ) ) diff --git a/flask_appbuilder/console.py b/flask_appbuilder/console.py index 99dab2a8cb..a8a26ca6f1 100644 --- a/flask_appbuilder/console.py +++ b/flask_appbuilder/console.py @@ -22,8 +22,13 @@ # Fall back to Python 2's urllib2 from urllib2 import urlopen -click.echo(click.style("fabmanager is going to be deprecated in 2.2.X, you can use " - "the same commands on the improved 'flask fab '", fg="red")) +click.echo( + click.style( + "fabmanager is going to be deprecated in 2.2.X, you can use " + "the same commands on the improved 'flask fab '", + fg="red", + ) +) SQLA_REPO_URL = ( "https://github.com/dpgaspar/Flask-AppBuilder-Skeleton/archive/master.zip" diff --git a/flask_appbuilder/const.py b/flask_appbuilder/const.py index 882efd9fb1..9a9fcc6681 100644 --- a/flask_appbuilder/const.py +++ b/flask_appbuilder/const.py @@ -37,18 +37,10 @@ LOGMSG_WAR_SEC_DEL_PERMVIEW = ( "Refused to delete permission view, assoc with role exists {}.{} {}" ) -LOGMSG_WAR_SEC_DEL_PERMISSION = ( - "Refused to delete, permission {} does not exist" -) -LOGMSG_WAR_SEC_DEL_VIEWMENU = ( - "Refused to delete, view menu {} does not exist" -) -LOGMSG_WAR_SEC_DEL_PERM_PVM = ( - "Refused to delete permission {}, PVM exists {}" -) -LOGMSG_WAR_SEC_DEL_VIEWMENU_PVM = ( - "Refused to delete view menu {}, PVM exists {}" -) +LOGMSG_WAR_SEC_DEL_PERMISSION = "Refused to delete, permission {} does not exist" +LOGMSG_WAR_SEC_DEL_VIEWMENU = "Refused to delete, view menu {} does not exist" +LOGMSG_WAR_SEC_DEL_PERM_PVM = "Refused to delete permission {}, PVM exists {}" +LOGMSG_WAR_SEC_DEL_VIEWMENU_PVM = "Refused to delete view menu {}, PVM exists {}" LOGMSG_ERR_SEC_ADD_PERMROLE = "Add Permission to Role Error: {0}" """ Error adding permission to role, format with err message """ LOGMSG_ERR_SEC_DEL_PERMROLE = "Remove Permission to Role Error: {0}" diff --git a/flask_appbuilder/exceptions.py b/flask_appbuilder/exceptions.py index d6c279888f..99eebaaeec 100644 --- a/flask_appbuilder/exceptions.py +++ b/flask_appbuilder/exceptions.py @@ -1,18 +1,22 @@ class FABException(Exception): """Base FAB Exception""" + pass class InvalidColumnFilterFABException(FABException): """Invalid column for filter""" + pass class InvalidOperationFilterFABException(FABException): """Invalid operation for filter""" + pass class InvalidOrderByColumnFABException(FABException): """Invalid order by column""" + pass diff --git a/flask_appbuilder/filemanager.py b/flask_appbuilder/filemanager.py index 402f070c8c..1147d408fc 100644 --- a/flask_appbuilder/filemanager.py +++ b/flask_appbuilder/filemanager.py @@ -56,8 +56,8 @@ def is_file_allowed(self, filename): if not self.allowed_extensions: return True return ( - "." in filename and - filename.rsplit(".", 1)[1].lower() in self.allowed_extensions + "." in filename + and filename.rsplit(".", 1)[1].lower() in self.allowed_extensions ) def generate_name(self, obj, file_data): diff --git a/flask_appbuilder/filters.py b/flask_appbuilder/filters.py index 98a2b48f5f..22cbc48393 100755 --- a/flask_appbuilder/filters.py +++ b/flask_appbuilder/filters.py @@ -160,7 +160,7 @@ def is_item_visible(self, permission: str, item: str) -> bool: if PERMISSION_PREFIX in permission: method = permission.split(PERMISSION_PREFIX)[1] else: - if hasattr(_view, 'actions') and _view.actions.get(permission): + if hasattr(_view, "actions") and _view.actions.get(permission): permission_name = _view.get_action_permission_name(permission) if permission_name not in _view.base_permissions: return False diff --git a/flask_appbuilder/forms.py b/flask_appbuilder/forms.py index 50dbba7dde..994ad2e875 100644 --- a/flask_appbuilder/forms.py +++ b/flask_appbuilder/forms.py @@ -9,7 +9,7 @@ FloatField, IntegerField, StringField, - TextAreaField + TextAreaField, ) from wtforms import validators @@ -20,14 +20,14 @@ DatePickerWidget, DateTimePickerWidget, Select2ManyWidget, - Select2Widget + Select2Widget, ) from .models.mongoengine.fields import MongoFileField, MongoImageField from .upload import ( BS3FileUploadFieldWidget, BS3ImageUploadFieldWidget, FileUploadField, - ImageUploadField + ImageUploadField, ) from .validators import Unique @@ -46,22 +46,26 @@ class FieldConverter(object): it has a conversion table with type method checks from model interfaces, these methods are invoked with a column name """ - conversion_table = (('is_image', ImageUploadField, BS3ImageUploadFieldWidget), - ('is_file', FileUploadField, BS3FileUploadFieldWidget), - ('is_gridfs_file', MongoFileField, BS3FileUploadFieldWidget), - ('is_gridfs_image', MongoImageField, BS3ImageUploadFieldWidget), - ('is_text', TextAreaField, BS3TextAreaFieldWidget), - ('is_binary', TextAreaField, BS3TextAreaFieldWidget), - ('is_string', StringField, BS3TextFieldWidget), - ('is_integer', IntegerField, BS3TextFieldWidget), - ('is_numeric', DecimalField, BS3TextFieldWidget), - ('is_float', FloatField, BS3TextFieldWidget), - ('is_boolean', BooleanField, None), - ('is_date', DateField, DatePickerWidget), - ('is_datetime', DateTimeField, DateTimePickerWidget), - ) - def __init__(self, datamodel, colname, label, description, validators, default=None): + conversion_table = ( + ("is_image", ImageUploadField, BS3ImageUploadFieldWidget), + ("is_file", FileUploadField, BS3FileUploadFieldWidget), + ("is_gridfs_file", MongoFileField, BS3FileUploadFieldWidget), + ("is_gridfs_image", MongoImageField, BS3ImageUploadFieldWidget), + ("is_text", TextAreaField, BS3TextAreaFieldWidget), + ("is_binary", TextAreaField, BS3TextAreaFieldWidget), + ("is_string", StringField, BS3TextFieldWidget), + ("is_integer", IntegerField, BS3TextFieldWidget), + ("is_numeric", DecimalField, BS3TextFieldWidget), + ("is_float", FloatField, BS3TextFieldWidget), + ("is_boolean", BooleanField, None), + ("is_date", DateField, DatePickerWidget), + ("is_datetime", DateTimeField, DateTimePickerWidget), + ) + + def __init__( + self, datamodel, colname, label, description, validators, default=None + ): self.datamodel = datamodel self.colname = colname self.label = label @@ -72,29 +76,35 @@ def __init__(self, datamodel, colname, label, description, validators, default=N def convert(self): # sqlalchemy.types.Enum inherits from String, therefore `is_enum` must be # checked before checking for `is_string`: - if getattr(self.datamodel, 'is_enum')(self.colname): + if getattr(self.datamodel, "is_enum")(self.colname): col_type = self.datamodel.list_columns[self.colname].type - return EnumField(enum_class=col_type.enum_class, - enums=col_type.enums, - label=self.label, - description=self.description, - validators=self.validators, - widget=Select2Widget(), - default=self.default) + return EnumField( + enum_class=col_type.enum_class, + enums=col_type.enums, + label=self.label, + description=self.description, + validators=self.validators, + widget=Select2Widget(), + default=self.default, + ) for type_marker, field, widget in self.conversion_table: if getattr(self.datamodel, type_marker)(self.colname): if widget: - return field(self.label, - description=self.description, - validators=self.validators, - widget=widget(), - default=self.default) + return field( + self.label, + description=self.description, + validators=self.validators, + widget=widget(), + default=self.default, + ) else: - return field(self.label, - description=self.description, - validators=self.validators, - default=self.default) - log.error('Column %s Type not supported' % self.colname) + return field( + self.label, + description=self.description, + validators=self.validators, + default=self.default, + ) + log.error("Column %s Type not supported" % self.colname) class GeneralModelConverter(object): @@ -129,13 +139,19 @@ def _get_related_query_func(self, col_name, filter_rel_fields): return lambda: self.datamodel.get_related_interface(col_name).query()[1] def _get_related_pk_func(self, col_name): - return lambda obj: self.datamodel.get_related_interface( - col_name - ).get_pk_value(obj) + return lambda obj: self.datamodel.get_related_interface(col_name).get_pk_value( + obj + ) - def _convert_many_to_one(self, col_name, label, description, - lst_validators, filter_rel_fields, - form_props): + def _convert_many_to_one( + self, + col_name, + label, + description, + lst_validators, + filter_rel_fields, + form_props, + ): """ Creates a WTForm field for many to one related fields, will use a Select box based on a query. Will only @@ -150,19 +166,26 @@ def _convert_many_to_one(self, col_name, label, description, allow_blank = False else: lst_validators.append(validators.Optional()) - form_props[col_name] = \ - QuerySelectField(label, - description=description, - query_func=query_func, - get_pk_func=get_pk_func, - allow_blank=allow_blank, - validators=lst_validators, - widget=Select2Widget(extra_classes=extra_classes)) + form_props[col_name] = QuerySelectField( + label, + description=description, + query_func=query_func, + get_pk_func=get_pk_func, + allow_blank=allow_blank, + validators=lst_validators, + widget=Select2Widget(extra_classes=extra_classes), + ) return form_props - def _convert_many_to_many(self, col_name, label, description, - lst_validators, filter_rel_fields, - form_props): + def _convert_many_to_many( + self, + col_name, + label, + description, + lst_validators, + filter_rel_fields, + form_props, + ): query_func = self._get_related_query_func(col_name, filter_rel_fields) get_pk_func = self._get_related_pk_func(col_name) allow_blank = True @@ -173,7 +196,7 @@ def _convert_many_to_many(self, col_name, label, description, get_pk_func=get_pk_func, allow_blank=allow_blank, validators=lst_validators, - widget=Select2ManyWidget() + widget=Select2ManyWidget(), ) return form_props @@ -198,44 +221,59 @@ def _convert_simple(self, col_name, label, description, lst_validators, form_pro label, description, lst_validators, - default=default_value + default=default_value, ) form_props[col_name] = fc.convert() return form_props - def _convert_col(self, col_name, - label, description, - lst_validators, filter_rel_fields, - form_props): + def _convert_col( + self, + col_name, + label, + description, + lst_validators, + filter_rel_fields, + form_props, + ): if self.datamodel.is_relation(col_name): - if self.datamodel.is_relation_many_to_one(col_name) or \ - self.datamodel.is_relation_one_to_one(col_name): - return self._convert_many_to_one(col_name, label, - description, - lst_validators, - filter_rel_fields, - form_props) - elif self.datamodel.is_relation_many_to_many(col_name) or \ - self.datamodel.is_relation_one_to_many(col_name): - return self._convert_many_to_many(col_name, label, - description, - lst_validators, - filter_rel_fields, - form_props) + if self.datamodel.is_relation_many_to_one( + col_name + ) or self.datamodel.is_relation_one_to_one(col_name): + return self._convert_many_to_one( + col_name, + label, + description, + lst_validators, + filter_rel_fields, + form_props, + ) + elif self.datamodel.is_relation_many_to_many( + col_name + ) or self.datamodel.is_relation_one_to_many(col_name): + return self._convert_many_to_many( + col_name, + label, + description, + lst_validators, + filter_rel_fields, + form_props, + ) else: log.warning("Relation {0} not supported".format(col_name)) else: return self._convert_simple( - col_name, - label, - description, - lst_validators, - form_props + col_name, label, description, lst_validators, form_props ) - def create_form(self, label_columns=None, inc_columns=None, - description_columns=None, validators_columns=None, - extra_fields=None, filter_rel_fields=None): + def create_form( + self, + label_columns=None, + inc_columns=None, + description_columns=None, + validators_columns=None, + extra_fields=None, + filter_rel_fields=None, + ): """ Converts a model to a form given @@ -270,11 +308,15 @@ def create_form(self, label_columns=None, inc_columns=None, if col_name in extra_fields: form_props[col_name] = extra_fields.get(col_name) else: - self._convert_col(col_name, self._get_label(col_name, label_columns), - self._get_description(col_name, description_columns), - self._get_validators(col_name, validators_columns), - filter_rel_fields, form_props) - return type('DynamicForm', (DynamicForm,), form_props) + self._convert_col( + col_name, + self._get_label(col_name, label_columns), + self._get_description(col_name, description_columns), + self._get_validators(col_name, validators_columns), + filter_rel_fields, + form_props, + ) + return type("DynamicForm", (DynamicForm,), form_props) class DynamicForm(FlaskForm): diff --git a/flask_appbuilder/menu.py b/flask_appbuilder/menu.py index 4bb16631cd..cca3983689 100644 --- a/flask_appbuilder/menu.py +++ b/flask_appbuilder/menu.py @@ -64,24 +64,28 @@ def get_data(self, menu=None): ) for i, item in enumerate(menu): - if item.name == '-' and not i == len(menu) - 1: - ret_list.append('-') + if item.name == "-" and not i == len(menu) - 1: + ret_list.append("-") elif item.name not in allowed_menus: continue elif item.childs: - ret_list.append({ - "name": item.name, - "icon": item.icon, - "label": str(item.label), - "childs": self.get_data(menu=item.childs) - }) + ret_list.append( + { + "name": item.name, + "icon": item.icon, + "label": str(item.label), + "childs": self.get_data(menu=item.childs), + } + ) else: - ret_list.append({ - "name": item.name, - "icon": item.icon, - "label": str(item.label), - "url": item.get_url() - }) + ret_list.append( + { + "name": item.name, + "icon": item.icon, + "label": str(item.label), + "url": item.get_url(), + } + ) return ret_list def find(self, name, menu=None): @@ -158,9 +162,9 @@ def add_separator(self, category=""): class MenuApi(BaseApi): resource_name = "menu" - @expose('/', methods=["GET"]) + @expose("/", methods=["GET"]) @protect(allow_browser_login=True) - @permission_name('get') + @permission_name("get") def get_menu_data(self): """An endpoint for retreiving the menu. --- @@ -198,5 +202,5 @@ def get_menu_data(self): class MenuApiManager(BaseManager): def register_views(self): - if self.appbuilder.app.config.get('FAB_ADD_MENU_API', True): + if self.appbuilder.app.config.get("FAB_ADD_MENU_API", True): self.appbuilder.add_api(MenuApi) diff --git a/flask_appbuilder/messages.py b/flask_appbuilder/messages.py index 69fb6b8933..6d24962db0 100644 --- a/flask_appbuilder/messages.py +++ b/flask_appbuilder/messages.py @@ -11,5 +11,5 @@ _("Save"), _("This field is required."), _("Not a valid date value"), - _("No records found") + _("No records found"), ] diff --git a/flask_appbuilder/models/decorators.py b/flask_appbuilder/models/decorators.py index e1ef947282..abb9b55f66 100644 --- a/flask_appbuilder/models/decorators.py +++ b/flask_appbuilder/models/decorators.py @@ -20,7 +20,7 @@ class MyModelView(ModelView): """ def wrap(f): - if not hasattr(f, '_col_name'): + if not hasattr(f, "_col_name"): f._col_name = col_name return f diff --git a/flask_appbuilder/models/filters.py b/flask_appbuilder/models/filters.py index 8dbf206b68..b9928c7c0b 100644 --- a/flask_appbuilder/models/filters.py +++ b/flask_appbuilder/models/filters.py @@ -176,9 +176,7 @@ def rest_add_filters(self, data): raise InvalidColumnFilterFABException( f"Filter column: {col} not allowed to filter" ) - elif not self._rest_check_valid_filter_operation( - col, opr - ): + elif not self._rest_check_valid_filter_operation(col, opr): raise InvalidOperationFilterFABException( f"Filter operation: {opr} not allowed on column: {col}" ) diff --git a/flask_appbuilder/models/generic/__init__.py b/flask_appbuilder/models/generic/__init__.py index e5aa62402c..a1e1432e97 100644 --- a/flask_appbuilder/models/generic/__init__.py +++ b/flask_appbuilder/models/generic/__init__.py @@ -369,7 +369,7 @@ def all(self): items = self._order_by(items, self._order_by_cmd) total_length = len(items) if self._limit != 0: - items = items[self._offset: self._offset + self._limit] + items = items[self._offset : self._offset + self._limit] return total_length, items def add(self, model): diff --git a/flask_appbuilder/models/mongoengine/interface.py b/flask_appbuilder/models/mongoengine/interface.py index 6abfacdcff..68eb1bf18e 100644 --- a/flask_appbuilder/models/mongoengine/interface.py +++ b/flask_appbuilder/models/mongoengine/interface.py @@ -11,7 +11,7 @@ ListField, ObjectIdField, ReferenceField, - StringField + StringField, ) from . import filters @@ -20,7 +20,7 @@ from ...const import ( LOGMSG_ERR_DBI_ADD_GENERIC, LOGMSG_ERR_DBI_DEL_GENERIC, - LOGMSG_ERR_DBI_EDIT_GENERIC + LOGMSG_ERR_DBI_EDIT_GENERIC, ) log = logging.getLogger(__name__) @@ -84,7 +84,7 @@ def query( log.warn("Retrieving %s %s items from DB" % (count, str(self.obj))) else: # get data segment for paginated page offset = (page or 0) * page_size - objs = objs[offset: offset + page_size] + objs = objs[offset : offset + page_size] return count, objs @@ -138,8 +138,9 @@ def is_gridfs_image(self, col_name): def is_relation(self, col_name): try: - return (isinstance(self.obj._fields[col_name], ReferenceField) or - isinstance(self.obj._fields[col_name], ListField)) + return isinstance(self.obj._fields[col_name], ReferenceField) or isinstance( + self.obj._fields[col_name], ListField + ) except Exception: return False diff --git a/flask_appbuilder/models/sqla/interface.py b/flask_appbuilder/models/sqla/interface.py index cd8c164cf9..aa2f62d0c4 100644 --- a/flask_appbuilder/models/sqla/interface.py +++ b/flask_appbuilder/models/sqla/interface.py @@ -169,10 +169,10 @@ def query( # MSSQL exception page/limit must have an order by if ( - page - and page_size - and not order_column - and self.session.bind.dialect.name == "mssql" + page + and page_size + and not order_column + and self.session.bind.dialect.name == "mssql" ): pk_name = self.get_pk_name() query = query.order_by(pk_name) diff --git a/flask_appbuilder/security/decorators.py b/flask_appbuilder/security/decorators.py index 92736a271b..a0134b531c 100644 --- a/flask_appbuilder/security/decorators.py +++ b/flask_appbuilder/security/decorators.py @@ -1,15 +1,7 @@ import functools import logging -from flask import ( - current_app, - flash, - jsonify, - make_response, - redirect, - request, - url_for -) +from flask import current_app, flash, jsonify, make_response, redirect, request, url_for from flask_jwt_extended import verify_jwt_in_request from flask_login import current_user @@ -17,7 +9,7 @@ from ..const import ( FLAMSG_ERR_SEC_ACCESS_DENIED, LOGMSG_ERR_SEC_ACCESS_DENIED, - PERMISSION_PREFIX + PERMISSION_PREFIX, ) log = logging.getLogger(__name__) @@ -48,7 +40,7 @@ def do_something_else(self): """ def _protect(f): - if hasattr(f, '_permission_name'): + if hasattr(f, "_permission_name"): permission_str = f._permission_name else: permission_str = f.__name__ @@ -64,29 +56,25 @@ def wraps(self, *args, **kwargs): if permission_str not in self.base_permissions: return self.response_401() if current_app.appbuilder.sm.is_item_public( - permission_str, - class_permission_name + permission_str, class_permission_name ): return f(self, *args, **kwargs) if not (self.allow_browser_login or allow_browser_login): verify_jwt_in_request() if current_app.appbuilder.sm.has_access( - permission_str, - class_permission_name + permission_str, class_permission_name ): return f(self, *args, **kwargs) elif self.allow_browser_login or allow_browser_login: if not current_user.is_authenticated: verify_jwt_in_request() if current_app.appbuilder.sm.has_access( - permission_str, - class_permission_name + permission_str, class_permission_name ): return f(self, *args, **kwargs) log.warning( LOGMSG_ERR_SEC_ACCESS_DENIED.format( - permission_str, - class_permission_name + permission_str, class_permission_name ) ) return self.response_401() @@ -104,7 +92,7 @@ def has_access(f): By default the permission's name is the methods name. """ - if hasattr(f, '_permission_name'): + if hasattr(f, "_permission_name"): permission_str = f._permission_name else: permission_str = f.__name__ @@ -115,24 +103,21 @@ def wraps(self, *args, **kwargs): _permission_name = self.method_permission_name.get(f.__name__) if _permission_name: permission_str = "{}{}".format(PERMISSION_PREFIX, _permission_name) - if (permission_str in self.base_permissions and - self.appbuilder.sm.has_access( - permission_str, - self.class_permission_name - )): + if permission_str in self.base_permissions and self.appbuilder.sm.has_access( + permission_str, self.class_permission_name + ): return f(self, *args, **kwargs) else: log.warning( LOGMSG_ERR_SEC_ACCESS_DENIED.format( - permission_str, - self.__class__.__name__ + permission_str, self.__class__.__name__ ) ) flash(as_unicode(FLAMSG_ERR_SEC_ACCESS_DENIED), "danger") return redirect( url_for( self.appbuilder.sm.auth_view.__class__.__name__ + ".login", - next=request.url + next=request.url, ) ) @@ -149,7 +134,7 @@ def has_access_api(f): this will return a message and HTTP 401 is case of unauthorized access. """ - if hasattr(f, '_permission_name'): + if hasattr(f, "_permission_name"): permission_str = f._permission_name else: permission_str = f.__name__ @@ -160,29 +145,23 @@ def wraps(self, *args, **kwargs): _permission_name = self.method_permission_name.get(f.__name__) if _permission_name: permission_str = "{}{}".format(PERMISSION_PREFIX, _permission_name) - if (permission_str in self.base_permissions and - self.appbuilder.sm.has_access( - permission_str, - self.class_permission_name - )): + if permission_str in self.base_permissions and self.appbuilder.sm.has_access( + permission_str, self.class_permission_name + ): return f(self, *args, **kwargs) else: log.warning( LOGMSG_ERR_SEC_ACCESS_DENIED.format( - permission_str, - self.__class__.__name__ + permission_str, self.__class__.__name__ ) ) response = make_response( jsonify( - { - 'message': str(FLAMSG_ERR_SEC_ACCESS_DENIED), - 'severity': 'danger' - } + {"message": str(FLAMSG_ERR_SEC_ACCESS_DENIED), "severity": "danger"} ), - 401 + 401, ) - response.headers['Content-Type'] = "application/json" + response.headers["Content-Type"] = "application/json" return response f._permission_name = permission_str diff --git a/flask_appbuilder/security/manager.py b/flask_appbuilder/security/manager.py index 390ff9ba09..078a4bf870 100644 --- a/flask_appbuilder/security/manager.py +++ b/flask_appbuilder/security/manager.py @@ -17,7 +17,7 @@ from .registerviews import ( RegisterUserDBView, RegisterUserOAuthView, - RegisterUserOIDView + RegisterUserOIDView, ) from .views import ( AuthDBView, @@ -38,7 +38,7 @@ UserOIDModelView, UserRemoteUserModelView, UserStatsChartView, - ViewMenuModelView + ViewMenuModelView, ) from ..basemanager import BaseManager from ..const import ( @@ -52,7 +52,7 @@ LOGMSG_WAR_SEC_LOGIN_FAILED, LOGMSG_WAR_SEC_NO_USER, LOGMSG_WAR_SEC_NOLDAP_OBJ, - PERMISSION_PREFIX + PERMISSION_PREFIX, ) log = logging.getLogger(__name__) @@ -294,7 +294,7 @@ def create_jwt_manager(self, app) -> JWTManager: return jwt_manager def create_builtin_roles(self): - return self.appbuilder.get_app.config.get('FAB_ROLES', {}) + return self.appbuilder.get_app.config.get("FAB_ROLES", {}) @property def get_url_for_registeruser(self): @@ -455,9 +455,7 @@ def wraps(provider, response=None): if not type(ret) == dict: log.error( "OAuth user info decorated function " - "did not returned a dict, but: {0}".format( - type(ret) - ) + "did not returned a dict, but: {0}".format(type(ret)) ) return {} return ret @@ -590,7 +588,7 @@ def _azure_jwt_token_parse(self, id_token): return jwt_decoded_payload def register_views(self): - if not self.appbuilder.app.config.get('FAB_ADD_SECURITY_VIEWS', True): + if not self.appbuilder.app.config.get("FAB_ADD_SECURITY_VIEWS", True): return # Security APIs self.appbuilder.add_api(self.security_api) @@ -669,9 +667,7 @@ def register_views(self): category="Security", ) self.appbuilder.menu.add_separator("Security") - if self.appbuilder.app.config.get( - "FAB_ADD_SECURITY_PERMISSION_VIEW", True - ): + if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEW", True): self.appbuilder.add_view( self.permissionmodelview, "Base Permissions", @@ -687,7 +683,9 @@ def register_views(self): label=_("Views/Menus"), category="Security", ) - if self.appbuilder.app.config.get("FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True): + if self.appbuilder.app.config.get( + "FAB_ADD_SECURITY_PERMISSION_VIEWS_VIEW", True + ): self.appbuilder.add_view( self.permissionviewmodelview, "Permission on Views/Menus", @@ -700,7 +698,7 @@ def create_db(self): """ Setups the DB, creates admin and public roles if they don't exist. """ - roles_mapping = self.appbuilder.get_app.config.get('FAB_ROLES_MAPPING', {}) + roles_mapping = self.appbuilder.get_app.config.get("FAB_ROLES_MAPPING", {}) for pk, name in roles_mapping.items(): self.update_role(pk, name) for role_name in self.builtin_roles: @@ -946,7 +944,7 @@ def auth_user_ldap(self, username, password): except ldap.LDAPError as e: msg = None if isinstance(e, dict): - msg = getattr(e, 'message', None) + msg = getattr(e, "message", None) if msg is not None and "desc" in msg: log.error(LOGMSG_ERR_SEC_AUTH_LDAP.format(e.message["desc"])) return None @@ -986,7 +984,7 @@ def auth_user_remote_user(self, username): username=username, first_name=username, last_name="-", - email=username + '@email.notfound', + email=username + "@email.notfound", role=self.find_role(self.auth_user_registration_role), ) @@ -1062,10 +1060,7 @@ def is_item_public(self, permission_name, view_name): return False def _has_access_builtin_roles( - self, - role, - permission_name: str, - view_name: str + self, role, permission_name: str, view_name: str ) -> bool: """ Checks permission on builtin role @@ -1074,13 +1069,14 @@ def _has_access_builtin_roles( for pvm in builtin_pvms: _view_name = pvm[0] _permission_name = pvm[1] - if (re.match(_view_name, view_name) and - re.match(_permission_name, permission_name)): + if re.match(_view_name, view_name) and re.match( + _permission_name, permission_name + ): return True return False def _has_view_access( - self, user: object, permission_name: str, view_name: str + self, user: object, permission_name: str, view_name: str ) -> bool: roles = user.roles db_role_ids = list() @@ -1088,27 +1084,16 @@ def _has_view_access( # because no database query is needed for role in roles: if role.name in self.builtin_roles: - if self._has_access_builtin_roles( - role, - permission_name, - view_name - ): + if self._has_access_builtin_roles(role, permission_name, view_name): return True else: db_role_ids.append(role.id) # Then check against database-stored roles - return self.exist_permission_on_roles( - view_name, - permission_name, - db_role_ids, - ) + return self.exist_permission_on_roles(view_name, permission_name, db_role_ids) def _get_user_permission_view_menus( - self, - user: object, - permission_name: str, - view_menus_name: List[str] + self, user: object, permission_name: str, view_menus_name: List[str] ) -> Set[str]: """ Return a set of view menu names with a certain permission name @@ -1128,9 +1113,7 @@ def _get_user_permission_view_menus( if role.name in self.builtin_roles: for view_menu_name in view_menus_name: if self._has_access_builtin_roles( - role, - permission_name, - view_menu_name + role, permission_name, view_menu_name ): result.add(view_menu_name) else: @@ -1138,7 +1121,9 @@ def _get_user_permission_view_menus( # Then check against database-stored roles pvms_names = [ pvm.view_menu.name - for pvm in self.find_roles_permission_view_menus(permission_name, db_role_ids) + for pvm in self.find_roles_permission_view_menus( + permission_name, db_role_ids + ) ] result.update(pvms_names) return result @@ -1157,13 +1142,16 @@ def has_access(self, permission_name, view_name): def get_user_menu_access(self, menu_names: List[str] = None) -> Set[str]: if current_user.is_authenticated: return self._get_user_permission_view_menus( - g.user, "menu_access", view_menus_name=menu_names) + g.user, "menu_access", view_menus_name=menu_names + ) elif current_user_jwt: return self._get_user_permission_view_menus( - current_user_jwt, "menu_access", view_menus_name=menu_names) + current_user_jwt, "menu_access", view_menus_name=menu_names + ) else: return self._get_user_permission_view_menus( - None, "menu_access", view_menus_name=menu_names) + None, "menu_access", view_menus_name=menu_names + ) def add_permissions_view(self, base_permissions, view_menu): """ @@ -1206,8 +1194,10 @@ def add_permissions_view(self, base_permissions, view_menu): for role in roles: self.del_permission_role(role, perm) self.del_permission_view_menu(perm_view.permission.name, view_menu) - elif (self.auth_role_admin not in self.builtin_roles and - perm_view not in role_admin.permissions): + elif ( + self.auth_role_admin not in self.builtin_roles + and perm_view not in role_admin.permissions + ): # Role Admin must have all permissions self.add_permission_role(role_admin, perm_view) @@ -1262,42 +1252,43 @@ def _get_new_old_permissions(baseview) -> Dict: method_name ) # Actions do not get prefix when normally defined - if (hasattr(baseview, 'actions') and - baseview.actions.get(old_permission_name)): - permission_prefix = '' + if hasattr(baseview, "actions") and baseview.actions.get( + old_permission_name + ): + permission_prefix = "" else: permission_prefix = PERMISSION_PREFIX if old_permission_name: if PERMISSION_PREFIX + permission_name not in ret: - ret[ - PERMISSION_PREFIX + permission_name - ] = {permission_prefix + old_permission_name, } + ret[PERMISSION_PREFIX + permission_name] = { + permission_prefix + old_permission_name + } else: - ret[ - PERMISSION_PREFIX + permission_name - ].add(permission_prefix + old_permission_name) + ret[PERMISSION_PREFIX + permission_name].add( + permission_prefix + old_permission_name + ) return ret @staticmethod def _add_state_transition( - state_transition: Dict, - old_view_name: str, - old_perm_name: str, - view_name: str, - perm_name: str + state_transition: Dict, + old_view_name: str, + old_perm_name: str, + view_name: str, + perm_name: str, ) -> None: - old_pvm = state_transition['add'].get((old_view_name, old_perm_name)) + old_pvm = state_transition["add"].get((old_view_name, old_perm_name)) if old_pvm: - state_transition['add'][(old_view_name, old_perm_name)].add( + state_transition["add"][(old_view_name, old_perm_name)].add( (view_name, perm_name) ) else: - state_transition['add'][(old_view_name, old_perm_name)] = { + state_transition["add"][(old_view_name, old_perm_name)] = { (view_name, perm_name) } - state_transition['del_role_pvm'].add((old_view_name, old_perm_name)) - state_transition['del_views'].add(old_view_name) - state_transition['del_perms'].add(old_perm_name) + state_transition["del_role_pvm"].add((old_view_name, old_perm_name)) + state_transition["del_views"].add(old_view_name) + state_transition["del_perms"].add(old_perm_name) @staticmethod def _update_del_transitions(state_transitions: Dict, baseviews: List) -> None: @@ -1311,15 +1302,12 @@ def _update_del_transitions(state_transitions: Dict, baseviews: List) -> None: :return: """ for baseview in baseviews: - state_transitions['del_views'].discard(baseview.class_permission_name) + state_transitions["del_views"].discard(baseview.class_permission_name) for permission in baseview.base_permissions: - state_transitions['del_role_pvm'].discard( - ( - baseview.class_permission_name, - permission - ) + state_transitions["del_role_pvm"].discard( + (baseview.class_permission_name, permission) ) - state_transitions['del_perms'].discard(permission) + state_transitions["del_perms"].discard(permission) def create_state_transitions(self, baseviews: List, menus: List) -> Dict: """ @@ -1337,10 +1325,10 @@ def create_state_transitions(self, baseviews: List, menus: List) -> Dict: :return: Dict with state transitions """ state_transitions = { - 'add': {}, - 'del_role_pvm': set(), - 'del_views': set(), - 'del_perms': set() + "add": {}, + "del_role_pvm": set(), + "del_views": set(), + "del_perms": set(), } for baseview in baseviews: add_all_flag = False @@ -1362,7 +1350,7 @@ def create_state_transitions(self, baseviews: List, menus: List) -> Dict: old_view_name, old_perm_name, new_view_name, - new_perm_name + new_perm_name, ) else: old_perm_names = permission_mapping.get(new_perm_name) or set() @@ -1372,7 +1360,7 @@ def create_state_transitions(self, baseviews: List, menus: List) -> Dict: old_view_name, old_perm_name, new_view_name, - new_perm_name + new_perm_name, ) self._update_del_transitions(state_transitions, baseviews) return state_transitions @@ -1400,7 +1388,7 @@ def security_converge(self, baseviews: List, menus: List, dry=False) -> Dict: for role in roles: permissions = list(role.permissions) for pvm in permissions: - new_pvm_states = state_transitions['add'].get( + new_pvm_states = state_transitions["add"].get( (pvm.view_menu.name, pvm.permission.name) ) if not new_pvm_states: @@ -1411,16 +1399,17 @@ def security_converge(self, baseviews: List, menus: List, dry=False) -> Dict: ) self.add_permission_role(role, new_pvm) if (pvm.view_menu.name, pvm.permission.name) in state_transitions[ - 'del_role_pvm' + "del_role_pvm" ]: self.del_permission_role(role, pvm) - for pvm in state_transitions['del_role_pvm']: + for pvm in state_transitions["del_role_pvm"]: self.del_permission_view_menu(pvm[1], pvm[0], cascade=False) - for view_name in state_transitions['del_views']: + for view_name in state_transitions["del_views"]: self.del_view_menu(view_name) - for permission_name in state_transitions['del_perms']: + for permission_name in state_transitions["del_perms"]: self.del_permission(permission_name) return state_transitions + """ --------------------------- INTERFACE ABSTRACT METHODS @@ -1532,17 +1521,12 @@ def find_permission(self, name): raise NotImplementedError def find_roles_permission_view_menus( - self, - permission_name: str, - role_ids: List[int], + self, permission_name: str, role_ids: List[int] ): raise NotImplementedError def exist_permission_on_roles( - self, - view_name: str, - permission_name: str, - role_ids: List[int], + self, view_name: str, permission_name: str, role_ids: List[int] ) -> bool: """ Finds and returns permission views for a group of roles diff --git a/flask_appbuilder/security/mongoengine/manager.py b/flask_appbuilder/security/mongoengine/manager.py index 511304fc4f..506b0e30ab 100644 --- a/flask_appbuilder/security/mongoengine/manager.py +++ b/flask_appbuilder/security/mongoengine/manager.py @@ -194,10 +194,7 @@ def find_permission(self, name): return self.permission_model.objects(name=name).first() def exist_permission_on_roles( - self, - view_name: str, - permission_name: str, - role_ids: List[int], + self, view_name: str, permission_name: str, role_ids: List[int] ) -> bool: for role_id in role_ids: role = self.role_model.objects(id=role_id).first() @@ -205,7 +202,7 @@ def exist_permission_on_roles( if permissions: for permission in permissions: if (view_name == permission.view_menu.name) and ( - permission_name == permission.permission.name + permission_name == permission.permission.name ): return True return False @@ -323,10 +320,7 @@ def add_permission_view_menu(self, permission_name, view_menu_name): """ if not (permission_name and view_menu_name): return None - pv = self.find_permission_view_menu( - permission_name, - view_menu_name - ) + pv = self.find_permission_view_menu(permission_name, view_menu_name) if pv: return pv vm = self.add_view_menu(view_menu_name) diff --git a/flask_appbuilder/security/mongoengine/models.py b/flask_appbuilder/security/mongoengine/models.py index 30a03d6176..283b3d5537 100644 --- a/flask_appbuilder/security/mongoengine/models.py +++ b/flask_appbuilder/security/mongoengine/models.py @@ -8,7 +8,7 @@ IntField, ListField, ReferenceField, - StringField + StringField, ) from ..._compat import as_unicode diff --git a/flask_appbuilder/security/sqla/manager.py b/flask_appbuilder/security/sqla/manager.py index bf6930fabb..93bc01b977 100644 --- a/flask_appbuilder/security/sqla/manager.py +++ b/flask_appbuilder/security/sqla/manager.py @@ -299,10 +299,7 @@ def find_permission(self, name): ) def exist_permission_on_roles( - self, - view_name: str, - permission_name: str, - role_ids: List[int], + self, view_name: str, permission_name: str, role_ids: List[int] ) -> bool: """ Method to efficiently check if a certain permission exists @@ -318,8 +315,10 @@ def exist_permission_on_roles( .join( assoc_permissionview_role, and_( - (self.permissionview_model.id == - assoc_permissionview_role.c.permission_view_id), + ( + self.permissionview_model.id + == assoc_permissionview_role.c.permission_view_id + ) ), ) .join(self.role_model) @@ -337,14 +336,18 @@ def exist_permission_on_roles( return self.appbuilder.get_session.query(literal(True)).filter(q).scalar() return self.appbuilder.get_session.query(q).scalar() - def find_roles_permission_view_menus(self, permission_name: str, role_ids: List[int]): + def find_roles_permission_view_menus( + self, permission_name: str, role_ids: List[int] + ): return ( self.appbuilder.get_session.query(self.permissionview_model) .join( assoc_permissionview_role, and_( - (self.permissionview_model.id == - assoc_permissionview_role.c.permission_view_id), + ( + self.permissionview_model.id + == assoc_permissionview_role.c.permission_view_id + ) ), ) .join(self.role_model) @@ -352,7 +355,8 @@ def find_roles_permission_view_menus(self, permission_name: str, role_ids: List[ .join(self.viewmenu_model) .filter( self.permission_model.name == permission_name, - self.role_model.id.in_(role_ids)) + self.role_model.id.in_(role_ids), + ) ).all() def add_permission(self, name): @@ -387,9 +391,11 @@ def del_permission(self, name: str) -> bool: log.warning(c.LOGMSG_WAR_SEC_DEL_PERMISSION.format(name)) return False try: - pvms = self.get_session.query(self.permissionview_model).filter( - self.permissionview_model.permission == perm - ).all() + pvms = ( + self.get_session.query(self.permissionview_model) + .filter(self.permissionview_model.permission == perm) + .all() + ) if pvms: log.warning(c.LOGMSG_WAR_SEC_DEL_PERM_PVM.format(perm, pvms)) return False @@ -447,9 +453,11 @@ def del_view_menu(self, name: str) -> bool: log.warning(c.LOGMSG_WAR_SEC_DEL_VIEWMENU.format(name)) return False try: - pvms = self.get_session.query(self.permissionview_model).filter( - self.permissionview_model.view_menu == view_menu - ).all() + pvms = ( + self.get_session.query(self.permissionview_model) + .filter(self.permissionview_model.view_menu == view_menu) + .all() + ) if pvms: log.warning(c.LOGMSG_WAR_SEC_DEL_VIEWMENU_PVM.format(view_menu, pvms)) return False @@ -504,10 +512,7 @@ def add_permission_view_menu(self, permission_name, view_menu_name): """ if not (permission_name and view_menu_name): return None - pv = self.find_permission_view_menu( - permission_name, - view_menu_name - ) + pv = self.find_permission_view_menu(permission_name, view_menu_name) if pv: return pv vm = self.add_view_menu(view_menu_name) @@ -529,9 +534,11 @@ def del_permission_view_menu(self, permission_name, view_menu_name, cascade=True pv = self.find_permission_view_menu(permission_name, view_menu_name) if not pv: return - roles_pvs = self.get_session.query(self.role_model).filter( - self.role_model.permissions.contains(pv) - ).first() + roles_pvs = ( + self.get_session.query(self.role_model) + .filter(self.role_model.permissions.contains(pv)) + .first() + ) if roles_pvs: log.warning( c.LOGMSG_WAR_SEC_DEL_PERMVIEW.format( diff --git a/flask_appbuilder/security/views.py b/flask_appbuilder/security/views.py index cbb7150a09..ae716783ad 100644 --- a/flask_appbuilder/security/views.py +++ b/flask_appbuilder/security/views.py @@ -400,18 +400,18 @@ class UserStatsChartView(DirectByChartView): class RoleListWidget(ListWidget): - template = 'appbuilder/general/widgets/roles/list.html' + template = "appbuilder/general/widgets/roles/list.html" def __init__(self, **kwargs): - kwargs['appbuilder'] = current_app.appbuilder + kwargs["appbuilder"] = current_app.appbuilder super().__init__(**kwargs) class RoleShowWidget(ShowWidget): - template = 'appbuilder/general/widgets/roles/show.html' + template = "appbuilder/general/widgets/roles/show.html" def __init__(self, **kwargs): - kwargs['appbuilder'] = current_app.appbuilder + kwargs["appbuilder"] = current_app.appbuilder super().__init__(**kwargs) diff --git a/flask_appbuilder/tests/mongoengine/models.py b/flask_appbuilder/tests/mongoengine/models.py index 1849f9eef3..7aa34b846a 100644 --- a/flask_appbuilder/tests/mongoengine/models.py +++ b/flask_appbuilder/tests/mongoengine/models.py @@ -5,7 +5,7 @@ ImageField, IntField, ReferenceField, - StringField + StringField, ) from mongoengine import Document diff --git a/flask_appbuilder/tests/sqla/models.py b/flask_appbuilder/tests/sqla/models.py index c0572d2659..7838f64aa5 100644 --- a/flask_appbuilder/tests/sqla/models.py +++ b/flask_appbuilder/tests/sqla/models.py @@ -194,6 +194,7 @@ def insert_model2(session, i=0, model1_collection=None): model.group = model1 import random + year = random.choice(range(1900, 2012)) month = random.choice(range(1, 12)) day = random.choice(range(1, 28)) diff --git a/flask_appbuilder/tests/test_fab_cli.py b/flask_appbuilder/tests/test_fab_cli.py index d19115f32f..76a6c356dd 100644 --- a/flask_appbuilder/tests/test_fab_cli.py +++ b/flask_appbuilder/tests/test_fab_cli.py @@ -3,7 +3,11 @@ from click.testing import CliRunner from flask_appbuilder.cli import ( - create_app, create_permissions, create_user, list_users, list_views, + create_app, + create_permissions, + create_user, + list_users, + list_views, ) from .base import FABTestCase @@ -60,8 +64,6 @@ def test_list_views(self): os.environ["FLASK_APP"] = "app:app" runner = CliRunner() with runner.isolated_filesystem(): - result = runner.invoke( - list_views, [] - ) + result = runner.invoke(list_views, []) self.assertIn("List of registered views", result.output) self.assertIn(" Route:/api/v1/security", result.output) diff --git a/flask_appbuilder/upload.py b/flask_appbuilder/upload.py index 1c372c90ee..f67f7495c8 100644 --- a/flask_appbuilder/upload.py +++ b/flask_appbuilder/upload.py @@ -152,9 +152,9 @@ def process_on_store(self, obj, byte_stream): def pre_validate(self, form): if ( - self.data and - isinstance(self.data, FileStorage) and - not self.filemanager.is_file_allowed(self.data.filename) + self.data + and isinstance(self.data, FileStorage) + and not self.filemanager.is_file_allowed(self.data.filename) ): raise ValidationError(gettext("Invalid file extension")) @@ -205,9 +205,9 @@ def __init__(self, label=None, validators=None, imagemanager=None, **kwargs): def pre_validate(self, form): if ( - self.data and - isinstance(self.data, FileStorage) and - not self.imagemanager.is_file_allowed(self.data.filename) + self.data + and isinstance(self.data, FileStorage) + and not self.imagemanager.is_file_allowed(self.data.filename) ): raise ValidationError(gettext("Invalid file extension")) diff --git a/tox.ini b/tox.ini index 872e380522..833a897be9 100644 --- a/tox.ini +++ b/tox.ini @@ -1,28 +1,5 @@ [flake8] accept-encodings = utf-8 -exclude = - .tox - build - docs - bin - examples - flask_appbuilder/templates - flask_appbuilder/static - venv -ignore = - FI12 - FI15 - FI16 - FI17 - FI50 - FI51 - FI53 - FI54 - W503 - W504 - W605 -import-order-style = google -max-line-length = 90 require-code = true [testenv:flake8] @@ -63,6 +40,10 @@ commands = commands = nosetests -v --with-coverage --cover-package=flask_appbuilder flask_appbuilder/tests/test_mongoengine.py +[testenv:black] +commands = + black --check setup.py flask_appbuilder + [tox] envlist = flake8