Skip to content

Commit

Permalink
update test account flow for stripe integration (#521)
Browse files Browse the repository at this point in the history
  • Loading branch information
diego-escobedo authored Feb 2, 2023
1 parent a56dc94 commit fc7714a
Show file tree
Hide file tree
Showing 13 changed files with 130 additions and 61 deletions.
3 changes: 2 additions & 1 deletion .github/workflows/django-postgres.yml
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,8 @@ jobs:
DJANGO_SETTINGS_MODULE: "lotus.settings"
PYTHONPATH: "."
SECRET_KEY: ${{ secrets.SECRET_KEY }}
STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }}
STRIPE_LIVE_SECRET_KEY: ${{ secrets.STRIPE_LIVE_SECRET_KEY }}
STRIPE_TEST_SECRET_KEY: ${{ secrets.STRIPE_TEST_SECRET_KEY }}
DEBUG: False
KAFKA_URL: "localhost:9092"
PYTHONDONTWRITEBYTECODE: 1
Expand Down
3 changes: 2 additions & 1 deletion .github/workflows/postman_workflow.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ jobs:
DJANGO_SETTINGS_MODULE: "lotus.settings"
PYTHONPATH: "."
SECRET_KEY: ${{ secrets.SECRET_KEY }}
STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }}
STRIPE_LIVE_SECRET_KEY: ${{ secrets.STRIPE_LIVE_SECRET_KEY }}
STRIPE_TEST_SECRET_KEY: ${{ secrets.STRIPE_TEST_SECRET_KEY }}
DEBUG: False
KAFKA_URL: "localhost:9092"
PYTHONDONTWRITEBYTECODE: 1
Expand Down
5 changes: 4 additions & 1 deletion backend/lotus/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,10 @@
PRODUCT_ANALYTICS_OPT_IN = config("PRODUCT_ANALYTICS_OPT_IN", default=True, cast=bool)
PRODUCT_ANALYTICS_OPT_IN = True if not SELF_HOSTED else PRODUCT_ANALYTICS_OPT_IN
# Stripe required
STRIPE_SECRET_KEY = config("STRIPE_SECRET_KEY", default="")
STRIPE_LIVE_SECRET_KEY = config("STRIPE_LIVE_SECRET_KEY", default=None)
if STRIPE_LIVE_SECRET_KEY is None:
STRIPE_LIVE_SECRET_KEY = config("STRIPE_SECRET_KEY", default=None)
STRIPE_TEST_SECRET_KEY = config("STRIPE_TEST_SECRET_KEY", default=None)
STRIPE_WEBHOOK_SECRET = config("STRIPE_WEBHOOK_SECRET", default="whsec_")
# Webhooks for Svix
SVIX_API_KEY = config("SVIX_API_KEY", default="")
Expand Down
15 changes: 7 additions & 8 deletions backend/metering_billing/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@
from django.db.models.constraints import CheckConstraint, UniqueConstraint
from django.db.models.functions import Cast, Coalesce
from django.utils.translation import gettext_lazy as _
from rest_framework_api_key.models import AbstractAPIKey
from simple_history.models import HistoricalRecords
from svix.api import ApplicationIn, EndpointIn, EndpointSecretRotateIn, EndpointUpdate
from svix.internal.openapi_client.models.http_error import HttpError
from svix.internal.openapi_client.models.http_validation_error import (
HTTPValidationError,
)

from metering_billing.exceptions.exceptions import (
ExternalConnectionFailure,
ExternalConnectionInvalid,
Expand Down Expand Up @@ -77,6 +69,13 @@
WEBHOOK_TRIGGER_EVENTS,
)
from metering_billing.webhooks import invoice_paid_webhook, usage_alert_webhook
from rest_framework_api_key.models import AbstractAPIKey
from simple_history.models import HistoricalRecords
from svix.api import ApplicationIn, EndpointIn, EndpointSecretRotateIn, EndpointUpdate
from svix.internal.openapi_client.models.http_error import HttpError
from svix.internal.openapi_client.models.http_validation_error import (
HTTPValidationError,
)

logger = logging.getLogger("django.server")
META = settings.META
Expand Down
107 changes: 77 additions & 30 deletions backend/metering_billing/payment_providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@
logger = logging.getLogger("django.server")

SELF_HOSTED = settings.SELF_HOSTED
STRIPE_SECRET_KEY = settings.STRIPE_SECRET_KEY
STRIPE_LIVE_SECRET_KEY = settings.STRIPE_LIVE_SECRET_KEY
STRIPE_TEST_SECRET_KEY = settings.STRIPE_TEST_SECRET_KEY
VITE_STRIPE_CLIENT = settings.VITE_STRIPE_CLIENT
VITE_API_URL = settings.VITE_API_URL

Expand All @@ -53,7 +54,7 @@ def working(self) -> bool:
pass

@abc.abstractmethod
def update_payment_object_status(self, payment_object_id: str):
def update_payment_object_status(self, organization, payment_object_id: str):
"""This method will be called periodically when the status of a payment object needs to be updated. It should return the status of the payment object, which should be either paid or unpaid."""
pass

Expand Down Expand Up @@ -108,7 +109,8 @@ def initialize_settings(self, organization) -> None:

class StripeConnector(PaymentProvider):
def __init__(self):
self.secret_key = STRIPE_SECRET_KEY
self.live_secret_key = STRIPE_LIVE_SECRET_KEY
self.test_secret_key = STRIPE_TEST_SECRET_KEY
self.self_hosted = SELF_HOSTED
redirect_dict = {
"response_type": "code",
Expand All @@ -123,7 +125,7 @@ def __init__(self):
self.redirect_url = ""

def working(self) -> bool:
return self.secret_key != "" and self.secret_key is not None
return self.live_secret_key is not None or self.test_secret_key is not None

def customer_connected(self, customer) -> bool:
pp_ids = customer.integrations
Expand All @@ -133,18 +135,26 @@ def customer_connected(self, customer) -> bool:

def organization_connected(self, organization) -> bool:
if self.self_hosted:
return self.secret_key != "" and self.secret_key is not None
return self.live_secret_key is not None or self.test_secret_key is not None
else:
return (
organization.payment_provider_ids.get(PAYMENT_PROVIDERS.STRIPE, "")
!= ""
organization.payment_provider_ids.get(PAYMENT_PROVIDERS.STRIPE, None)
is not None
)

def update_payment_object_status(self, payment_object_id):
from metering_billing.models import Invoice
def update_payment_object_status(self, organization, payment_object_id):
from metering_billing.models import Invoice, Organization

stripe.api_key = self.secret_key
invoice = stripe.Invoice.retrieve(payment_object_id)
invoice_payload = {}
if not self.self_hosted:
invoice_payload["stripe_account"] = organization.payment_provider_ids.get(
PAYMENT_PROVIDERS.STRIPE
)
if organization.organization_type == Organization.OrganizationType.PRODUCTION:
stripe.api_key = self.live_secret_key
else:
stripe.api_key = self.test_secret_key
invoice = stripe.Invoice.retrieve(payment_object_id, **invoice_payload)
if invoice.status == "paid":
return Invoice.PaymentStatus.PAID
else:
Expand All @@ -154,15 +164,18 @@ def import_customers(self, organization):
"""
Imports customers from Stripe. If they already exist (by checking that either they already have their Stripe ID in our system, or seeing that they have the same email address), then we update the Stripe section of payment_providers dict to reflect new information. If they don't exist, we create them (not as a Lotus customer yet, just as a Stripe customer).
"""
from metering_billing.models import Customer
from metering_billing.models import Customer, Organization

stripe.api_key = self.secret_key
if organization.organization_type == Organization.OrganizationType.PRODUCTION:
stripe.api_key = self.live_secret_key
else:
stripe.api_key = self.test_secret_key

num_cust_added = 0
org_ppis = organization.payment_provider_ids

stripe_cust_kwargs = {}
if org_ppis.get(PAYMENT_PROVIDERS.STRIPE) not in ["", None]:
if not self.self_hosted:
# this is to get "on behalf" of someone
stripe_cust_kwargs["stripe_account"] = org_ppis.get(
PAYMENT_PROVIDERS.STRIPE
Expand Down Expand Up @@ -233,7 +246,12 @@ def import_customers(self, organization):
return num_cust_added

def import_payment_objects(self, organization):
stripe.api_key = self.secret_key
from metering_billing.models import Organization

if organization.organization_type == Organization.OrganizationType.PRODUCTION:
stripe.api_key = self.live_secret_key
else:
stripe.api_key = self.test_secret_key
imported_invoices = {}
for customer in organization.customers.all():
if PAYMENT_PROVIDERS.STRIPE in customer.integrations:
Expand All @@ -244,9 +262,13 @@ def import_payment_objects(self, organization):
def _import_payment_objects_for_customer(self, customer):
from metering_billing.models import Invoice

stripe.api_key = self.secret_key
payload = {}
if not self.self_hosted:
payload["stripe_account"] = customer.organization.payment_provider_ids.get(
PAYMENT_PROVIDERS.STRIPE
)
invoices = stripe.Invoice.list(
customer=customer.integrations[PAYMENT_PROVIDERS.STRIPE]["id"]
customer=customer.integrations[PAYMENT_PROVIDERS.STRIPE]["id"], **payload
)
lotus_invoices = []
for stripe_invoice in invoices.auto_paging_iter():
Expand All @@ -273,8 +295,15 @@ def _import_payment_objects_for_customer(self, customer):
return lotus_invoices

def create_customer(self, customer):
stripe.api_key = self.secret_key
from metering_billing.models import OrganizationSetting
from metering_billing.models import Organization, OrganizationSetting

if (
customer.organization.organization_type
== Organization.OrganizationType.PRODUCTION
):
stripe.api_key = self.live_secret_key
else:
stripe.api_key = self.test_secret_key

setting = OrganizationSetting.objects.get(
setting_name=ORGANIZATION_SETTING_NAMES.GENERATE_CUSTOMER_IN_STRIPE_AFTER_LOTUS,
Expand All @@ -293,10 +322,10 @@ def create_customer(self, customer):
}
if not self.self_hosted:
org_stripe_acct = customer.organization.payment_provider_ids.get(
PAYMENT_PROVIDERS.STRIPE, ""
PAYMENT_PROVIDERS.STRIPE, None
)
assert (
org_stripe_acct != ""
org_stripe_acct is not None
), "Organization does not have a Stripe account ID"
customer_kwargs["stripe_account"] = org_stripe_acct
try:
Expand All @@ -318,7 +347,15 @@ def create_customer(self, customer):
)

def create_payment_object(self, invoice) -> str:
stripe.api_key = self.secret_key
from metering_billing.models import Organization

if (
invoice.organization.organization_type
== Organization.OrganizationType.PRODUCTION
):
stripe.api_key = self.live_secret_key
else:
stripe.api_key = self.test_secret_key
# check everything works as expected + build invoice item
assert invoice.external_payment_obj_id is None
customer = invoice.customer
Expand All @@ -329,20 +366,17 @@ def create_payment_object(self, invoice) -> str:
invoice_kwargs = {
"auto_advance": True,
"customer": stripe_customer_id,
# "automatic_tax": {
# "enabled": True,
# },
"description": "Invoice from {}".format(
customer.organization.organization_name
),
"currency": invoice.currency.code.lower(),
}
if not self.self_hosted:
org_stripe_acct = customer.organization.payment_provider_ids.get(
PAYMENT_PROVIDERS.STRIPE, ""
PAYMENT_PROVIDERS.STRIPE, None
)
assert (
org_stripe_acct != ""
org_stripe_acct is not None
), "Organization does not have a Stripe account ID"
invoice_kwargs["stripe_account"] = org_stripe_acct

Expand Down Expand Up @@ -374,6 +408,8 @@ def create_payment_object(self, invoice) -> str:
"tax_behavior": tax_behavior,
"metadata": metadata,
}
if not self.self_hosted:
inv_dict["stripe_account"] = org_stripe_acct
stripe.InvoiceItem.create(**inv_dict)
stripe_invoice = stripe.Invoice.create(**invoice_kwargs)
return stripe_invoice.id
Expand All @@ -385,7 +421,12 @@ class StripePostRequestDataSerializer(serializers.Serializer):
return StripePostRequestDataSerializer

def handle_post(self, data, organization) -> PaymentProviderPostResponseSerializer:
stripe.api_key = self.secret_key
from metering_billing.models import Organization

if organization.organization_type == Organization.OrganizationType.PRODUCTION:
stripe.api_key = self.live_secret_key
else:
stripe.api_key = self.test_secret_key
response = stripe.OAuth.token(
grant_type="authorization_code",
code=data["authorization_code"],
Expand Down Expand Up @@ -426,11 +467,15 @@ def transfer_subscriptions(
from metering_billing.models import (
Customer,
ExternalPlanLink,
Organization,
Plan,
SubscriptionRecord,
)

stripe.api_key = self.secret_key
if organization.organization_type == Organization.OrganizationType.PRODUCTION:
stripe.api_key = self.live_secret_key
else:
stripe.api_key = self.test_secret_key

org_ppis = organization.payment_provider_ids
stripe_cust_kwargs = {}
Expand All @@ -445,7 +490,7 @@ def transfer_subscriptions(
)

stripe_subscriptions = stripe.Subscription.search(
query="status:'active'",
query="status:'active'", **stripe_cust_kwargs
)
plans_with_links = (
Plan.objects.filter(organization=organization, status=PLAN_STATUS.ACTIVE)
Expand Down Expand Up @@ -507,6 +552,7 @@ def transfer_subscriptions(
subscription.id,
prorate=True,
invoice_now=True,
**stripe_cust_kwargs,
)
else:
validated_data["start_date"] = datetime.datetime.utcfromtimestamp(
Expand All @@ -515,6 +561,7 @@ def transfer_subscriptions(
sub = stripe.Subscription.modify(
subscription.id,
cancel_at_period_end=True,
**stripe_cust_kwargs,
)
ret_subs.append(sub)
SubscriptionRecord.objects.create(**validated_data)
Expand Down
4 changes: 2 additions & 2 deletions backend/metering_billing/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
from dateutil.relativedelta import relativedelta
from django.conf import settings
from django.db.models import Q

from metering_billing.payment_providers import PAYMENT_PROVIDER_MAP
from metering_billing.serializers.backtest_serializers import (
AllSubstitutionResultsSerializer,
Expand Down Expand Up @@ -143,8 +142,9 @@ def update_invoice_status():
for incomplete_invoice in incomplete_invoices:
pp = incomplete_invoice.external_payment_obj_type
if pp in PAYMENT_PROVIDER_MAP and PAYMENT_PROVIDER_MAP[pp].working():
organization = incomplete_invoice.organization
status = PAYMENT_PROVIDER_MAP[pp].update_payment_object_status(
incomplete_invoice.external_payment_obj_id
organization, incomplete_invoice.external_payment_obj_id
)
if status == Invoice.PaymentStatus.PAID:
incomplete_invoice.payment_status = Invoice.PaymentStatus.PAID
Expand Down
9 changes: 4 additions & 5 deletions backend/metering_billing/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import posthog
import pytest
from model_bakery import baker

from metering_billing.utils import now_utc
from metering_billing.utils.enums import (
FLAT_FEE_BILLING_TYPE,
Expand All @@ -13,6 +11,7 @@
PRODUCT_STATUS,
USAGE_BILLING_FREQUENCY,
)
from model_bakery import baker


@pytest.fixture(autouse=True)
Expand All @@ -39,12 +38,12 @@ def use_dummy_cache_backend(settings):
def turn_off_stripe_connection():
from metering_billing.payment_providers import PAYMENT_PROVIDER_MAP

sk = PAYMENT_PROVIDER_MAP["stripe"].secret_key
PAYMENT_PROVIDER_MAP["stripe"].secret_key = None
sk = PAYMENT_PROVIDER_MAP["stripe"].test_secret_key
PAYMENT_PROVIDER_MAP["stripe"].test_secret_key = None

yield

PAYMENT_PROVIDER_MAP["stripe"].secret_key = sk
PAYMENT_PROVIDER_MAP["stripe"].test_secret_key = sk


@pytest.fixture
Expand Down
Loading

0 comments on commit fc7714a

Please sign in to comment.