server: tests: fix re content

This commit is contained in:
Pierrick HYMBERT 2024-03-02 19:30:19 +01:00
parent 830d0efbd2
commit 1aa5ad9150

View file

@ -812,13 +812,21 @@ def assert_n_tokens_predicted(completion_response, expected_predicted_n=None, re
n_predicted = completion_response['timings']['predicted_n']
assert len(content) > 0, "no token predicted"
if re_content is not None:
re_content = f'^(.*)({re_content})(.*)$'
p = re.compile(re_content, flags=RegexFlag.IGNORECASE | RegexFlag.MULTILINE | RegexFlag.DOTALL)
match = p.match(content)
assert match and len(match.groups()) >= 3, f'/{re_content}/g must match ```{content}```'
matches = p.finditer(content)
last_match = 0
highlighted = ''
for match in matches:
start, end = match.span()
highlighted += content[last_match: start]
highlighted += '\x1b[33m'
highlighted += content[start: end]
highlighted += '\x1b[0m'
last_match = end
highlighted += content[last_match:]
if 'DEBUG' in os.environ and os.environ['DEBUG'] == 'ON':
highlighted = p.sub(r"\1<hi>\2</hi>\3", content).replace('<hi>', '\x1b[33m').replace('</hi>', '\x1b[0m')
print(f"Checking completion response: {highlighted}\n")
assert last_match > 0, f'/{re_content}/ must match ```{highlighted}```'
if expected_predicted_n and expected_predicted_n > 0:
assert n_predicted == expected_predicted_n, (f'invalid number of tokens predicted:'
f' {n_predicted} <> {expected_predicted_n}')