diff --git a/data/database.py b/data/database.py index 271a25c32..21692fb92 100644 --- a/data/database.py +++ b/data/database.py @@ -22,6 +22,7 @@ from sqlalchemy.engine.url import make_url import resumablehashlib from data.fields import ResumableSHA256Field, ResumableSHA1Field, JSONField, Base64BinaryField +from data.text import match_mysql, match_like from data.read_slave import ReadSlaveModel from util.names import urn_generator @@ -42,6 +43,14 @@ _SCHEME_DRIVERS = { 'postgresql+psycopg2': PostgresqlDatabase, } +SCHEME_MATCH_FUNCTION = { + 'mysql': match_mysql, + 'mysql+pymysql': match_mysql, + 'sqlite': match_like, + 'postgresql': match_like, + 'postgresql+psycopg2': match_like, +} + SCHEME_RANDOM_FUNCTION = { 'mysql': fn.Rand, 'mysql+pymysql': fn.Rand, @@ -211,6 +220,7 @@ class TupleSelector(object): db = Proxy() read_slave = Proxy() db_random_func = CallableProxy() +db_match_func = CallableProxy() db_for_update = CallableProxy() db_transaction = CallableProxy() db_concat_func = CallableProxy() @@ -257,6 +267,7 @@ def configure(config_object): parsed_write_uri = make_url(write_db_uri) db_random_func.initialize(SCHEME_RANDOM_FUNCTION[parsed_write_uri.drivername]) + db_match_func.initialize(SCHEME_MATCH_FUNCTION[parsed_write_uri.drivername]) db_for_update.initialize(SCHEME_SPECIALIZED_FOR_UPDATE.get(parsed_write_uri.drivername, real_for_update)) db_concat_func.initialize(SCHEME_SPECIALIZED_CONCAT.get(parsed_write_uri.drivername, diff --git a/data/fields.py b/data/fields.py index f73bc0cf1..8228dd099 100644 --- a/data/fields.py +++ b/data/fields.py @@ -2,7 +2,8 @@ import base64 import resumablehashlib import json -from peewee import TextField +from peewee import TextField, CharField, Clause +from data.text import prefix_search class _ResumableSHAField(TextField): @@ -64,3 +65,44 @@ class Base64BinaryField(TextField): if value is None: return None return base64.b64decode(value) + + +def _add_fulltext(field_class): + """ Adds support for full text indexing and lookup to the given field class. """ + class indexed_class(field_class): + # Marker used by SQLAlchemy translation layer to add the proper index for full text searching. + __fulltext__ = True + + def __init__(self, match_function, *args, **kwargs): + field_class.__init__(self, *args, **kwargs) + self.match_function = match_function + + def match(self, query): + return self.match_function(self, query) + + def match_prefix(self, query): + return prefix_search(self, query) + + def __mod__(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def __pow__(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def __contains__(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def contains(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def startswith(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + def endswith(self, _): + raise Exception('Unsafe operation: Use `match` or `match_prefix`') + + return indexed_class + + +FullIndexedCharField = _add_fulltext(CharField) +FullIndexedTextField = _add_fulltext(TextField) diff --git a/data/model/sqlalchemybridge.py b/data/model/sqlalchemybridge.py index df622e381..9add46ba9 100644 --- a/data/model/sqlalchemybridge.py +++ b/data/model/sqlalchemybridge.py @@ -1,5 +1,5 @@ from sqlalchemy import (Table, MetaData, Column, ForeignKey, Integer, String, Boolean, Text, - DateTime, Date, BigInteger, Index) + DateTime, Date, BigInteger, Index, text) from peewee import (PrimaryKeyField, CharField, BooleanField, DateTimeField, TextField, ForeignKeyField, BigIntegerField, IntegerField, DateField) @@ -28,6 +28,7 @@ def gen_sqlalchemy_metadata(peewee_model_list): meta = model._meta all_indexes = set(meta.indexes) + fulltext_indexes = [] columns = [] for field in meta.sorted_fields: @@ -60,6 +61,10 @@ def gen_sqlalchemy_metadata(peewee_model_list): else: raise RuntimeError('Unknown column type: %s' % field) + if hasattr(field, '__fulltext__'): + # Add the fulltext index for the field, based on whether we are under MySQL or Postgres. + fulltext_indexes.append(field.name) + for option_name in OPTIONS_TO_COPY: alchemy_option_name = (OPTION_TRANSLATIONS[option_name] if option_name in OPTION_TRANSLATIONS else option_name) @@ -81,4 +86,11 @@ def gen_sqlalchemy_metadata(peewee_model_list): col_refs = [getattr(new_table.c, col_name) for col_name in col_names] Index(index_name, *col_refs, unique=unique) + for col_field_name in fulltext_indexes: + index_name = '%s_%s__fulltext' % (meta.db_table, col_field_name) + col_ref = getattr(new_table.c, col_field_name) + Index(index_name, col_ref, postgresql_ops={col_field_name: 'gin_trgm_ops'}, + postgresql_using='gin', + mysql_prefix='FULLTEXT') + return metadata diff --git a/data/text.py b/data/text.py new file mode 100644 index 000000000..870dd935a --- /dev/null +++ b/data/text.py @@ -0,0 +1,38 @@ +from peewee import Clause, SQL, fn, TextField, Field + +def _escape_wildcard(search_query): + """ Escapes the wildcards found in the given search query so that they are treated as *characters* + rather than wildcards when passed to a LIKE or ILIKE clause with an ESCAPE '!'. + """ + search_query = (search_query + .replace('!', '!!') + .replace('%', '!%') + .replace('_', '!_') + .replace('[', '![')) + return search_query + + +def prefix_search(field, prefix_query): + """ Returns the wildcard match for searching for the given prefix query. """ + # Escape the known wildcard characters. + prefix_query = _escape_wildcard(prefix_query) + return Field.__pow__(field, Clause(prefix_query + '%', SQL("ESCAPE '!'"))) + + +def match_mysql(field, search_query): + """ Generates a full-text match query using a Match operation, which is needed for MySQL. + """ + if field.name.find('`') >= 0: # Just to be safe. + raise Exception("How did field name '%s' end up containing a backtick?" % field.name) + + return Clause(fn.MATCH(SQL("`%s`" % field.name)), fn.AGAINST(SQL('%s', search_query)), + parens=True) + + +def match_like(field, search_query): + """ Generates a full-text match query using an ILIKE operation, which is needed for SQLite and + Postgres. + """ + escaped_query = _escape_wildcard(search_query) + clause = Clause('%' + escaped_query + '%', SQL("ESCAPE '!'")) + return Field.__pow__(field, clause)