Paulo 2024-05-02 21:30:53 -03:00
parent 4a471b12d6
commit a772cde9dc
3 changed files with 91 additions and 2 deletions

View file

@@ -47,6 +47,8 @@ def step_server_config(context, server_fqdn, server_port):
    context.n_ga = None
    context.n_ga_w = None
    context.n_predict = None
    context.n_keep = 0
    context.n_truncate = 0
    context.n_prompts = 0
    context.n_server_predict = None
    context.slot_save_path = None
@@ -66,6 +68,7 @@ def step_server_config(context, server_fqdn, server_port):
    context.user_api_key = None
    context.response_format = None
    context.stop_string = []
    context.tasks_result = []
    context.concurrent_tasks = []
    context.prompts = []
@@ -168,6 +171,7 @@ def step_start_server(context):
    addrs = socket.getaddrinfo(context.server_fqdn, context.server_port, type=socket.SOCK_STREAM)
    family, typ, proto, _, sockaddr = addrs[0]
    print(sockaddr)

    while True:
        with closing(socket.socket(family, typ, proto)) as sock:
@@ -231,17 +235,33 @@ async def step_all_slots_status(context, expected_slot_status_string):
@step('a completion request with {api_error} api error')
@async_run_until_complete
async def step_request_completion(context, api_error):
    await make_completion_request(context, api_error)


@step('an ongoing completion request')
@async_run_until_complete
async def step_request_ongoing_completion(context):
    await make_completion_request(context, '', True)


async def make_completion_request(context, api_error, ongoing=False):
    expect_api_error = api_error == 'raised'
    completion = await request_completion(context.prompts.pop(),
    prompt = context.prompts[-1] if ongoing else context.prompts.pop()
    completion = await request_completion(prompt,
                                          context.base_url,
                                          debug=context.debug,
                                          n_predict=context.n_predict,
                                          n_keep=context.n_keep,
                                          n_truncate=context.n_truncate,
                                          cache_prompt=context.cache_prompt,
                                          stop_string=context.stop_string,
                                          id_slot=context.id_slot,
                                          seed=await completions_seed(context),
                                          expect_api_error=expect_api_error,
                                          user_api_key=context.user_api_key)
    context.tasks_result.append(completion)
    if ongoing and not expect_api_error:
        context.prompts[-1] += completion['content']
    if context.debug:
        print(f"Completion response: {completion}")
    if expect_api_error:
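The `ongoing` flag is what lets the new truncation scenario chain two requests over the same conversation: the helper peeks at the last prompt instead of popping it, then folds the generated text back into that prompt so a later step can extend it. Below is a minimal, self-contained sketch of that pattern; the `fake_completion` stub and its canned reply are illustrative stand-ins for the real server call, not part of the test suite.

# Illustrative sketch of the "ongoing" request pattern used above;
# fake_completion stands in for the real HTTP call to the server.
def fake_completion(prompt: str) -> dict:
    return {"content": "You: I'm doing well, thanks for asking!"}  # pretend server output

prompts = ["Continue the chat below.\nMe: Hey there, how's it going?\n"]

# ongoing=True: peek at the last prompt and grow it with the reply,
# so the next step can append more user text to the same conversation.
reply = fake_completion(prompts[-1])
prompts[-1] += reply["content"]
print(prompts[-1])

# ongoing=False (the default) would consume the prompt instead:
# completion = fake_completion(prompts.pop())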
@@ -306,6 +326,16 @@ def step_max_tokens(context, max_tokens):
    context.n_predict = max_tokens


@step('{n_keep:d} tokens to keep')
def step_keep_tokens(context, n_keep):
    context.n_keep = n_keep


@step('{n_truncate:d} tokens to truncate')
def step_truncate_tokens(context, n_truncate):
    context.n_truncate = n_truncate


@step('a response format {response_format}')
def step_response_format(context, response_format):
    context.response_format = json.loads(response_format)
@@ -355,6 +385,10 @@ def step_n_ubatch(context, n_ubatch):
def step_seed(context, seed):
    context.seed = seed


@step('a list of stop strings {stop_list}')
def step_stop_string(context, stop_list):
    context.stop_string = json.loads(stop_list)


@step('a prefix prompt')
def step_prompt_prefix(context):
@@ -457,6 +491,11 @@ def step_a_prompt_prompt(context, prompt):
    context.n_prompts = len(context.prompts)


@step('an ongoing prompt')
def step_a_ongoing_prompt(context):
    context.prompts[-1] += context_text(context)


@step('concurrent completion requests')
@async_run_until_complete()
async def step_concurrent_completion_requests(context):
@@ -786,7 +825,10 @@ async def request_completion(prompt,
                             prompt_prefix=None,
                             prompt_suffix=None,
                             n_predict=None,
                             n_keep=0,
                             n_truncate=0,
                             cache_prompt=False,
                             stop_string=None,
                             id_slot=None,
                             seed=None,
                             expect_api_error=None,
@@ -809,7 +851,10 @@ async def request_completion(prompt,
                                    "prompt": prompt,
                                    "input_suffix": prompt_suffix,
                                    "n_predict": n_predict if n_predict is not None else -1,
                                    "n_keep": n_keep,
                                    "n_truncate": n_truncate,
                                    "cache_prompt": cache_prompt,
                                    "stop": stop_string if stop_string is not None else [],
                                    "id_slot": id_slot,
                                    "seed": seed if seed is not None else 42
                                },
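For reference, the request body the updated helper ends up posting to the server's /completion endpoint looks roughly like the sketch below. The URL and values are placeholders taken from the new feature file, and n_truncate is the field this change introduces rather than an established server parameter.

# Rough sketch of the payload sent to /completion with the new fields;
# values mirror the truncation feature file and are placeholders only.
import asyncio

import aiohttp


async def main():
    payload = {
        "prompt": "Continue the chat below.\nMe: Hey there, how's it going?\n",
        "n_predict": 32,
        "n_keep": 82,       # tokens preserved at the start of the context
        "n_truncate": 52,   # tokens to drop on overflow (introduced by this change)
        "cache_prompt": True,
        "stop": ["\n"],
        "seed": 42,
    }
    async with aiohttp.ClientSession() as session:
        async with session.post("http://localhost:8080/completion", json=payload) as response:
            print(await response.json())


asyncio.run(main())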

View file

@@ -0,0 +1,44 @@
# run with: ./tests.sh --no-skipped --tags truncation
@truncation
@slow
Feature: Chat truncation

  Background: Server startup
    Given a server listening on localhost:8080
    And a model file mistral-7b-v0.2-iq3_s-imat.gguf from HF repo ggml-org/models
    And prompt caching is enabled
    And a list of stop strings ["\n"]
    And 82 tokens to keep
    And 256 KV cache size
    And 32 server max tokens to predict
    Then the server is starting
    Then the server is healthy
  Scenario: Correctly truncate the prompt when the prompt exceeds the context size
    Given a prompt:
    """
    Continue the chat below.
    Me: Hey there, how's it going?
    You: I'm doing well, thanks for asking! How are you?
    Me: I'm doing good, just trying to get some work done. How's your day?
    You: My day has been pretty productive so far. I've been working on a new project.
    Me: That's great to hear! What's the new project you're working on?
    You: It's a web application that's designed to help people manage their personal finances. I'm really excited about it.
    Me: That sounds really useful, I'd be interested to hear more about it. Do you have a timeframe for when you expect to have it ready to launch?
    You: I'm aiming to have the initial version ready within the next few months. I want to ensure it's robust before launching it.
    Me: That's really nice, are you happy with the progress so far?
    """
    And an ongoing completion request
    Then -1 tokens are predicted matching You:

    Given an ongoing prompt:
    """
    Me: I have one more question for you my friend. What's the most value thing you learned during your development journey?
    """
    And 52 tokens to truncate
    And a completion request with no api error
    Then -1 tokens are predicted matching You:
    # 28 because the '\n' stop string is not pushed to the context
    And 28 prompt tokens are processed
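The closing assertion relies on a step that reads how many prompt tokens the server actually processed. That step is defined elsewhere in steps.py and is not part of this diff; the sketch below only illustrates how such a check could be wired up, and the timings/prompt_n field names are assumptions rather than something taken from this commit.

# Hypothetical step definition for "And {n} prompt tokens are processed";
# the real suite defines its own version, and the response field names
# ("timings"/"prompt_n") are assumptions here.
from behave import step


@step('{prompt_n:d} prompt tokens are processed')
def step_prompt_tokens_processed(context, prompt_n):
    completion = context.tasks_result.pop()
    observed = completion.get('timings', {}).get('prompt_n')
    assert observed == prompt_n, f"expected {prompt_n} prompt tokens, got {observed}"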

View file

@@ -5,7 +5,7 @@ set -eu
if [ $# -lt 1 ]
then
  # Start @llama.cpp scenario
  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey' --tags llama.cpp
  behave --summary --stop --no-capture --exclude 'issues|wrong_usages|passkey|truncation' --tags llama.cpp
else
  behave "$@"
fi