server : many fixes (stop/truncated in partial responses, timings gating, slot-restore fields, finish_reason tests)

Xuan Son Nguyen 2024-12-04 18:58:16 +01:00
parent 0d6485f0f8
commit d2419b3255
4 changed files with 45 additions and 24 deletions

View file

@@ -1172,6 +1172,8 @@ struct server_context {
res.n_decoded = slot.n_decoded;
res.n_prompt_tokens = slot.n_prompt_tokens;
res.content = tkn.text_to_send;
res.stop = slot.stop;
res.truncated = slot.truncated;
if (slot.params.sampling.n_probs > 0) {
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
@@ -1186,7 +1188,8 @@ struct server_context {
}
}
- if (slot.params.timings_per_token) {
+ // populate timings if this is the final response or timings_per_token is enabled
+ if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
res.timings = slot.get_timings();
}
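With this change, per-token timings are attached either on every streamed chunk (when `timings_per_token` is requested) or unconditionally on the final chunk. A minimal stand-alone sketch of the gate; only `STOP_TYPE_NONE` is confirmed by this diff, the other enum values are assumed from the server's naming convention:

```cpp
// Simplified stand-in for the server's stop_type enum (values assumed).
enum stop_type { STOP_TYPE_NONE, STOP_TYPE_EOS, STOP_TYPE_WORD, STOP_TYPE_LIMIT };

// Timings go out with the final chunk of a stream, or with every chunk
// when the client asked for timings_per_token.
bool attach_timings(stop_type stop, bool timings_per_token) {
    return stop != STOP_TYPE_NONE || timings_per_token;
}
```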
@@ -1195,6 +1198,7 @@ struct server_context {
void send_final_response(server_slot & slot) {
if (slot.params.stream) {
// if in stream mode, send the last partial response
return send_partial_response(slot, {0, "", {}});
}
@@ -1209,6 +1213,8 @@ struct server_context {
res.n_tokens_cached = slot.n_past;
res.content = slot.generated_text;
res.stop = slot.stop;
res.truncated = slot.truncated;
res.timings = slot.get_timings();
res.generation_params = slot.params; // copy the parameters
@@ -1439,6 +1445,8 @@ struct server_context {
break;
}
SRV_DBG("received partial result, %s\n", result.to_json().dump().c_str());
if (result.stop != STOP_TYPE_NONE) {
if (++n_finished == id_tasks.size()) {
break;
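The added `result.stop != STOP_TYPE_NONE` check matters for streaming: partial chunks now flow through the collector without being miscounted as completed tasks. A hedged sketch of that pattern, where `receive_result()` and the surrounding types are stand-ins for the real result-queue API:

```cpp
#include <cstddef>
#include <vector>

enum stop_type { STOP_TYPE_NONE, STOP_TYPE_EOS, STOP_TYPE_WORD, STOP_TYPE_LIMIT };

struct task_result { stop_type stop = STOP_TYPE_NONE; };

// stand-in: blocks until any of the tracked tasks produces a result
task_result receive_result();

void wait_for_all(const std::vector<int> & id_tasks) {
    size_t n_finished = 0;
    while (n_finished < id_tasks.size()) {
        const task_result result = receive_result();
        // partial chunks (stop == STOP_TYPE_NONE) are forwarded to the client
        // elsewhere; only a final result advances the finished count
        if (result.stop != STOP_TYPE_NONE) {
            n_finished++;
        }
    }
}
```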
@@ -1533,7 +1541,7 @@ struct server_context {
res.id = task.id;
res.n_idle_slots = n_idle_slots;
res.n_processing_slots = n_processing_slots;
res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
res.t_start = metrics.t_start;
res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
@@ -1627,13 +1635,13 @@ struct server_context {
const double t_restore_ms = (t_end - t_start) / 1000.0;
server_task_result_slot_save_load result;
- result.id       = task.id;
- result.id_slot  = id_slot;
- result.filename = filename;
- result.is_save  = false;
- result.n_saved  = token_count;
- result.n_read   = nread;
- result.t_ms     = t_restore_ms;
+ result.id         = task.id;
+ result.id_slot    = id_slot;
+ result.filename   = filename;
+ result.is_save    = false;
+ result.n_restored = token_count;
+ result.n_read     = nread;
+ result.t_ms       = t_restore_ms;
queue_results.send(result);
} break;
case SERVER_TASK_TYPE_SLOT_ERASE:
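The hunk above is a field-level bug fix: the restore path used to write the token count into `n_saved`, the field meant for the save path; it now goes into `n_restored`. A simplified sketch of the distinction (the struct below is illustrative, not the server's actual definition):

```cpp
#include <cstddef>

// Illustrative save/restore result; field names follow the diff.
struct slot_io_result {
    bool   is_save    = false; // true for the save path, false for restore
    size_t n_saved    = 0;     // tokens written out on save
    size_t n_restored = 0;     // tokens read back on restore
    size_t n_read     = 0;     // bytes consumed from the file
    double t_ms       = 0.0;   // wall-clock time of the operation
};

slot_io_result make_restore_result(size_t token_count, size_t nread, double t_ms) {
    slot_io_result res;
    res.is_save    = false;
    res.n_restored = token_count; // previously stored in n_saved by mistake
    res.n_read     = nread;
    res.t_ms       = t_ms;
    return res;
}
```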

View file

@@ -15,6 +15,7 @@
using json = nlohmann::ordered_json;
// cast a smart pointer (e.g. unique_ptr) to a specific type using the copy constructor
#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()))
enum stop_type {
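For context, `copy_cast_ptr` recovers a concrete result type from a type-erased smart pointer: it downcasts the raw pointer with `static_cast` and dereferences it, so the assignment at the call site invokes the copy constructor. A self-contained illustration of the same pattern (`base`/`derived` are hypothetical names; note there is no runtime type check):

```cpp
#include <memory>

// same pattern as the macro above: downcast the raw pointer, then copy
#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()))

struct base    { virtual ~base() = default; };
struct derived : base { int value = 42; };

int main() {
    std::unique_ptr<base> p = std::make_unique<derived>();
    // safe only because p really holds a derived; a wrong TYPEOUT is UB
    derived d = copy_cast_ptr(derived, p);
    return d.value == 42 ? 0 : 1;
}
```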
@@ -281,23 +282,34 @@ struct server_task_result_cmpl_partial : server_task_result {
server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
int index = 0;
std::string content;
bool truncated;
int32_t n_decoded;
int32_t n_prompt_tokens;
stop_type stop = STOP_TYPE_NONE;
std::vector<completion_token_output> probs_output;
result_timings timings;
json to_json() {
bool is_stop = stop != STOP_TYPE_NONE;
// non-OAI-compat JSON
json res = json {
{"index", index},
{"content", content},
{"stop", stop != STOP_TYPE_NONE},
{"id_slot", id_slot},
{"index", index},
{"content", content},
{"stop_type", stop_type_to_str(stop)},
{"stop", is_stop},
{"id_slot", id_slot},
{"tokens_predicted", n_decoded},
{"tokens_evaluated", n_prompt_tokens},
};
- // populate the timings object when timings_per_token is set
+ // populate the timings object when needed (usually for the last response or with timings_per_token enabled)
if (timings.prompt_n > 0) {
res.push_back({"timings", timings.to_json()});
}
if (is_stop) {
res.push_back({"truncated", truncated});
}
return res;
}
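Putting the `to_json()` changes together: every chunk now reports `stop_type`, while `truncated` is emitted only on the final chunk. An illustrative final chunk (values invented; the `"eos"` string assumes `stop_type_to_str` yields lowercase names, which this diff does not show):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
    // hand-built example of the shape produced by to_json() above
    nlohmann::ordered_json res = {
        {"index",            0},
        {"content",          " world"},
        {"stop_type",        "eos"},   // assumed spelling
        {"stop",             true},
        {"id_slot",          0},
        {"tokens_predicted", 16},
        {"tokens_evaluated", 8},
        {"truncated",        false},   // present only because stop is true
    };
    std::cout << res.dump(2) << std::endl;
}
```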
@@ -464,7 +476,7 @@ struct server_task_result_slot_erase : server_task_result {
{ "n_erased", n_erased },
};
}
static server_task_result_slot_erase from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
return copy_cast_ptr(server_task_result_slot_erase, result_ptr);
}

View file

@@ -1,5 +1,9 @@
#!/bin/bash
# make sure we are in the right directory
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
cd "$SCRIPT_DIR"
set -eu
if [ $# -lt 1 ]

View file

@@ -12,13 +12,13 @@ def create_server():
@pytest.mark.parametrize(
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
[
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
]
)
- def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
+ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
global server
server.start()
res = server.make_request("POST", "/chat/completions", data={
@@ -35,10 +35,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
choice = res.body["choices"][0]
assert "assistant" == choice["message"]["role"]
assert match_regex(re_content, choice["message"]["content"])
- if truncated:
-     assert choice["finish_reason"] == "length"
- else:
-     assert choice["finish_reason"] == "stop"
+ assert choice["finish_reason"] == finish_reason
@pytest.mark.parametrize(
@@ -93,7 +90,7 @@ def test_chat_completion_with_openai_library():
temperature=0.8,
)
print(res)
assert res.choices[0].finish_reason == "stop"
assert res.choices[0].finish_reason == "length"
assert res.choices[0].message.content is not None
assert match_regex("(Suddenly)+", res.choices[0].message.content)
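Both parametrized cases stop because they exhaust `max_tokens`, so the tests now pin `finish_reason` to `"length"`; generations ended by EOS or a stop word would report `"stop"` instead. A hedged sketch of that mapping on the server side (not the server's actual helper; the enum values beyond `STOP_TYPE_NONE` are assumed):

```cpp
#include <string>

enum stop_type { STOP_TYPE_NONE, STOP_TYPE_EOS, STOP_TYPE_WORD, STOP_TYPE_LIMIT };

// OAI-compatible finish_reason for a finished generation
std::string to_finish_reason(stop_type stop) {
    switch (stop) {
        case STOP_TYPE_LIMIT: return "length"; // hit n_predict / max_tokens
        case STOP_TYPE_EOS:
        case STOP_TYPE_WORD:  return "stop";   // natural end or stop word
        default:              return "";       // still generating
    }
}
```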