From edbd2e98624cd725b4e0b8654dfa558c35a1dce0 Mon Sep 17 00:00:00 2001
From: ochafik
Date: Sun, 17 Mar 2024 22:51:29 +0000
Subject: [PATCH] json: add server tests for OAI JSON response_format

---
 examples/server/tests/features/security.feature | 16 ++++++++++++++++
 examples/server/tests/features/server.feature   | 16 ++++++++++++++++
 examples/server/tests/features/steps/steps.py   | 17 +++++++++++++++++
 examples/server/utils.hpp                       |  7 ++++---
 4 files changed, 53 insertions(+), 3 deletions(-)

diff --git a/examples/server/tests/features/security.feature b/examples/server/tests/features/security.feature
index 1d6aa40ea..eb82e7aca 100644
--- a/examples/server/tests/features/security.feature
+++ b/examples/server/tests/features/security.feature
@@ -37,6 +37,22 @@ Feature: Security
       | llama.cpp | no        |
       | hackme    | raised    |
 
+  Scenario Outline: OAI Compatibility (invalid response formats)
+    Given a system prompt test
+    And   a user prompt test
+    And   a response format <response_format>
+    And   a model test
+    And   2 max tokens to predict
+    And   streaming is disabled
+    Given an OAI compatible chat completions request with raised api error
+
+    Examples: Prompts
+      | response_format                                        |
+      | {"type": "sound"}                                      |
+      | {"type": "json_object", "schema": 123}                 |
+      | {"type": "json_object", "schema": {"type": 123}}       |
+      | {"type": "json_object", "schema": {"type": "hiccup"}}  |
+
 
   Scenario Outline: CORS Options
     Given a user api key llama.cpp
diff --git a/examples/server/tests/features/server.feature b/examples/server/tests/features/server.feature
index 5014f326d..4ecc41d52 100644
--- a/examples/server/tests/features/server.feature
+++ b/examples/server/tests/features/server.feature
@@ -69,6 +69,22 @@ Feature: llama.cpp server
       | codellama70b | You are a coding assistant. | Write the fibonacci function in c++. | 128 | (thanks\|happy\|bird)+ | -1 | 64 | enabled | |
 
+  Scenario Outline: OAI Compatibility w/ response format
+    Given a model test
+    And   a system prompt test
+    And   a user prompt test
+    And   a response format <response_format>
+    And   10 max tokens to predict
+    Given an OAI compatible chat completions request with no api error
+    Then  <n_predicted> tokens are predicted matching <re_content>
+
+    Examples: Prompts
+      | response_format                                                      | n_predicted | re_content             |
+      | {"type": "json_object", "schema": {"const": "42"}}                   | 5           | "42"                   |
+      | {"type": "json_object", "schema": {"items": [{"type": "integer"}]}}  | 10          | \[ -300 \]             |
+      | {"type": "json_object"}                                              | 10          | \{ " said the other an |
+
+
 
   Scenario: Tokenize / Detokenize
     When tokenizing:
     """
diff --git a/examples/server/tests/features/steps/steps.py b/examples/server/tests/features/steps/steps.py
index a59a52d21..2bebcb71b 100644
--- a/examples/server/tests/features/steps/steps.py
+++ b/examples/server/tests/features/steps/steps.py
@@ -52,6 +52,7 @@ def step_server_config(context, server_fqdn, server_port):
     context.seed = None
     context.server_seed = None
     context.user_api_key = None
+    context.response_format = None
 
     context.tasks_result = []
     context.concurrent_tasks = []
@@ -248,6 +249,11 @@ def step_max_tokens(context, max_tokens):
     context.n_predict = max_tokens
 
 
+@step('a response format {response_format}')
+def step_response_format(context, response_format):
+    context.response_format = json.loads(response_format)
+
+
 @step('streaming is {enable_streaming}')
 def step_streaming(context, enable_streaming):
     context.enable_streaming = enable_streaming == 'enabled'
@@ -363,6 +369,9 @@ async def step_oai_chat_completions(context, api_error):
                                          enable_streaming=context.enable_streaming
                                          if hasattr(context, 'enable_streaming') else None,
 
+                                         response_format=context.response_format
+                                         if hasattr(context, 'response_format') else None,
+
                                          seed=await completions_seed(context),
 
                                          user_api_key=context.user_api_key
@@ -422,6 +431,8 @@ async def step_oai_chat_completions(context):
                                               if hasattr(context, 'n_predict') else None,
                                               enable_streaming=context.enable_streaming
                                               if hasattr(context, 'enable_streaming') else None,
+                                              response_format=context.response_format
+                                              if hasattr(context, 'response_format') else None,
                                               seed=await completions_seed(context),
                                               user_api_key=context.user_api_key
                                               if hasattr(context, 'user_api_key') else None)
@@ -442,6 +453,8 @@ async def step_oai_chat_completions(context):
                                           if hasattr(context, 'n_predict') else None,
                                           enable_streaming=context.enable_streaming
                                           if hasattr(context, 'enable_streaming') else None,
+                                          response_format=context.response_format
+                                          if hasattr(context, 'response_format') else None,
                                           seed=context.seed
                                           if hasattr(context, 'seed') else
                                           context.server_seed
@@ -724,6 +737,7 @@ async def oai_chat_completions(user_prompt,
                                model=None,
                                n_predict=None,
                                enable_streaming=None,
+                               response_format=None,
                                seed=None,
                                user_api_key=None,
                                expect_api_error=None):
@@ -749,6 +763,8 @@ async def oai_chat_completions(user_prompt,
            "stream": enable_streaming,
            "seed": seed
        }
+       if response_format is not None:
+           payload['response_format'] = response_format
        completion_response = {
            'content': '',
            'timings': {
@@ -809,6 +825,7 @@ async def oai_chat_completions(user_prompt,
                model=model,
                max_tokens=n_predict,
                stream=enable_streaming,
+               response_format=payload.get('response_format'),
                seed=seed
            )
        except openai.error.AuthenticationError as e:
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 6591e84a1..e976ebda6 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -376,11 +376,12 @@ static json oaicompat_completion_params_parse(
         llama_params["grammar"] = json_value(body, "grammar", json::object());
     }
 
-    if (body.contains("response_format")) {
+    if (!body["response_format"].is_null()) {
         auto response_format = json_value(body, "response_format", json::object());
-        if (response_format.contains("schema") && response_format["type"] == "json_object") {
+        if (response_format["type"] == "json_object") {
             llama_params["json_schema"] = json_value(response_format, "schema", json::object());
-            std::cerr << "GOT " << llama_params["json_schema"] << std::endl;
+        } else {
+            throw std::runtime_error("response_format type not supported: " + response_format["type"].dump());
         }
     }
 
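--
Note (not part of the patch): a minimal sketch of how a client could exercise the
response_format handling that these tests cover, assuming a server built from this
branch is listening on http://localhost:8080; the host, port and model name below
are placeholders, not values taken from the patch.

    # sketch.py - hypothetical client for the server's OAI-compatible endpoint
    import requests

    payload = {
        "model": "test",
        "max_tokens": 10,
        "messages": [
            {"role": "system", "content": "test"},
            {"role": "user", "content": "test"},
        ],
        # Same shape the new Gherkin scenarios send: only type "json_object" is
        # accepted by oaicompat_completion_params_parse, and the optional "schema"
        # is forwarded as llama_params["json_schema"].
        "response_format": {"type": "json_object", "schema": {"const": "42"}},
    }

    resp = requests.post("http://localhost:8080/v1/chat/completions", json=payload)
    resp.raise_for_status()
    print(resp.json()["choices"][0]["message"]["content"])  # expected to match "42"

With an unsupported type such as {"type": "sound"}, the runtime error added in
utils.hpp should surface as an API error, which is what the invalid-response-format
scenarios in security.feature assert.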