diff --git a/endpoints/api/billing.py b/endpoints/api/billing.py index d889c7c82..aae577908 100644 --- a/endpoints/api/billing.py +++ b/endpoints/api/billing.py @@ -12,12 +12,42 @@ from auth.permissions import AdministerOrganizationPermission from auth.auth_context import get_authenticated_user from auth import scopes from data import model -from data.billing import PLANS +from data.billing import PLANS, get_plan import features import uuid import json +def lookup_allowed_private_repos(namespace): + """ Returns false if the given namespace has used its allotment of private repositories. """ + # Lookup the namespace and verify it has a subscription. + namespace_user = model.user.get_namespace_user(namespace) + if namespace_user is None: + return False + + if not namespace_user.stripe_id: + return False + + # Ask Stripe for the subscribed plan. + # TODO: Can we cache this or make it faster somehow? + try: + cus = billing.Customer.retrieve(namespace_user.stripe_id) + except stripe.APIConnectionError: + abort(503, message='Cannot contact Stripe') + + if not cus.subscription: + return False + + # Find the number of private repositories used by the namespace and compare it to the + # plan subscribed. + private_repos = model.user.get_private_repo_count(namespace) + current_plan = get_plan(cus.subscription.plan.id) + if current_plan is None: + return False + + return private_repos < current_plan['privateRepos'] + + def carderror_response(e): return {'carderror': e.message}, 402 diff --git a/endpoints/api/repository.py b/endpoints/api/repository.py index 215931785..b241a70a0 100644 --- a/endpoints/api/repository.py +++ b/endpoints/api/repository.py @@ -2,6 +2,7 @@ import logging import datetime +import features from datetime import timedelta @@ -15,7 +16,8 @@ from endpoints.api import (truthy_bool, format_date, nickname, log_action, valid require_repo_read, require_repo_write, require_repo_admin, RepositoryParamResource, resource, query_param, parse_args, ApiResource, request_error, require_scope, Unauthorized, NotFound, InvalidRequest, - path_param) + path_param, ExceedsLicenseException) +from endpoints.api.billing import lookup_allowed_private_repos from auth.permissions import (ModifyRepositoryPermission, AdministerRepositoryPermission, CreateRepositoryPermission) @@ -26,6 +28,18 @@ from auth import scopes logger = logging.getLogger(__name__) +def check_allowed_private_repos(namespace): + """ Checks to see if the given namespace has reached its private repository limit. If so, + raises a ExceedsLicenseException. + """ + # Not enabled if billing is disabled. + if not features.BILLING: + return + + if not lookup_allowed_private_repos(namespace): + raise ExceedsLicenseException() + + @resource('/v1/repository') class RepositoryList(ApiResource): """Operations for creating and listing repositories.""" @@ -87,6 +101,8 @@ class RepositoryList(ApiResource): raise request_error(message='Repository already exists') visibility = req['visibility'] + if visibility == 'private': + check_allowed_private_repos(namespace_name) repo = model.repository.create_repository(namespace_name, repository_name, owner, visibility) repo.description = req['description'] @@ -339,7 +355,11 @@ class RepositoryVisibility(RepositoryParamResource): repo = model.repository.get_repository(namespace, repository) if repo: values = request.get_json() - model.repository.set_repository_visibility(repo, values['visibility']) + visibility = values['visibility'] + if visibility == 'private': + check_allowed_private_repos(namespace) + + model.repository.set_repository_visibility(repo, visibility) log_action('change_repo_visibility', namespace, {'repo': repository, 'visibility': values['visibility']}, repo=repo) diff --git a/test/test_api_usage.py b/test/test_api_usage.py index 163a10977..20a8a8fec 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -315,8 +315,18 @@ class TestGetUserPrivateAllowed(ApiTestCase): def test_allowed(self): self.login(ADMIN_ACCESS_USER) + + # Change the subscription of the namespace. + self.putJsonResponse(UserPlan, data=dict(plan='personal-30')) + json = self.getJsonResponse(PrivateRepositories) assert json['privateCount'] >= 6 + assert not json['privateAllowed'] + + # Change the subscription of the namespace. + self.putJsonResponse(UserPlan, data=dict(plan='bus-large-30')) + + json = self.getJsonResponse(PrivateRepositories) assert json['privateAllowed'] @@ -1435,6 +1445,36 @@ class TestUpdateRepo(ApiTestCase): class TestChangeRepoVisibility(ApiTestCase): SIMPLE_REPO = ADMIN_ACCESS_USER + '/simple' + + def test_trychangevisibility(self): + self.login(ADMIN_ACCESS_USER) + + # Make public. + self.postJsonResponse(RepositoryVisibility, + params=dict(repository=self.SIMPLE_REPO), + data=dict(visibility='public')) + + # Verify the visibility. + json = self.getJsonResponse(Repository, + params=dict(repository=self.SIMPLE_REPO)) + + self.assertEquals(True, json['is_public']) + + # Change the subscription of the namespace. + self.putJsonResponse(UserPlan, data=dict(plan='personal-30')) + + # Try to make private. + self.postJsonResponse(RepositoryVisibility, + params=dict(repository=self.SIMPLE_REPO), + data=dict(visibility='private'), + expected_code=402) + + # Verify the visibility. + json = self.getJsonResponse(Repository, + params=dict(repository=self.SIMPLE_REPO)) + + self.assertEquals(True, json['is_public']) + def test_changevisibility(self): self.login(ADMIN_ACCESS_USER) diff --git a/test/testconfig.py b/test/testconfig.py index 2ee3e89bb..34ae8da48 100644 --- a/test/testconfig.py +++ b/test/testconfig.py @@ -20,6 +20,7 @@ TEST_DB_FILE = NamedTemporaryFile(delete=True) class TestConfig(DefaultConfig): TESTING = True SECRET_KEY = 'a36c9d7d-25a9-4d3f-a586-3d2f8dc40a83' + BILLING_TYPE = 'FakeStripe' TEST_DB_FILE = TEST_DB_FILE DB_URI = os.environ.get('TEST_DATABASE_URI', 'sqlite:///{0}'.format(TEST_DB_FILE.name))