Skip to content

Commit

Permalink
Simplifying source_registry (#1180)
Browse files Browse the repository at this point in the history
  • Loading branch information
mistercrunch authored Sep 23, 2016
1 parent aed473d commit fc921d6
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 38 deletions.
10 changes: 7 additions & 3 deletions caravel/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from flask_appbuilder.baseviews import expose
from flask_cache import Cache
from flask_migrate import Migrate
from caravel import source_registry
from caravel.source_registry import SourceRegistry
from werkzeug.contrib.fixers import ProxyFix


Expand Down Expand Up @@ -96,7 +96,11 @@ def index(self):

sm = appbuilder.sm

src_registry = source_registry.SourceRegistry()

get_session = appbuilder.get_session

# Registering sources
module_datasource_map = app.config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(app.config.get("ADDITIONAL_MODULE_DS_MAP"))
SourceRegistry.register_sources(module_datasource_map)

from caravel import views, config # noqa
8 changes: 0 additions & 8 deletions caravel/bin/caravel
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,6 @@ config = app.config

manager = Manager(app)
manager.add_command('db', MigrateCommand)
module_datasource_map = config.get("DEFAULT_MODULE_DS_MAP")
module_datasource_map.update(config.get("ADDITIONAL_MODULE_DS_MAP"))

datasources = {}
for module in module_datasource_map:
datasources[module] = __import__(module, fromlist=module_datasource_map[module])

utils.register_sources(datasources, module_datasource_map, caravel.src_registry)


@manager.option(
Expand Down
8 changes: 4 additions & 4 deletions caravel/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
from flask_appbuilder import Model
from flask_appbuilder.models.mixins import AuditMixin
from flask_appbuilder.models.decorators import renders
from flask_appbuilder.security.sqla.models import Role, PermissionView
from flask_babel import lazy_gettext as _

from pydruid.client import PyDruid
Expand All @@ -50,7 +49,8 @@
from werkzeug.datastructures import ImmutableMultiDict

import caravel
from caravel import app, db, get_session, utils, sm, src_registry
from caravel import app, db, get_session, utils, sm
from caravel.source_registry import SourceRegistry
from caravel.viz import viz_types
from caravel.utils import flasher, MetricPermException, DimSelector

Expand Down Expand Up @@ -172,7 +172,7 @@ def __repr__(self):

@property
def cls_model(self):
return src_registry.sources[self.datasource_type]
return SourceRegistry.sources[self.datasource_type]

@property
def datasource(self):
Expand Down Expand Up @@ -2028,7 +2028,7 @@ class DatasourceAccessRequest(Model, AuditMixinNullable):

@property
def cls_model(self):
return src_registry.sources[self.datasource_type]
return SourceRegistry.sources[self.datasource_type]

@property
def username(self):
Expand Down
15 changes: 7 additions & 8 deletions caravel/source_registry.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from flask import flash


class SourceRegistry(object):
""" Central Registry for all available datasource engines"""

sources = {}

def add_source(self, ds_type, cls_model):
if ds_type not in self.sources:
self.sources[ds_type] = cls_model
if self.sources[ds_type] is not cls_model:
raise Exception(
'source type: {} is already associated with Model: {}'.format(
ds_type, self.sources[ds_type]))
@classmethod
def register_sources(cls, datasource_config):
for module_name, class_names in datasource_config.items():
module_obj = __import__(module_name, fromlist=class_names)
for class_name in class_names:
source_class = getattr(module_obj, class_name)
cls.sources[source_class.type] = source_class
8 changes: 0 additions & 8 deletions caravel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,14 +426,6 @@ def readfile(filepath):
return content


def register_sources(datasources, module_datasource_map, registry):
for m in datasources:
datasource_list = module_datasource_map[m]
for ds in datasource_list:
ds_class = getattr(datasources[m], ds)
registry.add_source(ds_class.type, ds_class)


def generic_find_constraint_name(table, columns, referenced, db):
"""Utility to find a constraint name in alembic migrations"""
t = sa.Table(table, db.metadata, autoload=True, autoload_with=db.engine)
Expand Down
11 changes: 6 additions & 5 deletions caravel/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@
import caravel
from caravel import (
appbuilder, cache, db, models, viz, utils, app,
sm, ascii_art, sql_lab, src_registry
sm, ascii_art, sql_lab
)
from caravel.source_registry import SourceRegistry
from caravel.models import DatasourceAccessRequest as DAR

config = app.config
Expand Down Expand Up @@ -748,9 +749,9 @@ def add(self):
if not widget:
return redirect(self.get_redirect())

sources = src_registry.sources
sources = SourceRegistry.sources
for source in sources:
ds = db.session.query(src_registry.sources[source]).first()
ds = db.session.query(SourceRegistry.sources[source]).first()
if ds is not None:
url = "/{}/list/".format(ds.baselink)
msg = _("Click on a {} link to create a Slice".format(source))
Expand Down Expand Up @@ -1056,7 +1057,7 @@ def approve(self):
role_to_extend = request.args.get('role_to_extend')

session = db.session
datasource_class = src_registry.sources[datasource_type]
datasource_class = SourceRegistry.sources[datasource_type]
datasource = session.query(datasource_class).filter_by(
id=datasource_id).first()

Expand Down Expand Up @@ -1119,7 +1120,7 @@ def approve(self):
@log_this
def explore(self, datasource_type, datasource_id, slice_id=None):
error_redirect = '/slicemodelview/list/'
datasource_class = src_registry.sources[datasource_type]
datasource_class = SourceRegistry.sources[datasource_type]
datasources = db.session.query(datasource_class).all()
datasources = sorted(datasources, key=lambda ds: ds.full_name)
datasource = [ds for ds in datasources if int(datasource_id) == ds.id]
Expand Down
5 changes: 3 additions & 2 deletions tests/core_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,8 @@
from flask_appbuilder.security.sqla import models as ab_models

import caravel
from caravel import app, db, models, utils, appbuilder, sm, src_registry
from caravel import app, db, models, utils, appbuilder, sm
from caravel.source_registry import SourceRegistry
from caravel.models import DruidDatasource

from .base_tests import CaravelTestCase
Expand Down Expand Up @@ -243,7 +244,7 @@ def test_approve(self):
self.login('admin')

def prepare_request(ds_type, ds_name, role):
ds_class = src_registry.sources[ds_type]
ds_class = SourceRegistry.sources[ds_type]
# TODO: generalize datasource names
if ds_type == 'table':
ds = session.query(ds_class).filter(
Expand Down

0 comments on commit fc921d6

Please sign in to comment.