Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: using base_filters with FilterEqualFunction not working for relation fields #2011

Merged
merged 9 commits into from
Apr 28, 2023
Empty file.
4 changes: 2 additions & 2 deletions examples/extendsecurity/testdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
import logging
import random

from .app import appbuilder, db, create_app
from .app.models import ContactGroup, Gender, Contact, Company
from app import appbuilder, db, create_app
from app.models import ContactGroup, Gender, Contact, Company


log = logging.getLogger(__name__)
Expand Down
6 changes: 4 additions & 2 deletions flask_appbuilder/models/sqla/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,8 +344,10 @@ def get_inner_filters(self, filters: Optional[Filters]) -> Filters:
if not is_column_dotted(flt.column_name):
_filters.append((flt.column_name, flt.__class__, value))
elif self.is_relation_many_to_one(
flt.column_name
) or self.is_relation_one_to_one(flt.column_name):
get_column_root_relation(flt.column_name)
) or self.is_relation_one_to_one(
get_column_root_relation(flt.column_name)
):
_filters.append((flt.column_name, flt.__class__, value))
inner_filters.add_filter_list(_filters)
return inner_filters
Expand Down
50 changes: 48 additions & 2 deletions flask_appbuilder/tests/test_mvc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,11 @@
from flask_appbuilder.models.generic import PSSession
from flask_appbuilder.models.generic.interface import GenericInterface
from flask_appbuilder.models.group import aggregate_avg, aggregate_count, aggregate_sum
from flask_appbuilder.models.sqla.filters import FilterEqual, FilterStartsWith
from flask_appbuilder.models.sqla.filters import (
FilterEqual,
FilterEqualFunction,
FilterStartsWith,
)
from flask_appbuilder.models.sqla.interface import SQLAInterface
from flask_appbuilder.views import CompactCRUDMixin, MasterDetailView, ModelView
from flask_wtf import CSRFProtect
Expand Down Expand Up @@ -566,6 +570,28 @@ class Model1Filtered2View(ModelView):
datamodel = SQLAInterface(Model1)
base_filters = [["field_integer", FilterEqual, 0]]

def get_model1_by_name(datamodel, name):
model = (
datamodel.session.query(Model1)
.filter_by(field_string=name)
.one_or_none()
)
return model

class Model2FilterEqualFunctionView(ModelView):
datamodel = SQLAInterface(Model2)
base_filters = [
[
"group",
FilterEqualFunction,
lambda: get_model1_by_name(
Model2FilterEqualFunctionView.datamodel, "test1"
),
]
]
list_columns = ["group"]
search_columns = ["field_integer"]

class Model2ChartView(ChartView):
datamodel = SQLAInterface(Model2)
chart_title = "Test Model1 Chart"
Expand Down Expand Up @@ -668,6 +694,11 @@ def enabled(self):
self.appbuilder.add_view(
Model1Filtered2View, "Model1Filtered2", category="Model1"
)
self.appbuilder.add_view(
Model2FilterEqualFunctionView,
"Model2FilterEqualFunction",
category="Model2",
)
self.appbuilder.add_view(
Model1FormattedView, "Model1FormattedView", category="Model1FormattedView"
)
Expand Down Expand Up @@ -713,7 +744,7 @@ def test_fab_views(self):
"""
Test views creation and registration
"""
self.assertEqual(len(self.appbuilder.baseviews), 37)
self.assertEqual(len(self.appbuilder.baseviews), 38)

def test_back(self):
"""
Expand Down Expand Up @@ -1309,6 +1340,21 @@ def test_model_base_filter(self):
self.assertIn("test0", data)
self.assertNotIn("test1", data)

def test_filterequalfunction_with_relation(self):
"""
Test FilterEqualFunction
"""
client = self.app.test_client()
self.browser_login(client, USERNAME_ADMIN, PASSWORD_ADMIN)

# Base filter string starts with
rv = client.get("/model2filterequalfunctionview/list/")
self.assertEqual(rv.status_code, 200)
data = rv.data.decode("utf-8")
self.assertIn("test1", data)
self.assertNotIn("test0", data)
self.assertNotIn("test2", data)

def test_model_list_method_field(self):
"""
Tests a model's field has a method
Expand Down