From fc7714a3bc3a3f1bee4a7e0807e281c4dd730aa6 Mon Sep 17 00:00:00 2001 From: Diego Escobedo Date: Thu, 2 Feb 2023 12:11:20 -0800 Subject: [PATCH] update test account flow for stripe integration (#521) --- .github/workflows/django-postgres.yml | 3 +- .github/workflows/postman_workflow.yml | 3 +- backend/lotus/settings.py | 5 +- backend/metering_billing/models.py | 15 ++- backend/metering_billing/payment_providers.py | 107 +++++++++++++----- backend/metering_billing/tasks.py | 4 +- backend/metering_billing/tests/conftest.py | 9 +- .../tests/test_integrations.py | 8 +- .../metering_billing/views/webhook_views.py | 25 +++- docs/contributing.mdx | 3 +- docs/overview/self-hosting.mdx | 3 +- env/.env.dev.example | 3 +- env/.env.prod.example | 3 +- 13 files changed, 130 insertions(+), 61 deletions(-) diff --git a/.github/workflows/django-postgres.yml b/.github/workflows/django-postgres.yml index e2a8e7934..c9eea7d1d 100644 --- a/.github/workflows/django-postgres.yml +++ b/.github/workflows/django-postgres.yml @@ -49,7 +49,8 @@ jobs: DJANGO_SETTINGS_MODULE: "lotus.settings" PYTHONPATH: "." SECRET_KEY: ${{ secrets.SECRET_KEY }} - STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }} + STRIPE_LIVE_SECRET_KEY: ${{ secrets.STRIPE_LIVE_SECRET_KEY }} + STRIPE_TEST_SECRET_KEY: ${{ secrets.STRIPE_TEST_SECRET_KEY }} DEBUG: False KAFKA_URL: "localhost:9092" PYTHONDONTWRITEBYTECODE: 1 diff --git a/.github/workflows/postman_workflow.yml b/.github/workflows/postman_workflow.yml index c9e6b04b0..8a5dd1682 100644 --- a/.github/workflows/postman_workflow.yml +++ b/.github/workflows/postman_workflow.yml @@ -44,7 +44,8 @@ jobs: DJANGO_SETTINGS_MODULE: "lotus.settings" PYTHONPATH: "." SECRET_KEY: ${{ secrets.SECRET_KEY }} - STRIPE_SECRET_KEY: ${{ secrets.STRIPE_SECRET_KEY }} + STRIPE_LIVE_SECRET_KEY: ${{ secrets.STRIPE_LIVE_SECRET_KEY }} + STRIPE_TEST_SECRET_KEY: ${{ secrets.STRIPE_TEST_SECRET_KEY }} DEBUG: False KAFKA_URL: "localhost:9092" PYTHONDONTWRITEBYTECODE: 1 diff --git a/backend/lotus/settings.py b/backend/lotus/settings.py index 9f3e81317..90f26adea 100644 --- a/backend/lotus/settings.py +++ b/backend/lotus/settings.py @@ -63,7 +63,10 @@ PRODUCT_ANALYTICS_OPT_IN = config("PRODUCT_ANALYTICS_OPT_IN", default=True, cast=bool) PRODUCT_ANALYTICS_OPT_IN = True if not SELF_HOSTED else PRODUCT_ANALYTICS_OPT_IN # Stripe required -STRIPE_SECRET_KEY = config("STRIPE_SECRET_KEY", default="") +STRIPE_LIVE_SECRET_KEY = config("STRIPE_LIVE_SECRET_KEY", default=None) +if STRIPE_LIVE_SECRET_KEY is None: + STRIPE_LIVE_SECRET_KEY = config("STRIPE_SECRET_KEY", default=None) +STRIPE_TEST_SECRET_KEY = config("STRIPE_TEST_SECRET_KEY", default=None) STRIPE_WEBHOOK_SECRET = config("STRIPE_WEBHOOK_SECRET", default="whsec_") # Webhooks for Svix SVIX_API_KEY = config("SVIX_API_KEY", default="") diff --git a/backend/metering_billing/models.py b/backend/metering_billing/models.py index 5aaea9f81..c88fb19af 100644 --- a/backend/metering_billing/models.py +++ b/backend/metering_billing/models.py @@ -17,14 +17,6 @@ from django.db.models.constraints import CheckConstraint, UniqueConstraint from django.db.models.functions import Cast, Coalesce from django.utils.translation import gettext_lazy as _ -from rest_framework_api_key.models import AbstractAPIKey -from simple_history.models import HistoricalRecords -from svix.api import ApplicationIn, EndpointIn, EndpointSecretRotateIn, EndpointUpdate -from svix.internal.openapi_client.models.http_error import HttpError -from svix.internal.openapi_client.models.http_validation_error import ( - HTTPValidationError, -) - from metering_billing.exceptions.exceptions import ( ExternalConnectionFailure, ExternalConnectionInvalid, @@ -77,6 +69,13 @@ WEBHOOK_TRIGGER_EVENTS, ) from metering_billing.webhooks import invoice_paid_webhook, usage_alert_webhook +from rest_framework_api_key.models import AbstractAPIKey +from simple_history.models import HistoricalRecords +from svix.api import ApplicationIn, EndpointIn, EndpointSecretRotateIn, EndpointUpdate +from svix.internal.openapi_client.models.http_error import HttpError +from svix.internal.openapi_client.models.http_validation_error import ( + HTTPValidationError, +) logger = logging.getLogger("django.server") META = settings.META diff --git a/backend/metering_billing/payment_providers.py b/backend/metering_billing/payment_providers.py index 4a9b93210..37a0517cc 100644 --- a/backend/metering_billing/payment_providers.py +++ b/backend/metering_billing/payment_providers.py @@ -26,7 +26,8 @@ logger = logging.getLogger("django.server") SELF_HOSTED = settings.SELF_HOSTED -STRIPE_SECRET_KEY = settings.STRIPE_SECRET_KEY +STRIPE_LIVE_SECRET_KEY = settings.STRIPE_LIVE_SECRET_KEY +STRIPE_TEST_SECRET_KEY = settings.STRIPE_TEST_SECRET_KEY VITE_STRIPE_CLIENT = settings.VITE_STRIPE_CLIENT VITE_API_URL = settings.VITE_API_URL @@ -53,7 +54,7 @@ def working(self) -> bool: pass @abc.abstractmethod - def update_payment_object_status(self, payment_object_id: str): + def update_payment_object_status(self, organization, payment_object_id: str): """This method will be called periodically when the status of a payment object needs to be updated. It should return the status of the payment object, which should be either paid or unpaid.""" pass @@ -108,7 +109,8 @@ def initialize_settings(self, organization) -> None: class StripeConnector(PaymentProvider): def __init__(self): - self.secret_key = STRIPE_SECRET_KEY + self.live_secret_key = STRIPE_LIVE_SECRET_KEY + self.test_secret_key = STRIPE_TEST_SECRET_KEY self.self_hosted = SELF_HOSTED redirect_dict = { "response_type": "code", @@ -123,7 +125,7 @@ def __init__(self): self.redirect_url = "" def working(self) -> bool: - return self.secret_key != "" and self.secret_key is not None + return self.live_secret_key is not None or self.test_secret_key is not None def customer_connected(self, customer) -> bool: pp_ids = customer.integrations @@ -133,18 +135,26 @@ def customer_connected(self, customer) -> bool: def organization_connected(self, organization) -> bool: if self.self_hosted: - return self.secret_key != "" and self.secret_key is not None + return self.live_secret_key is not None or self.test_secret_key is not None else: return ( - organization.payment_provider_ids.get(PAYMENT_PROVIDERS.STRIPE, "") - != "" + organization.payment_provider_ids.get(PAYMENT_PROVIDERS.STRIPE, None) + is not None ) - def update_payment_object_status(self, payment_object_id): - from metering_billing.models import Invoice + def update_payment_object_status(self, organization, payment_object_id): + from metering_billing.models import Invoice, Organization - stripe.api_key = self.secret_key - invoice = stripe.Invoice.retrieve(payment_object_id) + invoice_payload = {} + if not self.self_hosted: + invoice_payload["stripe_account"] = organization.payment_provider_ids.get( + PAYMENT_PROVIDERS.STRIPE + ) + if organization.organization_type == Organization.OrganizationType.PRODUCTION: + stripe.api_key = self.live_secret_key + else: + stripe.api_key = self.test_secret_key + invoice = stripe.Invoice.retrieve(payment_object_id, **invoice_payload) if invoice.status == "paid": return Invoice.PaymentStatus.PAID else: @@ -154,15 +164,18 @@ def import_customers(self, organization): """ Imports customers from Stripe. If they already exist (by checking that either they already have their Stripe ID in our system, or seeing that they have the same email address), then we update the Stripe section of payment_providers dict to reflect new information. If they don't exist, we create them (not as a Lotus customer yet, just as a Stripe customer). """ - from metering_billing.models import Customer + from metering_billing.models import Customer, Organization - stripe.api_key = self.secret_key + if organization.organization_type == Organization.OrganizationType.PRODUCTION: + stripe.api_key = self.live_secret_key + else: + stripe.api_key = self.test_secret_key num_cust_added = 0 org_ppis = organization.payment_provider_ids stripe_cust_kwargs = {} - if org_ppis.get(PAYMENT_PROVIDERS.STRIPE) not in ["", None]: + if not self.self_hosted: # this is to get "on behalf" of someone stripe_cust_kwargs["stripe_account"] = org_ppis.get( PAYMENT_PROVIDERS.STRIPE @@ -233,7 +246,12 @@ def import_customers(self, organization): return num_cust_added def import_payment_objects(self, organization): - stripe.api_key = self.secret_key + from metering_billing.models import Organization + + if organization.organization_type == Organization.OrganizationType.PRODUCTION: + stripe.api_key = self.live_secret_key + else: + stripe.api_key = self.test_secret_key imported_invoices = {} for customer in organization.customers.all(): if PAYMENT_PROVIDERS.STRIPE in customer.integrations: @@ -244,9 +262,13 @@ def import_payment_objects(self, organization): def _import_payment_objects_for_customer(self, customer): from metering_billing.models import Invoice - stripe.api_key = self.secret_key + payload = {} + if not self.self_hosted: + payload["stripe_account"] = customer.organization.payment_provider_ids.get( + PAYMENT_PROVIDERS.STRIPE + ) invoices = stripe.Invoice.list( - customer=customer.integrations[PAYMENT_PROVIDERS.STRIPE]["id"] + customer=customer.integrations[PAYMENT_PROVIDERS.STRIPE]["id"], **payload ) lotus_invoices = [] for stripe_invoice in invoices.auto_paging_iter(): @@ -273,8 +295,15 @@ def _import_payment_objects_for_customer(self, customer): return lotus_invoices def create_customer(self, customer): - stripe.api_key = self.secret_key - from metering_billing.models import OrganizationSetting + from metering_billing.models import Organization, OrganizationSetting + + if ( + customer.organization.organization_type + == Organization.OrganizationType.PRODUCTION + ): + stripe.api_key = self.live_secret_key + else: + stripe.api_key = self.test_secret_key setting = OrganizationSetting.objects.get( setting_name=ORGANIZATION_SETTING_NAMES.GENERATE_CUSTOMER_IN_STRIPE_AFTER_LOTUS, @@ -293,10 +322,10 @@ def create_customer(self, customer): } if not self.self_hosted: org_stripe_acct = customer.organization.payment_provider_ids.get( - PAYMENT_PROVIDERS.STRIPE, "" + PAYMENT_PROVIDERS.STRIPE, None ) assert ( - org_stripe_acct != "" + org_stripe_acct is not None ), "Organization does not have a Stripe account ID" customer_kwargs["stripe_account"] = org_stripe_acct try: @@ -318,7 +347,15 @@ def create_customer(self, customer): ) def create_payment_object(self, invoice) -> str: - stripe.api_key = self.secret_key + from metering_billing.models import Organization + + if ( + invoice.organization.organization_type + == Organization.OrganizationType.PRODUCTION + ): + stripe.api_key = self.live_secret_key + else: + stripe.api_key = self.test_secret_key # check everything works as expected + build invoice item assert invoice.external_payment_obj_id is None customer = invoice.customer @@ -329,9 +366,6 @@ def create_payment_object(self, invoice) -> str: invoice_kwargs = { "auto_advance": True, "customer": stripe_customer_id, - # "automatic_tax": { - # "enabled": True, - # }, "description": "Invoice from {}".format( customer.organization.organization_name ), @@ -339,10 +373,10 @@ def create_payment_object(self, invoice) -> str: } if not self.self_hosted: org_stripe_acct = customer.organization.payment_provider_ids.get( - PAYMENT_PROVIDERS.STRIPE, "" + PAYMENT_PROVIDERS.STRIPE, None ) assert ( - org_stripe_acct != "" + org_stripe_acct is not None ), "Organization does not have a Stripe account ID" invoice_kwargs["stripe_account"] = org_stripe_acct @@ -374,6 +408,8 @@ def create_payment_object(self, invoice) -> str: "tax_behavior": tax_behavior, "metadata": metadata, } + if not self.self_hosted: + inv_dict["stripe_account"] = org_stripe_acct stripe.InvoiceItem.create(**inv_dict) stripe_invoice = stripe.Invoice.create(**invoice_kwargs) return stripe_invoice.id @@ -385,7 +421,12 @@ class StripePostRequestDataSerializer(serializers.Serializer): return StripePostRequestDataSerializer def handle_post(self, data, organization) -> PaymentProviderPostResponseSerializer: - stripe.api_key = self.secret_key + from metering_billing.models import Organization + + if organization.organization_type == Organization.OrganizationType.PRODUCTION: + stripe.api_key = self.live_secret_key + else: + stripe.api_key = self.test_secret_key response = stripe.OAuth.token( grant_type="authorization_code", code=data["authorization_code"], @@ -426,11 +467,15 @@ def transfer_subscriptions( from metering_billing.models import ( Customer, ExternalPlanLink, + Organization, Plan, SubscriptionRecord, ) - stripe.api_key = self.secret_key + if organization.organization_type == Organization.OrganizationType.PRODUCTION: + stripe.api_key = self.live_secret_key + else: + stripe.api_key = self.test_secret_key org_ppis = organization.payment_provider_ids stripe_cust_kwargs = {} @@ -445,7 +490,7 @@ def transfer_subscriptions( ) stripe_subscriptions = stripe.Subscription.search( - query="status:'active'", + query="status:'active'", **stripe_cust_kwargs ) plans_with_links = ( Plan.objects.filter(organization=organization, status=PLAN_STATUS.ACTIVE) @@ -507,6 +552,7 @@ def transfer_subscriptions( subscription.id, prorate=True, invoice_now=True, + **stripe_cust_kwargs, ) else: validated_data["start_date"] = datetime.datetime.utcfromtimestamp( @@ -515,6 +561,7 @@ def transfer_subscriptions( sub = stripe.Subscription.modify( subscription.id, cancel_at_period_end=True, + **stripe_cust_kwargs, ) ret_subs.append(sub) SubscriptionRecord.objects.create(**validated_data) diff --git a/backend/metering_billing/tasks.py b/backend/metering_billing/tasks.py index 876dd91b6..aab245df2 100644 --- a/backend/metering_billing/tasks.py +++ b/backend/metering_billing/tasks.py @@ -6,7 +6,6 @@ from dateutil.relativedelta import relativedelta from django.conf import settings from django.db.models import Q - from metering_billing.payment_providers import PAYMENT_PROVIDER_MAP from metering_billing.serializers.backtest_serializers import ( AllSubstitutionResultsSerializer, @@ -143,8 +142,9 @@ def update_invoice_status(): for incomplete_invoice in incomplete_invoices: pp = incomplete_invoice.external_payment_obj_type if pp in PAYMENT_PROVIDER_MAP and PAYMENT_PROVIDER_MAP[pp].working(): + organization = incomplete_invoice.organization status = PAYMENT_PROVIDER_MAP[pp].update_payment_object_status( - incomplete_invoice.external_payment_obj_id + organization, incomplete_invoice.external_payment_obj_id ) if status == Invoice.PaymentStatus.PAID: incomplete_invoice.payment_status = Invoice.PaymentStatus.PAID diff --git a/backend/metering_billing/tests/conftest.py b/backend/metering_billing/tests/conftest.py index b982c2a8b..a6282fb5c 100644 --- a/backend/metering_billing/tests/conftest.py +++ b/backend/metering_billing/tests/conftest.py @@ -2,8 +2,6 @@ import posthog import pytest -from model_bakery import baker - from metering_billing.utils import now_utc from metering_billing.utils.enums import ( FLAT_FEE_BILLING_TYPE, @@ -13,6 +11,7 @@ PRODUCT_STATUS, USAGE_BILLING_FREQUENCY, ) +from model_bakery import baker @pytest.fixture(autouse=True) @@ -39,12 +38,12 @@ def use_dummy_cache_backend(settings): def turn_off_stripe_connection(): from metering_billing.payment_providers import PAYMENT_PROVIDER_MAP - sk = PAYMENT_PROVIDER_MAP["stripe"].secret_key - PAYMENT_PROVIDER_MAP["stripe"].secret_key = None + sk = PAYMENT_PROVIDER_MAP["stripe"].test_secret_key + PAYMENT_PROVIDER_MAP["stripe"].test_secret_key = None yield - PAYMENT_PROVIDER_MAP["stripe"].secret_key = sk + PAYMENT_PROVIDER_MAP["stripe"].test_secret_key = sk @pytest.fixture diff --git a/backend/metering_billing/tests/test_integrations.py b/backend/metering_billing/tests/test_integrations.py index 6519cdccf..513e7b5db 100644 --- a/backend/metering_billing/tests/test_integrations.py +++ b/backend/metering_billing/tests/test_integrations.py @@ -19,8 +19,8 @@ from metering_billing.utils import now_utc from metering_billing.utils.enums import PAYMENT_PROVIDERS -STRIPE_SECRET_KEY = settings.STRIPE_SECRET_KEY -stripe.api_key = STRIPE_SECRET_KEY +STRIPE_TEST_SECRET_KEY = settings.STRIPE_TEST_SECRET_KEY +stripe.api_key = STRIPE_TEST_SECRET_KEY @pytest.fixture @@ -197,7 +197,7 @@ def test_update_invoice_status(self, integration_test_common_setup): # update the status of the invoice new_status = stripe_connector.update_payment_object_status( - invoice.external_payment_obj_id + setup_dict["org"], invoice.external_payment_obj_id ) assert new_status == Invoice.PaymentStatus.UNPAID # now add payment method @@ -206,7 +206,7 @@ def test_update_invoice_status(self, integration_test_common_setup): paid_out_of_band=True, ) new_status = stripe_connector.update_payment_object_status( - invoice.external_payment_obj_id + setup_dict["org"], invoice.external_payment_obj_id ) assert new_status == Invoice.PaymentStatus.PAID diff --git a/backend/metering_billing/views/webhook_views.py b/backend/metering_billing/views/webhook_views.py index 78bc4124b..09659a2f6 100644 --- a/backend/metering_billing/views/webhook_views.py +++ b/backend/metering_billing/views/webhook_views.py @@ -1,8 +1,6 @@ import stripe from django.conf import settings from django.views.decorators.csrf import csrf_exempt -from metering_billing.models import Customer, Invoice -from metering_billing.utils.enums import PAYMENT_PROVIDERS from rest_framework import status from rest_framework.decorators import ( api_view, @@ -11,9 +9,13 @@ ) from rest_framework.response import Response +from metering_billing.models import Customer, Invoice +from metering_billing.payment_providers import PAYMENT_PROVIDER_MAP +from metering_billing.utils.enums import PAYMENT_PROVIDERS + STRIPE_WEBHOOK_SECRET = settings.STRIPE_WEBHOOK_SECRET -STRIPE_SECRET_KEY = settings.STRIPE_SECRET_KEY -stripe.api_key = STRIPE_SECRET_KEY +STRIPE_TEST_SECRET_KEY = settings.STRIPE_TEST_SECRET_KEY +STRIPE_LIVE_SECRET_KEY = settings.STRIPE_LIVE_SECRET_KEY def _invoice_paid_handler(event): @@ -32,9 +34,22 @@ def _payment_method_refresh_handler(stripe_customer_id): integrations__stripe__id=stripe_customer_id ).first() if matching_customer: + organization = matching_customer.organization + if organization.OrganizationType == organization.OrganizationType.PRODUCTION: + stripe.api_key = STRIPE_LIVE_SECRET_KEY + else: + stripe.api_key = STRIPE_TEST_SECRET_KEY integrations_dict = matching_customer.integrations stripe_payment_methods = [] - payment_methods = stripe.Customer.list_payment_methods(stripe_customer_id) + if PAYMENT_PROVIDER_MAP[PAYMENT_PROVIDERS.STRIPE].self_hosted: + payment_methods = stripe.Customer.list_payment_methods(stripe_customer_id) + else: + payment_methods = stripe.Customer.list_payment_methods( + customer=stripe_customer_id, + stripe_account=organization.payment_provider_ids.get( + PAYMENT_PROVIDERS.STRIPE + ), + ) for payment_method in payment_methods.auto_paging_iter(): pm_dict = { "id": payment_method.id, diff --git a/docs/contributing.mdx b/docs/contributing.mdx index f4dbbff83..174b108e9 100644 --- a/docs/contributing.mdx +++ b/docs/contributing.mdx @@ -82,6 +82,7 @@ chmod +x ./scripts/run-codestyle-docker.sh && ./scripts/run-codestyle-docker.sh | NODE_ENV | development | | | VITE_API_URL | "http://localhost:8000/" | | | VITE_STRIPE_CLIENT | ca\_ | ✔ | -| STRIPE_SECRET_KEY | "" | ✔ | +| STRIPE_LIVE_SECRET_KEY | sk_live\_ | ✔ | +| STRIPE_TEST_SECRET_KEY | sk_test\_ | ✔ | | STRIPE_WEBHOOK_SECRET | whsec\_ | ✔ | | SVIX_JWT_SECRET | change_me | ✔ | diff --git a/docs/overview/self-hosting.mdx b/docs/overview/self-hosting.mdx index 2adfde86a..2d8b38d53 100644 --- a/docs/overview/self-hosting.mdx +++ b/docs/overview/self-hosting.mdx @@ -44,6 +44,7 @@ Optionally: | NODE_ENV | production | | | VITE_API_URL | "http://localhost/" | | | VITE_STRIPE_CLIENT | ca\_ | ✔ | -| STRIPE_SECRET_KEY | sk_live\_ | ✔ | +| STRIPE_LIVE_SECRET_KEY | sk_live\_ | ✔ | +| STRIPE_TEST_SECRET_KEY | sk_test\_ | ✔ | | STRIPE_WEBHOOK_SECRET | whsec\_ | ✔ | | SVIX_JWT_SECRET | change_me | ✔ | diff --git a/env/.env.dev.example b/env/.env.dev.example index deb34749b..2e72312ed 100644 --- a/env/.env.dev.example +++ b/env/.env.dev.example @@ -15,7 +15,8 @@ NODE_ENV=development VITE_API_URL="http://localhost:8000/" VITE_STRIPE_CLIENT=ca_ -STRIPE_SECRET_KEY=sk_test_ +STRIPE_LIVE_SECRET_KEY=sk_live_ +STRIPE_TEST_SECRET_KEY=sk_test_ STRIPE_WEBHOOK_SECRET=whsec_ KAFKA_URL="redpanda:29092" diff --git a/env/.env.prod.example b/env/.env.prod.example index dc5484dad..7106d81be 100644 --- a/env/.env.prod.example +++ b/env/.env.prod.example @@ -15,7 +15,8 @@ NODE_ENV=production VITE_API_URL="http://localhost/" VITE_STRIPE_CLIENT=ca_ -STRIPE_SECRET_KEY=sk_live_ +STRIPE_LIVE_SECRET_KEY=sk_live_ +STRIPE_TEST_SECRET_KEY=sk_test_ STRIPE_WEBHOOK_SECRET=whsec_ KAFKA_URL="redpanda:29092"