# Source: forked from andika/membership-be (306 lines, 9.9 KiB, Python)
"""
|
|
Payment service for Stripe integration.
|
|
Handles subscription creation, checkout sessions, and webhook processing.
|
|
"""
|
|
|
|
import stripe
|
|
import os
|
|
from dotenv import load_dotenv
|
|
from datetime import datetime, timezone, timedelta
|
|
|
|
# Pull variables from a local .env file (when present) into the process
# environment before the Stripe settings below are read.
load_dotenv()

# Secret API key used by every stripe.* call in this module; None when the
# STRIPE_SECRET_KEY environment variable is unset.
stripe.api_key = os.environ.get("STRIPE_SECRET_KEY")

# Signing secret used by verify_webhook_signature() to check incoming webhook
# payloads; None when STRIPE_WEBHOOK_SECRET is unset.
STRIPE_WEBHOOK_SECRET = os.environ.get("STRIPE_WEBHOOK_SECRET")
|
|
|
|
def create_checkout_session(
    user_id: str,
    user_email: str,
    plan_id: str,
    stripe_price_id: str,
    success_url: str,
    cancel_url: str
):
    """
    Create a Stripe Checkout session for subscription payment.

    The user and plan IDs are attached as metadata both to the Checkout
    Session itself and (via subscription_data) to the Subscription that
    Stripe creates from it, so either object can be traced back to them.

    Args:
        user_id: User's UUID
        user_email: User's email address (prefilled as the customer email)
        plan_id: SubscriptionPlan UUID
        stripe_price_id: Stripe Price ID for the plan
        success_url: URL to redirect after successful payment
        cancel_url: URL to redirect if user cancels

    Returns:
        dict: {"session_id": <Checkout Session ID>, "url": <hosted checkout URL>}

    Raises:
        Exception: wrapping any stripe.error.StripeError; the original Stripe
            error is chained as ``__cause__``.
    """
    metadata = {
        "user_id": str(user_id),
        "plan_id": str(plan_id),
    }

    try:
        checkout_session = stripe.checkout.Session.create(
            customer_email=user_email,
            payment_method_types=["card"],
            line_items=[
                {
                    "price": stripe_price_id,
                    "quantity": 1,
                }
            ],
            mode="subscription",
            success_url=success_url,
            cancel_url=cancel_url,
            metadata=metadata,
            # Separate copy so the two metadata payloads never share state.
            subscription_data={"metadata": dict(metadata)},
        )
    except stripe.error.StripeError as e:
        raise Exception(f"Stripe error: {str(e)}") from e

    return {
        "session_id": checkout_session.id,
        "url": checkout_session.url,
    }
|
|
|
|
|
|
def verify_webhook_signature(payload: bytes, sig_header: str) -> dict:
    """
    Verify a Stripe webhook signature and construct the event.

    Args:
        payload: Raw webhook request body, exactly as received (bytes).
        sig_header: Value of the Stripe signature header.

    Returns:
        dict: Verified webhook event returned by stripe.Webhook.construct_event.

    Raises:
        RuntimeError: If STRIPE_WEBHOOK_SECRET is not configured (a server
            misconfiguration, not a bad request).
        ValueError: If the payload is invalid or signature verification
            fails; the underlying error is chained as ``__cause__``.
    """
    if not STRIPE_WEBHOOK_SECRET:
        # Without a signing secret nothing can be verified; fail with a clear
        # message instead of letting the failure surface from inside the SDK.
        raise RuntimeError("STRIPE_WEBHOOK_SECRET is not configured")

    try:
        return stripe.Webhook.construct_event(
            payload, sig_header, STRIPE_WEBHOOK_SECRET
        )
    except ValueError as e:
        raise ValueError(f"Invalid payload: {str(e)}") from e
    except stripe.error.SignatureVerificationError as e:
        raise ValueError(f"Invalid signature: {str(e)}") from e
|
|
|
|
|
|
def get_subscription_end_date(billing_cycle: str = "yearly") -> datetime:
    """
    Calculate a subscription end date relative to the current UTC time.

    Uses the same fixed cycle lengths as the relative fallback in
    calculate_subscription_period, so the two functions agree for every
    billing cycle (previously "quarterly" and "lifetime" fell through to
    the yearly default here).

    Args:
        billing_cycle: "yearly", "quarterly", "monthly" or "lifetime".
            Any other value is treated as "yearly".

    Returns:
        datetime: Timezone-aware (UTC) end date for the subscription.
    """
    cycle_days = {
        "yearly": 365,
        "quarterly": 90,
        "monthly": 30,          # approximation of one month
        "lifetime": 365 * 100,  # effectively never expires
    }
    days = cycle_days.get(billing_cycle, 365)
    return datetime.now(timezone.utc) + timedelta(days=days)
|
|
|
|
|
|
def calculate_subscription_period(plan, start_date=None, admin_override_dates=None):
    """
    Calculate subscription start and end dates based on plan's custom cycle or billing_cycle.

    Scenarios, checked in this order:
        1. Admin-overridden dates for manual activation (returned verbatim).
        2. Plan with a custom recurring date range (e.g. Jan 1 - Dec 31, or a
           Jul 1 - Jun 30 fiscal year): the period ends at 23:59:59 UTC on the
           first occurrence of the cycle's end month/day that is not before
           start_date.
        3. Standard relative billing cycle: 365 / 90 / 30 days for yearly /
           quarterly / monthly, 100 * 365 days for lifetime, and 365 days for
           any other value.

    A custom cycle is used only when custom_cycle_enabled is truthy and all
    four start/end month/day fields are set. If the dates involved form an
    invalid calendar date (e.g. Feb 30) the relative billing cycle is used.

    Args:
        plan: SubscriptionPlan-like object exposing custom_cycle_enabled,
            custom_cycle_start_month, custom_cycle_start_day,
            custom_cycle_end_month, custom_cycle_end_day and billing_cycle.
        start_date: Optional start date (defaults to now in UTC). When a
            custom cycle applies it is compared against UTC-aware datetimes,
            so it must be timezone-aware.
        admin_override_dates: Optional dict with 'start_date' and 'end_date'
            keys; when truthy its values are returned as-is.

    Returns:
        tuple: (start_date, end_date)

    Examples (custom cycles, subscribing in 2025):
        Jan 1 - Dec 31, start May 15  -> end 2025-12-31 23:59:59 UTC
        Jul 1 - Jun 30, start Aug 20  -> end 2026-06-30 23:59:59 UTC
        Sep 1 - Mar 31, start May 15  -> end 2026-03-31 23:59:59 UTC
    """
    # Admin override takes precedence over anything derived from the plan.
    if admin_override_dates:
        return (admin_override_dates['start_date'], admin_override_dates['end_date'])

    if start_date is None:
        start_date = datetime.now(timezone.utc)

    end_date = _custom_cycle_end_date(plan, start_date)
    if end_date is None:
        end_date = _relative_end_date(plan.billing_cycle, start_date)

    return (start_date, end_date)


def _custom_cycle_end_date(plan, start_date):
    """Return the custom-cycle end date for start_date, or None if the plan has no usable custom cycle."""
    if not (
        plan.custom_cycle_enabled
        and plan.custom_cycle_start_month
        and plan.custom_cycle_start_day
        and plan.custom_cycle_end_month
        and plan.custom_cycle_end_day
    ):
        return None

    start_month, start_day = plan.custom_cycle_start_month, plan.custom_cycle_start_day
    end_month, end_day = plan.custom_cycle_end_month, plan.custom_cycle_end_day
    current_year = start_date.year

    try:
        # Compare (month, day) pairs so same-month cycles such as
        # Jul 15 - Jul 14 are also recognised as year-spanning.
        year_spanning = (end_month, end_day) < (start_month, start_day)

        if year_spanning:
            # Fiscal-year style cycle: once this year's cycle has started, it
            # ends next year; otherwise the current cycle ends this year.
            cycle_start_this_year = datetime(current_year, start_month, start_day,
                                             tzinfo=timezone.utc)
            end_year = current_year + 1 if start_date >= cycle_start_this_year else current_year
        else:
            # Calendar-aligned cycle (e.g. Jan 1 - Dec 31).
            end_year = current_year

        end_date = datetime(end_year, end_month, end_day, 23, 59, 59, tzinfo=timezone.utc)

        # If that end has already passed (e.g. after Dec 31's cut-off, or in
        # the gap of a Sep 1 - Mar 31 cycle), use the following year's end so
        # the period never ends before it starts.
        if end_date < start_date:
            end_date = datetime(end_year + 1, end_month, end_day, 23, 59, 59,
                                tzinfo=timezone.utc)
    except ValueError:
        # Invalid calendar date (e.g. Feb 30) - caller falls back to relative billing.
        return None

    return end_date


def _relative_end_date(billing_cycle, start_date):
    """Return start_date plus the fixed duration of billing_cycle (yearly if unrecognised)."""
    cycle_days = {
        "yearly": 365,
        "quarterly": 90,
        "monthly": 30,
        "lifetime": 365 * 100,  # lifetime membership: 100 years in the future
    }
    return start_date + timedelta(days=cycle_days.get(billing_cycle, 365))
|
|
|
|
|
|
def get_stripe_interval(billing_cycle: str) -> "str | None":
    """
    Map a plan billing_cycle to a Stripe recurring interval.

    Args:
        billing_cycle: Plan billing cycle (yearly, monthly, quarterly, lifetime, custom).

    Returns:
        "year" for yearly, custom and any unrecognised value; "month" for
        monthly and quarterly (a quarterly price additionally needs
        interval_count=3); None for lifetime, which is a one-time payment
        rather than a recurring one.

    Examples:
        >>> get_stripe_interval("yearly")
        'year'
        >>> get_stripe_interval("quarterly")
        'month'
        >>> get_stripe_interval("lifetime") is None
        True
    """
    if billing_cycle == "lifetime":
        return None  # One-time payment, not recurring
    if billing_cycle in ("monthly", "quarterly"):
        return "month"
    # yearly, custom, and anything unrecognised bill annually
    return "year"
|
|
|
|
|
|
def create_stripe_price(
    product_name: str,
    price_cents: int,
    billing_cycle: str = "yearly"
) -> str:
    """
    Create a Stripe Product and a USD Price for a subscription plan.

    The billing interval follows get_stripe_interval(): yearly, custom and
    unrecognised cycles bill every year; monthly bills every month;
    quarterly bills every 3 months; lifetime produces a one-time
    (non-recurring) price.

    NOTE(review): create_checkout_session() always uses
    mode="subscription", which requires a recurring price, so a lifetime
    price created here cannot be sold through that function as-is.

    Args:
        product_name: Name of the product/plan
        price_cents: Price in cents
        billing_cycle: "yearly", "monthly", "quarterly", "lifetime" or "custom"

    Returns:
        str: Stripe Price ID

    Raises:
        Exception: wrapping any stripe.error.StripeError; the original Stripe
            error is chained as ``__cause__``.
    """
    interval = get_stripe_interval(billing_cycle)

    price_params = {
        "unit_amount": price_cents,
        "currency": "usd",
    }
    if interval is not None:
        recurring = {"interval": interval}
        if billing_cycle == "quarterly":
            # Stripe has no quarterly interval; bill every 3 months instead.
            recurring["interval_count"] = 3
        price_params["recurring"] = recurring

    try:
        # A Price must belong to a Product, so create the product first.
        product = stripe.Product.create(name=product_name)
        price = stripe.Price.create(product=product.id, **price_params)
    except stripe.error.StripeError as e:
        raise Exception(f"Stripe error creating price: {str(e)}") from e

    return price.id
|
|
|
|
|
|
def get_customer_portal_url(stripe_customer_id: str, return_url: str) -> str:
    """
    Create a Stripe Customer Portal session for subscription management.

    Args:
        stripe_customer_id: Stripe Customer ID
        return_url: URL to return to after the portal session

    Returns:
        str: Customer portal URL

    Raises:
        Exception: wrapping any stripe.error.StripeError; the original Stripe
            error is chained as ``__cause__``.
    """
    try:
        session = stripe.billing_portal.Session.create(
            customer=stripe_customer_id,
            return_url=return_url,
        )
    except stripe.error.StripeError as e:
        raise Exception(f"Stripe error creating portal session: {str(e)}") from e

    return session.url
|