quay/data/logs_model/test/fake_elasticsearch.py
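# An httmock-based fake of the Elasticsearch HTTP API used by the logs model
# tests: fake_elasticsearch() is a context manager that intercepts requests to
# FAKE_ES_HOST and serves index templates, document writes, index lookups and
# deletes, searches, counts, scrolls and simple aggregations from in-memory
# dictionaries.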
import json
import uuid
import fnmatch
from collections import defaultdict
from contextlib import contextmanager
from datetime import datetime
import dateutil.parser
from httmock import urlmatch, HTTMock
FAKE_ES_HOST = 'fakees'
EMPTY_RESULT = {
'hits': {'hits': [], 'total': 0},
'_shards': {'successful': 1, 'total': 1},
}
def parse_query(query):
if not query:
return {}
return {s.split('=')[0]: s.split('=')[1] for s in query.split("&")}
@contextmanager
def fake_elasticsearch(allow_wildcard=True):
templates = {}
docs = defaultdict(list)
scrolls = {}
id_counter = [1]
def transform(value, field_name):
# TODO: implement this using a real index template if we ever need more than a few
# fields here.
if field_name == 'datetime':
if isinstance(value, int):
return datetime.utcfromtimestamp(value / 1000)
parsed = dateutil.parser.parse(value)
return parsed
return value
@urlmatch(netloc=FAKE_ES_HOST, path=r'/_template/(.+)', method='GET')
def get_template(url, request):
template_name = url.path[len('/_template/'):]
if template_name in templates:
return {'status_code': 200}
return {'status_code': 404}
@urlmatch(netloc=FAKE_ES_HOST, path=r'/_template/(.+)', method='PUT')
def put_template(url, request):
template_name = url.path[len('/_template/'):]
templates[template_name] = True
return {'status_code': 201}
@urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)/_doc', method='POST')
def post_doc(url, request):
index_name, _ = url.path[1:].split('/')
item = json.loads(request.body)
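# The caller-supplied random_id becomes the document _id; id_counter is
# incremented per insert but never otherwise consulted.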
item['_id'] = item['random_id']
id_counter[0] += 1
docs[index_name].append(item)
return {
'status_code': 204,
'headers': {
'Content-Type': 'application/json',
},
'content': json.dumps({
"result": "created",
}),
}
@urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)$', method='DELETE')
def index_delete(url, request):
index_name_or_pattern = url.path[1:]
to_delete = []
for index_name in docs.keys():
if not fnmatch.fnmatch(index_name, index_name_or_pattern):
continue
to_delete.append(index_name)
for index in to_delete:
docs.pop(index)
return {
'status_code': 200,
'headers': {
'Content-Type': 'application/json',
},
'content': {'acknowledged': True}
}
@urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)$', method='GET')
def index_lookup(url, request):
index_name_or_pattern = url.path[1:]
found = {}
for index_name in docs.keys():
if not fnmatch.fnmatch(index_name, index_name_or_pattern):
continue
found[index_name] = {}
if not found:
return {
'status_code': 404,
}
return {
'status_code': 200,
'headers': {
'Content-Type': 'application/json',
},
'content': json.dumps(found),
}
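# Shared query evaluator for the _count and _search handlers: returns the
# matching documents and a flag that is true when a concrete index matched
# (wildcard patterns are always treated as found).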
def _match_query(index_name_or_pattern, query):
found = []
found_index = False
for index_name in docs.keys():
if not allow_wildcard and index_name_or_pattern.find('*') >= 0:
break
if not fnmatch.fnmatch(index_name, index_name_or_pattern):
continue
found_index = True
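# Recursive matcher for the subset of the query DSL exercised by the tests:
# range, term, terms and bool (must / must_not / filter) clauses.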
def _is_match(doc, current_query):
if current_query is None:
return True
for filter_type, filter_params in current_query.iteritems():
for field_name, filter_props in filter_params.iteritems():
if filter_type == 'range':
lt = transform(filter_props['lt'], field_name)
gte = transform(filter_props['gte'], field_name)
doc_value = transform(doc[field_name], field_name)
if not (doc_value < lt and doc_value >= gte):
return False
elif filter_type == 'term':
doc_value = transform(doc[field_name], field_name)
return doc_value == filter_props
elif filter_type == 'terms':
doc_value = transform(doc[field_name], field_name)
return doc_value in filter_props
elif filter_type == 'bool':
assert 'should' not in filter_params, 'should is unsupported'
must = filter_params.get('must')
must_not = filter_params.get('must_not')
filter_bool = filter_params.get('filter')
if must:
for check in must:
if not _is_match(doc, check):
return False
if must_not:
for check in must_not:
if _is_match(doc, check):
return False
if filter_bool:
for check in filter_bool:
if not _is_match(doc, check):
return False
else:
raise Exception('Unimplemented query %s: %s' % (filter_type, query))
return True
for doc in docs[index_name]:
if not _is_match(doc, query):
continue
found.append({'_source': doc, '_index': index_name})
return found, found_index or (index_name_or_pattern.find('*') >= 0)
@urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)/_count$', method='GET')
def count_docs(url, request):
request = json.loads(request.body)
index_name_or_pattern, _ = url.path[1:].split('/')
found, found_index = _match_query(index_name_or_pattern, request['query'])
if not found_index:
return {
'status_code': 404,
}
return {
'status_code': 200,
'headers': {
'Content-Type': 'application/json',
},
'content': json.dumps({'count': len(found)}),
}
@urlmatch(netloc=FAKE_ES_HOST, path=r'/_search/scroll$', method='GET')
def lookup_scroll(url, request):
request_obj = json.loads(request.body)
scroll_id = request_obj['scroll_id']
if scroll_id in scrolls:
return {
'status_code': 200,
'headers': {
'Content-Type': 'application/json',
},
'content': json.dumps(scrolls[scroll_id]),
}
return {
'status_code': 404,
}
@urlmatch(netloc=FAKE_ES_HOST, path=r'/_search/scroll$', method='DELETE')
def delete_scroll(url, request):
request = json.loads(request.body)
for scroll_id in request['scroll_id']:
scrolls.pop(scroll_id, None)
return {
'status_code': 404,
}
@urlmatch(netloc=FAKE_ES_HOST, path=r'/([^/]+)/_search$', method='GET')
def lookup_docs(url, request):
query_params = parse_query(url.query)
request = json.loads(request.body)
index_name_or_pattern, _ = url.path[1:].split('/')
# Find matching docs.
query = request.get('query')
found, found_index = _match_query(index_name_or_pattern, query)
if not found_index:
return {
'status_code': 404,
}
# Sort.
sort = request.get('sort')
if sort:
if sort == ['_doc'] or sort == '_doc':
found.sort(key=lambda x: x['_source']['_id'])
else:
def get_sort_key(item):
source = item['_source']
key = ''
for sort_config in sort:
for sort_key, direction in sort_config.iteritems():
assert direction == 'desc'
sort_key = sort_key.replace('.keyword', '')
key += str(transform(source[sort_key], sort_key))
key += '|'
return key
found.sort(key=get_sort_key, reverse=True)
# Search after.
search_after = request.get('search_after')
if search_after:
sort_fields = []
for sort_config in sort:
if isinstance(sort_config, unicode):
sort_fields.append(sort_config)
continue
for sort_key, _ in sort_config.iteritems():
sort_key = sort_key.replace('.keyword', '')
sort_fields.append(sort_key)
for index, search_after_value in enumerate(search_after):
field_name = sort_fields[index]
value = transform(search_after_value, field_name)
if field_name == '_doc':
found = [f for f in found if transform(f['_source']['_id'], field_name) > value]
else:
found = [f for f in found if transform(f['_source'][field_name], field_name) < value]
if len(found) < 2:
break
if field_name == '_doc':
if found[0]['_source']['_id'] != found[1]['_source']['_id']:
break
else:
if found[0]['_source'][field_name] != found[1]['_source'][field_name]:
break
# Size.
size = request.get('size')
if size:
found = found[0:size]
# Aggregation.
# {u'query':
# {u'range':
# {u'datetime': {u'lt': u'2019-06-27T15:45:09.768085',
# u'gte': u'2019-06-27T15:35:09.768085'}}},
# u'aggs': {
# u'by_id': {
# u'terms': {u'field': u'kind_id'},
# u'aggs': {
# u'by_date': {u'date_histogram': {u'field': u'datetime', u'interval': u'day'}}}}},
# u'size': 0}
def _by_field(agg_field_params, results):
aggregated_by_field = defaultdict(list)
for agg_means, agg_means_params in agg_field_params.iteritems():
if agg_means == 'terms':
field_name = agg_means_params['field']
for result in results:
value = result['_source'][field_name]
aggregated_by_field[value].append(result)
elif agg_means == 'date_histogram':
field_name = agg_means_params['field']
interval = agg_means_params['interval']
for result in results:
value = transform(result['_source'][field_name], field_name)
aggregated_by_field[getattr(value, interval)].append(result)
elif agg_means == 'aggs':
# Skip. Handled below.
continue
else:
raise Exception('Unsupported aggregation method: %s' % agg_means)
# Invoke the aggregation recursively.
buckets = []
for field_value, field_results in aggregated_by_field.iteritems():
aggregated = _aggregate(agg_field_params, field_results)
if isinstance(aggregated, list):
aggregated = {'doc_count': len(aggregated)}
aggregated['key'] = field_value
buckets.append(aggregated)
return {'buckets': buckets}
def _aggregate(query_config, results):
agg_params = query_config.get(u'aggs')
if not agg_params:
return results
by_field_name = {}
for agg_field_name, agg_field_params in agg_params.iteritems():
by_field_name[agg_field_name] = _by_field(agg_field_params, results)
return by_field_name
final_result = {
'hits': {
'hits': found,
'total': len(found),
},
'_shards': {
'successful': 1,
'total': 1,
},
'aggregations': _aggregate(request, found),
}
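# Scroll support is minimal: a new scroll id always maps to EMPTY_RESULT, so
# the first follow-up scroll request ends pagination for the caller.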
if query_params.get('scroll'):
scroll_id = str(uuid.uuid4())
scrolls[scroll_id] = EMPTY_RESULT
final_result['_scroll_id'] = scroll_id
return {
'status_code': 200,
'headers': {
'Content-Type': 'application/json',
},
'content': json.dumps(final_result),
}
@urlmatch(netloc=FAKE_ES_HOST)
def catchall_handler(url, request):
print "Unsupported URL: %s %s" % (request.method, url, )
return {'status_code': 501}
handlers = [get_template, put_template, index_delete, index_lookup, post_doc, count_docs,
lookup_docs, lookup_scroll, delete_scroll, catchall_handler]
with HTTMock(*handlers):
yield
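# ---------------------------------------------------------------------------
# Illustrative usage: a minimal sketch, not part of the original helper.
# HTTMock patches the `requests` library, so any HTTP call made with
# `requests` against FAKE_ES_HOST inside the context manager is answered by
# the handlers above. The function name, index name and document fields below
# are hypothetical examples.
# ---------------------------------------------------------------------------
def _example_fake_es_usage():
    import requests

    with fake_elasticsearch():
        # Index a document; the fake takes the caller-supplied `random_id`
        # as the document `_id`.
        requests.post(
            'http://' + FAKE_ES_HOST + '/logentry_2019-06-27/_doc',
            data=json.dumps({
                'random_id': 1,
                'kind_id': 2,
                'datetime': '2019-06-27T15:40:00',
            }))

        # Count documents matching a simple term query.
        count_response = requests.get(
            'http://' + FAKE_ES_HOST + '/logentry_2019-06-27/_count',
            data=json.dumps({'query': {'term': {'kind_id': 2}}}))
        assert count_response.json()['count'] == 1

        # Search with a descending sort, using the sort format the fake
        # expects ({field: 'desc'}).
        search_response = requests.get(
            'http://' + FAKE_ES_HOST + '/logentry_2019-06-27/_search',
            data=json.dumps({
                'query': {'term': {'kind_id': 2}},
                'sort': [{'datetime.keyword': 'desc'}],
                'size': 10,
            }))
        assert search_response.json()['hits']['total'] == 1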