Fix the tests and implement a fake stripe.

This commit is contained in:
jakedt 2014-04-10 15:20:16 -04:00
parent 4f3fa34206
commit d39f3cc5d4
14 changed files with 262 additions and 136 deletions

View file

@ -52,13 +52,13 @@ restart daemons
running the tests:
```
STACK=test python -m unittest discover
TEST=true python -m unittest discover
```
running the tests with coverage (requires coverage module):
```
STACK=test coverage run -m unittest discover
TEST=true coverage run -m unittest discover
coverage html
```

5
app.py
View file

@ -1,6 +1,5 @@
import logging
import os
import stripe
from flask import Flask
from flask.ext.principal import Principal
@ -12,6 +11,7 @@ import features
from storage import Storage
from data.userfiles import Userfiles
from util.analytics import Analytics
from data.billing import Billing
OVERRIDE_CONFIG_FILENAME = 'conf/stack/config.py'
@ -43,5 +43,4 @@ mail = Mail(app)
storage = Storage(app)
userfiles = Userfiles(app)
analytics = Analytics(app)
stripe.api_key = app.config.get('STRIPE_SECRET_KEY', None)
billing = Billing(app)

View file

@ -93,8 +93,7 @@ class DefaultConfig(object):
USER_EVENTS = UserEventBuilder('logs.quay.io')
# Stripe config
STRIPE_SECRET_KEY = ''
STRIPE_PUBLISHABLE_KEY = ''
BILLING_TYPE = 'FakeStripe'
# Userfiles
USERFILES_TYPE = 'LocalUserfiles'

232
data/billing.py Normal file
View file

@ -0,0 +1,232 @@
import stripe
from datetime import datetime, timedelta
from calendar import timegm
PLANS = [
# Deprecated Plans
{
'title': 'Micro',
'price': 700,
'privateRepos': 5,
'stripeId': 'micro',
'audience': 'For smaller teams',
'bus_features': False,
'deprecated': True,
},
{
'title': 'Basic',
'price': 1200,
'privateRepos': 10,
'stripeId': 'small',
'audience': 'For your basic team',
'bus_features': False,
'deprecated': True,
},
{
'title': 'Medium',
'price': 2200,
'privateRepos': 20,
'stripeId': 'medium',
'audience': 'For medium teams',
'bus_features': False,
'deprecated': True,
},
{
'title': 'Large',
'price': 5000,
'privateRepos': 50,
'stripeId': 'large',
'audience': 'For larger teams',
'bus_features': False,
'deprecated': True,
},
# Active plans
{
'title': 'Open Source',
'price': 0,
'privateRepos': 0,
'stripeId': 'free',
'audience': 'Committment to FOSS',
'bus_features': False,
'deprecated': False,
},
{
'title': 'Personal',
'price': 1200,
'privateRepos': 5,
'stripeId': 'personal',
'audience': 'Individuals',
'bus_features': False,
'deprecated': False,
},
{
'title': 'Skiff',
'price': 2500,
'privateRepos': 10,
'stripeId': 'bus-micro',
'audience': 'For startups',
'bus_features': True,
'deprecated': False,
},
{
'title': 'Yacht',
'price': 5000,
'privateRepos': 20,
'stripeId': 'bus-small',
'audience': 'For small businesses',
'bus_features': True,
'deprecated': False,
},
{
'title': 'Freighter',
'price': 10000,
'privateRepos': 50,
'stripeId': 'bus-medium',
'audience': 'For normal businesses',
'bus_features': True,
'deprecated': False,
},
{
'title': 'Tanker',
'price': 20000,
'privateRepos': 125,
'stripeId': 'bus-large',
'audience': 'For large businesses',
'bus_features': True,
'deprecated': False,
},
]
def get_plan(plan_id):
""" Returns the plan with the given ID or None if none. """
for plan in PLANS:
if plan['stripeId'] == plan_id:
return plan
return None
class AttrDict(dict):
def __init__(self, *args, **kwargs):
super(AttrDict, self).__init__(*args, **kwargs)
self.__dict__ = self
@classmethod
def deep_copy(cls, attr_dict):
copy = AttrDict(attr_dict)
for key, value in copy.items():
if isinstance(value, AttrDict):
copy[key] = cls.deep_copy(value)
return copy
class FakeStripe(object):
class Customer(AttrDict):
FAKE_PLAN = AttrDict({
'id': 'bus-small',
})
FAKE_SUBSCRIPTION = AttrDict({
'plan': FAKE_PLAN,
'current_period_start': timegm(datetime.now().utctimetuple()),
'current_period_end': timegm((datetime.now() + timedelta(days=30)).utctimetuple()),
})
FAKE_CARD = AttrDict({
'id': 'card123',
'name': 'Joe User',
'type': 'Visa',
'last4': '4242',
})
FAKE_CARD_LIST = AttrDict({
'data': [FAKE_CARD],
})
ACTIVE_CUSTOMERS = {}
@property
def card(self):
return self.get('new_card', None)
@card.setter
def card(self, card_token):
self['new_card'] = card_token
@property
def plan(self):
return self.get('new_plan', None)
@plan.setter
def plan(self, plan_name):
self['new_plan'] = plan_name
def save(self):
if self.get('new_card', None) is not None:
raise stripe.CardError('Test raising exception on set card.', self.get('new_card'), 402)
if self.get('new_plan', None) is not None:
if self.subscription is None:
self.subscription = AttrDict.deep_copy(self.FAKE_SUBSCRIPTION)
self.subscription.plan.id = self.get('new_plan')
if self.get('cancel_subscription', None) is not None:
self.subscription = None
def cancel_subscription(self):
self['cancel_subscription'] = True
@classmethod
def retrieve(cls, stripe_customer_id):
if stripe_customer_id in cls.ACTIVE_CUSTOMERS:
cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop('new_card', None)
cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop('new_plan', None)
cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop('cancel_subscription', None)
return cls.ACTIVE_CUSTOMERS[stripe_customer_id]
else:
new_customer = cls({
'default_card': 'card123',
'cards': AttrDict.deep_copy(cls.FAKE_CARD_LIST),
'subscription': AttrDict.deep_copy(cls.FAKE_SUBSCRIPTION),
'id': stripe_customer_id,
})
cls.ACTIVE_CUSTOMERS[stripe_customer_id] = new_customer
return new_customer
class Invoice(AttrDict):
@staticmethod
def all(customer, count):
return AttrDict({
'data': [],
})
class Billing(object):
def __init__(self, app=None):
self.app = app
if app is not None:
self.state = self.init_app(app)
else:
self.state = None
def init_app(self, app):
billing_type = app.config.get('BILLING_TYPE', 'FakeStripe')
if billing_type == 'Stripe':
billing = stripe
stripe.api_key = app.config.get('STRIPE_SECRET_KEY', None)
elif billing_type == 'FakeStripe':
billing = FakeStripe
else:
raise RuntimeError('Unknown billing type: %s' % billing_type)
# register extension with app
app.extensions = getattr(app, 'extensions', {})
app.extensions['billing'] = billing
return billing
def __getattr__(self, name):
return getattr(self.state, name, None)

View file

@ -1,104 +0,0 @@
PLANS = [
# Deprecated Plans
{
'title': 'Micro',
'price': 700,
'privateRepos': 5,
'stripeId': 'micro',
'audience': 'For smaller teams',
'bus_features': False,
'deprecated': True,
},
{
'title': 'Basic',
'price': 1200,
'privateRepos': 10,
'stripeId': 'small',
'audience': 'For your basic team',
'bus_features': False,
'deprecated': True,
},
{
'title': 'Medium',
'price': 2200,
'privateRepos': 20,
'stripeId': 'medium',
'audience': 'For medium teams',
'bus_features': False,
'deprecated': True,
},
{
'title': 'Large',
'price': 5000,
'privateRepos': 50,
'stripeId': 'large',
'audience': 'For larger teams',
'bus_features': False,
'deprecated': True,
},
# Active plans
{
'title': 'Open Source',
'price': 0,
'privateRepos': 0,
'stripeId': 'free',
'audience': 'Committment to FOSS',
'bus_features': False,
'deprecated': False,
},
{
'title': 'Personal',
'price': 1200,
'privateRepos': 5,
'stripeId': 'personal',
'audience': 'Individuals',
'bus_features': False,
'deprecated': False,
},
{
'title': 'Skiff',
'price': 2500,
'privateRepos': 10,
'stripeId': 'bus-micro',
'audience': 'For startups',
'bus_features': True,
'deprecated': False,
},
{
'title': 'Yacht',
'price': 5000,
'privateRepos': 20,
'stripeId': 'bus-small',
'audience': 'For small businesses',
'bus_features': True,
'deprecated': False,
},
{
'title': 'Freighter',
'price': 10000,
'privateRepos': 50,
'stripeId': 'bus-medium',
'audience': 'For normal businesses',
'bus_features': True,
'deprecated': False,
},
{
'title': 'Tanker',
'price': 20000,
'privateRepos': 125,
'stripeId': 'bus-large',
'audience': 'For large businesses',
'bus_features': True,
'deprecated': False,
},
]
def get_plan(plan_id):
""" Returns the plan with the given ID or None if none. """
for plan in PLANS:
if plan['stripeId'] == plan_id:
return plan
return None

View file

@ -154,15 +154,18 @@ class Userfiles(object):
download_userfile_endpoint, methods=['GET'])
userfiles = LocalUserfiles(path)
elif userfiles_type == 'S3Userfiles':
elif storage_type == 'S3Userfiles':
access_key = app.config.get('USERFILES_AWS_ACCESS_KEY', '')
secret_key = app.config.get('USERFILES_AWS_SECRET_KEY', '')
bucket = app.config.get('USERFILES_S3_BUCKET', '')
userfiles = S3Userfiles(path, access_key, secret_key, bucket)
else:
elif storage_type == 'FakeUserfiles':
userfiles = FakeUserfiles()
else:
raise RuntimeError('Unknown userfiles type: %s' % storage_type)
# register extension with app
app.extensions = getattr(app, 'extensions', {})
app.extensions['userfiles'] = userfiles

View file

@ -1,7 +1,7 @@
import stripe
from flask import request
from app import billing
from endpoints.api import (resource, nickname, ApiResource, validate_json_request, log_action,
related_user_resource, internal_only, Unauthorized, NotFound,
require_user_admin, show_if, hide_if)
@ -9,7 +9,7 @@ from endpoints.api.subscribe import subscribe, subscription_view
from auth.permissions import AdministerOrganizationPermission
from auth.auth_context import get_authenticated_user
from data import model
from data.plans import PLANS
from data.billing import PLANS
import features
@ -23,7 +23,7 @@ def get_card(user):
}
if user.stripe_id:
cus = stripe.Customer.retrieve(user.stripe_id)
cus = billing.Customer.retrieve(user.stripe_id)
if cus and cus.default_card:
# Find the default card.
default_card = None
@ -44,7 +44,7 @@ def get_card(user):
def set_card(user, token):
if user.stripe_id:
cus = stripe.Customer.retrieve(user.stripe_id)
cus = billing.Customer.retrieve(user.stripe_id)
if cus:
try:
cus.card = token
@ -73,7 +73,7 @@ def get_invoices(customer_id):
'plan': i.lines.data[0].plan.id if i.lines.data[0].plan else None
}
invoices = stripe.Invoice.all(customer=customer_id, count=12)
invoices = billing.Invoice.all(customer=customer_id, count=12)
return {
'invoices': [invoice_view(i) for i in invoices.data]
}
@ -225,7 +225,7 @@ class UserPlan(ApiResource):
private_repos = model.get_private_repo_count(user.username)
if user.stripe_id:
cus = stripe.Customer.retrieve(user.stripe_id)
cus = billing.Customer.retrieve(user.stripe_id)
if cus.subscription:
return subscription_view(cus.subscription, private_repos)
@ -285,7 +285,7 @@ class OrganizationPlan(ApiResource):
private_repos = model.get_private_repo_count(orgname)
organization = model.get_organization(orgname)
if organization.stripe_id:
cus = stripe.Customer.retrieve(organization.stripe_id)
cus = billing.Customer.retrieve(organization.stripe_id)
if cus.subscription:
return subscription_view(cus.subscription, private_repos)

View file

@ -1,8 +1,8 @@
import logging
import stripe
from flask import request
from app import billing as stripe
from endpoints.api import (resource, nickname, ApiResource, validate_json_request, request_error,
related_user_resource, internal_only, Unauthorized, NotFound,
require_user_admin, log_action, show_if)
@ -12,7 +12,7 @@ from auth.permissions import (AdministerOrganizationPermission, OrganizationMem
CreateRepositoryPermission)
from auth.auth_context import get_authenticated_user
from data import model
from data.plans import get_plan
from data.billing import get_plan
from util.gravatar import compute_hash
import features

View file

@ -1,10 +1,11 @@
import logging
import stripe
from app import billing
from endpoints.api import request_error, log_action, NotFound
from endpoints.common import check_repository_usage
from data import model
from data.plans import PLANS
from data.billing import PLANS
import features
@ -60,7 +61,7 @@ def subscribe(user, plan, token, require_business_plan):
card = token
try:
cus = stripe.Customer.create(email=user.email, plan=plan, card=card)
cus = billing.Customer.create(email=user.email, plan=plan, card=card)
user.stripe_id = cus.id
user.save()
check_repository_usage(user, plan_found)
@ -73,7 +74,7 @@ def subscribe(user, plan, token, require_business_plan):
else:
# Change the plan
cus = stripe.Customer.retrieve(user.stripe_id)
cus = billing.Customer.retrieve(user.stripe_id)
if plan_found['price'] == 0:
if cus.subscription is not None:

View file

@ -1,19 +1,18 @@
import logging
import stripe
import json
from flask import request
from flask.ext.login import logout_user
from flask.ext.principal import identity_changed, AnonymousIdentity
from app import app
from app import app, billing as stripe
from endpoints.api import (ApiResource, nickname, resource, validate_json_request, request_error,
log_action, internal_only, NotFound, require_user_admin,
InvalidToken, require_scope, format_date, hide_if, show_if)
from endpoints.api.subscribe import subscribe
from endpoints.common import common_login
from data import model
from data.plans import get_plan
from data.billing import get_plan
from auth.permissions import (AdministerOrganizationPermission, CreateRepositoryPermission,
UserAdminPermission, UserReadPermission)
from auth.auth_context import get_authenticated_user

View file

@ -1,5 +1,4 @@
import logging
import stripe
import os
from flask import (abort, redirect, request, url_for, make_response, Response,
@ -9,7 +8,7 @@ from urlparse import urlparse
from data import model
from data.model.oauth import DatabaseAuthorizationProvider
from app import app
from app import app, billing as stripe
from auth.permissions import AdministerOrganizationPermission
from util.invoice import renderInvoiceToPdf
from util.seo import render_snapshot

View file

@ -1,9 +1,9 @@
import logging
import stripe
import json
from flask import request, make_response, Blueprint
from app import billing as stripe
from data import model
from data.queue import dockerfile_build_queue
from auth.auth import process_auth

View file

@ -148,8 +148,7 @@ def setup_database_for_testing(testcase):
# Sanity check to make sure we're not killing our prod db
db = model.db
if (not isinstance(model.db, SqliteDatabase) or
app.config['DB_DRIVER'] is not SqliteDatabase):
if not isinstance(model.db, SqliteDatabase):
raise RuntimeError('Attempted to wipe production database!')
global db_initialized_for_testing
@ -240,8 +239,7 @@ def wipe_database():
# Sanity check to make sure we're not killing our prod db
db = model.db
if (not isinstance(model.db, SqliteDatabase) or
app.config['DB_DRIVER'] is not SqliteDatabase):
if not isinstance(model.db, SqliteDatabase):
raise RuntimeError('Attempted to wipe production database!')
drop_model_tables(all_models, fail_silently=True)

View file

@ -13,7 +13,7 @@ class FakeTransaction(object):
class TestConfig(DefaultConfig):
TESTING = True
DB_NAME = ':memory:'
DB_URL = 'sqlite:///:memory:'
DB_CONNECTION_ARGS = {}
@staticmethod