diff --git a/djstripe/event_handlers.py b/djstripe/event_handlers.py index 881b03329d..4e1639541a 100644 --- a/djstripe/event_handlers.py +++ b/djstripe/event_handlers.py @@ -5,6 +5,7 @@ .. moduleauthor:: Bill Huneke (@wahuneke) .. moduleauthor:: Alex Kavanaugh (@akavanau) +.. moduleauthor:: Lee Skillen (@lskillen) Implement webhook event handlers for all the models that need to respond to webhook events. @@ -17,92 +18,164 @@ from . import webhooks from .models import Charge, Customer, Card, Subscription, Plan, Transfer, Invoice, InvoiceItem -STRIPE_CRUD_EVENTS = ["created", "updated", "deleted"] - -# --------------------------- -# Charge model events -# --------------------------- -@webhooks.handler(['charge']) -def charge_webhook_handler(event, event_data, event_type, event_subtype): - versioned_charge_data = Charge(stripe_id=event_data["object"]["id"]).api_retrieve() - Charge.sync_from_stripe_data(versioned_charge_data) - - -# --------------------------- -# Customer model events -# --------------------------- @webhooks.handler_all def customer_event_attach(event, event_data, event_type, event_subtype): + """ Makes the related customer available on the event for all handlers. """ - if event_type == "customer" and event_subtype in STRIPE_CRUD_EVENTS: - stripe_customer_id = event_data["object"]["id"] + event.customer = None + crud_type = CrudType.determine(event_subtype, exact=True) + + if event_type == "customer" and crud_type.valid: + customer_stripe_id = event_data["object"]["id"] else: - stripe_customer_id = event_data["object"].get("customer", None) + customer_stripe_id = event_data["object"].get("customer", None) - if stripe_customer_id: + if customer_stripe_id: try: - event.customer = Customer.objects.get(stripe_id=stripe_customer_id) + event.customer = Customer.objects.get(stripe_id=customer_stripe_id) except Customer.DoesNotExist: pass -@webhooks.handler(['customer']) +@webhooks.handler("customer") def customer_webhook_handler(event, event_data, event_type, event_subtype): + """ Handles updates for customer objects. """ + + crud_type = CrudType.determine(event_subtype, exact=True) + if crud_type.valid and event.customer: + # As customers are tied to local users, djstripe will not create + # customers that do not already exist locally. + _handle_crud_type_event(target_cls=Customer, event_data=event_data, event_subtype=event_subtype, crud_type=crud_type) + + +@webhooks.handler("customer.source") +def customer_source_webhook_handler( + event, event_data, event_type, event_subtype): + """ Handles updates for customer source objects. """ + + source_type = event_data["object"]["object"] + + # TODO: other sources + if source_type == "card": + _handle_crud_type_event(target_cls=Card, event_data=event_data, event_subtype=event_subtype, customer=event.customer) + + +@webhooks.handler("customer.subscription") +def customer_subscription_webhook_handler(event, event_data, event_type, event_subtype): + """ Handles updates for customer subscription objects. """ + + _handle_crud_type_event(target_cls=Subscription, event_data=event_data, event_subtype=event_subtype, customer=event.customer) + + +@webhooks.handler(["transfer", "charge", "invoice", "invoiceitem", "plan"]) +def other_object_webhook_handler(event, event_data, event_type, event_subtype): + """ Handles updates for transfer, charge, invoice, invoiceitem and plan objects. """ + + target_cls = { + "charge": Charge, + "invoice": Invoice, + "invoiceitem": InvoiceItem, + "plan": Plan, + "transfer": Transfer + }.get(event_type) + + _handle_crud_type_event(target_cls=target_cls, event_data=event_data, event_subtype=event_subtype, customer=event.customer) + + +# +# Helpers +# + +class CrudType(object): + """ Helper object to determine CRUD-like event state. """ - customer = event.customer - if customer: - if event_subtype in STRIPE_CRUD_EVENTS: - versioned_customer_data = Customer(stripe_id=event_data["object"]["id"]).api_retrieve() - Customer.sync_from_stripe_data(versioned_customer_data) - - if event_subtype == "deleted": - customer.purge() -# elif event_subtype.startswith("discount."): -# pass # TODO - elif event_subtype.startswith("source."): - source_type = event_data["object"]["object"] - - # TODO: other sources - if source_type == "card": - versioned_card_data = Card(stripe_id=event_data["object"]["id"], customer=customer).api_retrieve() - Card.sync_from_stripe_data(versioned_card_data) - elif event_subtype.startswith("subscription."): - versioned_subscription_data = Subscription(stripe_id=event_data["object"]["id"], customer=customer).api_retrieve() - Subscription.sync_from_stripe_data(versioned_subscription_data) - - -# --------------------------- -# Transfer model events -# --------------------------- -@webhooks.handler(["transfer"]) -def transfer_webhook_handler(event, event_data, event_type, event_subtype): - versioned_transfer_data = Transfer(stripe_id=event_data["object"]["id"]).api_retrieve() - Transfer.sync_from_stripe_data(versioned_transfer_data) - - -# --------------------------- -# Invoice model events -# --------------------------- -@webhooks.handler(['invoice']) -def invoice_webhook_handler(event, event_data, event_type, event_subtype): - versioned_invoice_data = Invoice(stripe_id=event_data["object"]["id"]).api_retrieve() - Invoice.sync_from_stripe_data(versioned_invoice_data) - - -# --------------------------- -# InvoiceItem model events -# --------------------------- -@webhooks.handler(['invoiceitem']) -def invoiceitem_webhook_handler(event, event_data, event_type, event_subtype): - versioned_invoiceitem_data = InvoiceItem(stripe_id=event_data["object"]["id"]).api_retrieve() - InvoiceItem.sync_from_stripe_data(versioned_invoiceitem_data) - - -# --------------------------- -# Plan model events -# --------------------------- -@webhooks.handler(['plan']) -def plan_webhook_handler(event, event_data, event_type, event_subtype): - versioned_plan_data = Plan(stripe_id=event_data["object"]["id"]).api_retrieve() - Plan.sync_from_stripe_data(versioned_plan_data) + created = False + updated = False + deleted = False + + def __init__(self, **kwargs): + for k, v in kwargs.items(): + setattr(self, k, v) + + @property + def valid(self): + """ Returns True if this is a CRUD-like event. """ + + return self.created or self.updated or self.deleted + + @classmethod + def determine(cls, event_subtype, exact=False): + """ + Determines if the event subtype is a crud_type (without the 'R') event. + + :param event_subtype: The event subtype to examine. + :type event_subtype: string (``str``/`unicode``) + :param exact: If True, match crud_type to event subtype string exactly. + :param type: ``bool`` + :returns: The CrudType state object. + :rtype: ``CrudType`` + """ + + def check(crud_type_event): + if exact: + return event_subtype == crud_type_event + else: + return event_subtype.endswith(crud_type_event) + + created = updated = deleted = False + + if check("updated"): + updated = True + elif check("created"): + created = True + elif check("deleted"): + deleted = True + + return cls(created=created, updated=updated, deleted=deleted) + + +def _handle_crud_type_event(target_cls, event_data, event_subtype, stripe_id=None, customer=None, crud_type=None): + """ + Helper to process crud_type-like events for objects. + + Non-deletes (creates, updates and "anything else" events) are treated as + update_or_create events - The object will be retrieved locally, then it is + synchronised with the Stripe API for parity. + + Deletes only occur for delete events and cause the object to be deleted + from the local database, if it existed. If it doesn't exist then it is + ignored (but the event processing still succeeds). + + :param target_cls: The djstripe model being handled. + :type: ``djstripe.stripe_objects.StripeObject`` + :param event_data: The event object data received from the Stripe API. + :param event_subtype: The event subtype string. + :param stripe_id: The object Stripe ID - If not provided then this is + retrieved from the event object data by "object.id" key. + :param customer: The customer object which is passed on object creation. + :param crud_type: The CrudType object - If not provided it is determined + based on the event subtype string. + :returns: The object (if any) and the event CrudType. + :rtype: ``tuple(obj, CrudType)`` + """ + + crud_type = crud_type or CrudType.determine(event_subtype) + stripe_id = stripe_id or event_data["object"]["id"] + obj = None + + if crud_type.deleted: + try: + obj = target_cls.objects.get(stripe_id=stripe_id) + obj.delete() + except target_cls.DoesNotExist: + pass + else: + # Any other event type (creates, updates, etc.) + kwargs = {"stripe_id": stripe_id} + if customer: + kwargs["customer"] = customer + data = target_cls(**kwargs).api_retrieve() + obj = target_cls.sync_from_stripe_data(data) + + return obj, crud_type diff --git a/djstripe/fields.py b/djstripe/fields.py index 0ea352e809..861c7ef93e 100644 --- a/djstripe/fields.py +++ b/djstripe/fields.py @@ -129,8 +129,12 @@ class StripeIdField(StripeCharField): """A field with enough space to hold any stripe ID.""" def __init__(self, *args, **kwargs): + # As per: https://stripe.com/docs/upgrades + # You can safely assume object IDs we generate will never exceed 255 + # characters, but you should be able to handle IDs of up to that + # length. defaults = { - 'max_length': 50, + 'max_length': 255, 'blank': False, 'null': False, } diff --git a/djstripe/migrations/0014_auto_20160625_1851.py b/djstripe/migrations/0014_auto_20160625_1851.py new file mode 100644 index 0000000000..db950f7a18 --- /dev/null +++ b/djstripe/migrations/0014_auto_20160625_1851.py @@ -0,0 +1,33 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.9.1 on 2016-06-25 17:51 +from __future__ import unicode_literals + +import django.core.validators +from django.db import migrations, models +import django.db.models.deletion +import djstripe.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ('djstripe', '0013_sync_cleanup'), + ] + + operations = [ + migrations.AlterField( + model_name='charge', + name='receipt_sent', + field=models.BooleanField(default=False, help_text='Whether or not a receipt was sent for this charge.'), + ), + migrations.AlterField( + model_name='charge', + name='source', + field=models.ForeignKey(help_text='The source used for this charge.', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='charges', to='djstripe.StripeSource'), + ), + migrations.AlterField( + model_name='subscription', + name='application_fee_percent', + field=djstripe.fields.StripePercentField(decimal_places=2, help_text=b'A positive decimal that represents the fee percentage of the subscription invoice amount that will be transferred to the application owner\xe2\x80\x99s Stripe account each billing period.', max_digits=5, null=True, validators=[django.core.validators.MinValueValidator(1.0), django.core.validators.MaxValueValidator(100.0)]), + ), + ] diff --git a/djstripe/migrations/0015_upcoming_invoices.py b/djstripe/migrations/0015_upcoming_invoices.py new file mode 100644 index 0000000000..78727fbaca --- /dev/null +++ b/djstripe/migrations/0015_upcoming_invoices.py @@ -0,0 +1,32 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.9.1 on 2016-06-25 17:51 +from __future__ import unicode_literals + +import django.core.validators +from django.db import migrations, models +import django.db.models.deletion + + +class Migration(migrations.Migration): + + dependencies = [ + ('djstripe', '0014_auto_20160625_1851'), + ] + + operations = [ + migrations.CreateModel( + name='UpcomingInvoice', + fields=[ + ('invoice_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='djstripe.Invoice')), + ], + options={ + 'abstract': False, + }, + bases=('djstripe.invoice',), + ), + migrations.AlterField( + model_name='invoiceitem', + name='invoice', + field=models.ForeignKey(help_text='The invoice to which this invoiceitem is attached.', null=True, on_delete=django.db.models.deletion.CASCADE, related_name='invoiceitems', to='djstripe.Invoice'), + ), + ] diff --git a/djstripe/migrations/0016_stripe_id_255_length.py b/djstripe/migrations/0016_stripe_id_255_length.py new file mode 100644 index 0000000000..18e32b6da8 --- /dev/null +++ b/djstripe/migrations/0016_stripe_id_255_length.py @@ -0,0 +1,86 @@ +# -*- coding: utf-8 -*- +# Generated by Django 1.9.1 on 2016-06-25 22:40 +from __future__ import unicode_literals + +from django.db import migrations +import djstripe.fields + + +class Migration(migrations.Migration): + + dependencies = [ + ('djstripe', '0015_upcoming_invoices'), + ] + + operations = [ + migrations.AlterField( + model_name='account', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='charge', + name='source_stripe_id', + field=djstripe.fields.StripeIdField(help_text=b'The payment source id.', max_length=255, null=True), + ), + migrations.AlterField( + model_name='charge', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='customer', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='event', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='invoice', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='invoiceitem', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='plan', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='stripesource', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='subscription', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + migrations.AlterField( + model_name='transfer', + name='destination', + field=djstripe.fields.StripeIdField(help_text=b'ID of the bank account, card, or Stripe account the transfer was sent to.', max_length=255), + ), + migrations.AlterField( + model_name='transfer', + name='destination_payment', + field=djstripe.fields.StripeIdField(help_text=b'If the destination is a Stripe account, this will be the ID of the payment that the destination account received for the transfer.', max_length=255, null=True), + ), + migrations.AlterField( + model_name='transfer', + name='source_transaction', + field=djstripe.fields.StripeIdField(help_text=b'ID of the charge (or other transaction) that was used to fund the transfer. If null, the transfer was funded from the available balance.', max_length=255, null=True), + ), + migrations.AlterField( + model_name='transfer', + name='stripe_id', + field=djstripe.fields.StripeIdField(max_length=255, unique=True), + ), + ] diff --git a/djstripe/models.py b/djstripe/models.py index c5af145e95..11f6f5ef65 100644 --- a/djstripe/models.py +++ b/djstripe/models.py @@ -19,9 +19,11 @@ from django.core.mail import EmailMessage from django.db import models from django.template.loader import render_to_string -from django.utils import timezone +from django.utils import six, timezone from django.utils.encoding import python_2_unicode_compatible, smart_text +from django.utils.functional import cached_property from doc_inherit import class_doc_inherit +from mock_django.query import QuerySetMock from model_utils.models import TimeStampedModel from stripe.error import StripeError, InvalidRequestError @@ -125,7 +127,9 @@ class Customer(StripeCustomer): # account = models.ForeignKey(Account, related_name="customers") - default_source = models.ForeignKey(StripeSource, null=True, related_name="customers") + default_source = models.ForeignKey( + StripeSource, null=True, related_name="customers", + on_delete=models.SET_NULL) subscriber = models.OneToOneField(getattr(settings, 'DJSTRIPE_SUBSCRIBER_MODEL', settings.AUTH_USER_MODEL), null=True) date_purged = models.DateTimeField(null=True, editable=False) @@ -331,7 +335,6 @@ def retry_unpaid_invoices(self): def has_valid_source(self): """ Check whether the customer has a valid payment source.""" - return self.default_source is not None def add_card(self, source, set_default=True): @@ -345,6 +348,16 @@ def add_card(self, source, set_default=True): return new_card + def upcoming_invoice(self, **kwargs): + """ Gets the upcoming preview invoice (singular) for this customer. + + See `StripeInvoice.upcoming() <#djstripe.stripe_objects.StripeInvoice.upcoming>`__ + The ``customer`` argument to the ``upcoming()`` call is automatically set by this method. + """ + + kwargs['customer'] = self + return Invoice.upcoming(**kwargs) + def _attach_objects_hook(self, cls, data): # TODO: other sources if data["default_source"] and data["default_source"]["object"] == "card": @@ -402,7 +415,7 @@ def validate(self): self.valid = self.webhook_message == self.api_retrieve()["data"] self.save() - def process(self, force=False): + def process(self, force=False, raise_exception=False): """ Invokes any webhook handlers that have been registered for this event based on event type or event sub-type. @@ -413,27 +426,32 @@ def process(self, force=False): :param force: If True, force the event to be processed by webhook handlers, even if the event has already been processed previously. :type force: bool + :param raise_exception: If True, any Stripe errors raised during + processing will be raised to the caller after logging the exception. + :type raise_exception: bool :returns: True if the webhook was processed successfully or was previously processed successfully. :rtype: bool """ + if not self.valid: return False if not self.processed or force: - event_type, event_subtype = self.type.split(".", 1) + exc_value = None try: # TODO: would it make sense to wrap the next 4 lines in a transaction.atomic context? Yes it would, # except that some webhook handlers can have side effects outside of our local database, meaning that # even if we rollback on our database, some updates may have been sent to Stripe, etc in resposne to # webhooks... - webhooks.call_handlers(self, self.message, event_type, event_subtype) + webhooks.call_handlers(self, self.message, self.event_type, self.event_subtype) self._send_signal() self.processed = True except StripeError as exc: # TODO: What if we caught all exceptions or a broader range of exceptions here? How about DoesNotExist # exceptions, for instance? or how about TypeErrors, KeyErrors, ValueErrors, etc? + exc_value = exc self.processed = False EventProcessingException.log( data=exc.http_body, @@ -451,6 +469,9 @@ def process(self, force=False): # an event handle was broken. self.save() + if exc_value and raise_exception: + six.reraise(StripeError, exc_value) + return self.processed def _send_signal(self): @@ -458,6 +479,21 @@ def _send_signal(self): if signal: return signal.send(sender=Event, event=self) + @cached_property + def parts(self): + """ Gets the event type/subtype as a list of parts. """ + return str(self.type).split(".") + + @cached_property + def event_type(self): + """ Gets the event type string. """ + return self.parts[0] + + @cached_property + def event_subtype(self): + """ Gets the event subtype string. """ + return ".".join(self.parts[1:]) + @class_doc_inherit class Transfer(StripeTransfer): @@ -541,6 +577,84 @@ def _attach_objects_hook(self, cls, data): if subscription: self.subscription = subscription + def _attach_objects_post_save_hook(self, cls, data): + # InvoiceItems need a saved invoice because they're associated via a + # RelatedManager, so this must be done as part of the post save hook. + cls._stripe_object_to_invoice_items(target_cls=InvoiceItem, data=data, invoice=self) + + @classmethod + def upcoming(cls, **kwargs): + upcoming_stripe_invoice = StripeInvoice.upcoming(**kwargs) + + if upcoming_stripe_invoice: + return UpcomingInvoice._create_from_stripe_object(upcoming_stripe_invoice, save=False) + + @property + def plan(self): + """ Gets the associated plan for this invoice. + + In order to provide a consistent view of invoices, the plan object + should be taken from the first invoice item that has one, rather than + using the plan associated with the subscription. + + Subscriptions (and their associated plan) are updated by the customer + and represent what is current, but invoice items are immutable within + the invoice and stay static/unchanged. + + In other words, a plan retrieved from an invoice item will represent + the plan as it was at the time an invoice was issued. The plan + retrieved from the subscription will be the currently active plan. + + :returns: The associated plan for the invoice. + :rtype: ``djstripe.models.Plan`` + """ + + for invoiceitem in self.invoiceitems.all(): + if invoiceitem.plan: + return invoiceitem.plan + + if self.subscription: + return self.subscription.plan + + +@class_doc_inherit +class UpcomingInvoice(Invoice): + __doc__ = getattr(Invoice, "__doc__") + + def __init__(self, *args, **kwargs): + super(UpcomingInvoice, self).__init__(*args, **kwargs) + self._invoiceitems = [] + + def _attach_objects_hook(self, cls, data): + super(UpcomingInvoice, self)._attach_objects_hook(cls, data) + self._invoiceitems = cls._stripe_object_to_invoice_items(target_cls=InvoiceItem, data=data, invoice=self) + + @property + def invoiceitems(self): + """ Gets the invoice items associated with this upcoming invoice. + + This differs from normal (non-upcoming) invoices, in that upcoming + invoices are in-memory and do not persist to the database. Therefore, + all of the data comes from the Stripe API itself. + + Instead of returning a normal queryset for the invoiceitems, this will + return a mock of a queryset, but with the data fetched from Stripe - It + will act like a normal queryset, but mutation will silently fail. + """ + + return QuerySetMock(InvoiceItem, *self._invoiceitems) + + @property + def stripe_id(self): + return None + + @stripe_id.setter + def stripe_id(self, value): + return # noop + + def save(self, *args, **kwargs): + return # noop + @class_doc_inherit class InvoiceItem(StripeInvoiceItem): @@ -548,13 +662,17 @@ class InvoiceItem(StripeInvoiceItem): # account = models.ForeignKey(Account, related_name="invoiceitems") customer = models.ForeignKey(Customer, related_name="invoiceitems", help_text="The customer associated with this invoiceitem.") - invoice = models.ForeignKey(Invoice, related_name="invoiceitems", help_text="The invoice to which this invoiceitem is attached.") + invoice = models.ForeignKey(Invoice, null=True, related_name="invoiceitems", help_text="The invoice to which this invoiceitem is attached.") plan = models.ForeignKey("Plan", null=True, related_name="invoiceitems", help_text="If the invoice item is a proration, the plan of the subscription for which the proration was computed.") subscription = models.ForeignKey("Subscription", null=True, related_name="invoiceitems", help_text="The subscription that this invoice item has been created for, if any.") def _attach_objects_hook(self, cls, data): - self.customer = cls._stripe_object_to_customer(target_cls=Customer, data=data) - self.invoice = cls._stripe_object_to_invoice(target_cls=Invoice, data=data) + customer = cls._stripe_object_to_customer(target_cls=Customer, data=data) + + invoice = cls._stripe_object_to_invoice(target_cls=Invoice, data=data) + if invoice: + self.invoice = invoice + customer = customer or invoice.customer plan = cls._stripe_object_to_plan(target_cls=Plan, data=data) if plan: @@ -563,6 +681,9 @@ def _attach_objects_hook(self, cls, data): subscription = cls._stripe_object_to_subscription(target_cls=Subscription, data=data) if subscription: self.subscription = subscription + customer = customer or subscription.customer + + self.customer = customer @class_doc_inherit diff --git a/djstripe/settings.py b/djstripe/settings.py index e2a2d79153..fae38c7053 100644 --- a/djstripe/settings.py +++ b/djstripe/settings.py @@ -6,10 +6,47 @@ from django.apps import apps as django_apps from django.conf import settings from django.core.exceptions import ImproperlyConfigured +from django.utils import six +from django.utils.module_loading import import_string PY3 = sys.version > "3" -subscriber_request_callback = getattr(settings, "DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK", (lambda request: request.user)) + +def get_callback_function(setting_name, default=None): + """ + Resolves a callback function based on a setting name. + + If the setting value isn't set, default is returned. If the setting value + is already a callable function, that value is used - If the setting value + is a string, an attempt is made to import it. Anything else will result in + a failed import causing ImportError to be raised. + + :param setting_name: The name of the setting to resolve a callback from. + :type setting_name: string (``str``/``unicode``) + :param default: The default to return if setting isn't populated. + :type default: ``bool`` + :returns: The resolved callback function (if any). + :type: ``callable`` + """ + + func = getattr(settings, setting_name, None) + if not func: + return default + + if callable(func): + return func + + if isinstance(func, six.string_types): + func = import_string(func) + + if not callable(func): + raise ImproperlyConfigured("{name} must be callable.".format(name=setting_name)) + + return func + + +subscriber_request_callback = get_callback_function("DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK", + default=(lambda request: request.user)) INVOICE_FROM_EMAIL = getattr(settings, "DJSTRIPE_INVOICE_FROM_EMAIL", "billing@example.com") PAYMENTS_PLANS = getattr(settings, "DJSTRIPE_PLANS", {}) @@ -33,13 +70,17 @@ # Try to find the new settings variable first. If that fails, revert to the # old variable. -trial_period_for_subscriber_callback = getattr(settings, - "DJSTRIPE_TRIAL_PERIOD_FOR_SUBSCRIBER_CALLBACK", - getattr(settings, "DJSTRIPE_TRIAL_PERIOD_FOR_USER_CALLBACK", None) -) +trial_period_for_subscriber_callback = ( + get_callback_function("DJSTRIPE_TRIAL_PERIOD_FOR_SUBSCRIBER_CALLBACK") or + get_callback_function("DJSTRIPE_TRIAL_PERIOD_FOR_USER_CALLBACK")) DJSTRIPE_WEBHOOK_URL = getattr(settings, "DJSTRIPE_WEBHOOK_URL", r"^webhook/$") +# Webhook event callbacks allow an application to take control of what happens +# when an event from Stripe is received. One suggestion is to put the event +# onto a task queue (such as celery) for asynchronous processing. +WEBHOOK_EVENT_CALLBACK = get_callback_function("DJSTRIPE_WEBHOOK_EVENT_CALLBACK") + def _check_subscriber_for_email_address(subscriber_model, message): """Ensure the custom model has an ``email`` field or property.""" @@ -82,10 +123,10 @@ def get_subscriber_model(): _check_subscriber_for_email_address(subscriber_model, "DJSTRIPE_SUBSCRIBER_MODEL must have an email attribute.") # Custom user model detected. Make sure the callback is configured. - if hasattr(settings, "DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK"): - if not callable(getattr(settings, "DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK")): - raise ImproperlyConfigured("DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK must be callable.") - else: - raise ImproperlyConfigured("DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK must be implemented if a DJSTRIPE_SUBSCRIBER_MODEL is defined.") + func = get_callback_function("DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK") + if not func: + raise ImproperlyConfigured( + "DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK must be implemented " + "if a DJSTRIPE_SUBSCRIBER_MODEL is defined.") return subscriber_model diff --git a/djstripe/stripe_objects.py b/djstripe/stripe_objects.py index bd13254dc0..31ee1ebf02 100644 --- a/djstripe/stripe_objects.py +++ b/djstripe/stripe_objects.py @@ -5,6 +5,7 @@ .. moduleauthor:: Bill Huneke (@wahuneke) .. moduleauthor:: Alex Kavanaugh (@kavdev) +.. moduleauthor:: Lee Skillen (@lskillen) This module is an effort to isolate (as much as possible) the API dependent code in one place. Primarily this is: @@ -19,14 +20,16 @@ from copy import deepcopy import decimal +import sys from django.conf import settings from django.db import models -from django.utils import timezone +from django.utils import dateformat, six, timezone from django.utils.encoding import python_2_unicode_compatible, smart_text from model_utils.models import TimeStampedModel from polymorphic.models import PolymorphicModel import stripe +from stripe.error import InvalidRequestError from djstripe.exceptions import CustomerDoesNotExistLocallyException @@ -125,7 +128,7 @@ def _api_delete(self, api_key=settings.STRIPE_SECRET_KEY, **kwargs): :type api_key: string """ - return self.api_retrieve(api_key).delete(**kwargs) + return self.api_retrieve(api_key=api_key).delete(**kwargs) def str_parts(self): """ @@ -173,33 +176,73 @@ def _attach_objects_hook(self, cls, data): """ Gets called by this object's create and sync methods just before save. Use this to populate fields before the model is saved. + + :param cls: The target class for the instantiated object. + :param data: The data dictionary received from the Stripe API. + :type data: dict + """ + + pass + + def _attach_objects_post_save_hook(self, cls, data): + """ + Gets called by this object's create and sync methods just after save. + Use this to populate fields after the model is saved. + + :param cls: The target class for the instantiated object. + :param data: The data dictionary received from the Stripe API. + :type data: dict """ pass @classmethod - def _create_from_stripe_object(cls, data): + def _create_from_stripe_object(cls, data, save=True): """ - Create a model instance using the given data object from Stripe + Instantiates a model instance using the provided data object received + from Stripe, and saves it to the database if specified. + + :param data: The data dictionary received from the Stripe API. :type data: dict + :param save: If True, the object is saved after instantiation. + :type save: bool + :returns: The instantiated object. """ + instance = cls(**cls._stripe_object_to_record(data)) instance._attach_objects_hook(cls, data) - instance.save() + + if save: + instance.save() + + instance._attach_objects_post_save_hook(cls, data) return instance @classmethod - def _get_or_create_from_stripe_object(cls, data, field_name="id"): + def _get_or_create_from_stripe_object(cls, data, field_name="id", refetch=True, save=True): + field = data.get(field_name) + + if isinstance(field, six.string_types): + # A field like {"subscription": "sub_6lsC8pt7IcFpjA", ...} + stripe_id = field + elif field: + # A field like {"subscription": {"id": sub_6lsC8pt7IcFpjA", ...}} + data = field + stripe_id = field.get("id") + else: + # An empty field - We need to return nothing here because there is + # no way of knowing what needs to be fetched! + return None, False + try: - return cls.stripe_objects.get_by_json(data, field_name), False + return cls.stripe_objects.get(stripe_id=stripe_id), False except cls.DoesNotExist: - # Grab the stripe data for a nested object - if field_name != "id": - cls_instance = cls(stripe_id=data[field_name]) + if refetch and field_name != "id": + cls_instance = cls(stripe_id=stripe_id) data = cls_instance.api_retrieve() - return cls._create_from_stripe_object(data), True + return cls._create_from_stripe_object(data, save=save), True @classmethod def _stripe_object_to_customer(cls, target_cls, data): @@ -261,6 +304,57 @@ def _stripe_object_to_invoice(cls, target_cls, data): return target_cls._get_or_create_from_stripe_object(data, "invoice")[0] + @classmethod + def _stripe_object_to_invoice_items(cls, target_cls, data, invoice): + """ + Retrieves InvoiceItems for an invoice. + + If the invoice item doesn't exist already then it is created. + + If the invoice is an upcoming invoice that doesn't persist to the + database (i.e. ephemeral) then the invoice items are also not saved. + + :param target_cls: The target class to instantiate per invoice item. + :type target_cls: ``StripeInvoiceItem`` + :param data: The data dictionary received from the Stripe API. + :type data: dict + :param invoice: The invoice object that should hold the invoice items. + :type invoice: ``djstripe.models.Invoice`` + """ + + lines = data.get("lines") + if not lines: + return [] + + invoiceitems = [] + for line in lines.get("data", []): + if invoice.stripe_id: + save = True + line.setdefault("invoice", invoice.stripe_id) + + if line.get("type") == "subscription": + # Lines for subscriptions need to be keyed based on invoice and + # subscription, because their id is *just* the subscription + # when received from Stripe. This means that future updates to + # a subscription will change previously saved invoices - Doing + # the composite key avoids this. + if not line["id"].startswith(invoice.stripe_id): + line["id"] = "{invoice_id}-{subscription_id}".format( + invoice_id=invoice.stripe_id, + subscription_id=line["id"]) + else: + # Don't save invoice items for ephemeral invoices + save = False + + line.setdefault("customer", invoice.customer.stripe_id) + line.setdefault("date", int(dateformat.format(invoice.date, 'U'))) + + item, _ = target_cls._get_or_create_from_stripe_object( + line, refetch=False, save=save) + invoiceitems.append(item) + + return invoiceitems + @classmethod def _stripe_object_to_subscription(cls, target_cls, data): """ @@ -294,6 +388,7 @@ def sync_from_stripe_data(cls, data): instance._sync(cls._stripe_object_to_record(data)) instance._attach_objects_hook(cls, data) instance.save() + instance._attach_objects_post_save_hook(cls, data) return instance @@ -932,7 +1027,7 @@ def api_retrieve(self, api_key=settings.STRIPE_SECRET_KEY): # Cards must be manipulated through a customer or account. # TODO: When managed accounts are supported, this method needs to check if either a customer or account is supplied to determine the correct object to use. - return self.customer.api_retrieve().sources.retrieve(id=self.stripe_id, api_key=api_key, expand=self.expand_fields) + return self.customer.api_retrieve(api_key=api_key).sources.retrieve(self.stripe_id, expand=self.expand_fields) @staticmethod def _get_customer_from_kwargs(**kwargs): @@ -1097,6 +1192,70 @@ def _stripe_object_to_charge(cls, target_cls, data): if "charge" in data and data["charge"]: return target_cls._get_or_create_from_stripe_object(data, "charge")[0] + @classmethod + def upcoming(cls, api_key=settings.STRIPE_SECRET_KEY, customer=None, coupon=None, subscription=None, + subscription_plan=None, subscription_prorate=None, subscription_proration_date=None, + subscription_quantity=None, subscription_trial_end=None, **kwargs): + """ + Gets the upcoming preview invoice (singular) for a customer. + + As per the Stripe docs: "At any time, you can preview the upcoming + invoice for a customer. This will show you all the charges that are + pending, including subscription renewal charges, invoice item charges, + etc. It will also show you any discount that is applicable to the + customer." + + See for details: https://stripe.com/docs/api#upcoming_invoice + + :param customer: The identifier of the customer whose upcoming invoice + you'd like to retrieve. + :type customer: ``djstripe.models.Customer`` + :param coupon: The code of the coupon to apply. + :type customer: ``str`` + :param subscription: The identifier of the subscription to retrieve an + invoice for. + :type customer: ``djstripe.models.Subscription`` or ``str`` (id) + :param subscription_plan: If set, the invoice returned will preview + updating the subscription given to this plan, or creating a new + subscription to this plan if no subscription is given. + :type subscription_plan: ``djstripe.models.Subscription`` or ``str`` (id) + :param subscription_prorate: If previewing an update to a subscription, + this decides whether the preview will show the result of applying + prorations or not. + :type subscription_prorate: ``bool`` + :param subscription_proration_date: If previewing an update to a + subscription, and doing proration, subscription_proration_date forces + the proration to be calculated as though the update was done at the + specified time. + :type subscription_proration_date: ``datetime`` + :param subscription_quantity: If provided, the invoice returned will + preview updating or creating a subscription with that quantity. + :type subscription_proration_quantity: ``int`` + :param subscription_trial_end: If provided, the invoice returned will + preview updating or creating a subscription with that trial end. + :type subscription_trial_end: ``datetime`` + :returns: The upcoming preview invoice. + :rtype: ``djstripe.models.UpcomingInvoice`` + """ + try: + upcoming_stripe_invoice = cls._api().upcoming( + api_key=api_key, customer=customer, + coupon=coupon, subscription=subscription, + subscription_plan=subscription_plan, + subscription_prorate=subscription_prorate, + subscription_proration_date=subscription_proration_date, + subscription_quantity=subscription_quantity, + subscription_trial_end=subscription_trial_end, **kwargs) + except InvalidRequestError as exc: + if str(exc) != "Nothing to invoice for customer": + six.reraise(*sys.exc_info()) + return + + # Workaround for "id" being missing (upcoming invoices don't persist). + upcoming_stripe_invoice["id"] = "upcoming" + + return upcoming_stripe_invoice + def retry(self): """ Retry payment on this invoice if it isn't paid, closed, or forgiven.""" diff --git a/djstripe/views.py b/djstripe/views.py index 12cd2f60b5..0978ebf11b 100644 --- a/djstripe/views.py +++ b/djstripe/views.py @@ -14,10 +14,10 @@ from django.views.generic import DetailView, FormView, TemplateView, View from stripe.error import StripeError +from . import settings as djstripe_settings from .forms import PlanForm, CancelSubscriptionForm from .mixins import PaymentsContextMixin, SubscriptionMixin from .models import Customer, Event, EventProcessingException, Plan -from .settings import PRORATION_POLICY_FOR_UPGRADES, subscriber_request_callback from .sync import sync_subscriber @@ -41,7 +41,7 @@ def get_object(self): if hasattr(self, "customer"): return self.customer self.customer, _created = Customer.get_or_create( - subscriber=subscriber_request_callback(self.request)) + subscriber=djstripe_settings.subscriber_request_callback(self.request)) return self.customer def post(self, request, *args, **kwargs): @@ -84,7 +84,7 @@ class HistoryView(LoginRequiredMixin, SelectRelatedMixin, DetailView): def get_object(self): customer, _created = Customer.get_or_create( - subscriber=subscriber_request_callback(self.request)) + subscriber=djstripe_settings.subscriber_request_callback(self.request)) return customer @@ -97,7 +97,7 @@ def post(self, request, *args, **kwargs): return render( request, self.template_name, - {"customer": sync_subscriber(subscriber_request_callback(request))} + {"customer": sync_subscriber(djstripe_settings.subscriber_request_callback(request))} ) @@ -118,7 +118,7 @@ def get(self, request, *args, **kwargs): if not Plan.objects.filter(id=plan_id).exists(): return HttpResponseNotFound() - customer, _created = Customer.get_or_create(subscriber=subscriber_request_callback(self.request)) + customer, _created = Customer.get_or_create(subscriber=djstripe_settings.subscriber_request_callback(self.request)) if customer.subscription and str(customer.subscription.plan.id) == plan_id and customer.subscription.is_valid(): message = "You already subscribed to this plan" @@ -141,7 +141,7 @@ def post(self, request, *args, **kwargs): form = self.get_form(form_class) if form.is_valid(): try: - customer, _created = Customer.get_or_create(subscriber=subscriber_request_callback(self.request)) + customer, _created = Customer.get_or_create(subscriber=djstripe_settings.subscriber_request_callback(self.request)) customer.add_card(self.request.POST.get("stripe_token")) customer.subscribe(form.cleaned_data["plan"]) except StripeError as exc: @@ -171,7 +171,7 @@ class ChangePlanView(LoginRequiredMixin, FormValidMessageMixin, SubscriptionMixi def post(self, request, *args, **kwargs): form = PlanForm(request.POST) - customer, _created = Customer.get_or_create(subscriber=subscriber_request_callback(self.request)) + customer, _created = Customer.get_or_create(subscriber=djstripe_settings.subscriber_request_callback(self.request)) if not customer.subscription: form.add_error(None, "You must already be subscribed to a plan before you can change it.") @@ -184,7 +184,7 @@ def post(self, request, *args, **kwargs): # When a customer upgrades their plan, and DJSTRIPE_PRORATION_POLICY_FOR_UPGRADES is set to True, # we force the proration of the current plan and use it towards the upgraded plan, # no matter what DJSTRIPE_PRORATION_POLICY is set to. - if PRORATION_POLICY_FOR_UPGRADES: + if djstripe_settings.PRORATION_POLICY_FOR_UPGRADES: # Is it an upgrade? if selected_plan.amount > customer.subscription.plan.amount: customer.subscription.update(plan=selected_plan, prorate=True) @@ -206,7 +206,7 @@ class CancelSubscriptionView(LoginRequiredMixin, SubscriptionMixin, FormView): success_url = reverse_lazy("djstripe:account") def form_valid(self, form): - customer, _created = Customer.get_or_create(subscriber=subscriber_request_callback(self.request)) + customer, _created = Customer.get_or_create(subscriber=djstripe_settings.subscriber_request_callback(self.request)) subscription = customer.subscription.cancel() if subscription.status == subscription.STATUS_CANCELED: @@ -244,5 +244,10 @@ def post(self, request, *args, **kwargs): else: event = Event._create_from_stripe_object(data) event.validate() - event.process() + + if djstripe_settings.WEBHOOK_EVENT_CALLBACK: + djstripe_settings.WEBHOOK_EVENT_CALLBACK(event) + else: + event.process() + return HttpResponse() diff --git a/djstripe/webhooks.py b/djstripe/webhooks.py index 70d1cdf05b..47e75c0b4f 100644 --- a/djstripe/webhooks.py +++ b/djstripe/webhooks.py @@ -51,6 +51,7 @@ def handler(event_types): :param event_types: The event type(s) or sub-type(s) that should be handled. :type event_types: A sequence (`list`) or string (`str`/`unicode`). """ + if isinstance(event_types, six.string_types): event_types = [event_types] @@ -67,10 +68,12 @@ def handler_all(func=None): Decorator which registers a function as a webhook handler for ALL webhook events, regardless of event type or sub-type. """ + if not func: return functools.partial(handler_all) registrations_global.append(func) + return func @@ -87,20 +90,25 @@ def call_handlers(event, event_data, event_type, event_subtype): Handlers within each group are invoked in order of registration. :param event: The event model object. - :type event: `djstripe.models.Event` + :type event: ``djstripe.models.Event`` :param event_data: The raw data for the event. - :type event_data: `dict` + :type event_data: ``dict`` :param event_type: The event type, e.g. 'customer'. - :type event_type: string (`str`/`unicode`) + :type event_type: string (``str``/``unicode``) :param event_subtype: The event sub-type, e.g. 'updated'. - :type event_subtype: string (`str`/`unicode`) + :type event_subtype: string (``str``/`unicode``) """ - qualified_event_type = ( - "{event_type}.{event_subtype}".format( - event_type=event_type, event_subtype=event_subtype)) - - for handler_func in itertools.chain( - registrations_global, - registrations[event_type], - registrations[qualified_event_type]): + + chain = [registrations_global] + + # Build up a list of handlers with each qualified part of the event + # type and subtype. For example, "customer.subscription.created" creates: + # 1. "customer" + # 2. "customer.subscription" + # 3. "customer.subscription.created" + for index, _ in enumerate(event.parts): + qualified_event_type = ".".join(event.parts[:(index + 1)]) + chain.append(registrations[qualified_event_type]) + + for handler_func in itertools.chain(*chain): handler_func(event, event_data, event_type, event_subtype) diff --git a/docs/models.rst b/docs/models.rst index 6adee2a73e..399ba0b343 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -51,6 +51,7 @@ Customer .. automethod:: djstripe.models.Customer.retry_unpaid_invoices .. automethod:: djstripe.models.Customer.has_valid_source .. automethod:: djstripe.models.Customer.add_card + .. automethod:: djstripe.models.Customer.upcoming_invoice .. automethod:: djstripe.models.Customer.str_parts .. automethod:: djstripe.stripe_objects.StripeObject.sync_from_stripe_data @@ -123,8 +124,10 @@ Invoice .. autoattribute:: djstripe.models.Invoice.STATUS_CLOSED .. autoattribute:: djstripe.models.Invoice.STATUS_OPEN .. autoattribute:: djstripe.models.Invoice.status + .. autoattribute:: djstripe.models.Invoice.plan .. automethod:: djstripe.models.Invoice.retry + .. automethod:: djstripe.models.Invoice.upcoming .. automethod:: djstripe.models.Invoice.str_parts .. automethod:: djstripe.stripe_objects.StripeObject.sync_from_stripe_data diff --git a/docs/settings.rst b/docs/settings.rst index ce190d8ed2..22e5d2497c 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -222,7 +222,7 @@ DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK (=None) ================================================== If you choose to use a custom subscriber model, you'll need a way to pull it from ``request``. That's where this callback comes in. -It must be a callable that takes a request object and returns an instance of DJSTRIPE_SUBSCRIBER_MODEL +It must be a callable or importable string to a callable that takes a request object and returns an instance of DJSTRIPE_SUBSCRIBER_MODEL Examples: @@ -263,7 +263,7 @@ DJSTRIPE_TRIAL_PERIOD_FOR_SUBSCRIBER_CALLBACK (=None) Used by ``djstripe.models.Customer`` only when creating stripe customers when you have a default plan set via ``DJSTRIPE_DEFAULT_PLAN``. -This is called to dynamically add a trial period to a subscriber's plan. It must be a callable that takes a subscriber object and returns the number of days the trial period should last. +This is called to dynamically add a trial period to a subscriber's plan. It must be a callable or importable string to a callable that takes a subscriber object and returns the number of days the trial period should last. Examples: @@ -295,6 +295,48 @@ This is where you can set *Stripe.com* to send webhook response. You can set thi As this is embedded in the URLConf, this must be a resolvable regular expression. +DJSTRIPE_WEBHOOK_EVENT_CALLBACK (=None) +======================================= + +Webhook event callbacks allow an application to take control of what happens when an event from Stripe is received. +It must be a callable or importable string to a callable that takes an event object. + +One suggestion is to put the event onto a task queue (such as celery) for asynchronous processing. + +Examples: + +`callbacks.py` + +.. code-block:: python + + def webhook_event_callback(event): + """ Dispatches the event to celery for processing. """ + from . import tasks + # Ansychronous hand-off to celery so that we can continue immediately + tasks.process_webhook_event.s(event).apply_async() + +`tasks.py` + +.. code-block:: python + + from stripe.error import StripeError + + @shared_task(bind=True) + def process_webhook_event(self, event): + """ Processes events from Stripe asynchronously. """ + log.debug("Processing Stripe event: %s", str(event)) + try: + event.process(raise_exception=True): + except StripeError as exc: + log.error("Failed to process Stripe event: %s", str(event)) + raise self.retry(exc=exc, countdown=60) # retry after 60 seconds + +`settings.py` + +.. code-block:: python + + DJSTRIPE_WEBHOOK_EVENT_CALLBACK = 'callbacks.webhook_event_callback' + DJSTRIPE_CURRENCIES (=(('usd', 'U.S. Dollars',), ('gbp', 'Pounds (GBP)',), ('eur', 'Euros',))) ============================================================================================== diff --git a/requirements.txt b/requirements.txt index a078e80023..7ebd289307 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,5 @@ jsonfield>=1.0.3 pytz>=2015.7 stripe~=1.35.0 tqdm>=4.7.4 -python-doc-inherit~=0.3.0 \ No newline at end of file +python-doc-inherit~=0.3.0 +mock-django~=0.6.10 diff --git a/tests/__init__.py b/tests/__init__.py index 730a3ebad5..d26bdd5354 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -3,6 +3,7 @@ :synopsis: dj-stripe test fakes .. moduleauthor:: Alex Kavanaugh (@kavdev) +.. moduleauthor:: Lee Skillen (@lskillen) A Fake or multiple fakes for each stripe object. @@ -542,7 +543,7 @@ def create(self, source, api_key=None): if fake_card["id"] == source: return fake_card - def retrieve(self, id, api_key, expand): + def retrieve(self, id, expand=None): for fake_card in self.card_fakes: if fake_card["id"] == id: return fake_card @@ -812,12 +813,71 @@ def pay(self): "webhooks_delivered_at": 1439426955, }) +FAKE_UPCOMING_INVOICE = InvoiceDict({ + "id": "in", + "object": "invoice", + "amount_due": 2000, + "application_fee": None, + "attempt_count": 1, + "attempted": False, + "charge": None, + "closed": False, + "currency": "usd", + "customer": FAKE_CUSTOMER["id"], + "date": 1439218864, + "description": None, + "discount": None, + "ending_balance": None, + "forgiven": False, + "lines": { + "data": [ + { + "id": FAKE_SUBSCRIPTION["id"], + "object": "line_item", + "amount": 2000, + "currency": "usd", + "description": None, + "discountable": True, + "livemode": True, + "metadata": {}, + "period": { + "start": 1441907581, + "end": 1444499581 + }, + "plan": deepcopy(FAKE_PLAN), + "proration": False, + "quantity": 1, + "subscription": None, + "type": "subscription", + } + ], + "total_count": 1, + "object": "list", + "url": "/v1/invoices/in_16YHls2eZvKYlo2CwwH968Mc/lines", + }, + "livemode": False, + "metadata": {}, + "next_payment_attempt": 1439218689, + "paid": False, + "period_end": 1439218689, + "period_start": 1439132289, + "receipt_number": None, + "starting_balance": 0, + "statement_descriptor": None, + "subscription": FAKE_SUBSCRIPTION["id"], + "subtotal": 2000, + "tax": None, + "tax_percent": None, + "total": 2000, + "webhooks_delivered_at": 1439218870, +}) + FAKE_INVOICEITEM = { "id": "ii_16XVTY2eZvKYlo2Cxz5n3RaS", "object": "invoiceitem", "amount": 2000, "currency": "usd", - "customer": "cus_4UbFSo9tl62jqj", + "customer": FAKE_CUSTOMER_II["id"], "date": 1439033216, "description": "One-time setup fee", "discountable": True, @@ -837,6 +897,31 @@ def pay(self): "subscription": None, } +FAKE_INVOICEITEM_II = { + "id": "ii_16XVTY2eZvKYlo2Cxz5n3RaS", + "object": "invoiceitem", + "amount": 2000, + "currency": "usd", + "customer": FAKE_CUSTOMER["id"], + "date": 1439033216, + "description": "One-time setup fee", + "discountable": True, + "invoice": FAKE_INVOICE["id"], + "livemode": False, + "metadata": { + "key1": "value1", + "key2": "value2" + }, + "period": { + "start": 1439033216, + "end": 1439033216, + }, + "plan": None, + "proration": False, + "quantity": None, + "subscription": None, +} + FAKE_TRANSFER = { "id": "tr_16Y9BK2eZvKYlo2CR0ySu1BA", "object": "transfer", @@ -1043,6 +1128,12 @@ def pay(self): "type": "customer.created", } +FAKE_EVENT_CUSTOMER_DELETED = deepcopy(FAKE_EVENT_CUSTOMER_CREATED) +FAKE_EVENT_CUSTOMER_DELETED.update({ + "id": "evt_38DHch3whaDvKYlo2jksfsFFxy", + "type": "customer.deleted" +}) + FAKE_EVENT_CUSTOMER_SOURCE_CREATED = { "id": "evt_DvKYlo38huDvKYlo2C7SXedrZk", "object": "event", @@ -1057,6 +1148,17 @@ def pay(self): "type": "customer.source.created", } +FAKE_EVENT_CUSTOMER_SOURCE_DELETED = deepcopy(FAKE_EVENT_CUSTOMER_SOURCE_CREATED) +FAKE_EVENT_CUSTOMER_SOURCE_DELETED.update({ + "id": "evt_DvKYlo38huDvKYlo2C7SXedrYk", + "type": "customer.source.deleted" +}) + +FAKE_EVENT_CUSTOMER_SOURCE_DELETED_DUPE = deepcopy(FAKE_EVENT_CUSTOMER_SOURCE_DELETED) +FAKE_EVENT_CUSTOMER_SOURCE_DELETED_DUPE.update({ + "id": "evt_DvKYlo38huDvKYlo2C7SXedzAk", +}) + FAKE_EVENT_CUSTOMER_SUBSCRIPTION_CREATED = { "id": "evt_38DHch3wHD2eZvKYlCT2oe5ff3", "object": "event", @@ -1071,6 +1173,11 @@ def pay(self): "type": "customer.subscription.created", } +FAKE_EVENT_CUSTOMER_SUBSCRIPTION_DELETED = deepcopy(FAKE_EVENT_CUSTOMER_SUBSCRIPTION_CREATED) +FAKE_EVENT_CUSTOMER_SUBSCRIPTION_DELETED.update({ + "id": "evt_38DHch3wHD2eZvKYlCT2oeryaf", + "type": "customer.subscription.deleted"}) + FAKE_EVENT_INVOICE_CREATED = { "id": "evt_187IHD2eZvKYlo2C6YKQi2eZ", "object": "event", @@ -1085,6 +1192,11 @@ def pay(self): "type": "invoice.created", } +FAKE_EVENT_INVOICE_DELETED = deepcopy(FAKE_EVENT_INVOICE_CREATED) +FAKE_EVENT_INVOICE_DELETED.update({ + "id": "evt_187IHD2eZvKYlo2Cjkjsr34H", + "type": "invoice.deleted"}) + FAKE_EVENT_INVOICEITEM_CREATED = { "id": "evt_187IHD2eZvKYlo2C7SXedrZk", "object": "event", @@ -1099,6 +1211,11 @@ def pay(self): "type": "invoiceitem.created", } +FAKE_EVENT_INVOICEITEM_DELETED = deepcopy(FAKE_EVENT_INVOICEITEM_CREATED) +FAKE_EVENT_INVOICEITEM_DELETED.update({ + "id": "evt_187IHD2eZvKYloJfdsnnfs34", + "type": "invoiceitem.deleted"}) + FAKE_EVENT_PLAN_CREATED = { "id": "evt_1877X72eZvKYlo2CLK6daFxu", "object": "event", @@ -1113,6 +1230,11 @@ def pay(self): "type": "plan.created", } +FAKE_EVENT_PLAN_DELETED = deepcopy(FAKE_EVENT_PLAN_CREATED) +FAKE_EVENT_PLAN_DELETED.update({ + "id": "evt_1877X72eZvKYl2jkds32jJFc", + "type": "plan.deleted"}) + FAKE_EVENT_TRANSFER_CREATED = { "id": "evt_16igNU2eZvKYlo2CYyMkYvet", "object": "event", @@ -1127,6 +1249,11 @@ def pay(self): "type": "transfer.created", } +FAKE_EVENT_TRANSFER_DELETED = deepcopy(FAKE_EVENT_TRANSFER_CREATED) +FAKE_EVENT_TRANSFER_DELETED.update({ + "id": "evt_16igNU2eZvKjklfsdjk232Mf", + "type": "transfer.deleted"}) + FAKE_TOKEN = { "id": "tok_16YDIe2eZvKYlo2CPvqprIJd", "object": "token", diff --git a/tests/test_customer.py b/tests/test_customer.py index e1a312c26d..7cdde1076d 100644 --- a/tests/test_customer.py +++ b/tests/test_customer.py @@ -5,6 +5,7 @@ .. moduleauthor:: Daniel Greenfeld (@pydanny) .. moduleauthor:: Alex Kavanaugh (@kavdev) .. moduleauthor:: Michael Thornhill (@mthornhill) +.. moduleauthor:: Lee Skillen (@lskillen) """ @@ -16,14 +17,14 @@ from django.contrib.auth import get_user_model from django.test import TestCase from django.utils import timezone -from mock import patch +from mock import patch, ANY from stripe.error import InvalidRequestError from djstripe.exceptions import MultipleSubscriptionException from djstripe.models import Account, Customer, Charge, Card, Subscription, Invoice, Plan from tests import (FAKE_CARD, FAKE_CHARGE, FAKE_CUSTOMER, FAKE_ACCOUNT, FAKE_INVOICE, FAKE_INVOICE_III, FAKE_INVOICEITEM, FAKE_PLAN, FAKE_SUBSCRIPTION, FAKE_SUBSCRIPTION_II, - StripeList, FAKE_CARD_V, FAKE_CUSTOMER_II) + StripeList, FAKE_CARD_V, FAKE_CUSTOMER_II, FAKE_UPCOMING_INVOICE) class TestCustomer(TestCase): @@ -99,7 +100,8 @@ def test_customer_purge_raises_customer_exception(self, customer_retrieve_mock): self.assertTrue(not customer.sources.all()) self.assertTrue(get_user_model().objects.filter(pk=self.user.pk).exists()) - customer_retrieve_mock.assert_called_with(id=self.customer.stripe_id, api_key=settings.STRIPE_SECRET_KEY, expand=['default_source']) + customer_retrieve_mock.assert_called_with(id=self.customer.stripe_id, api_key=settings.STRIPE_SECRET_KEY, + expand=['default_source']) self.assertEquals(2, customer_retrieve_mock.call_count) @patch("stripe.Customer.retrieve") @@ -109,7 +111,8 @@ def test_customer_delete_raises_unexpected_exception(self, customer_retrieve_moc with self.assertRaisesMessage(InvalidRequestError, "Unexpected Exception"): self.customer.purge() - customer_retrieve_mock.assert_called_once_with(id=self.customer.stripe_id, api_key=settings.STRIPE_SECRET_KEY, expand=['default_source']) + customer_retrieve_mock.assert_called_once_with(id=self.customer.stripe_id, api_key=settings.STRIPE_SECRET_KEY, + expand=['default_source']) def test_can_charge(self): self.assertTrue(self.customer.can_charge()) @@ -576,3 +579,27 @@ def test_add_invoice_item_djstripe_objects(self, invoiceitem_create_mock, invoic def test_add_invoice_item_bad_decimal(self): with self.assertRaisesMessage(ValueError, "You must supply a decimal value representing dollars."): self.customer.add_invoice_item(amount=5000, currency="usd") + + @patch("stripe.Plan.retrieve", return_value=deepcopy(FAKE_PLAN)) + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Invoice.upcoming", return_value=deepcopy(FAKE_UPCOMING_INVOICE)) + def test_upcoming_invoice(self, invoice_upcoming_mock, subscription_retrieve_mock, plan_retrieve_mock): + invoice = self.customer.upcoming_invoice() + self.assertIsNotNone(invoice) + self.assertIsNone(invoice.stripe_id) + self.assertIsNone(invoice.save()) + + subscription_retrieve_mock.assert_called_once_with(api_key=ANY, expand=ANY, id=FAKE_SUBSCRIPTION["id"]) + plan_retrieve_mock.assert_not_called() + + items = invoice.invoiceitems.all() + self.assertEquals(1, len(items)) + self.assertEquals(FAKE_SUBSCRIPTION["id"], items[0].stripe_id) + + self.assertIsNotNone(invoice.plan) + self.assertEquals(FAKE_PLAN["id"], invoice.plan.stripe_id) + + invoice._invoiceitems = [] + items = invoice.invoiceitems.all() + self.assertEquals(0, len(items)) + self.assertIsNotNone(invoice.plan) diff --git a/tests/test_event.py b/tests/test_event.py index 4a675377de..1fed9e3f41 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -38,13 +38,22 @@ def test_str(self): ), str(event)) @patch('djstripe.models.EventProcessingException.log') - def test_stripe_error(self, event_exception_log_mock): + def test_process_event_with_log_stripe_error(self, event_exception_log_mock): event = self._create_event(FAKE_EVENT_TRANSFER_CREATED) self.call_handlers.side_effect = StripeError("Boom!") self.assertFalse(event.process()) self.assertTrue(event_exception_log_mock.called) self.assertFalse(event.processed) + @patch('djstripe.models.EventProcessingException.log') + def test_process_event_with_raise_stripe_error(self, event_exception_log_mock): + event = self._create_event(FAKE_EVENT_TRANSFER_CREATED) + self.call_handlers.side_effect = StripeError("Boom!") + with self.assertRaises(StripeError): + event.process(raise_exception=True) + self.assertTrue(event_exception_log_mock.called) + self.assertFalse(event.processed) + def test_process_event_when_invalid(self): event = self._create_event(FAKE_EVENT_TRANSFER_CREATED) event.valid = False diff --git a/tests/test_event_handlers.py b/tests/test_event_handlers.py index 8e3df68d13..580186a2be 100644 --- a/tests/test_event_handlers.py +++ b/tests/test_event_handlers.py @@ -3,6 +3,7 @@ :synopsis: dj-stripe Event Handler Tests. .. moduleauthor:: Alex Kavanaugh (@kavdev) +.. moduleauthor:: Lee Skillen (@lskillen) """ @@ -13,15 +14,39 @@ from django.test import TestCase from mock import patch -from djstripe.models import Event, Charge, Transfer, Account, Plan, Customer, InvoiceItem, Invoice, Card, Subscription -from tests import (FAKE_CUSTOMER, FAKE_CUSTOMER_II, FAKE_EVENT_CHARGE_SUCCEEDED, FAKE_EVENT_TRANSFER_CREATED, - FAKE_EVENT_PLAN_CREATED, FAKE_CHARGE, FAKE_CHARGE_II, FAKE_INVOICE_II, FAKE_EVENT_INVOICEITEM_CREATED, - FAKE_EVENT_INVOICE_CREATED, FAKE_EVENT_CUSTOMER_CREATED, FAKE_EVENT_CUSTOMER_SOURCE_CREATED, - FAKE_EVENT_CUSTOMER_SUBSCRIPTION_CREATED, FAKE_PLAN, FAKE_SUBSCRIPTION, FAKE_SUBSCRIPTION_III) from djstripe.exceptions import CustomerDoesNotExistLocallyException +from djstripe.models import Event, Charge, Transfer, Account, Plan, Customer, InvoiceItem, Invoice, Card, Subscription +from tests import (FAKE_CARD, FAKE_CHARGE, FAKE_CHARGE_II, FAKE_CUSTOMER, FAKE_CUSTOMER_II, + FAKE_EVENT_CHARGE_SUCCEEDED, FAKE_EVENT_CUSTOMER_CREATED, + FAKE_EVENT_CUSTOMER_DELETED, FAKE_EVENT_CUSTOMER_SOURCE_CREATED, + FAKE_EVENT_CUSTOMER_SOURCE_DELETED, FAKE_EVENT_CUSTOMER_SOURCE_DELETED_DUPE, + FAKE_EVENT_CUSTOMER_SUBSCRIPTION_CREATED, FAKE_EVENT_CUSTOMER_SUBSCRIPTION_DELETED, + FAKE_EVENT_INVOICE_CREATED, FAKE_EVENT_INVOICE_DELETED, FAKE_EVENT_INVOICEITEM_CREATED, + FAKE_EVENT_INVOICEITEM_DELETED, FAKE_EVENT_PLAN_CREATED, FAKE_EVENT_PLAN_DELETED, + FAKE_EVENT_TRANSFER_CREATED, FAKE_EVENT_TRANSFER_DELETED, FAKE_INVOICE, FAKE_INVOICE_II, + FAKE_INVOICEITEM, FAKE_PLAN, FAKE_SUBSCRIPTION, FAKE_SUBSCRIPTION_III, FAKE_TRANSFER) + + +class EventTestCase(TestCase): + # + # Helpers + # + + @patch('stripe.Event.retrieve') + def _create_event(self, event_data, event_retrieve_mock, patch_data=None): + event_data = deepcopy(event_data) + + if patch_data: + event_data.update(patch_data) + + event_retrieve_mock.return_value = event_data + event = Event.sync_from_stripe_data(event_data) + event.validate() + return event -class TestChargeEvents(TestCase): + +class TestChargeEvents(EventTestCase): def setUp(self): self.user = get_user_model().objects.create_user(username="pydanny", email="pydanny@gmail.com") @@ -48,17 +73,17 @@ def test_charge_created(self, event_retrieve_mock, charge_retrieve_mock, custome self.assertEquals(charge.status, fake_stripe_event["data"]["object"]["status"]) -class TestCustomerEvents(TestCase): +class TestCustomerEvents(EventTestCase): def setUp(self): self.user = get_user_model().objects.create_user(username="pydanny", email="pydanny@gmail.com") @patch("stripe.Customer.retrieve") @patch("stripe.Event.retrieve") - def test_customer_created(self, event_retrieve_mock, customer_retreive_mock): + def test_customer_created(self, event_retrieve_mock, customer_retrieve_mock): fake_stripe_event = deepcopy(FAKE_EVENT_CUSTOMER_CREATED) event_retrieve_mock.return_value = fake_stripe_event - customer_retreive_mock.return_value = fake_stripe_event["data"]["object"] + customer_retrieve_mock.return_value = fake_stripe_event["data"]["object"] Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") @@ -73,10 +98,10 @@ def test_customer_created(self, event_retrieve_mock, customer_retreive_mock): @patch("stripe.Customer.retrieve") @patch("stripe.Event.retrieve") - def test_customer_created_no_customer_exists(self, event_retrieve_mock, customer_retreive_mock): + def test_customer_created_no_customer_exists(self, event_retrieve_mock, customer_retrieve_mock): fake_stripe_event = deepcopy(FAKE_EVENT_CUSTOMER_CREATED) event_retrieve_mock.return_value = fake_stripe_event - customer_retreive_mock.return_value = fake_stripe_event["data"]["object"] + customer_retrieve_mock.return_value = fake_stripe_event["data"]["object"] event = Event.sync_from_stripe_data(fake_stripe_event) @@ -85,24 +110,17 @@ def test_customer_created_no_customer_exists(self, event_retrieve_mock, customer self.assertFalse(Customer.objects.filter(stripe_id=fake_stripe_event["data"]["object"]["id"]).exists()) - @patch("stripe.Customer.retrieve") - @patch("stripe.Event.retrieve") - def test_customer_deleted(self, event_retrieve_mock, customer_retreive_mock): - fake_stripe_event = deepcopy(FAKE_EVENT_CUSTOMER_CREATED) - fake_stripe_event["type"] = "customer.deleted" - - event_retrieve_mock.return_value = fake_stripe_event - customer_retreive_mock.return_value = fake_stripe_event["data"]["object"] - + @patch("stripe.Customer.retrieve", return_value=FAKE_CUSTOMER) + def test_customer_deleted(self, customer_retrieve_mock): Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") + event = self._create_event(FAKE_EVENT_CUSTOMER_CREATED) + self.assertTrue(event.process()) - event = Event.sync_from_stripe_data(fake_stripe_event) + event = self._create_event(FAKE_EVENT_CUSTOMER_DELETED) + self.assertTrue(event.process()) - event.validate() - event.process() - - customer = Customer.objects.get(stripe_id=fake_stripe_event["data"]["object"]["id"]) - self.assertNotEqual(None, customer.date_purged) + customer = Customer.objects.get(stripe_id=FAKE_CUSTOMER["id"]) + self.assertIsNotNone(customer.date_purged) @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Event.retrieve") @@ -138,6 +156,38 @@ def test_customer_unknown_source_created(self, event_retrieve_mock, customer_ret self.assertFalse(Card.objects.filter(stripe_id=fake_stripe_event["data"]["object"]["id"]).exists()) + @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) + def test_customer_default_source_deleted(self, customer_retrieve_mock): + event = self._create_event(FAKE_EVENT_CUSTOMER_SOURCE_CREATED) + Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") + self.assertTrue(event.process()) + + card = Card.objects.get(stripe_id=FAKE_CARD["id"]) + customer = Customer.objects.get(stripe_id=FAKE_CUSTOMER["id"]) + customer.default_source = card + customer.save() + self.assertIsNotNone(customer.default_source) + self.assertTrue(customer.has_valid_source()) + + event = self._create_event(FAKE_EVENT_CUSTOMER_SOURCE_DELETED) + self.assertTrue(event.process()) + + customer = Customer.objects.get(stripe_id=FAKE_CUSTOMER["id"]) + self.assertIsNone(customer.default_source) + self.assertFalse(customer.has_valid_source()) + + @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) + def test_customer_source_double_delete(self, customer_retrieve_mock): + event = self._create_event(FAKE_EVENT_CUSTOMER_SOURCE_CREATED) + Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") + self.assertTrue(event.process()) + + event = self._create_event(FAKE_EVENT_CUSTOMER_SOURCE_DELETED) + self.assertTrue(event.process()) + + event = self._create_event(FAKE_EVENT_CUSTOMER_SOURCE_DELETED_DUPE) + self.assertTrue(event.process()) + @patch("stripe.Plan.retrieve", return_value=deepcopy(FAKE_PLAN)) @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @@ -158,15 +208,32 @@ def test_customer_subscription_created(self, event_retrieve_mock, customer_retri self.assertEqual(subscription.status, fake_stripe_event["data"]["object"]["status"]) self.assertEqual(subscription.quantity, fake_stripe_event["data"]["object"]["quantity"]) + @patch("stripe.Plan.retrieve", return_value=deepcopy(FAKE_PLAN)) + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) + def test_customer_subscription_deleted( + self, customer_retrieve_mock, subscription_retrieve_mock, plan_retrieve_mock): + Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") + event = self._create_event(FAKE_EVENT_CUSTOMER_SUBSCRIPTION_CREATED) + self.assertTrue(event.process()) + + Subscription.objects.get(stripe_id=FAKE_SUBSCRIPTION["id"]) + + event = self._create_event(FAKE_EVENT_CUSTOMER_SUBSCRIPTION_DELETED) + self.assertTrue(event.process()) + + with self.assertRaises(Subscription.DoesNotExist): + Subscription.objects.get(stripe_id=FAKE_SUBSCRIPTION["id"]) + @patch("stripe.Customer.retrieve") @patch("stripe.Event.retrieve") - def test_customer_bogus_event_type(self, event_retrieve_mock, customer_retreive_mock): + def test_customer_bogus_event_type(self, event_retrieve_mock, customer_retrieve_mock): fake_stripe_event = deepcopy(FAKE_EVENT_CUSTOMER_CREATED) fake_stripe_event["data"]["object"]["customer"] = fake_stripe_event["data"]["object"]["id"] fake_stripe_event["type"] = "customer.praised" event_retrieve_mock.return_value = fake_stripe_event - customer_retreive_mock.return_value = fake_stripe_event["data"]["object"] + customer_retrieve_mock.return_value = fake_stripe_event["data"]["object"] Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") @@ -179,7 +246,7 @@ def test_customer_bogus_event_type(self, event_retrieve_mock, customer_retreive_ self.assertEqual(None, customer.account_balance) -class TestInvoiceEvents(TestCase): +class TestInvoiceEvents(EventTestCase): @patch("djstripe.models.Charge.send_receipt", autospec=True) @patch("djstripe.models.Account.get_default_account") @@ -237,8 +304,32 @@ def test_invoice_created(self, event_retrieve_mock, invoice_retrieve_mock, charg self.assertEquals(invoice.amount_due, fake_stripe_event["data"]["object"]["amount_due"] / decimal.Decimal("100")) self.assertEquals(invoice.paid, fake_stripe_event["data"]["object"]["paid"]) + @patch("djstripe.models.Charge.send_receipt", autospec=True) + @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) + @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) + @patch("stripe.Invoice.retrieve", return_value=deepcopy(FAKE_INVOICE)) + def test_invoice_deleted(self, invoice_retrieve_mock, charge_retrieve_mock, customer_retrieve_mock, + subscription_retrieve_mock, default_account_mock, send_receipt_mock): + default_account_mock.return_value = Account.objects.create() + + user = get_user_model().objects.create_user(username="pydanny", email="pydanny@gmail.com") + Customer.objects.create(subscriber=user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") + + event = self._create_event(FAKE_EVENT_INVOICE_CREATED) + self.assertTrue(event.process()) -class TestInvoiceItemEvents(TestCase): + Invoice.objects.get(stripe_id=FAKE_INVOICE["id"]) + + event = self._create_event(FAKE_EVENT_INVOICE_DELETED) + self.assertTrue(event.process()) + + with self.assertRaises(Invoice.DoesNotExist): + Invoice.objects.get(stripe_id=FAKE_INVOICE["id"]) + + +class TestInvoiceItemEvents(EventTestCase): @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION_III)) @@ -267,8 +358,34 @@ def test_invoiceitem_created(self, event_retrieve_mock, invoiceitem_retrieve_moc invoiceitem = InvoiceItem.objects.get(stripe_id=fake_stripe_event["data"]["object"]["id"]) self.assertEquals(invoiceitem.amount, fake_stripe_event["data"]["object"]["amount"] / decimal.Decimal("100")) + @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION_III)) + @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER_II)) + @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE_II)) + @patch("stripe.Invoice.retrieve", return_value=deepcopy(FAKE_INVOICE_II)) + @patch("stripe.InvoiceItem.retrieve", return_value=deepcopy(FAKE_INVOICEITEM)) + def test_invoiceitem_deleted( + self, invoiceitem_retrieve_mock, invoice_retrieve_mock, + charge_retrieve_mock, customer_retrieve_mock, + subscription_retrieve_mock, default_account_mock): + default_account_mock.return_value = Account.objects.create() + + user = get_user_model().objects.create_user(username="pydanny", email="pydanny@gmail.com") + Customer.objects.create(subscriber=user, stripe_id=FAKE_CUSTOMER_II["id"], currency="usd") + + event = self._create_event(FAKE_EVENT_INVOICEITEM_CREATED) + self.assertTrue(event.process()) + + InvoiceItem.objects.get(stripe_id=FAKE_INVOICEITEM["id"]) + + event = self._create_event(FAKE_EVENT_INVOICEITEM_DELETED) + self.assertTrue(event.process()) + + with self.assertRaises(InvoiceItem.DoesNotExist): + InvoiceItem.objects.get(stripe_id=FAKE_INVOICEITEM["id"]) + -class TestPlanEvents(TestCase): +class TestPlanEvents(EventTestCase): @patch('stripe.Plan.retrieve') @patch("stripe.Event.retrieve") @@ -285,8 +402,21 @@ def test_plan_created(self, event_retrieve_mock, plan_retrieve_mock): plan = Plan.objects.get(stripe_id=fake_stripe_event["data"]["object"]["id"]) self.assertEquals(plan.name, fake_stripe_event["data"]["object"]["name"]) + @patch('stripe.Plan.retrieve', return_value=FAKE_PLAN) + def test_plan_deleted(self, plan_retrieve_mock): + event = self._create_event(FAKE_EVENT_PLAN_CREATED) + self.assertTrue(event.process()) -class TestTransferEvents(TestCase): + Plan.objects.get(stripe_id=FAKE_PLAN["id"]) + + event = self._create_event(FAKE_EVENT_PLAN_DELETED) + self.assertTrue(event.process()) + + with self.assertRaises(Plan.DoesNotExist): + Plan.objects.get(stripe_id=FAKE_PLAN["id"]) + + +class TestTransferEvents(EventTestCase): @patch('stripe.Transfer.retrieve') @patch("stripe.Event.retrieve") @@ -303,3 +433,16 @@ def test_transfer_created(self, event_retrieve_mock, transfer_retrieve_mock): transfer = Transfer.objects.get(stripe_id=fake_stripe_event["data"]["object"]["id"]) self.assertEquals(transfer.amount, fake_stripe_event["data"]["object"]["amount"] / decimal.Decimal("100")) self.assertEquals(transfer.status, fake_stripe_event["data"]["object"]["status"]) + + @patch('stripe.Transfer.retrieve', return_value=FAKE_TRANSFER) + def test_transfer_deleted(self, transfer_retrieve_mock): + event = self._create_event(FAKE_EVENT_TRANSFER_CREATED) + self.assertTrue(event.process()) + + Transfer.objects.get(stripe_id=FAKE_TRANSFER["id"]) + + event = self._create_event(FAKE_EVENT_TRANSFER_DELETED) + self.assertTrue(event.process()) + + with self.assertRaises(Transfer.DoesNotExist): + Transfer.objects.get(stripe_id=FAKE_TRANSFER["id"]) diff --git a/tests/test_invoice.py b/tests/test_invoice.py index bf3807f60e..1a3f6b820d 100644 --- a/tests/test_invoice.py +++ b/tests/test_invoice.py @@ -3,6 +3,7 @@ :synopsis: dj-stripe Invoice Model Tests. .. moduleauthor:: Alex Kavanaugh (@kavdev) +.. moduleauthor:: Lee Skillen (@lskillen) """ @@ -11,24 +12,25 @@ from django.conf import settings from django.contrib.auth import get_user_model from django.test.testcases import TestCase -from mock import patch +from mock import patch, ANY -from djstripe.models import Customer, Invoice, Account -from tests import FAKE_INVOICE, FAKE_CHARGE, FAKE_CUSTOMER, FAKE_SUBSCRIPTION +from djstripe.models import Customer, Invoice, Account, UpcomingInvoice +from djstripe.models import InvalidRequestError + +from tests import FAKE_INVOICE, FAKE_CHARGE, FAKE_CUSTOMER, FAKE_SUBSCRIPTION, FAKE_PLAN, FAKE_INVOICEITEM_II, FAKE_UPCOMING_INVOICE class InvoiceTest(TestCase): def setUp(self): self.account = Account.objects.create() - user = get_user_model().objects.create_user(username="pydanny", email="pydanny@gmail.com") - Customer.objects.create(subscriber=user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") + self.user = get_user_model().objects.create_user(username="pydanny", email="pydanny@gmail.com") + self.customer = Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_str(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrive_mock, default_account_mock): + def test_str(self, charge_retrieve_mock, subscription_retrive_mock, default_account_mock): default_account_mock.return_value = self.account invoice = Invoice.sync_from_stripe_data(deepcopy(FAKE_INVOICE)) @@ -42,9 +44,8 @@ def test_str(self, charge_retrieve_mock, customer_retrieve_mock, subscription_re @patch("stripe.Invoice.retrieve") @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_retry_true(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrieve_mock, default_account_mock, invoice_retrieve_mock): + def test_retry_true(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock, invoice_retrieve_mock): default_account_mock.return_value = self.account fake_invoice = deepcopy(FAKE_INVOICE) @@ -54,15 +55,15 @@ def test_retry_true(self, charge_retrieve_mock, customer_retrieve_mock, subscrip invoice = Invoice.sync_from_stripe_data(fake_invoice) return_value = invoice.retry() - invoice_retrieve_mock.assert_called_once_with(id=invoice.stripe_id, api_key=settings.STRIPE_SECRET_KEY, expand=None) + invoice_retrieve_mock.assert_called_once_with(id=invoice.stripe_id, api_key=settings.STRIPE_SECRET_KEY, + expand=None) self.assertTrue(return_value) @patch("stripe.Invoice.retrieve") @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_retry_false(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrieve_mock, default_account_mock, invoice_retrieve_mock): + def test_retry_false(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock, invoice_retrieve_mock): default_account_mock.return_value = self.account fake_invoice = deepcopy(FAKE_INVOICE) @@ -76,9 +77,8 @@ def test_retry_false(self, charge_retrieve_mock, customer_retrieve_mock, subscri @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_status_paid(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrieve_mock, default_account_mock): + def test_status_paid(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): default_account_mock.return_value = self.account invoice = Invoice.sync_from_stripe_data(deepcopy(FAKE_INVOICE)) @@ -87,9 +87,8 @@ def test_status_paid(self, charge_retrieve_mock, customer_retrieve_mock, subscri @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_status_open(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrieve_mock, default_account_mock): + def test_status_open(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): default_account_mock.return_value = self.account invoice_data = deepcopy(FAKE_INVOICE) @@ -100,9 +99,8 @@ def test_status_open(self, charge_retrieve_mock, customer_retrieve_mock, subscri @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_status_forgiven(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrieve_mock, default_account_mock): + def test_status_forgiven(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): default_account_mock.return_value = self.account invoice_data = deepcopy(FAKE_INVOICE) @@ -113,9 +111,8 @@ def test_status_forgiven(self, charge_retrieve_mock, customer_retrieve_mock, sub @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_status_closed(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrieve_mock, default_account_mock): + def test_status_closed(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): default_account_mock.return_value = self.account invoice_data = deepcopy(FAKE_INVOICE) @@ -128,9 +125,8 @@ def test_status_closed(self, charge_retrieve_mock, customer_retrieve_mock, subsc @patch("djstripe.models.Charge.send_receipt") @patch("djstripe.models.Account.get_default_account") @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_sync_send_emails_false(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrieve_mock, default_account_mock, send_receipt_mock, settings_fake): + def test_sync_send_emails_false(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock, send_receipt_mock, settings_fake): default_account_mock.return_value = self.account settings_fake.SEND_INVOICE_RECEIPT_EMAILS = False @@ -140,10 +136,10 @@ def test_sync_send_emails_false(self, charge_retrieve_mock, customer_retrieve_mo self.assertFalse(send_receipt_mock.called) @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Plan.retrieve", return_value=deepcopy(FAKE_PLAN)) @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) - @patch("stripe.Customer.retrieve", return_value=deepcopy(FAKE_CUSTOMER)) @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) - def test_sync_no_subscription(self, charge_retrieve_mock, customer_retrieve_mock, subscription_retrieve_mock, default_account_mock): + def test_sync_no_subscription(self, charge_retrieve_mock, subscription_retrieve_mock, plan_retrieve_mock, default_account_mock): default_account_mock.return_value = self.account invoice_data = deepcopy(FAKE_INVOICE) @@ -151,3 +147,119 @@ def test_sync_no_subscription(self, charge_retrieve_mock, customer_retrieve_mock invoice = Invoice.sync_from_stripe_data(invoice_data) self.assertEqual(None, invoice.subscription) + + charge_retrieve_mock.assert_called_once_with(api_key=ANY, expand=ANY, id=FAKE_CHARGE["id"]) + plan_retrieve_mock.assert_called_once_with(api_key=ANY, expand=ANY, id=FAKE_PLAN["id"]) + + subscription_retrieve_mock.assert_not_called() + + @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) + def test_invoice_with_subscription_invoice_items(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): + default_account_mock.return_value = self.account + + invoice_data = deepcopy(FAKE_INVOICE) + invoice = Invoice.sync_from_stripe_data(invoice_data) + + items = invoice.invoiceitems.all() + self.assertEquals(1, len(items)) + item_id = "{invoice_id}-{subscription_id}".format(invoice_id=invoice.stripe_id, subscription_id=FAKE_SUBSCRIPTION["id"]) + self.assertEquals(item_id, items[0].stripe_id) + + @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) + def test_invoice_with_no_invoice_items(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): + default_account_mock.return_value = self.account + + invoice_data = deepcopy(FAKE_INVOICE) + invoice_data["lines"] = [] + invoice = Invoice.sync_from_stripe_data(invoice_data) + + self.assertIsNotNone(invoice.plan) # retrieved from invoice item + self.assertEquals(FAKE_PLAN["id"], invoice.plan.stripe_id) + + @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) + def test_invoice_with_non_subscription_invoice_items(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): + default_account_mock.return_value = self.account + + invoice_data = deepcopy(FAKE_INVOICE) + invoice_data["lines"]["data"].append(deepcopy(FAKE_INVOICEITEM_II)) + invoice_data["lines"]["total_count"] += 1 + invoice = Invoice.sync_from_stripe_data(invoice_data) + + self.assertIsNotNone(invoice) + self.assertEquals(2, len(invoice.invoiceitems.all())) + + @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) + def test_invoice_plan_from_invoice_items(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): + default_account_mock.return_value = self.account + + invoice_data = deepcopy(FAKE_INVOICE) + invoice = Invoice.sync_from_stripe_data(invoice_data) + + self.assertIsNotNone(invoice.plan) # retrieved from invoice item + self.assertEquals(FAKE_PLAN["id"], invoice.plan.stripe_id) + + @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) + def test_invoice_plan_from_subscription(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): + default_account_mock.return_value = self.account + + invoice_data = deepcopy(FAKE_INVOICE) + invoice_data["lines"]["data"][0]["plan"] = None + invoice = Invoice.sync_from_stripe_data(invoice_data) + self.assertIsNotNone(invoice.plan) # retrieved from subscription + self.assertEquals(FAKE_PLAN["id"], invoice.plan.stripe_id) + + @patch("djstripe.models.Account.get_default_account") + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Charge.retrieve", return_value=deepcopy(FAKE_CHARGE)) + def test_invoice_without_plan(self, charge_retrieve_mock, subscription_retrieve_mock, default_account_mock): + default_account_mock.return_value = self.account + + invoice_data = deepcopy(FAKE_INVOICE) + invoice_data["lines"]["data"][0]["plan"] = None + invoice_data["subscription"] = None + invoice = Invoice.sync_from_stripe_data(invoice_data) + self.assertIsNone(invoice.plan) + + @patch("stripe.Plan.retrieve", return_value=deepcopy(FAKE_PLAN)) + @patch("stripe.Subscription.retrieve", return_value=deepcopy(FAKE_SUBSCRIPTION)) + @patch("stripe.Invoice.upcoming", return_value=deepcopy(FAKE_UPCOMING_INVOICE)) + def test_upcoming_invoice(self, invoice_upcoming_mock, subscription_retrieve_mock, plan_retrieve_mock): + invoice = UpcomingInvoice.upcoming() + self.assertIsNotNone(invoice) + self.assertIsNone(invoice.stripe_id) + self.assertIsNone(invoice.save()) + + subscription_retrieve_mock.assert_called_once_with(api_key=ANY, expand=ANY, id=FAKE_SUBSCRIPTION["id"]) + plan_retrieve_mock.assert_not_called() + + items = invoice.invoiceitems.all() + self.assertEquals(1, len(items)) + self.assertEquals(FAKE_SUBSCRIPTION["id"], items[0].stripe_id) + + self.assertIsNotNone(invoice.plan) + self.assertEquals(FAKE_PLAN["id"], invoice.plan.stripe_id) + + invoice._invoiceitems = [] + items = invoice.invoiceitems.all() + self.assertEquals(0, len(items)) + self.assertIsNotNone(invoice.plan) + + @patch("stripe.Invoice.upcoming", side_effect=InvalidRequestError("Nothing to invoice for customer", None)) + def test_no_upcoming_invoices(self, invoice_upcoming_mock): + invoice = Invoice.upcoming() + self.assertIsNone(invoice) + + @patch("stripe.Invoice.upcoming", side_effect=InvalidRequestError("Some other error", None)) + def test_upcoming_invoice_error(self, invoice_upcoming_mock): + with self.assertRaises(InvalidRequestError): + Invoice.upcoming() diff --git a/tests/test_plan.py b/tests/test_plan.py index 09aa541fe1..e48f4a4942 100644 --- a/tests/test_plan.py +++ b/tests/test_plan.py @@ -87,4 +87,5 @@ def test_str(self): @patch("stripe.Plan.retrieve", return_value="soup") def test_stripe_plan(self, plan_retrieve_mock): self.assertEqual("soup", self.plan.api_retrieve()) - plan_retrieve_mock.assert_called_once_with(id=self.plan_data["id"], api_key=settings.STRIPE_SECRET_KEY, expand=None) + plan_retrieve_mock.assert_called_once_with(id=self.plan_data["id"], api_key=settings.STRIPE_SECRET_KEY, + expand=None) diff --git a/tests/test_settings.py b/tests/test_settings.py index 327aec4141..c7b1e1ecdc 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -3,6 +3,7 @@ :synopsis: dj-stripe Settings Tests. .. moduleauthor:: Alex Kavanaugh (@kavdev) +.. moduleauthor:: Lee Skillen (@lskillen) """ @@ -10,8 +11,10 @@ from django.db.models.base import ModelBase from django.test import TestCase from django.test.utils import override_settings +from mock import patch -from djstripe.settings import get_subscriber_model +from djstripe import settings +from djstripe.settings import get_subscriber_model, get_callback_function class TestSubscriberModelRetrievalMethod(TestCase): @@ -49,3 +52,32 @@ def test_no_callback(self): @override_settings(DJSTRIPE_SUBSCRIBER_MODEL='testapp.Organization', DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK=5) def test_bad_callback(self): self.assertRaisesMessage(ImproperlyConfigured, "DJSTRIPE_SUBSCRIBER_MODEL_REQUEST_CALLBACK must be callable.", get_subscriber_model) + + @override_settings(DJSTRIPE_TEST_CALLBACK=(lambda: "ok")) + def test_get_callback_function_with_valid_func_callable(self): + func = get_callback_function("DJSTRIPE_TEST_CALLBACK") + self.assertEquals("ok", func()) + + @override_settings(DJSTRIPE_TEST_CALLBACK='foo.valid_callback') + @patch.object(settings, 'import_string', return_value=(lambda: "ok")) + def test_get_callback_function_with_valid_string_callable(self, import_string_mock): + func = get_callback_function("DJSTRIPE_TEST_CALLBACK") + self.assertEquals("ok", func()) + import_string_mock.assert_called_with('foo.valid_callback') + + @override_settings(DJSTRIPE_TEST_CALLBACK='foo.non_existant_callback') + def test_get_callback_function_import_error(self): + with self.assertRaises(ImportError): + get_callback_function("DJSTRIPE_TEST_CALLBACK") + + @override_settings(DJSTRIPE_TEST_CALLBACK='foo.invalid_callback') + @patch.object(settings, 'import_string', return_value="not_callable") + def test_get_callback_function_with_non_callable_string(self, import_string_mock): + with self.assertRaises(ImproperlyConfigured): + get_callback_function("DJSTRIPE_TEST_CALLBACK") + import_string_mock.assert_called_with('foo.invalid_callback') + + @override_settings(DJSTRIPE_TEST_CALLBACK='foo.non_existant_callback') + def test_get_callback_function_(self): + with self.assertRaises(ImportError): + get_callback_function("DJSTRIPE_TEST_CALLBACK") diff --git a/tests/test_views.py b/tests/test_views.py index f4ed6545a4..b2ca310af9 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -278,7 +278,7 @@ def test_change_sub_no_proration(self, subscription_update_mock): subscription_update_mock.assert_called_once_with(subscription, plan=plan) - @patch("djstripe.views.PRORATION_POLICY_FOR_UPGRADES", return_value=True) + @patch("djstripe.views.djstripe_settings.PRORATION_POLICY_FOR_UPGRADES", return_value=True) @patch("djstripe.models.Subscription.update", autospec=True) def test_change_sub_with_proration_downgrade(self, subscription_update_mock, proration_policy_mock): Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") @@ -293,7 +293,7 @@ def test_change_sub_with_proration_downgrade(self, subscription_update_mock, pro subscription_update_mock.assert_called_once_with(subscription, plan=plan) - @patch("djstripe.views.PRORATION_POLICY_FOR_UPGRADES", return_value=True) + @patch("djstripe.views.djstripe_settings.PRORATION_POLICY_FOR_UPGRADES", return_value=True) @patch("djstripe.models.Subscription.update", autospec=True) def test_change_sub_with_proration_upgrade(self, subscription_update_mock, proration_policy_mock): Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") @@ -308,7 +308,7 @@ def test_change_sub_with_proration_upgrade(self, subscription_update_mock, prora subscription_update_mock.assert_called_once_with(subscription, plan=plan, prorate=True) - @patch("djstripe.views.PRORATION_POLICY_FOR_UPGRADES", return_value=True) + @patch("djstripe.views.djstripe_settings.PRORATION_POLICY_FOR_UPGRADES", return_value=True) @patch("djstripe.models.Subscription.update", autospec=True) def test_change_sub_with_proration_same_plan(self, subscription_update_mock, proration_policy_mock): Customer.objects.create(subscriber=self.user, stripe_id=FAKE_CUSTOMER["id"], currency="usd") diff --git a/tests/test_webhooks.py b/tests/test_webhooks.py index 9a0131d7b7..add7552f56 100644 --- a/tests/test_webhooks.py +++ b/tests/test_webhooks.py @@ -14,9 +14,9 @@ from django.core.urlresolvers import reverse from django.test import TestCase from django.test.client import Client -from mock import patch, Mock, ANY +from mock import call, patch, Mock, PropertyMock, ANY -from djstripe import webhooks +from djstripe import views, webhooks from djstripe.models import Event, EventProcessingException from djstripe.webhooks import handler, handler_all, call_handlers from tests import FAKE_EVENT_TRANSFER_CREATED, FAKE_TRANSFER @@ -38,6 +38,24 @@ def test_webhook_with_transfer_event(self, event_retrieve_mock, transfer_retriev self.assertEquals(resp.status_code, 200) self.assertTrue(Event.objects.filter(type="transfer.created").exists()) + @patch.object(views.djstripe_settings, 'WEBHOOK_EVENT_CALLBACK', return_value=(lambda event: event.process())) + @patch("stripe.Transfer.retrieve", return_value=deepcopy(FAKE_TRANSFER)) + @patch("stripe.Event.retrieve") + def test_webhook_with_custom_callback(self, + event_retrieve_mock, transfer_retrieve_mock, + webhook_event_callback_mock): + fake_event = deepcopy(FAKE_EVENT_TRANSFER_CREATED) + event_retrieve_mock.return_value = fake_event + + resp = Client().post( + reverse("djstripe:webhook"), + json.dumps(fake_event), + content_type="application/json" + ) + self.assertEquals(resp.status_code, 200) + event = Event.objects.get(type="transfer.created") + webhook_event_callback_mock.called_once_with(event) + @patch("stripe.Transfer.retrieve", return_value=deepcopy(FAKE_TRANSFER)) @patch("stripe.Event.retrieve") def test_webhook_with_transfer_event_duplicate(self, event_retrieve_mock, transfer_retrieve_mock): @@ -67,52 +85,71 @@ def test_webhook_with_transfer_event_duplicate(self, event_retrieve_mock, transf class TestWebhookHandlers(TestCase): def setUp(self): # Reset state of registrations per test - patcher = patch.object( - webhooks, 'registrations', new_callable=lambda: defaultdict(list)) + patcher = patch.object(webhooks, 'registrations', new_callable=(lambda: defaultdict(list))) self.addCleanup(patcher.stop) self.registrations = patcher.start() - patcher = patch.object( - webhooks, 'registrations_global', new_callable=list) + patcher = patch.object(webhooks, 'registrations_global', new_callable=list) self.addCleanup(patcher.stop) self.registrations_global = patcher.start() def test_global_handler_registration(self): func_mock = Mock() handler_all()(func_mock) - call_handlers(Mock(), {'data': 'foo'}, 'wib', 'ble') # handled + self._call_handlers("wib.ble", {"data": "foo"}) # handled self.assertEqual(1, func_mock.call_count) + func_mock.assert_called_with(ANY, ANY, "wib", "ble") def test_event_handler_registration(self): global_func_mock = Mock() handler_all()(global_func_mock) func_mock = Mock() - handler(['foo'])(func_mock) - call_handlers(Mock(), {'data': 'foo'}, 'foo', 'bar') # handled - call_handlers(Mock(), {'data': 'foo'}, 'bar', 'foo') # not handled + handler(["foo"])(func_mock) + self._call_handlers("foo.bar", {"data": "foo"}) # handled + self._call_handlers("bar.foo", {"data": "foo"}) # not handled self.assertEqual(2, global_func_mock.call_count) # called each time self.assertEqual(1, func_mock.call_count) - func_mock.assert_called_with(ANY, ANY, 'foo', 'bar') + func_mock.assert_called_with(ANY, ANY, "foo", "bar") def test_event_subtype_handler_registration(self): global_func_mock = Mock() handler_all()(global_func_mock) func_mock = Mock() - handler(['foo.bar'])(func_mock) - call_handlers(Mock(), {'data': 'foo'}, 'foo', 'bar') # handled - call_handlers(Mock(), {'data': 'foo'}, 'foo', 'baz') # not handled - self.assertEqual(2, global_func_mock.call_count) # called each time - self.assertEqual(1, func_mock.call_count) - func_mock.assert_called_with(ANY, ANY, 'foo', 'bar') + handler(["foo.bar"])(func_mock) + self._call_handlers("foo.bar", {"data": "foo"}) # handled + self._call_handlers("foo.bar.wib", {"data": "foo"}) # handled + self._call_handlers("foo.baz", {"data": "foo"}) # not handled + self.assertEqual(3, global_func_mock.call_count) # called each time + self.assertEqual(2, func_mock.call_count) + func_mock.assert_has_calls([ + call(ANY, ANY, "foo", "bar"), + call(ANY, ANY, "foo", "bar.wib")]) def test_global_handler_registration_with_function(self): func_mock = Mock() handler_all(func_mock) - call_handlers(Mock(), {'data': 'foo'}, 'wib', 'ble') # handled + self._call_handlers("wib.ble", {"data": "foo"}) # handled self.assertEqual(1, func_mock.call_count) + func_mock.assert_called_with(ANY, ANY, "wib", "ble") def test_event_handle_registation_with_string(self): func_mock = Mock() - handler('foo')(func_mock) - call_handlers(Mock(), {'data': 'foo'}, 'foo', 'bar') # handled + handler("foo")(func_mock) + self._call_handlers("foo.bar", {"data": "foo"}) # handled self.assertEqual(1, func_mock.call_count) + func_mock.assert_called_with(ANY, ANY, "foo", "bar") + + # + # Helpers + # + + @staticmethod + def _call_handlers(event_spec, data): + event = Mock(spec=Event) + event_parts = event_spec.split(".") + event_type = event_parts[0] + event_subtype = ".".join(event_parts[1:]) + type(event).parts = PropertyMock(return_value=event_parts) + type(event).event_type = PropertyMock(return_value=event_type) + type(event).event_subtype = PropertyMock(return_value=event_subtype) + return call_handlers(event, data, event_type, event_subtype)