diff --git a/README.md b/README.md index c2d045c08..958070640 100644 --- a/README.md +++ b/README.md @@ -52,13 +52,13 @@ restart daemons running the tests: ``` -STACK=test python -m unittest discover +TEST=true python -m unittest discover ``` running the tests with coverage (requires coverage module): ``` -STACK=test coverage run -m unittest discover +TEST=true coverage run -m unittest discover coverage html ``` diff --git a/app.py b/app.py index c4747cc68..2a99a2446 100644 --- a/app.py +++ b/app.py @@ -1,6 +1,5 @@ import logging import os -import stripe from flask import Flask from flask.ext.principal import Principal @@ -12,6 +11,7 @@ import features from storage import Storage from data.userfiles import Userfiles from util.analytics import Analytics +from data.billing import Billing OVERRIDE_CONFIG_FILENAME = 'conf/stack/config.py' @@ -43,5 +43,4 @@ mail = Mail(app) storage = Storage(app) userfiles = Userfiles(app) analytics = Analytics(app) - -stripe.api_key = app.config.get('STRIPE_SECRET_KEY', None) +billing = Billing(app) diff --git a/config.py b/config.py index 09f883cb6..63fad0dbe 100644 --- a/config.py +++ b/config.py @@ -93,8 +93,7 @@ class DefaultConfig(object): USER_EVENTS = UserEventBuilder('logs.quay.io') # Stripe config - STRIPE_SECRET_KEY = '' - STRIPE_PUBLISHABLE_KEY = '' + BILLING_TYPE = 'FakeStripe' # Userfiles USERFILES_TYPE = 'LocalUserfiles' diff --git a/data/billing.py b/data/billing.py new file mode 100644 index 000000000..69f1d7c04 --- /dev/null +++ b/data/billing.py @@ -0,0 +1,232 @@ +import stripe + +from datetime import datetime, timedelta +from calendar import timegm + +PLANS = [ + # Deprecated Plans + { + 'title': 'Micro', + 'price': 700, + 'privateRepos': 5, + 'stripeId': 'micro', + 'audience': 'For smaller teams', + 'bus_features': False, + 'deprecated': True, + }, + { + 'title': 'Basic', + 'price': 1200, + 'privateRepos': 10, + 'stripeId': 'small', + 'audience': 'For your basic team', + 'bus_features': False, + 'deprecated': True, + }, + { + 'title': 'Medium', + 'price': 2200, + 'privateRepos': 20, + 'stripeId': 'medium', + 'audience': 'For medium teams', + 'bus_features': False, + 'deprecated': True, + }, + { + 'title': 'Large', + 'price': 5000, + 'privateRepos': 50, + 'stripeId': 'large', + 'audience': 'For larger teams', + 'bus_features': False, + 'deprecated': True, + }, + + # Active plans + { + 'title': 'Open Source', + 'price': 0, + 'privateRepos': 0, + 'stripeId': 'free', + 'audience': 'Committment to FOSS', + 'bus_features': False, + 'deprecated': False, + }, + { + 'title': 'Personal', + 'price': 1200, + 'privateRepos': 5, + 'stripeId': 'personal', + 'audience': 'Individuals', + 'bus_features': False, + 'deprecated': False, + }, + { + 'title': 'Skiff', + 'price': 2500, + 'privateRepos': 10, + 'stripeId': 'bus-micro', + 'audience': 'For startups', + 'bus_features': True, + 'deprecated': False, + }, + { + 'title': 'Yacht', + 'price': 5000, + 'privateRepos': 20, + 'stripeId': 'bus-small', + 'audience': 'For small businesses', + 'bus_features': True, + 'deprecated': False, + }, + { + 'title': 'Freighter', + 'price': 10000, + 'privateRepos': 50, + 'stripeId': 'bus-medium', + 'audience': 'For normal businesses', + 'bus_features': True, + 'deprecated': False, + }, + { + 'title': 'Tanker', + 'price': 20000, + 'privateRepos': 125, + 'stripeId': 'bus-large', + 'audience': 'For large businesses', + 'bus_features': True, + 'deprecated': False, + }, +] + + +def get_plan(plan_id): + """ Returns the plan with the given ID or None if none. """ + for plan in PLANS: + if plan['stripeId'] == plan_id: + return plan + + return None + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self + + @classmethod + def deep_copy(cls, attr_dict): + copy = AttrDict(attr_dict) + for key, value in copy.items(): + if isinstance(value, AttrDict): + copy[key] = cls.deep_copy(value) + return copy + + +class FakeStripe(object): + class Customer(AttrDict): + FAKE_PLAN = AttrDict({ + 'id': 'bus-small', + }) + + FAKE_SUBSCRIPTION = AttrDict({ + 'plan': FAKE_PLAN, + 'current_period_start': timegm(datetime.now().utctimetuple()), + 'current_period_end': timegm((datetime.now() + timedelta(days=30)).utctimetuple()), + }) + + FAKE_CARD = AttrDict({ + 'id': 'card123', + 'name': 'Joe User', + 'type': 'Visa', + 'last4': '4242', + }) + + FAKE_CARD_LIST = AttrDict({ + 'data': [FAKE_CARD], + }) + + ACTIVE_CUSTOMERS = {} + + @property + def card(self): + return self.get('new_card', None) + + @card.setter + def card(self, card_token): + self['new_card'] = card_token + + @property + def plan(self): + return self.get('new_plan', None) + + @plan.setter + def plan(self, plan_name): + self['new_plan'] = plan_name + + def save(self): + if self.get('new_card', None) is not None: + raise stripe.CardError('Test raising exception on set card.', self.get('new_card'), 402) + if self.get('new_plan', None) is not None: + if self.subscription is None: + self.subscription = AttrDict.deep_copy(self.FAKE_SUBSCRIPTION) + self.subscription.plan.id = self.get('new_plan') + if self.get('cancel_subscription', None) is not None: + self.subscription = None + + def cancel_subscription(self): + self['cancel_subscription'] = True + + @classmethod + def retrieve(cls, stripe_customer_id): + if stripe_customer_id in cls.ACTIVE_CUSTOMERS: + cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop('new_card', None) + cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop('new_plan', None) + cls.ACTIVE_CUSTOMERS[stripe_customer_id].pop('cancel_subscription', None) + return cls.ACTIVE_CUSTOMERS[stripe_customer_id] + else: + new_customer = cls({ + 'default_card': 'card123', + 'cards': AttrDict.deep_copy(cls.FAKE_CARD_LIST), + 'subscription': AttrDict.deep_copy(cls.FAKE_SUBSCRIPTION), + 'id': stripe_customer_id, + }) + cls.ACTIVE_CUSTOMERS[stripe_customer_id] = new_customer + return new_customer + + class Invoice(AttrDict): + @staticmethod + def all(customer, count): + return AttrDict({ + 'data': [], + }) + + +class Billing(object): + def __init__(self, app=None): + self.app = app + if app is not None: + self.state = self.init_app(app) + else: + self.state = None + + def init_app(self, app): + billing_type = app.config.get('BILLING_TYPE', 'FakeStripe') + + if billing_type == 'Stripe': + billing = stripe + stripe.api_key = app.config.get('STRIPE_SECRET_KEY', None) + + elif billing_type == 'FakeStripe': + billing = FakeStripe + + else: + raise RuntimeError('Unknown billing type: %s' % billing_type) + + # register extension with app + app.extensions = getattr(app, 'extensions', {}) + app.extensions['billing'] = billing + return billing + + def __getattr__(self, name): + return getattr(self.state, name, None) diff --git a/data/plans.py b/data/plans.py deleted file mode 100644 index 2b8b6af2b..000000000 --- a/data/plans.py +++ /dev/null @@ -1,104 +0,0 @@ -PLANS = [ - # Deprecated Plans - { - 'title': 'Micro', - 'price': 700, - 'privateRepos': 5, - 'stripeId': 'micro', - 'audience': 'For smaller teams', - 'bus_features': False, - 'deprecated': True, - }, - { - 'title': 'Basic', - 'price': 1200, - 'privateRepos': 10, - 'stripeId': 'small', - 'audience': 'For your basic team', - 'bus_features': False, - 'deprecated': True, - }, - { - 'title': 'Medium', - 'price': 2200, - 'privateRepos': 20, - 'stripeId': 'medium', - 'audience': 'For medium teams', - 'bus_features': False, - 'deprecated': True, - }, - { - 'title': 'Large', - 'price': 5000, - 'privateRepos': 50, - 'stripeId': 'large', - 'audience': 'For larger teams', - 'bus_features': False, - 'deprecated': True, - }, - - # Active plans - { - 'title': 'Open Source', - 'price': 0, - 'privateRepos': 0, - 'stripeId': 'free', - 'audience': 'Committment to FOSS', - 'bus_features': False, - 'deprecated': False, - }, - { - 'title': 'Personal', - 'price': 1200, - 'privateRepos': 5, - 'stripeId': 'personal', - 'audience': 'Individuals', - 'bus_features': False, - 'deprecated': False, - }, - { - 'title': 'Skiff', - 'price': 2500, - 'privateRepos': 10, - 'stripeId': 'bus-micro', - 'audience': 'For startups', - 'bus_features': True, - 'deprecated': False, - }, - { - 'title': 'Yacht', - 'price': 5000, - 'privateRepos': 20, - 'stripeId': 'bus-small', - 'audience': 'For small businesses', - 'bus_features': True, - 'deprecated': False, - }, - { - 'title': 'Freighter', - 'price': 10000, - 'privateRepos': 50, - 'stripeId': 'bus-medium', - 'audience': 'For normal businesses', - 'bus_features': True, - 'deprecated': False, - }, - { - 'title': 'Tanker', - 'price': 20000, - 'privateRepos': 125, - 'stripeId': 'bus-large', - 'audience': 'For large businesses', - 'bus_features': True, - 'deprecated': False, - }, -] - - -def get_plan(plan_id): - """ Returns the plan with the given ID or None if none. """ - for plan in PLANS: - if plan['stripeId'] == plan_id: - return plan - - return None diff --git a/data/userfiles.py b/data/userfiles.py index 9617811e8..7d1f8b69b 100644 --- a/data/userfiles.py +++ b/data/userfiles.py @@ -154,15 +154,18 @@ class Userfiles(object): download_userfile_endpoint, methods=['GET']) userfiles = LocalUserfiles(path) - elif userfiles_type == 'S3Userfiles': + elif storage_type == 'S3Userfiles': access_key = app.config.get('USERFILES_AWS_ACCESS_KEY', '') secret_key = app.config.get('USERFILES_AWS_SECRET_KEY', '') bucket = app.config.get('USERFILES_S3_BUCKET', '') userfiles = S3Userfiles(path, access_key, secret_key, bucket) - else: + elif storage_type == 'FakeUserfiles': userfiles = FakeUserfiles() + else: + raise RuntimeError('Unknown userfiles type: %s' % storage_type) + # register extension with app app.extensions = getattr(app, 'extensions', {}) app.extensions['userfiles'] = userfiles diff --git a/endpoints/api/billing.py b/endpoints/api/billing.py index 89dda31f0..a00a28a72 100644 --- a/endpoints/api/billing.py +++ b/endpoints/api/billing.py @@ -1,7 +1,7 @@ import stripe from flask import request - +from app import billing from endpoints.api import (resource, nickname, ApiResource, validate_json_request, log_action, related_user_resource, internal_only, Unauthorized, NotFound, require_user_admin, show_if, hide_if) @@ -9,7 +9,7 @@ from endpoints.api.subscribe import subscribe, subscription_view from auth.permissions import AdministerOrganizationPermission from auth.auth_context import get_authenticated_user from data import model -from data.plans import PLANS +from data.billing import PLANS import features @@ -23,7 +23,7 @@ def get_card(user): } if user.stripe_id: - cus = stripe.Customer.retrieve(user.stripe_id) + cus = billing.Customer.retrieve(user.stripe_id) if cus and cus.default_card: # Find the default card. default_card = None @@ -44,7 +44,7 @@ def get_card(user): def set_card(user, token): if user.stripe_id: - cus = stripe.Customer.retrieve(user.stripe_id) + cus = billing.Customer.retrieve(user.stripe_id) if cus: try: cus.card = token @@ -73,7 +73,7 @@ def get_invoices(customer_id): 'plan': i.lines.data[0].plan.id if i.lines.data[0].plan else None } - invoices = stripe.Invoice.all(customer=customer_id, count=12) + invoices = billing.Invoice.all(customer=customer_id, count=12) return { 'invoices': [invoice_view(i) for i in invoices.data] } @@ -225,7 +225,7 @@ class UserPlan(ApiResource): private_repos = model.get_private_repo_count(user.username) if user.stripe_id: - cus = stripe.Customer.retrieve(user.stripe_id) + cus = billing.Customer.retrieve(user.stripe_id) if cus.subscription: return subscription_view(cus.subscription, private_repos) @@ -285,7 +285,7 @@ class OrganizationPlan(ApiResource): private_repos = model.get_private_repo_count(orgname) organization = model.get_organization(orgname) if organization.stripe_id: - cus = stripe.Customer.retrieve(organization.stripe_id) + cus = billing.Customer.retrieve(organization.stripe_id) if cus.subscription: return subscription_view(cus.subscription, private_repos) diff --git a/endpoints/api/organization.py b/endpoints/api/organization.py index f89ddc5d5..f6a381ace 100644 --- a/endpoints/api/organization.py +++ b/endpoints/api/organization.py @@ -1,8 +1,8 @@ import logging -import stripe from flask import request +from app import billing as stripe from endpoints.api import (resource, nickname, ApiResource, validate_json_request, request_error, related_user_resource, internal_only, Unauthorized, NotFound, require_user_admin, log_action, show_if) @@ -12,7 +12,7 @@ from auth.permissions import (AdministerOrganizationPermission, OrganizationMem CreateRepositoryPermission) from auth.auth_context import get_authenticated_user from data import model -from data.plans import get_plan +from data.billing import get_plan from util.gravatar import compute_hash import features diff --git a/endpoints/api/subscribe.py b/endpoints/api/subscribe.py index efc2dfea7..03d8a0b4c 100644 --- a/endpoints/api/subscribe.py +++ b/endpoints/api/subscribe.py @@ -1,10 +1,11 @@ import logging import stripe +from app import billing from endpoints.api import request_error, log_action, NotFound from endpoints.common import check_repository_usage from data import model -from data.plans import PLANS +from data.billing import PLANS import features @@ -60,7 +61,7 @@ def subscribe(user, plan, token, require_business_plan): card = token try: - cus = stripe.Customer.create(email=user.email, plan=plan, card=card) + cus = billing.Customer.create(email=user.email, plan=plan, card=card) user.stripe_id = cus.id user.save() check_repository_usage(user, plan_found) @@ -73,7 +74,7 @@ def subscribe(user, plan, token, require_business_plan): else: # Change the plan - cus = stripe.Customer.retrieve(user.stripe_id) + cus = billing.Customer.retrieve(user.stripe_id) if plan_found['price'] == 0: if cus.subscription is not None: diff --git a/endpoints/api/user.py b/endpoints/api/user.py index 40194a436..edd5c1f52 100644 --- a/endpoints/api/user.py +++ b/endpoints/api/user.py @@ -1,19 +1,18 @@ import logging -import stripe import json from flask import request from flask.ext.login import logout_user from flask.ext.principal import identity_changed, AnonymousIdentity -from app import app +from app import app, billing as stripe from endpoints.api import (ApiResource, nickname, resource, validate_json_request, request_error, log_action, internal_only, NotFound, require_user_admin, InvalidToken, require_scope, format_date, hide_if, show_if) from endpoints.api.subscribe import subscribe from endpoints.common import common_login from data import model -from data.plans import get_plan +from data.billing import get_plan from auth.permissions import (AdministerOrganizationPermission, CreateRepositoryPermission, UserAdminPermission, UserReadPermission) from auth.auth_context import get_authenticated_user diff --git a/endpoints/web.py b/endpoints/web.py index e14c70e79..f2acf1c70 100644 --- a/endpoints/web.py +++ b/endpoints/web.py @@ -1,5 +1,4 @@ import logging -import stripe import os from flask import (abort, redirect, request, url_for, make_response, Response, @@ -9,7 +8,7 @@ from urlparse import urlparse from data import model from data.model.oauth import DatabaseAuthorizationProvider -from app import app +from app import app, billing as stripe from auth.permissions import AdministerOrganizationPermission from util.invoice import renderInvoiceToPdf from util.seo import render_snapshot diff --git a/endpoints/webhooks.py b/endpoints/webhooks.py index 93d5e413c..df5988750 100644 --- a/endpoints/webhooks.py +++ b/endpoints/webhooks.py @@ -1,9 +1,9 @@ import logging -import stripe import json from flask import request, make_response, Blueprint +from app import billing as stripe from data import model from data.queue import dockerfile_build_queue from auth.auth import process_auth diff --git a/initdb.py b/initdb.py index c45f5be7b..9381ea282 100644 --- a/initdb.py +++ b/initdb.py @@ -148,8 +148,7 @@ def setup_database_for_testing(testcase): # Sanity check to make sure we're not killing our prod db db = model.db - if (not isinstance(model.db, SqliteDatabase) or - app.config['DB_DRIVER'] is not SqliteDatabase): + if not isinstance(model.db, SqliteDatabase): raise RuntimeError('Attempted to wipe production database!') global db_initialized_for_testing @@ -240,8 +239,7 @@ def wipe_database(): # Sanity check to make sure we're not killing our prod db db = model.db - if (not isinstance(model.db, SqliteDatabase) or - app.config['DB_DRIVER'] is not SqliteDatabase): + if not isinstance(model.db, SqliteDatabase): raise RuntimeError('Attempted to wipe production database!') drop_model_tables(all_models, fail_silently=True) diff --git a/test/testconfig.py b/test/testconfig.py index 11eb2f605..d012af469 100644 --- a/test/testconfig.py +++ b/test/testconfig.py @@ -13,7 +13,7 @@ class FakeTransaction(object): class TestConfig(DefaultConfig): TESTING = True - DB_NAME = ':memory:' + DB_URL = 'sqlite:///:memory:' DB_CONNECTION_ARGS = {} @staticmethod