From f303d5b99ec92759861b881cfed7157540800f65 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dominik=20Miedzi=C5=84ski?= Date: Mon, 30 Nov 2015 11:26:49 +0100 Subject: [PATCH] Add SQLAlchemy storage --- oauth2client/contrib/sql_alchemy.py | 86 +++++++++++++++++++++ tests/contrib/test_sqlalchemy.py | 111 ++++++++++++++++++++++++++++ tox.ini | 1 + 3 files changed, 198 insertions(+) create mode 100644 oauth2client/contrib/sql_alchemy.py create mode 100644 tests/contrib/test_sqlalchemy.py diff --git a/oauth2client/contrib/sql_alchemy.py b/oauth2client/contrib/sql_alchemy.py new file mode 100644 index 000000000..f4b94c3fb --- /dev/null +++ b/oauth2client/contrib/sql_alchemy.py @@ -0,0 +1,86 @@ +"""OAuth 2.0 utilities for SQLAlchemy. +Utilities for using OAuth 2.0 in conjunction with SQLAlchemy. +Heavily inspired by equivalent Django module. +""" + +from oauth2client.client import Storage as BaseStorage +from sqlalchemy.types import PickleType + + +__author__ = 'dominik@mdzn.pl (Dominik Miedzinski)' + + +class CredentialsType(PickleType): + pass + + +class FlowType(PickleType): + pass + + +class Storage(BaseStorage): + """Store and retrieve a single credential to and from the SQLAlchemy. + This Storage helper presumes the Credentials + have been stored as a Credentials column + on a db model class. + """ + + def __init__(self, session, model_class, key_name, + key_value, property_name): + """Constructor for Storage. + Args: + session: sqlalchemy.orm.Session + model_class: SQLAlchemy mapping + key_name: string, key name for the entity that has the credentials + key_value: string, key value for the entity that has the + credentials + property_name: string, name of the property that is an + CredentialsProperty + """ + self.session = session + self.model_class = model_class + self.key_name = key_name + self.key_value = key_value + self.property_name = property_name + + def locked_get(self): + """Retrieve stored credential. + + Returns: + oauth2client.Credentials + """ + credential = None + + session = self.session + query = {self.key_name: self.key_value} + + entity = session.query(self.model_class).filter_by(**query).first() + if entity: + credential = getattr(entity, self.property_name) + if credential and hasattr(credential, 'set_store'): + credential.set_store(self) + + return credential + + def locked_put(self, credentials): + """Write a Credentials to the SQLAlchemy datastore. + + Args: + credentials: Credentials, the credentials to store. + """ + session = self.session + query = {self.key_name: self.key_value} + + entity = session.query(self.model_class).filter_by(**query).first() + if not entity: + entity = self.model_class(**query) + + setattr(entity, self.property_name, credentials) + session.add(entity) + + def locked_delete(self): + """Delete Credentials from the SQLAlchemy datastore.""" + + session = self.session + query = {self.key_name: self.key_value} + session.query(self.model_class).filter_by(**query).delete() diff --git a/tests/contrib/test_sqlalchemy.py b/tests/contrib/test_sqlalchemy.py new file mode 100644 index 000000000..bde989f1f --- /dev/null +++ b/tests/contrib/test_sqlalchemy.py @@ -0,0 +1,111 @@ +import datetime +import unittest + +from sqlalchemy import Column, create_engine, Integer +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker + +from oauth2client import GOOGLE_TOKEN_URI +from oauth2client.client import OAuth2Credentials +from oauth2client.contrib.sql_alchemy import CredentialsType, FlowType, Storage + +Base = declarative_base() + + +class DummyModel(Base): + __tablename__ = 'dummy' + + id = Column(Integer, primary_key=True) + key = Column(Integer) # we will query against this, because of ROWID + credentials = Column(CredentialsType) + flow = Column(FlowType) + + +class TestSQLAlchemyStorage(unittest.TestCase): + engine = create_engine('sqlite://') + + @classmethod + def setUpClass(cls): + Base.metadata.create_all(cls.engine) + cls.session = sessionmaker(bind=cls.engine) + + def setUp(self): + self.credentials = OAuth2Credentials( + access_token='token', + client_id='client_id', + client_secret='client_secret', + refresh_token='refresh_token', + token_expiry=datetime.datetime.utcnow(), + token_uri=GOOGLE_TOKEN_URI, + user_agent='DummyAgent', + ) + + def tearDown(self): + session = self.session() + session.query(DummyModel).filter_by(id=1).delete() + session.commit() + + def compare_credentials(self, result): + self.assertEqual(result.access_token, self.credentials.access_token) + self.assertEqual(result.client_id, self.credentials.client_id) + self.assertEqual(result.client_secret, self.credentials.client_secret) + self.assertEqual(result.refresh_token, self.credentials.refresh_token) + self.assertEqual(result.token_expiry, self.credentials.token_expiry) + self.assertEqual(result.token_uri, self.credentials.token_uri) + self.assertEqual(result.user_agent, self.credentials.user_agent) + + def test_locked_get(self): + session = self.session() + session.add(DummyModel( + key=1, + credentials=self.credentials, + )) + session.commit() + + ret = Storage( + session=session, + model_class=DummyModel, + key_name='key', + key_value=1, + property_name='credentials', + ).locked_get() + + self.compare_credentials(ret) + + def test_locked_put(self): + session = self.session() + Storage( + session=session, + model_class=DummyModel, + key_name='key', + key_value=1, + property_name='credentials', + ).locked_put(self.credentials) + session.commit() + + ret = session.query(DummyModel).filter_by(key=1).first() + self.compare_credentials(ret.credentials) + + def test_locked_delete(self): + session = self.session() + session.add(DummyModel( + key=1, + credentials=self.credentials, + )) + session.commit() + + q = session.query(DummyModel).filter_by(key=1) + self.assertIsNotNone(q.first()) + Storage( + session=session, + model_class=DummyModel, + key_name='key', + key_value=1, + property_name='credentials', + ).locked_delete() + session.commit() + self.assertIsNone(q.first()) + + +if __name__ == '__main__': # pragma: NO COVER + unittest.main() diff --git a/tox.ini b/tox.ini index 8e94c012b..ac068544b 100644 --- a/tox.ini +++ b/tox.ini @@ -10,6 +10,7 @@ basedeps = mock>=1.3.0 nose flask unittest2 + sqlalchemy deps = {[testenv]basedeps} django keyring