diff --git a/data/model/test/test_user.py b/data/model/test/test_user.py index 47dd7c0ec..2f4f82995 100644 --- a/data/model/test/test_user.py +++ b/data/model/test/test_user.py @@ -5,7 +5,7 @@ import pytest from mock import patch from data.database import EmailConfirmation -from data.model.user import create_user_noverify, validate_reset_code +from data.model.user import create_user_noverify, validate_reset_code, get_active_users from util.timedeltastring import convert_to_timedelta from test.fixtures import * @@ -28,3 +28,13 @@ def test_validation_code(token_lifetime, time_since, initialized_db): result = validate_reset_code(confirmation.code) expect_success = convert_to_timedelta(token_lifetime) >= convert_to_timedelta(time_since) assert expect_success == (result is not None) + +@pytest.mark.parametrize('disabled', [ + (True), + (False), +]) +def test_get_active_users(disabled, initialized_db): + users = get_active_users(disabled=disabled) + for user in users: + if not disabled: + assert user.enabled diff --git a/data/model/user.py b/data/model/user.py index fdd17e445..7895fe9ba 100644 --- a/data/model/user.py +++ b/data/model/user.py @@ -750,8 +750,11 @@ def get_private_repo_count(username): .count()) -def get_active_users(): - return User.select().where(User.organization == False, User.robot == False) +def get_active_users(disabled=True): + query = User.select().where(User.organization == False, User.robot == False) + if not disabled: + query = query.where(User.enabled == True) + return query def get_active_user_count(): diff --git a/endpoints/api/superuser.py b/endpoints/api/superuser.py index f173d8587..212d5e69a 100644 --- a/endpoints/api/superuser.py +++ b/endpoints/api/superuser.py @@ -1,36 +1,31 @@ """ Superuser API. """ -import json import logging import os import string import socket -import pathvalidate - from datetime import datetime, timedelta from random import SystemRandom -from dateutil.relativedelta import relativedelta +import pathvalidate + from flask import request, make_response, jsonify -from tzlocal import get_localzone import features -from app import (app, avatar, superusers, authentication, config_provider, license_validator, - all_queues, log_archive, build_logs) +from app import app, avatar, superusers, authentication, config_provider, license_validator from auth import scopes from auth.auth_context import get_authenticated_user from auth.permissions import SuperUserPermission -from data.buildlogs import BuildStatusRetrievalError +from data.database import ServiceKeyApprovalType from endpoints.api import (ApiResource, nickname, resource, validate_json_request, internal_only, require_scope, show_if, parse_args, query_param, abort, require_fresh_login, path_param, verify_not_prod, - page_support, log_action, InvalidRequest, format_date) -from endpoints.api.build import build_status_view, get_logs_or_log_url -from data import model, database -from data.database import ServiceKeyApprovalType -from endpoints.api.superuser_models_pre_oci import pre_oci_model, ServiceKeyDoesNotExist, ServiceKeyAlreadyApproved, \ - InvalidRepositoryBuildException + page_support, log_action, InvalidRequest, format_date, truthy_bool) +from endpoints.api.build import get_logs_or_log_url +from endpoints.api.superuser_models_pre_oci import (pre_oci_model, ServiceKeyDoesNotExist, + ServiceKeyAlreadyApproved, + InvalidRepositoryBuildException) from endpoints.exception import NotFound, InvalidResponse from util.useremails import send_confirmation_email, send_recovery_email from util.license import decode_license, LicenseDecodeError @@ -142,7 +137,8 @@ class SuperUserAggregateLogs(ApiResource): def get(self, parsed_args): """ Returns the aggregated logs for the current system. """ if SuperUserPermission().can(): - (start_time, end_time) = _validate_logs_arguments(parsed_args['starttime'], parsed_args['endtime']) + (start_time, end_time) = _validate_logs_arguments(parsed_args['starttime'], + parsed_args['endtime']) aggregated_logs = pre_oci_model.get_aggregated_logs(start_time, end_time) return { @@ -151,10 +147,8 @@ class SuperUserAggregateLogs(ApiResource): abort(403) - LOGS_PER_PAGE = 20 - @resource('/v1/superuser/logs') @internal_only @show_if(features.SUPER_USERS) @@ -179,10 +173,10 @@ class SuperUserLogs(ApiResource): log_page = pre_oci_model.get_logs_query(start_time, end_time, page_token=page_token) return { - 'start_time': format_date(start_time), - 'end_time': format_date(end_time), - 'logs': [log.to_dict() for log in log_page.logs], - }, log_page.next_page_token + 'start_time': format_date(start_time), + 'end_time': format_date(end_time), + 'logs': [log.to_dict() for log in log_page.logs], + }, log_page.next_page_token abort(403) @@ -281,11 +275,14 @@ class SuperUserList(ApiResource): @require_fresh_login @verify_not_prod @nickname('listAllUsers') + @parse_args() + @query_param('disabled', 'If false, only enabled users will be returned.', type=truthy_bool, + default=True) @require_scope(scopes.SUPERUSER) - def get(self): + def get(self, parsed_args): """ Returns a list of all users in the system. """ if SuperUserPermission().can(): - users = pre_oci_model.get_active_users() + users = pre_oci_model.get_active_users(disabled=parsed_args['disabled']) return { 'users': [user.to_dict() for user in users] } @@ -469,7 +466,8 @@ class SuperUserManagement(ApiResource): return_value = user.to_dict() if user_data.get('password') is not None: - return_value['encrypted_password'] = authentication.encrypt_user_password(user_data.get('password')) + password = user_data.get('password') + return_value['encrypted_password'] = authentication.encrypt_user_password(password) return return_value diff --git a/endpoints/api/superuser_models_pre_oci.py b/endpoints/api/superuser_models_pre_oci.py index f2d4dfda6..4e44aea78 100644 --- a/endpoints/api/superuser_models_pre_oci.py +++ b/endpoints/api/superuser_models_pre_oci.py @@ -198,8 +198,8 @@ class PreOCIModel(SuperuserDataInterface): return return_user, confirmation.code return return_user, '' - def get_active_users(self): - users = model.user.get_active_users() + def get_active_users(self, disabled=True): + users = model.user.get_active_users(disabled=disabled) return [_create_user(user) for user in users] def get_organizations(self): diff --git a/endpoints/api/test/test_superuser.py b/endpoints/api/test/test_superuser.py new file mode 100644 index 000000000..5b98066b5 --- /dev/null +++ b/endpoints/api/test/test_superuser.py @@ -0,0 +1,19 @@ +import pytest + +from endpoints.api.superuser import SuperUserList +from endpoints.api.test.shared import conduct_api_call +from endpoints.test.shared import client_with_identity +from test.fixtures import * + +@pytest.mark.parametrize('disabled', [ + (True), + (False), +]) +def test_list_all_users(disabled, client): + with client_with_identity('devtable', client) as cl: + params = {'disabled': disabled} + result = conduct_api_call(cl, SuperUserList, 'GET', params, None, 200).json + assert len(result['users']) + for user in result['users']: + if not disabled: + assert user['enabled'] diff --git a/test/test_api_usage.py b/test/test_api_usage.py index 434693bb7..b1f8b7476 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -68,7 +68,7 @@ from endpoints.api.repository import ( from endpoints.api.permission import (RepositoryUserPermission, RepositoryTeamPermission, RepositoryTeamPermissionList, RepositoryUserPermissionList) from endpoints.api.superuser import ( - SuperUserLogs, SuperUserList, SuperUserManagement, SuperUserServiceKeyManagement, + SuperUserLogs, SuperUserManagement, SuperUserServiceKeyManagement, SuperUserServiceKey, SuperUserServiceKeyApproval, SuperUserTakeOwnership, SuperUserCustomCertificates, SuperUserCustomCertificate) from endpoints.api.globalmessages import ( @@ -3989,16 +3989,6 @@ class TestSuperUserLogs(ApiTestCase): assert len(json['logs']) > 0 -class TestSuperUserList(ApiTestCase): - def test_get_users(self): - self.login(ADMIN_ACCESS_USER) - - json = self.getJsonResponse(SuperUserList) - - assert 'users' in json - assert len(json['users']) > 0 - - class TestSuperUserCreateInitialSuperUser(ApiTestCase): def test_create_superuser(self): data = {