Files
membership-be/payment_service.py

306 lines
9.9 KiB
Python

"""
Payment service for Stripe integration.
Handles subscription creation, checkout sessions, and webhook processing.
"""
import os
from datetime import datetime, timezone, timedelta
from typing import Optional

import stripe
from dotenv import load_dotenv
# Load environment variables (python-dotenv reads a local .env file, if one
# exists) so the os.getenv() lookups below can see them.
load_dotenv()
# Initialize Stripe with secret key.
# NOTE: os.getenv() returns None when the variable is unset, so a missing
# STRIPE_SECRET_KEY is not detected here at import time.
stripe.api_key = os.getenv("STRIPE_SECRET_KEY")
# Stripe webhook secret for signature verification; passed to
# stripe.Webhook.construct_event() in verify_webhook_signature().
# May be None if the environment variable is not set.
STRIPE_WEBHOOK_SECRET = os.getenv("STRIPE_WEBHOOK_SECRET")
def create_checkout_session(
    user_id: str,
    user_email: str,
    plan_id: str,
    stripe_price_id: str,
    success_url: str,
    cancel_url: str
):
    """
    Create a Stripe Checkout session for subscription payment.

    Args:
        user_id: User's UUID
        user_email: User's email address
        plan_id: SubscriptionPlan UUID
        stripe_price_id: Stripe Price ID for the plan
        success_url: URL to redirect after successful payment
        cancel_url: URL to redirect if user cancels

    Returns:
        dict: ``{"session_id": <checkout session id>, "url": <checkout URL>}``

    Raises:
        Exception: If the Stripe API call fails. The original
            ``stripe.error.StripeError`` is attached as ``__cause__``.
    """
    # The same identifiers are attached to both the checkout session and the
    # subscription it creates, so either object can be mapped back to the
    # user and plan.
    metadata = {
        "user_id": str(user_id),
        "plan_id": str(plan_id),
    }

    try:
        checkout_session = stripe.checkout.Session.create(
            customer_email=user_email,
            payment_method_types=["card"],
            line_items=[
                {
                    "price": stripe_price_id,
                    "quantity": 1,
                }
            ],
            mode="subscription",
            success_url=success_url,
            cancel_url=cancel_url,
            metadata=metadata,
            # Separate dict instance so the two payloads never alias.
            subscription_data={"metadata": dict(metadata)},
        )
    except stripe.error.StripeError as e:
        # Chain explicitly so the Stripe error remains visible in tracebacks.
        raise Exception(f"Stripe error: {str(e)}") from e

    return {
        "session_id": checkout_session.id,
        "url": checkout_session.url
    }
def verify_webhook_signature(payload: bytes, sig_header: str) -> dict:
    """
    Verify Stripe webhook signature and construct event.

    Args:
        payload: Raw webhook payload bytes
        sig_header: Stripe signature header

    Returns:
        dict: Verified webhook event, as returned by
            ``stripe.Webhook.construct_event``

    Raises:
        ValueError: If the payload is invalid or signature verification
            fails. The underlying error is attached as ``__cause__``.
    """
    try:
        return stripe.Webhook.construct_event(
            payload, sig_header, STRIPE_WEBHOOK_SECRET
        )
    except ValueError as e:
        # construct_event raised ValueError itself; re-raise with context.
        raise ValueError(f"Invalid payload: {str(e)}") from e
    except stripe.error.SignatureVerificationError as e:
        # Normalize to ValueError so callers only need to catch one type.
        raise ValueError(f"Invalid signature: {str(e)}") from e
def get_subscription_end_date(billing_cycle: str = "yearly") -> datetime:
    """
    Return the subscription end date for a billing cycle, counted from now.

    Args:
        billing_cycle: "yearly" or "monthly"; any other value is treated
            as "yearly"

    Returns:
        datetime: Timezone-aware (UTC) end date: now + 30 days for
            "monthly" (an approximation of one month), otherwise
            now + 365 days
    """
    # Only "monthly" differs; "yearly" and every unrecognized value share
    # the 365-day default.
    duration_days = 30 if billing_cycle == "monthly" else 365
    return datetime.now(timezone.utc) + timedelta(days=duration_days)
def calculate_subscription_period(plan, start_date=None, admin_override_dates=None):
    """
    Calculate subscription start and end dates based on plan's custom cycle or billing_cycle.

    Supports three scenarios:
    1. Plan with custom billing cycle (e.g., Jan 1 - Dec 31 recurring annually)
    2. Admin-overridden custom dates for manual activation
    3. Standard relative billing cycle (30/90/365 days from start_date)

    Args:
        plan: SubscriptionPlan object. Reads ``custom_cycle_enabled``,
            ``custom_cycle_start_month``/``_start_day``,
            ``custom_cycle_end_month``/``_end_day`` and ``billing_cycle``.
        start_date: Optional custom start date (defaults to now, UTC).
            Custom-cycle boundaries are built as UTC-aware datetimes, so
            this is expected to be timezone-aware as well.
        admin_override_dates: Optional dict with
            {'start_date': datetime, 'end_date': datetime}

    Returns:
        tuple: (start_date, end_date) as datetime objects

    Examples:
        # Plan with Jan 1 - Dec 31 custom cycle, subscribing on May 15, 2025
        #   -> (2025-05-15, 2025-12-31 23:59:59)
        # Plan with Jul 1 - Jun 30 fiscal year cycle, subscribing on Aug 20, 2025
        #   -> (2025-08-20, 2026-06-30 23:59:59)
        # Plan with Sep 1 - May 31 cycle, subscribing in the gap on Jul 15, 2025
        #   -> (2025-07-15, 2026-05-31 23:59:59)
        # Admin override for custom dates
        #   -> (custom_start, custom_end)
    """
    # Admin override takes precedence over everything else.
    if admin_override_dates:
        return (admin_override_dates['start_date'], admin_override_dates['end_date'])

    # Default start date to now if not provided.
    if start_date is None:
        start_date = datetime.now(timezone.utc)

    # A custom cycle is only usable when all four boundary fields are set;
    # a partially configured plan falls through to relative billing instead
    # of failing on a None comparison.
    has_custom_cycle = (
        plan.custom_cycle_enabled
        and plan.custom_cycle_start_month
        and plan.custom_cycle_start_day
        and plan.custom_cycle_end_month
        and plan.custom_cycle_end_day
    )

    if has_custom_cycle:
        def cycle_end(year):
            # Last second of the cycle's end day in the given year.
            return datetime(year, plan.custom_cycle_end_month,
                            plan.custom_cycle_end_day, 23, 59, 59,
                            tzinfo=timezone.utc)

        try:
            end_year = start_date.year

            # Year-spanning cycle (e.g., Jul 1 - Jun 30): once this year's
            # cycle has started, it ends in the following calendar year.
            if plan.custom_cycle_end_month < plan.custom_cycle_start_month:
                cycle_start_this_year = datetime(
                    end_year, plan.custom_cycle_start_month,
                    plan.custom_cycle_start_day, tzinfo=timezone.utc)
                if start_date >= cycle_start_this_year:
                    end_year += 1

            end_date = cycle_end(end_year)

            # If that end has already passed (calendar cycle late in the
            # year, or a start date in the gap between two year-spanning
            # cycles), use the next occurrence so the end is never before
            # the start.
            if end_date < start_date:
                end_date = cycle_end(end_year + 1)

            return (start_date, end_date)
        except ValueError:
            # Invalid date (e.g., Feb 30) - fall back to relative billing.
            pass

    # Fall back to relative billing cycle; unknown values default to yearly.
    relative_days = {
        "yearly": 365,
        "quarterly": 90,
        "monthly": 30,
        # Lifetime membership: end date 100 years in the future.
        "lifetime": 365 * 100,
    }
    end_date = start_date + timedelta(days=relative_days.get(plan.billing_cycle, 365))
    return (start_date, end_date)
def get_stripe_interval(billing_cycle: str) -> Optional[str]:
    """
    Map billing_cycle to Stripe recurring interval.

    Args:
        billing_cycle: Plan billing cycle (yearly, monthly, quarterly,
            lifetime, custom). Unrecognized values are treated as yearly.

    Returns:
        Optional[str]: Stripe interval ("year" or "month"), or None for a
            one-time (non-recurring) payment

    Examples:
        >>> get_stripe_interval("yearly")
        'year'
        >>> get_stripe_interval("monthly")
        'month'
        >>> get_stripe_interval("quarterly")  # caller uses interval_count=3
        'month'
        >>> get_stripe_interval("lifetime") is None  # one-time payment
        True
    """
    if billing_cycle == "lifetime":
        # One-time payment, not recurring.
        return None
    if billing_cycle in ("monthly", "quarterly"):
        # Quarterly is expressed in months; the interval count is the
        # caller's responsibility.
        return "month"
    # "yearly", "custom" and any unrecognized value default to a yearly cycle.
    return "year"
def create_stripe_price(
    product_name: str,
    price_cents: int,
    billing_cycle: str = "yearly"
) -> str:
    """
    Create a Stripe Product and a Price for a subscription plan.

    The recurring interval is derived with ``get_stripe_interval`` so that
    prices agree with the rest of this module:

    - "yearly" / "custom" / unrecognized -> billed every year
    - "monthly" -> billed every month
    - "quarterly" -> billed every 3 months (interval_count=3)
    - "lifetime" -> one-time price (no ``recurring`` section)

    Args:
        product_name: Name of the product/plan
        price_cents: Price in cents
        billing_cycle: Plan billing cycle (yearly, monthly, quarterly,
            lifetime, custom)

    Returns:
        str: Stripe Price ID

    Raises:
        Exception: If a Stripe API call fails. The original
            ``stripe.error.StripeError`` is attached as ``__cause__``.
    """
    price_params = {
        "unit_amount": price_cents,
        "currency": "usd",
    }

    interval = get_stripe_interval(billing_cycle)
    if interval is not None:
        recurring = {"interval": interval}
        if billing_cycle == "quarterly":
            # Stripe has no quarterly interval; bill every 3 months instead.
            recurring["interval_count"] = 3
        price_params["recurring"] = recurring
    # interval is None -> lifetime plan: omit "recurring" so Stripe creates
    # a one-time price.

    try:
        # A price must belong to a product, so create the product first.
        product = stripe.Product.create(name=product_name)
        price = stripe.Price.create(product=product.id, **price_params)
    except stripe.error.StripeError as e:
        # Chain explicitly so the Stripe error remains visible in tracebacks.
        raise Exception(f"Stripe error creating price: {str(e)}") from e

    return price.id
def get_customer_portal_url(stripe_customer_id: str, return_url: str) -> str:
    """
    Create a Stripe Customer Portal session for subscription management.

    Args:
        stripe_customer_id: Stripe Customer ID
        return_url: URL to return to after portal session

    Returns:
        str: Customer portal URL

    Raises:
        Exception: If the Stripe API call fails. The original
            ``stripe.error.StripeError`` is attached as ``__cause__``.
    """
    try:
        session = stripe.billing_portal.Session.create(
            customer=stripe_customer_id,
            return_url=return_url,
        )
    except stripe.error.StripeError as e:
        # Chain explicitly so the Stripe error remains visible in tracebacks.
        raise Exception(f"Stripe error creating portal session: {str(e)}") from e

    return session.url