diff --git a/data/database.py b/data/database.py
index e12c1e147..8bc0488a7 100644
--- a/data/database.py
+++ b/data/database.py
@@ -167,6 +167,22 @@ class BaseModel(ReadSlaveModel):
     database = db
     read_slaves = (read_slave,)
 
+  def __getattribute__(self, name):
+    """ Adds _id accessors so that foreign key field IDs can be looked up without making
+        a database roundtrip.
+
+        When name is `<field>_id` and `<field>` is one of the model's declared fields, the
+        value that _data holds for `<field>` is returned. Any other name falls through to
+        the regular attribute lookup.
+    """
+    if name.endswith('_id'):
+      # Drop the '_id' suffix to recover the name of the underlying field.
+      field_name = name[:-len('_id')]
+      if field_name in self._meta.fields:
+        return self._data.get(field_name)
+
+    return super(BaseModel, self).__getattribute__(name)
+
 
 class User(BaseModel):
   uuid = CharField(default=uuid_generator, max_length=36, null=True)
diff --git a/data/model/legacy.py b/data/model/legacy.py
index ad80ddaeb..6f8ded0a7 100644
--- a/data/model/legacy.py
+++ b/data/model/legacy.py
@@ -991,8 +991,84 @@ def _get_public_repo_visibility():
   return _public_repo_visibility_cache
 
 
-def get_matching_repositories(repo_term, username=None, limit=10, include_public=True,
-                              pull_count_sort=False):
+def get_sorted_matching_repositories(prefix, only_public, checker, limit=10):
+  """ Returns repositories matching the given prefix string and passing the given checker
+      function.
+
+      prefix: matched against the start of repository names and of namespace usernames.
+      only_public: when truthy, only repositories with the public visibility are queried.
+      checker: called with each candidate repository; candidates for which it returns a
+               falsy value are skipped.
+      limit: maximum number of repositories to return.
+
+      Every returned repository has two extra attributes set on it: is_public, and count (the
+      number of its log entries from the last week, or 0 when it was found without counts).
+      Repository name matches are returned ahead of namespace matches and, within each group,
+      repositories with log entries from the last week come first, highest count first.
+  """
+
+  last_week = datetime.now() - timedelta(weeks=1)
+  # NOTE(review): prefix is used as-is, so any '%' or '_' in it presumably acts as a wildcard
+  # in the match; confirm whether it should be escaped.
+  prefix_pattern = prefix + '%'
+  results = []
+  existing_ids = []
+
+  def get_search_results(search_clause, with_count):
+    """ Appends the repositories matching search_clause to results, up to the limit. """
+    if len(results) >= limit:
+      return
+
+    selected = [Repository, Namespace]
+    if with_count:
+      selected.append(fn.Count(LogEntry.id).alias('count'))
+
+    query = (Repository.select(*selected)
+             .join(Namespace, JOIN_LEFT_OUTER, on=(Namespace.id == Repository.namespace_user))
+             .switch(Repository)
+             .where(search_clause)
+             .group_by(Repository, Namespace))
+
+    if only_public:
+      query = query.where(Repository.visibility == _get_public_repo_visibility())
+
+    if existing_ids:
+      query = query.where(~(Repository.id << existing_ids))
+
+    if with_count:
+      # The datetime filter leaves out repositories that have no log entries from the last
+      # week; the follow-up search without counts is what picks those up.
+      query = (query.join(LogEntry, JOIN_LEFT_OUTER)
+               .where(LogEntry.datetime >= last_week)
+               .order_by(fn.Count(LogEntry.id).desc()))
+
+    for result in query:
+      if len(results) >= limit:
+        return
+
+      # Note: We compare IDs here, instead of objects, because calling .visibility on the
+      # Repository will kick off a new SQL query to retrieve that visibility enum value. We don't
+      # join the visibility table in SQL, as well, because it is ungodly slow in MySQL :-/
+      result.is_public = result.visibility_id == _get_public_repo_visibility().id
+      result.count = result.count if with_count else 0
+
+      if not checker(result):
+        continue
+
+      results.append(result)
+      existing_ids.append(result.id)
+
+  # For performance reasons, we conduct the repo name and repo namespace searches on their
+  # own, and with and without counts on their own. This also affords us the ability to give
+  # higher precedence to repository names matching over namespaces, which is semantically correct.
+  for search_field in (Repository.name, Namespace.username):
+    for with_count in (True, False):
+      get_search_results(search_field ** prefix_pattern, with_count=with_count)
+
+  return results
+
+
+def get_matching_repositories(repo_term, username=None, limit=10, include_public=True):
   namespace_term = repo_term
   name_term = repo_term
 
@@ -1010,22 +1086,7 @@ def get_matching_repositories(repo_term, username=None, limit=10, include_public
   search_clauses = (Repository.name ** ('%' + name_term + '%') &
                     Namespace.username ** ('%' + namespace_term + '%'))
 
-  query = visible.where(search_clauses).limit(limit)
-
-  if pull_count_sort:
-    repo_pull = LogEntryKind.get(name = 'pull_repo')
-    last_month = datetime.now() - timedelta(weeks=4)
-
-    query = (query.switch(Repository)
-             .join(LogEntry, JOIN_LEFT_OUTER)
-             .where(((LogEntry.kind == repo_pull) & (LogEntry.datetime >= last_month)) |
-                    (LogEntry.id >> None))
-             .group_by(Repository, Namespace, Visibility)
-             .order_by(fn.Count(LogEntry.id).desc())
-             .select(Repository, Namespace, Visibility,
-                     fn.Count(LogEntry.id).alias('count')))
-
-  return query
+  return visible.where(search_clauses).limit(limit)
 
 
 def change_password(user, new_password):
diff --git a/endpoints/api/search.py b/endpoints/api/search.py
index 9619a5021..20a34b495 100644
--- a/endpoints/api/search.py
+++ b/endpoints/api/search.py
@@ -205,22 +205,31 @@ def conduct_admined_team_search(username, query, encountered_teams, results):
 
 def conduct_repo_search(username, query, results):
   """ Finds matching repositories. """
-  matching_repos = model.get_matching_repositories(query, username, limit=5, pull_count_sort=True)
+  def can_read(repository):
+    """ Returns True for public repositories; otherwise asks ReadRepositoryPermission. """
+    if repository.is_public:
+      return True
+
+    return ReadRepositoryPermission(repository.namespace_user.username, repository.name).can()
+
+  # When no username is given, the search is restricted to public repositories.
+  only_public = username is None
+  matching_repos = model.get_sorted_matching_repositories(query, only_public, can_read, limit=5)
 
   for repo in matching_repos:
     repo_score = math.log(repo.count or 1, 10) or 1
 
-    # If the repository is under the user's namespace, give it 50% more weight.
+    # If the repository is under the user's namespace, give it 20% more weight.
     namespace = repo.namespace_user.username
     if OrganizationMemberPermission(namespace).can() or namespace == username:
-      repo_score = repo_score * 1.5
+      repo_score = repo_score * 1.2
 
     results.append({
       'kind': 'repository',
       'namespace': search_entity_view(username, repo.namespace_user),
       'name': repo.name,
       'description': repo.description,
-      'is_public': repo.visibility.name == 'public',
+      'is_public': repo.is_public,
       'score': repo_score,
       'href': '/repository/' + repo.namespace_user.username + '/' + repo.name
     })