import json
import uuid
import fnmatch

from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime

import dateutil.parser

from httmock import urlmatch, HTTMock

FAKE_ES_HOST = 'fakees'

EMPTY_RESULT = {
    'hits': {'hits': [], 'total': 0},
    '_shards': {'successful': 1, 'total': 1},
}


def parse_query(query):
    if not query:
        return {}

    return {s.split('=')[0]: s.split('=')[1] for s in query.split('&')}


@contextmanager
def fake_elasticsearch(allow_wildcard=True):
    """ Emulates, with in-memory state, the subset of the Elasticsearch REST API
        that the tests exercise, intercepting requests via HTTMock. """
    templates = {}
    docs = defaultdict(list)
    scrolls = {}
    id_counter = [1]

    def transform(value, field_name):
        # TODO: implement this using a real index template if we ever need more than a few
        # fields here.
        if field_name == 'datetime':
            if isinstance(value, int):
                # Integer datetimes are epoch milliseconds.
                return datetime.utcfromtimestamp(value / 1000)

            parsed = dateutil.parser.parse(value)
            return parsed

        return value

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/_template/(.+)', method='GET')
    def get_template(url, request):
        template_name = url.path[len('/_template/'):]
        if template_name in templates:
            return {'status_code': 200}

        return {'status_code': 404}

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/_template/(.+)', method='PUT')
    def put_template(url, request):
        template_name = url.path[len('/_template/'):]
        templates[template_name] = True
        return {'status_code': 201}

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)/_doc', method='POST')
    def post_doc(url, request):
        index_name, _ = url.path[1:].split('/')
        item = json.loads(request.body)

        # The document supplies its own id via its `random_id` field.
        item['_id'] = item['random_id']
        id_counter[0] += 1
        docs[index_name].append(item)

        return {
            'status_code': 201,
            'headers': {
                'Content-Type': 'application/json',
            },
            'content': json.dumps({
                'result': 'created',
            }),
        }

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)$', method='DELETE')
    def index_delete(url, request):
        index_name_or_pattern = url.path[1:]
        to_delete = []
        for index_name in docs.keys():
            if not fnmatch.fnmatch(index_name, index_name_or_pattern):
                continue

            to_delete.append(index_name)

        for index in to_delete:
            docs.pop(index)

        return {
            'status_code': 200,
            'headers': {
                'Content-Type': 'application/json',
            },
            'content': json.dumps({'acknowledged': True}),
        }

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)$', method='GET')
    def index_lookup(url, request):
        index_name_or_pattern = url.path[1:]
        found = {}
        for index_name in docs.keys():
            if not fnmatch.fnmatch(index_name, index_name_or_pattern):
                continue

            found[index_name] = {}

        if not found:
            return {
                'status_code': 404,
            }

        return {
            'status_code': 200,
            'headers': {
                'Content-Type': 'application/json',
            },
            'content': json.dumps(found),
        }

    def _match_query(index_name_or_pattern, query):
        found = []
        found_index = False

        for index_name in docs.keys():
            if not allow_wildcard and '*' in index_name_or_pattern:
                break

            if not fnmatch.fnmatch(index_name, index_name_or_pattern):
                continue

            found_index = True

            def _is_match(doc, current_query):
                if current_query is None:
                    return True

                for filter_type, filter_params in current_query.items():
                    for field_name, filter_props in filter_params.items():
                        if filter_type == 'range':
                            lt = transform(filter_props['lt'], field_name)
                            gte = transform(filter_props['gte'], field_name)
                            doc_value = transform(doc[field_name], field_name)
                            if not (doc_value < lt and doc_value >= gte):
                                return False
                        elif filter_type == 'term':
                            doc_value = transform(doc[field_name], field_name)
                            return doc_value == filter_props
                        elif filter_type == 'terms':
                            doc_value = transform(doc[field_name], field_name)
                            return doc_value in filter_props
                        elif filter_type == 'bool':
                            assert 'should' not in filter_params, 'should is unsupported'
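
                            # A `bool` query is satisfied when every `must` and
                            # `filter` clause matches and every `must_not` clause
                            # fails; this fake does no scoring, so `filter` is
                            # treated exactly like `must`.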
                            must = filter_params.get('must')
                            must_not = filter_params.get('must_not')
                            filter_bool = filter_params.get('filter')

                            if must:
                                for check in must:
                                    if not _is_match(doc, check):
                                        return False

                            if must_not:
                                for check in must_not:
                                    if _is_match(doc, check):
                                        return False

                            if filter_bool:
                                for check in filter_bool:
                                    if not _is_match(doc, check):
                                        return False
                        else:
                            raise Exception('Unimplemented query %s: %s' % (filter_type, query))

                return True

            for doc in docs[index_name]:
                if not _is_match(doc, query):
                    continue

                found.append({'_source': doc, '_index': index_name})

        # A wildcard search "finds" an index even if nothing matched it.
        return found, found_index or ('*' in index_name_or_pattern)

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)/_count$', method='GET')
    def count_docs(url, request):
        request_obj = json.loads(request.body)
        index_name_or_pattern, _ = url.path[1:].split('/')
        found, found_index = _match_query(index_name_or_pattern, request_obj['query'])
        if not found_index:
            return {
                'status_code': 404,
            }

        return {
            'status_code': 200,
            'headers': {
                'Content-Type': 'application/json',
            },
            'content': json.dumps({'count': len(found)}),
        }

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/_search/scroll$', method='GET')
    def lookup_scroll(url, request):
        request_obj = json.loads(request.body)
        scroll_id = request_obj['scroll_id']
        if scroll_id in scrolls:
            return {
                'status_code': 200,
                'headers': {
                    'Content-Type': 'application/json',
                },
                'content': json.dumps(scrolls[scroll_id]),
            }

        return {
            'status_code': 404,
        }

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/_search/scroll$', method='DELETE')
    def delete_scroll(url, request):
        request_obj = json.loads(request.body)
        for scroll_id in request_obj['scroll_id']:
            scrolls.pop(scroll_id, None)

        return {
            'status_code': 200,
            'headers': {
                'Content-Type': 'application/json',
            },
            'content': json.dumps({'succeeded': True}),
        }

    @urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)/_search$', method='GET')
    def lookup_docs(url, request):
        query_params = parse_query(url.query)
        request_obj = json.loads(request.body)
        index_name_or_pattern, _ = url.path[1:].split('/')

        # Find matching docs.
        query = request_obj.get('query')
        found, found_index = _match_query(index_name_or_pattern, query)
        if not found_index:
            return {
                'status_code': 404,
            }

        # Sort.
        sort = request_obj.get('sort')
        if sort:
            if sort == ['_doc'] or sort == '_doc':
                found.sort(key=lambda x: x['_source']['_id'])
            else:
                def get_sort_key(item):
                    source = item['_source']
                    key = ''
                    for sort_config in sort:
                        for sort_key, direction in sort_config.items():
                            assert direction == 'desc'
                            sort_key = sort_key.replace('.keyword', '')
                            key += str(transform(source[sort_key], sort_key))
                            key += '|'
                    return key

                found.sort(key=get_sort_key, reverse=True)

        # Search after.
        search_after = request_obj.get('search_after')
        if search_after:
            sort_fields = []
            for sort_config in sort:
                if isinstance(sort_config, str):
                    sort_fields.append(sort_config)
                    continue

                for sort_key, _ in sort_config.items():
                    sort_key = sort_key.replace('.keyword', '')
                    sort_fields.append(sort_key)

            for index, search_after_value in enumerate(search_after):
                field_name = sort_fields[index]
                value = transform(search_after_value, field_name)

                # Drop hits at or before the cursor: `_doc` sorts ascending by
                # id, every other field descends.
                if field_name == '_doc':
                    found = [f for f in found if transform(f['_source']['_id'], field_name) > value]
                else:
                    found = [f for f in found if transform(f['_source'][field_name], field_name) < value]

                if len(found) < 2:
                    break

                # Only fall through to the next search_after field when the top
                # two remaining documents are tied on the current one.
                if field_name == '_doc':
                    if found[0]['_source']['_id'] != found[1]['_source']['_id']:
                        break
                else:
                    if found[0]['_source'][field_name] != found[1]['_source'][field_name]:
                        break

        # Size.
        size = request_obj.get('size')
        if size:
            found = found[0:size]

        # Aggregation.
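        # Buckets are computed in-memory: `terms` groups hits by the raw field
        # value, while `date_histogram` groups them by the attribute of the
        # parsed datetime named by `interval` (e.g. `day`). Nested `aggs` are
        # applied recursively within each bucket, and leaf buckets carry only
        # `key` and `doc_count`. An example of a request handled here: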
        # {'query':
        #     {'range':
        #         {'datetime': {'lt': '2019-06-27T15:45:09.768085',
        #                       'gte': '2019-06-27T15:35:09.768085'}}},
        #  'aggs': {
        #      'by_id': {
        #          'terms': {'field': 'kind_id'},
        #          'aggs': {
        #              'by_date': {'date_histogram': {'field': 'datetime',
        #                                             'interval': 'day'}}}}},
        #  'size': 0}
        def _by_field(agg_field_params, results):
            aggregated_by_field = defaultdict(list)
            for agg_means, agg_means_params in agg_field_params.items():
                if agg_means == 'terms':
                    field_name = agg_means_params['field']
                    for result in results:
                        value = result['_source'][field_name]
                        aggregated_by_field[value].append(result)
                elif agg_means == 'date_histogram':
                    field_name = agg_means_params['field']
                    interval = agg_means_params['interval']
                    for result in results:
                        value = transform(result['_source'][field_name], field_name)
                        aggregated_by_field[getattr(value, interval)].append(result)
                elif agg_means == 'aggs':
                    # Skip. Handled below.
                    continue
                else:
                    raise Exception('Unsupported aggregation method: %s' % agg_means)

            # Invoke the aggregation recursively on each bucket.
            buckets = []
            for field_value, field_results in aggregated_by_field.items():
                aggregated = _aggregate(agg_field_params, field_results)
                if isinstance(aggregated, list):
                    aggregated = {'doc_count': len(aggregated)}

                aggregated['key'] = field_value
                buckets.append(aggregated)

            return {'buckets': buckets}

        def _aggregate(query_config, results):
            agg_params = query_config.get('aggs')
            if not agg_params:
                return results

            by_field_name = {}
            for agg_field_name, agg_field_params in agg_params.items():
                by_field_name[agg_field_name] = _by_field(agg_field_params, results)

            return by_field_name

        final_result = {
            'hits': {
                'hits': found,
                'total': len(found),
            },
            '_shards': {
                'successful': 1,
                'total': 1,
            },
            'aggregations': _aggregate(request_obj, found),
        }

        if query_params.get('scroll'):
            scroll_id = str(uuid.uuid4())
            scrolls[scroll_id] = EMPTY_RESULT
            final_result['_scroll_id'] = scroll_id

        return {
            'status_code': 200,
            'headers': {
                'Content-Type': 'application/json',
            },
            'content': json.dumps(final_result),
        }

    @urlmatch(netloc=FAKE_ES_HOST)
    def catchall_handler(url, request):
        print("Unsupported URL: %s %s" % (request.method, url))
        return {'status_code': 501}

    handlers = [get_template, put_template, index_delete, index_lookup, post_doc, count_docs,
                lookup_docs, lookup_scroll, delete_scroll, catchall_handler]

    with HTTMock(*handlers):
        yield
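
# A minimal smoke test sketching how the fake might be used; the index name
# ('logs'), the field names, and the use of plain `requests` are illustrative
# assumptions, not part of the fake's contract. Two things to keep in mind:
# the fake requires each posted document to carry a `random_id` field, which
# it reuses as the document's `_id`, and HTTMock only intercepts traffic that
# goes through the `requests` library.
if __name__ == '__main__':
    import requests

    with fake_elasticsearch():
        # Index a document into a hypothetical 'logs' index.
        requests.post('http://%s/logs/_doc' % FAKE_ES_HOST,
                      data=json.dumps({'random_id': 'abc', 'kind_id': 1}))

        # An empty search body matches every document in the index.
        result = requests.get('http://%s/logs/_search' % FAKE_ES_HOST,
                              data=json.dumps({})).json()
        assert result['hits']['total'] == 1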