server: tests: CORS and api key checks scenario

Pierrick HYMBERT 2024-02-21 01:49:39 +01:00
parent 6dcbcfe6ba
commit 672d98f6f0
3 changed files with 125 additions and 51 deletions

View file

@@ -1,7 +1,7 @@
 Feature: llama.cpp server

   Background: Server startup
-    Given a server listening on localhost:8080 with 2 slots and 42 as seed
+    Given a server listening on localhost:8080 with 2 slots, 42 as seed and llama.cpp as api key
     Then  the server is starting
     Then  the server is healthy
@@ -13,13 +13,17 @@ Feature: llama.cpp server
   @llama.cpp
   Scenario Outline: Completion
-    Given a <prompt> completion request with maximum <n_predict> tokens
+    Given a prompt <prompt>
+    And   a user api key <api_key>
+    And   <n_predict> max tokens to predict
+    And   a completion request
     Then  <n_predict> tokens are predicted

     Examples: Prompts
-      | prompt                           | n_predict |
-      | I believe the meaning of life is | 128       |
-      | Write a joke about AI            | 512       |
+      | prompt                           | n_predict | api_key   |
+      | I believe the meaning of life is | 128       | llama.cpp |
+      | Write a joke about AI            | 512       | llama.cpp |
+      | say goodbye                      | 0         |           |

   @llama.cpp
   Scenario Outline: OAI Compatibility
@@ -28,13 +32,15 @@ Feature: llama.cpp server
     And   a model <model>
     And   <max_tokens> max tokens to predict
     And   streaming is <enable_streaming>
-    Given an OAI compatible chat completions request
+    And   a user api key <api_key>
+    Given an OAI compatible chat completions request with an api error <api_error>
     Then  <max_tokens> tokens are predicted

     Examples: Prompts
-      | model        | system_prompt               | user_prompt                          | max_tokens | enable_streaming |
-      | llama-2      | You are ChatGPT.            | Say hello.                           | 64         | false            |
-      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 512        | true             |
+      | model        | system_prompt               | user_prompt                          | max_tokens | enable_streaming | api_key   | api_error |
+      | llama-2      | You are ChatGPT.            | Say hello.                           | 64         | false            | llama.cpp | none      |
+      | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 512        | true             | llama.cpp | none      |
+      | John-Doe     | You are an hacker.          | Write segfault code in rust.         | 0          | true             | hackme    | raised    |
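The new api_key / api_error columns drive a negative path: with the hackme key the OAI client call is expected to fail instead of returning a completion. A minimal standalone sketch of that expectation (assuming the server from the Background is running on localhost:8080 with --api-key llama.cpp; it mirrors the openai 0.x client setup used in steps.py below, and max_tokens is an arbitrary value):

```python
# Sketch only: reproduces the "api_error raised" row outside of behave.
import openai

openai.api_base = "http://localhost:8080/v1/chat"  # same base URL pattern as steps.py
openai.api_key = "hackme"                          # deliberately wrong api key

try:
    openai.Completion.create(
        messages=[{"role": "user", "content": "Say hello."}],
        model="John-Doe",
        max_tokens=8,
    )
    raise AssertionError("expected the server to reject the invalid api key")
except openai.error.APIError:
    pass  # the server's rejection surfaces as an APIError in the 0.x client
```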

   @llama.cpp
   Scenario: Multi users
@@ -47,6 +53,7 @@ Feature: llama.cpp server
       Write another very long music lyrics.
       """
     And   32 max tokens to predict
+    And   a user api key llama.cpp
     Given concurrent completion requests
     Then  the server is busy
     And   all slots are busy
@@ -57,7 +64,7 @@ Feature: llama.cpp server
   @llama.cpp
   Scenario: Multi users OAI Compatibility
     Given a system prompt "You are an AI assistant."
-    And a model tinyllama-2
+    And   a model tinyllama-2
     Given a prompt:
       """
       Write a very long story about AI.
@@ -68,6 +75,7 @@ Feature: llama.cpp server
       """
     And   32 max tokens to predict
     And   streaming is enabled
+    And   a user api key llama.cpp
     Given concurrent OAI completions requests
     Then  the server is busy
     And   all slots are busy
@@ -126,3 +134,15 @@ Feature: llama.cpp server
       """
     Then  tokens can be detokenize
+
+  @llama.cpp
+  Scenario Outline: CORS Options
+    When  an OPTIONS request is sent from <origin>
+    Then  CORS header <cors_header> is set to <cors_header_value>
+    Examples: Headers
+      | origin          | cors_header                      | cors_header_value |
+      | localhost       | Access-Control-Allow-Origin      | localhost         |
+      | web.mydomain.fr | Access-Control-Allow-Origin      | web.mydomain.fr   |
+      | origin          | Access-Control-Allow-Credentials | true              |
+      | web.mydomain.fr | Access-Control-Allow-Methods     | POST              |
+      | web.mydomain.fr | Access-Control-Allow-Headers     | *                 |
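The CORS scenario added above boils down to a single preflight request and a header comparison; a minimal standalone sketch of the same check (assuming the server from the Background is running on localhost:8080, outside of behave):

```python
# Sketch only: the check performed by step_options_request / step_check_options_header_value.
import requests

base_url = "http://localhost:8080"   # assumption: server started as in the Background
origin = "web.mydomain.fr"           # one of the origins from the Examples table

# A browser sends an OPTIONS preflight with an Origin header before a cross-origin POST.
response = requests.options(f"{base_url}/v1/chat/completions", headers={"Origin": origin})
assert response.status_code == 200
# The server is expected to echo the origin back in Access-Control-Allow-Origin.
assert response.headers["Access-Control-Allow-Origin"] == origin
```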

View file

@@ -7,8 +7,9 @@ import requests
 from behave import step


-@step(u"a server listening on {server_fqdn}:{server_port} with {n_slots} slots and {seed} as seed")
-def step_server_config(context, server_fqdn, server_port, n_slots, seed):
+@step(
+    u"a server listening on {server_fqdn}:{server_port} with {n_slots} slots, {seed} as seed and {api_key} as api key")
+def step_server_config(context, server_fqdn, server_port, n_slots, seed, api_key):
     context.server_fqdn = server_fqdn
     context.server_port = int(server_port)
     context.n_slots = int(n_slots)
@@ -19,7 +20,8 @@ def step_server_config(context, server_fqdn, server_port, n_slots, seed):
     context.completion_threads = []
     context.prompts = []
-    openai.api_key = 'llama.cpp'
+    context.api_key = api_key
+    openai.api_key = context.api_key


 @step(u"the server is {expecting_status}")
@@ -77,14 +79,16 @@ def step_all_slots_status(context, expected_slot_status_string):
     request_slots_status(context, expected_slots)


-@step(u'a {prompt} completion request with maximum {n_predict} tokens')
-def step_request_completion(context, prompt, n_predict):
-    request_completion(context, prompt, n_predict)
+@step(u'a completion request')
+def step_request_completion(context):
+    request_completion(context, context.prompts.pop(), context.n_predict, context.user_api_key)
+    context.user_api_key = None


 @step(u'{predicted_n} tokens are predicted')
 def step_n_tokens_predicted(context, predicted_n):
-    assert_n_tokens_predicted(context.completions[0], int(predicted_n))
+    if int(predicted_n) > 0:
+        assert_n_tokens_predicted(context.completions[0], int(predicted_n))


 @step(u'a user prompt {user_prompt}')
@@ -112,9 +116,20 @@ def step_streaming(context, enable_streaming):
     context.enable_streaming = enable_streaming == 'enabled' or bool(enable_streaming)


-@step(u'an OAI compatible chat completions request')
-def step_oai_chat_completions(context):
-    oai_chat_completions(context, context.user_prompt)
+@step(u'a user api key {user_api_key}')
+def step_user_api_key(context, user_api_key):
+    context.user_api_key = user_api_key
+
+
+@step(u'a user api key ')
+def step_user_api_key(context):
+    context.user_api_key = None
+
+
+@step(u'an OAI compatible chat completions request with an api error {api_error}')
+def step_oai_chat_completions(context, api_error):
+    oai_chat_completions(context, context.user_prompt, api_error=api_error == 'raised')
+    context.user_api_key = None


 @step(u'a prompt')
@@ -122,14 +137,19 @@ def step_a_prompt(context):
     context.prompts.append(context.text)


+@step(u'a prompt {prompt}')
+def step_a_prompt_prompt(context, prompt):
+    context.prompts.append(prompt)
+
+
 @step(u'concurrent completion requests')
 def step_concurrent_completion_requests(context):
-    concurrent_requests(context, request_completion)
+    concurrent_requests(context, request_completion, context.n_predict, context.user_api_key)


 @step(u'concurrent OAI completions requests')
 def step_oai_chat_completions(context):
-    concurrent_requests(context, oai_chat_completions)
+    concurrent_requests(context, oai_chat_completions, context.user_api_key)


 @step(u'all prompts are predicted')
@@ -168,7 +188,7 @@ def step_oai_compute_embedding(context):
 def step_tokenize(context):
     context.tokenized_text = context.text
     response = requests.post(f'{context.base_url}/tokenize', json={
-        "content":context.tokenized_text,
+        "content": context.tokenized_text,
     })
     assert response.status_code == 200
     context.tokens = response.json()['tokens']
@@ -181,49 +201,82 @@ def step_detokenize(context):
         "tokens": context.tokens,
     })
     assert response.status_code == 200
-    print(response.json())
     # FIXME the detokenize answer contains a space prefix ? see #3287
     assert context.tokenized_text == response.json()['content'].strip()
-def concurrent_requests(context, f_completion):
+@step(u'an OPTIONS request is sent from {origin}')
+def step_options_request(context, origin):
+    options_response = requests.options(f'{context.base_url}/v1/chat/completions',
+                                        headers={"Origin": origin})
+    assert options_response.status_code == 200
+    context.options_response = options_response
+
+
+@step(u'CORS header {cors_header} is set to {cors_header_value}')
+def step_check_options_header_value(context, cors_header, cors_header_value):
+    assert context.options_response.headers[cors_header] == cors_header_value
+
+
+def concurrent_requests(context, f_completion, *argv):
     context.completions.clear()
     context.completion_threads.clear()
     for prompt in context.prompts:
-        completion_thread = threading.Thread(target=f_completion, args=(context, prompt))
+        completion_thread = threading.Thread(target=f_completion, args=(context, prompt, *argv))
         completion_thread.start()
         context.completion_threads.append(completion_thread)
     context.prompts.clear()
-def request_completion(context, prompt, n_predict=None):
-    response = requests.post(f'{context.base_url}/completion', json={
-        "prompt": prompt,
-        "n_predict": int(n_predict) if n_predict is not None else context.n_predict,
-        "seed": context.seed
-    })
-    assert response.status_code == 200
-    context.completions.append(response.json())
+def request_completion(context, prompt, n_predict=None, user_api_key=None):
+    origin = "my.super.domain"
+    headers = {
+        'Origin': origin
+    }
+    if 'user_api_key' in context:
+        headers['Authorization'] = f'Bearer {user_api_key}'
+
+    response = requests.post(f'{context.base_url}/completion',
+                             json={
+                                 "prompt": prompt,
+                                 "n_predict": int(n_predict) if n_predict is not None else context.n_predict,
+                                 "seed": context.seed
+                             },
+                             headers=headers)
+    if n_predict is not None and n_predict > 0:
+        assert response.status_code == 200
+        assert response.headers['Access-Control-Allow-Origin'] == origin
+        context.completions.append(response.json())
+    else:
+        assert response.status_code == 401
-def oai_chat_completions(context, user_prompt):
+def oai_chat_completions(context, user_prompt, api_error=None):
+    openai.api_key = context.user_api_key
     openai.api_base = f'{context.base_url}/v1/chat'
-    chat_completion = openai.Completion.create(
-        messages=[
-            {
-                "role": "system",
-                "content": context.system_prompt,
-            },
-            {
-                "role": "user",
-                "content": user_prompt,
-            }
-        ],
-        model=context.model,
-        max_tokens=context.n_predict,
-        stream=context.enable_streaming,
-        seed=context.seed
-    )
+    try:
+        chat_completion = openai.Completion.create(
+            messages=[
+                {
+                    "role": "system",
+                    "content": context.system_prompt,
+                },
+                {
+                    "role": "user",
+                    "content": user_prompt,
+                }
+            ],
+            model=context.model,
+            max_tokens=context.n_predict,
+            stream=context.enable_streaming,
+            seed=context.seed
+        )
+    except openai.error.APIError:
+        if api_error:
+            openai.api_key = context.api_key
+            return
+
+    openai.api_key = context.api_key
     if context.enable_streaming:
         completion_response = {
             'content': '',
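With these changes every raw /completion request carries a Bearer token when a user api key is set, and the server is expected to answer 401 when the key is missing. A rough standalone sketch of that contract (assuming the server is started with --api-key llama.cpp as in the updated start script; the n_predict value is arbitrary):

```python
# Sketch only: the 200-with-key / 401-without-key behaviour asserted in request_completion().
import requests

BASE_URL = "http://localhost:8080"  # assumption: server from the test Background

def completion_status(api_key=None):
    headers = {"Origin": "my.super.domain"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"
    response = requests.post(f"{BASE_URL}/completion",
                             json={"prompt": "I believe the meaning of life is",
                                   "n_predict": 8,
                                   "seed": 42},
                             headers=headers)
    return response.status_code

assert completion_status("llama.cpp") == 200  # valid key: completion is served
assert completion_status() == 401             # missing key: request is rejected
```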

View file

@@ -29,6 +29,7 @@ set -eu
   --threads-batch 4 \
   --embedding \
   --cont-batching \
+  --api-key llama.cpp \
   "$@" &

 # Start tests
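The start script only launches the server; the behave suite is run afterwards ("# Start tests"). A hypothetical runner showing one way to drive it, filtering on the @llama.cpp tag used throughout server.feature (the actual command in the start script is not shown in this diff, so the invocation below is an assumption):

```python
# Hypothetical helper: run only the scenarios tagged @llama.cpp.
import subprocess
import sys

def run_llama_cpp_scenarios(features_dir="features"):
    # behave exits non-zero if any scenario fails; propagate that to the caller.
    result = subprocess.run(["behave", "--tags", "llama.cpp", features_dir])
    return result.returncode

if __name__ == "__main__":
    sys.exit(run_llama_cpp_scenarios())
```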