many fixes
This commit is contained in:
parent
0d6485f0f8
commit
d2419b3255
4 changed files with 45 additions and 24 deletions
|
@ -1172,6 +1172,8 @@ struct server_context {
|
|||
res.n_decoded = slot.n_decoded;
|
||||
res.n_prompt_tokens = slot.n_prompt_tokens;
|
||||
res.content = tkn.text_to_send;
|
||||
res.stop = slot.stop;
|
||||
res.truncated = slot.truncated;
|
||||
|
||||
if (slot.params.sampling.n_probs > 0) {
|
||||
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||
|
@ -1186,7 +1188,8 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
if (slot.params.timings_per_token) {
|
||||
// populate timings if this is final response or timings_per_token is enabled
|
||||
if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
|
||||
res.timings = slot.get_timings();
|
||||
}
|
||||
|
||||
|
@ -1195,6 +1198,7 @@ struct server_context {
|
|||
|
||||
void send_final_response(server_slot & slot) {
|
||||
if (slot.params.stream) {
|
||||
// if in stream mode, send the last partial response
|
||||
return send_partial_response(slot, {0, "", {}});
|
||||
}
|
||||
|
||||
|
@ -1209,6 +1213,8 @@ struct server_context {
|
|||
res.n_tokens_cached = slot.n_past;
|
||||
res.content = slot.generated_text;
|
||||
res.stop = slot.stop;
|
||||
res.truncated = slot.truncated;
|
||||
res.timings = slot.get_timings();
|
||||
|
||||
res.generation_params = slot.params; // copy the parameters
|
||||
|
||||
|
@ -1439,6 +1445,8 @@ struct server_context {
|
|||
break;
|
||||
}
|
||||
|
||||
SRV_ERR("received partial result, %s\n", result.to_json().dump().c_str());
|
||||
|
||||
if (result.stop != STOP_TYPE_NONE) {
|
||||
if (++n_finished == id_tasks.size()) {
|
||||
break;
|
||||
|
@ -1533,7 +1541,7 @@ struct server_context {
|
|||
res.id = task.id;
|
||||
res.n_idle_slots = n_idle_slots;
|
||||
res.n_processing_slots = n_processing_slots;
|
||||
res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
|
||||
res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
|
||||
res.t_start = metrics.t_start;
|
||||
|
||||
res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
|
||||
|
@ -1627,13 +1635,13 @@ struct server_context {
|
|||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||
|
||||
server_task_result_slot_save_load result;
|
||||
result.id = task.id;
|
||||
result.id_slot = id_slot;
|
||||
result.filename = filename;
|
||||
result.is_save = false;
|
||||
result.n_saved = token_count;
|
||||
result.n_read = nread;
|
||||
result.t_ms = t_restore_ms;
|
||||
result.id = task.id;
|
||||
result.id_slot = id_slot;
|
||||
result.filename = filename;
|
||||
result.is_save = false;
|
||||
result.n_restored = token_count;
|
||||
result.n_read = nread;
|
||||
result.t_ms = t_restore_ms;
|
||||
queue_results.send(result);
|
||||
} break;
|
||||
case SERVER_TASK_TYPE_SLOT_ERASE:
|
||||
|
|
|
@ -15,6 +15,7 @@
|
|||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
// cast a shared_ptr to a specific type using copy constructor
|
||||
#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()))
|
||||
|
||||
enum stop_type {
|
||||
|
@ -281,23 +282,34 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
|
||||
int index = 0;
|
||||
std::string content;
|
||||
|
||||
bool truncated;
|
||||
int32_t n_decoded;
|
||||
int32_t n_prompt_tokens;
|
||||
|
||||
stop_type stop = STOP_TYPE_NONE;
|
||||
std::vector<completion_token_output> probs_output;
|
||||
result_timings timings;
|
||||
|
||||
json to_json() {
|
||||
bool is_stop = stop != STOP_TYPE_NONE;
|
||||
// non-OAI-compat JSON
|
||||
json res = json {
|
||||
{"index", index},
|
||||
{"content", content},
|
||||
{"stop", stop != STOP_TYPE_NONE},
|
||||
{"id_slot", id_slot},
|
||||
{"index", index},
|
||||
{"content", content},
|
||||
{"stop_type", stop_type_to_str(stop)},
|
||||
{"stop", is_stop},
|
||||
{"id_slot", id_slot},
|
||||
{"tokens_predicted", n_decoded},
|
||||
{"tokens_evaluated", n_prompt_tokens},
|
||||
};
|
||||
// populate the timings object when timings_per_token is set
|
||||
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
||||
if (timings.prompt_n > 0) {
|
||||
res.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
if (is_stop) {
|
||||
res.push_back({"truncated", truncated});
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
|
@ -464,7 +476,7 @@ struct server_task_result_slot_erase : server_task_result {
|
|||
{ "n_erased", n_erased },
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
static server_task_result_slot_erase from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
|
||||
return copy_cast_ptr(server_task_result_slot_erase, result_ptr);
|
||||
}
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
#!/bin/bash
|
||||
|
||||
# make sure we are in the right directory
|
||||
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||
cd $SCRIPT_DIR
|
||||
|
||||
set -eu
|
||||
|
||||
if [ $# -lt 1 ]
|
||||
|
|
|
@ -12,13 +12,13 @@ def create_server():
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
[
|
||||
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
|
||||
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
|
||||
]
|
||||
)
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
|
||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/chat/completions", data={
|
||||
|
@ -35,10 +35,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
|||
choice = res.body["choices"][0]
|
||||
assert "assistant" == choice["message"]["role"]
|
||||
assert match_regex(re_content, choice["message"]["content"])
|
||||
if truncated:
|
||||
assert choice["finish_reason"] == "length"
|
||||
else:
|
||||
assert choice["finish_reason"] == "stop"
|
||||
assert choice["finish_reason"] == finish_reason
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
|
@ -93,7 +90,7 @@ def test_chat_completion_with_openai_library():
|
|||
temperature=0.8,
|
||||
)
|
||||
print(res)
|
||||
assert res.choices[0].finish_reason == "stop"
|
||||
assert res.choices[0].finish_reason == "length"
|
||||
assert res.choices[0].message.content is not None
|
||||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue