diff --git a/data/database.py b/data/database.py index 271a25c32..83dbe74de 100644 --- a/data/database.py +++ b/data/database.py @@ -21,7 +21,9 @@ from sqlalchemy.engine.url import make_url import resumablehashlib -from data.fields import ResumableSHA256Field, ResumableSHA1Field, JSONField, Base64BinaryField +from data.fields import (ResumableSHA256Field, ResumableSHA1Field, JSONField, Base64BinaryField, + FullIndexedTextField, FullIndexedCharField) +from data.text import match_mysql, match_like from data.read_slave import ReadSlaveModel from util.names import urn_generator @@ -30,10 +32,12 @@ logger = logging.getLogger(__name__) DEFAULT_DB_CONNECT_TIMEOUT = 10 # seconds + # IMAGE_NOT_SCANNED_ENGINE_VERSION is the version found in security_indexed_engine when the # image has not yet been scanned. IMAGE_NOT_SCANNED_ENGINE_VERSION = -1 + _SCHEME_DRIVERS = { 'mysql': MySQLDatabase, 'mysql+pymysql': MySQLDatabase, @@ -42,6 +46,16 @@ _SCHEME_DRIVERS = { 'postgresql+psycopg2': PostgresqlDatabase, } + +SCHEME_MATCH_FUNCTION = { + 'mysql': match_mysql, + 'mysql+pymysql': match_mysql, + 'sqlite': match_like, + 'postgresql': match_like, + 'postgresql+psycopg2': match_like, +} + + SCHEME_RANDOM_FUNCTION = { 'mysql': fn.Rand, 'mysql+pymysql': fn.Rand, @@ -50,6 +64,7 @@ SCHEME_RANDOM_FUNCTION = { 'postgresql+psycopg2': fn.Random, } + def pipes_concat(arg1, arg2, *extra_args): """ Concat function for sqlite, since it doesn't support fn.Concat. Concatenates clauses with || characters. @@ -211,6 +226,7 @@ class TupleSelector(object): db = Proxy() read_slave = Proxy() db_random_func = CallableProxy() +db_match_func = CallableProxy() db_for_update = CallableProxy() db_transaction = CallableProxy() db_concat_func = CallableProxy() @@ -257,6 +273,7 @@ def configure(config_object): parsed_write_uri = make_url(write_db_uri) db_random_func.initialize(SCHEME_RANDOM_FUNCTION[parsed_write_uri.drivername]) + db_match_func.initialize(SCHEME_MATCH_FUNCTION[parsed_write_uri.drivername]) db_for_update.initialize(SCHEME_SPECIALIZED_FOR_UPDATE.get(parsed_write_uri.drivername, real_for_update)) db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername, @@ -471,9 +488,9 @@ class Visibility(BaseModel): class Repository(BaseModel): namespace_user = QuayUserField(null=True) - name = CharField() + name = FullIndexedCharField(match_function=db_match_func) visibility = ForeignKeyField(Visibility) - description = TextField(null=True) + description = FullIndexedTextField(match_function=db_match_func, null=True) badge_token = CharField(default=uuid_generator) class Meta: diff --git a/data/fields.py b/data/fields.py index f73bc0cf1..8228dd099 100644 --- a/data/fields.py +++ b/data/fields.py @@ -2,7 +2,8 @@ import base64 import resumablehashlib import json -from peewee import TextField +from peewee import TextField, CharField, Clause +from data.text import prefix_search class _ResumableSHAField(TextField): @@ -64,3 +65,44 @@ class Base64BinaryField(TextField): if value is None: return None return base64.b64decode(value) + + +def _add_fulltext(field_class): + """ Adds support for full text indexing and lookup to the given field class. """ + class indexed_class(field_class): + # Marker used by SQLAlchemy translation layer to add the proper index for full text searching. + __fulltext__ = True + + def __init__(self, match_function, *args, **kwargs): + field_class.__init__(self, *args, **kwargs) + self.match_function = match_function + + def match(self, query): + return self.match_function(self, query) + + def match_prefix(self, query): + return prefix_search(self, query) + + def __mod__(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def __pow__(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def __contains__(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def contains(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def startswith(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def endswith(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + return indexed_class + + +FullIndexedCharField = _add_fulltext(CharField) +FullIndexedTextField = _add_fulltext(TextField) diff --git a/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py b/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py new file mode 100644 index 000000000..9395fd926 --- /dev/null +++ b/data/migrations/versions/e2894a3a3c19_add_full_text_search_indexing_for_repo_.py @@ -0,0 +1,31 @@ +"""Add full text search indexing for repo name and description + +Revision ID: e2894a3a3c19 +Revises: d42c175b439a +Create Date: 2017-01-11 13:55:54.890774 + +""" + +# revision identifiers, used by Alembic. +revision = 'e2894a3a3c19' +down_revision = 'd42c175b439a' + +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import mysql + +def upgrade(tables): + if op.get_bind().engine.name == 'postgresql': + op.execute('CREATE EXTENSION IF NOT EXISTS pg_trgm') + + # ### commands auto generated by Alembic - please adjust! ### + op.create_index('repository_description__fulltext', 'repository', ['description'], unique=False, postgresql_using='gin', postgresql_ops={'description': 'gin_trgm_ops'}, mysql_prefix='FULLTEXT') + op.create_index('repository_name__fulltext', 'repository', ['name'], unique=False, postgresql_using='gin', postgresql_ops={'name': 'gin_trgm_ops'}, mysql_prefix='FULLTEXT') + # ### end Alembic commands ### + + +def downgrade(tables): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_index('repository_name__fulltext', table_name='repository') + op.drop_index('repository_description__fulltext', table_name='repository') + # ### end Alembic commands ### diff --git a/data/model/_basequery.py b/data/model/_basequery.py index bc5ab9748..59de01a3b 100644 --- a/data/model/_basequery.py +++ b/data/model/_basequery.py @@ -1,4 +1,4 @@ -from peewee import Clause, SQL, fn +from peewee import fn from cachetools import lru_cache from data.model import DataModelException @@ -6,18 +6,6 @@ from data.database import (Repository, User, Team, TeamMember, RepositoryPermiss Namespace, Visibility, ImageStorage, Image, db_for_update) -def prefix_search(field, prefix_query): - """ Returns the wildcard match for searching for the given prefix query. """ - # Escape the known wildcard characters. - prefix_query = (prefix_query - .replace('!', '!!') - .replace('%', '!%') - .replace('_', '!_') - .replace('[', '![')) - - return field ** Clause(prefix_query + '%', SQL("ESCAPE '!'")) - - def get_existing_repository(namespace_name, repository_name, for_update=False): query = (Repository .select(Repository, Namespace) diff --git a/data/model/label.py b/data/model/label.py index 166acedec..467eca86f 100644 --- a/data/model/label.py +++ b/data/model/label.py @@ -4,7 +4,7 @@ from cachetools import lru_cache from data.database import Label, TagManifestLabel, MediaType, LabelSourceType, db_transaction from data.model import InvalidLabelKeyException, InvalidMediaTypeException, DataModelException -from data.model._basequery import prefix_search +from data.text import prefix_search from util.validation import validate_label_key from util.validation import is_json diff --git a/data/model/repository.py b/data/model/repository.py index 11bcafcf1..910ae358a 100644 --- a/data/model/repository.py +++ b/data/model/repository.py @@ -12,6 +12,7 @@ from data.database import (Repository, Namespace, RepositoryTag, Star, Image, Im Role, RepositoryAuthorizedEmail, TagManifest, DerivedStorageForImage, Label, TagManifestLabel, db_for_update, get_epoch_timestamp, db_random_func, db_concat_func) +from data.text import prefix_search logger = logging.getLogger(__name__) @@ -318,8 +319,8 @@ def get_visible_repositories(username, namespace=None, include_public=False, sta return query -def get_sorted_matching_repositories(prefix, only_public, checker, limit=10): - """ Returns repositories matching the given prefix string and passing the given checker +def get_sorted_matching_repositories(lookup_value, only_public, checker, limit=10): + """ Returns repositories matching the given lookup string and passing the given checker function. """ last_week = datetime.now() - timedelta(weeks=1) @@ -370,14 +371,16 @@ def get_sorted_matching_repositories(prefix, only_public, checker, limit=10): results.append(result) existing_ids.append(result.id) - # For performance reasons, we conduct the repo name and repo namespace searches on their - # own. This also affords us the ability to give higher precedence to repository names matching - # over namespaces, which is semantically correct. - get_search_results(_basequery.prefix_search(Repository.name, prefix), with_count=True) - get_search_results(_basequery.prefix_search(Repository.name, prefix), with_count=False) + # For performance reasons, we conduct each set of searches on their own. This also affords us the + # ability to easily define an order precedence. + get_search_results(Repository.name.match(lookup_value), with_count=True) + get_search_results(Repository.name.match(lookup_value), with_count=False) - get_search_results(_basequery.prefix_search(Namespace.username, prefix), with_count=True) - get_search_results(_basequery.prefix_search(Namespace.username, prefix), with_count=False) + get_search_results(Repository.description.match(lookup_value), with_count=True) + get_search_results(Repository.description.match(lookup_value), with_count=False) + + get_search_results(prefix_search(Namespace.username, lookup_value), with_count=True) + get_search_results(prefix_search(Namespace.username, lookup_value), with_count=False) return results diff --git a/data/model/sqlalchemybridge.py b/data/model/sqlalchemybridge.py index df622e381..9add46ba9 100644 --- a/data/model/sqlalchemybridge.py +++ b/data/model/sqlalchemybridge.py @@ -1,5 +1,5 @@ from sqlalchemy import (Table, MetaData, Column, ForeignKey, Integer, String, Boolean, Text, - DateTime, Date, BigInteger, Index) + DateTime, Date, BigInteger, Index, text) from peewee import (PrimaryKeyField, CharField, BooleanField, DateTimeField, TextField, ForeignKeyField, BigIntegerField, IntegerField, DateField) @@ -28,6 +28,7 @@ def gen_sqlalchemy_metadata(peewee_model_list): meta = model._meta all_indexes = set(meta.indexes) + fulltext_indexes = [] columns = [] for field in meta.sorted_fields: @@ -60,6 +61,10 @@ def gen_sqlalchemy_metadata(peewee_model_list): else: raise RuntimeError('Unknown column type: %s' % field) + if hasattr(field, '__fulltext__'): + # Add the fulltext index for the field, based on whether we are under MySQL or Postgres. + fulltext_indexes.append(field.name) + for option_name in OPTIONS_TO_COPY: alchemy_option_name = (OPTION_TRANSLATIONS[option_name] if option_name in OPTION_TRANSLATIONS else option_name) @@ -81,4 +86,11 @@ def gen_sqlalchemy_metadata(peewee_model_list): col_refs = [getattr(new_table.c, col_name) for col_name in col_names] Index(index_name, *col_refs, unique=unique) + for col_field_name in fulltext_indexes: + index_name = '%s_%s__fulltext' % (meta.db_table, col_field_name) + col_ref = getattr(new_table.c, col_field_name) + Index(index_name, col_ref, postgresql_ops={col_field_name: 'gin_trgm_ops'}, + postgresql_using='gin', + mysql_prefix='FULLTEXT') + return metadata diff --git a/data/model/team.py b/data/model/team.py index 64d877c5b..738a27407 100644 --- a/data/model/team.py +++ b/data/model/team.py @@ -1,6 +1,7 @@ from data.database import Team, TeamMember, TeamRole, User, TeamMemberInvite, RepositoryPermission from data.model import (DataModelException, InvalidTeamException, UserAlreadyInTeam, InvalidTeamMemberException, user, _basequery) +from data.text import prefix_search from util.validation import validate_username from peewee import fn, JOIN_LEFT_OUTER from util.morecollections import AttrDict @@ -137,7 +138,7 @@ def add_or_invite_to_team(inviter, team, user_obj=None, email=None, requires_inv def get_matching_user_teams(team_prefix, user_obj, limit=10): - team_prefix_search = _basequery.prefix_search(Team.name, team_prefix) + team_prefix_search = prefix_search(Team.name, team_prefix) query = (Team .select() .join(User) @@ -163,7 +164,7 @@ def get_organization_team(orgname, teamname): def get_matching_admined_teams(team_prefix, user_obj, limit=10): - team_prefix_search = _basequery.prefix_search(Team.name, team_prefix) + team_prefix_search = prefix_search(Team.name, team_prefix) admined_orgs = (_basequery.get_user_organizations(user_obj.username) .switch(Team) .join(TeamRole) @@ -182,7 +183,7 @@ def get_matching_admined_teams(team_prefix, user_obj, limit=10): def get_matching_teams(team_prefix, organization): - team_prefix_search = _basequery.prefix_search(Team.name, team_prefix) + team_prefix_search = prefix_search(Team.name, team_prefix) query = Team.select().where(team_prefix_search, Team.organization == organization) return query.limit(10) diff --git a/data/model/user.py b/data/model/user.py index ddb27400d..5672dc6ac 100644 --- a/data/model/user.py +++ b/data/model/user.py @@ -18,6 +18,7 @@ from data.model import (DataModelException, InvalidPasswordException, InvalidRob InvalidUsernameException, InvalidEmailAddressException, TooManyLoginAttemptsException, db_transaction, notification, config, repository, _basequery) +from data.text import prefix_search from util.names import format_robot_username, parse_robot_username from util.validation import (validate_username, validate_email, validate_password, INVALID_PASSWORD_MESSAGE) @@ -259,10 +260,10 @@ def get_matching_robots(name_prefix, username, limit=10): prefix_checks = False for org in admined_orgs: - org_search = _basequery.prefix_search(User.username, org.username + '+' + name_prefix) + org_search = prefix_search(User.username, org.username + '+' + name_prefix) prefix_checks = prefix_checks | org_search - user_search = _basequery.prefix_search(User.username, username + '+' + name_prefix) + user_search = prefix_search(User.username, username + '+' + name_prefix) prefix_checks = prefix_checks | user_search return User.select().where(prefix_checks).limit(limit) @@ -562,7 +563,7 @@ def get_user_or_org_by_customer_id(customer_id): def get_matching_user_namespaces(namespace_prefix, username, limit=10): - namespace_search = _basequery.prefix_search(Namespace.username, namespace_prefix) + namespace_search = prefix_search(Namespace.username, namespace_prefix) base_query = (Namespace .select() .distinct() @@ -573,12 +574,12 @@ def get_matching_user_namespaces(namespace_prefix, username, limit=10): return _basequery.filter_to_repos_for_user(base_query, username).limit(limit) def get_matching_users(username_prefix, robot_namespace=None, organization=None, limit=20): - user_search = _basequery.prefix_search(User.username, username_prefix) + user_search = prefix_search(User.username, username_prefix) direct_user_query = (user_search & (User.organization == False) & (User.robot == False)) if robot_namespace: robot_prefix = format_robot_username(robot_namespace, username_prefix) - robot_search = _basequery.prefix_search(User.username, robot_prefix) + robot_search = prefix_search(User.username, robot_prefix) direct_user_query = ((robot_search & (User.robot == True)) | direct_user_query) query = (User diff --git a/data/text.py b/data/text.py new file mode 100644 index 000000000..870dd935a --- /dev/null +++ b/data/text.py @@ -0,0 +1,38 @@ +from peewee import Clause, SQL, fn, TextField, Field + +def _escape_wildcard(search_query): + """ Escapes the wildcards found in the given search query so that they are treated as *characters* + rather than wildcards when passed to a LIKE or ILIKE clause with an ESCAPE '!'. + """ + search_query = (search_query + .replace('!', '!!') + .replace('%', '!%') + .replace('_', '!_') + .replace('[', '![')) + return search_query + + +def prefix_search(field, prefix_query): + """ Returns the wildcard match for searching for the given prefix query. """ + # Escape the known wildcard characters. + prefix_query = _escape_wildcard(prefix_query) + return Field.__pow__(field, Clause(prefix_query + '%', SQL("ESCAPE '!'"))) + + +def match_mysql(field, search_query): + """ Generates a full-text match query using a Match operation, which is needed for MySQL. + """ + if field.name.find('`') >= 0: # Just to be safe. + raise Exception("How did field name '%s' end up containing a backtick?" % field.name) + + return Clause(fn.MATCH(SQL("`%s`" % field.name)), fn.AGAINST(SQL('%s', search_query)), + parens=True) + + +def match_like(field, search_query): + """ Generates a full-text match query using an ILIKE operation, which is needed for SQLite and + Postgres. + """ + escaped_query = _escape_wildcard(search_query) + clause = Clause('%' + escaped_query + '%', SQL("ESCAPE '!'")) + return Field.__pow__(field, clause) diff --git a/initdb.py b/initdb.py index 0e92aec80..2c1615cf7 100644 --- a/initdb.py +++ b/initdb.py @@ -568,6 +568,11 @@ def populate_database(minimal=False, with_storage=False): [(new_user_2, 'write'), (reader, 'read')], (5, [], 'latest')) + __generate_repository(with_storage, new_user_1, 'text-full-repo', + 'This is a repository for testing text search', False, + [(new_user_2, 'write'), (reader, 'read')], + (5, [], 'latest')) + building = __generate_repository(with_storage, new_user_1, 'building', 'Empty repository which is building.', False, [], (0, [], None)) diff --git a/requirements-nover.txt b/requirements-nover.txt index 80a9efcb0..1d28e929a 100644 --- a/requirements-nover.txt +++ b/requirements-nover.txt @@ -60,7 +60,7 @@ redis redlock reportlab==2.7 semantic-version -sqlalchemy +sqlalchemy==1.1.5 stringscore stripe toposort diff --git a/requirements.txt b/requirements.txt index 3ff9ac3fc..0076672fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -109,7 +109,7 @@ requests-oauthlib==0.7.0 rfc3986==0.4.1 semantic-version==2.6.0 six==1.10.0 -SQLAlchemy==1.1.2 +SQLAlchemy==1.1.5 stevedore==1.17.1 stringscore==0.1.0 stripe==1.41.0 diff --git a/test/test_api_usage.py b/test/test_api_usage.py index 35c322ee3..a3cdf7eb0 100644 --- a/test/test_api_usage.py +++ b/test/test_api_usage.py @@ -1001,6 +1001,23 @@ class TestConductSearch(ApiTestCase): self.assertEquals(json['results'][0]['name'], 'shared') + def test_full_text(self): + self.login(ADMIN_ACCESS_USER) + + # Make sure the repository is found via `full` and `text search`. + json = self.getJsonResponse(ConductSearch, + params=dict(query='full')) + self.assertEquals(1, len(json['results'])) + self.assertEquals(json['results'][0]['kind'], 'repository') + self.assertEquals(json['results'][0]['name'], 'text-full-repo') + + json = self.getJsonResponse(ConductSearch, + params=dict(query='text search')) + self.assertEquals(1, len(json['results'])) + self.assertEquals(json['results'][0]['kind'], 'repository') + self.assertEquals(json['results'][0]['name'], 'text-full-repo') + + class TestGetMatchingEntities(ApiTestCase): def test_simple_lookup(self): self.login(ADMIN_ACCESS_USER)