many fixes
This commit is contained in:
parent
0d6485f0f8
commit
d2419b3255
4 changed files with 45 additions and 24 deletions
|
@ -1172,6 +1172,8 @@ struct server_context {
|
||||||
res.n_decoded = slot.n_decoded;
|
res.n_decoded = slot.n_decoded;
|
||||||
res.n_prompt_tokens = slot.n_prompt_tokens;
|
res.n_prompt_tokens = slot.n_prompt_tokens;
|
||||||
res.content = tkn.text_to_send;
|
res.content = tkn.text_to_send;
|
||||||
|
res.stop = slot.stop;
|
||||||
|
res.truncated = slot.truncated;
|
||||||
|
|
||||||
if (slot.params.sampling.n_probs > 0) {
|
if (slot.params.sampling.n_probs > 0) {
|
||||||
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
const llama_tokens to_send_toks = common_tokenize(ctx, tkn.text_to_send, false);
|
||||||
|
@ -1186,7 +1188,8 @@ struct server_context {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (slot.params.timings_per_token) {
|
// populate timings if this is final response or timings_per_token is enabled
|
||||||
|
if (slot.stop != STOP_TYPE_NONE || slot.params.timings_per_token) {
|
||||||
res.timings = slot.get_timings();
|
res.timings = slot.get_timings();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1195,6 +1198,7 @@ struct server_context {
|
||||||
|
|
||||||
void send_final_response(server_slot & slot) {
|
void send_final_response(server_slot & slot) {
|
||||||
if (slot.params.stream) {
|
if (slot.params.stream) {
|
||||||
|
// if in stream mode, send the last partial response
|
||||||
return send_partial_response(slot, {0, "", {}});
|
return send_partial_response(slot, {0, "", {}});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1209,6 +1213,8 @@ struct server_context {
|
||||||
res.n_tokens_cached = slot.n_past;
|
res.n_tokens_cached = slot.n_past;
|
||||||
res.content = slot.generated_text;
|
res.content = slot.generated_text;
|
||||||
res.stop = slot.stop;
|
res.stop = slot.stop;
|
||||||
|
res.truncated = slot.truncated;
|
||||||
|
res.timings = slot.get_timings();
|
||||||
|
|
||||||
res.generation_params = slot.params; // copy the parameters
|
res.generation_params = slot.params; // copy the parameters
|
||||||
|
|
||||||
|
@ -1439,6 +1445,8 @@ struct server_context {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
SRV_ERR("received partial result, %s\n", result.to_json().dump().c_str());
|
||||||
|
|
||||||
if (result.stop != STOP_TYPE_NONE) {
|
if (result.stop != STOP_TYPE_NONE) {
|
||||||
if (++n_finished == id_tasks.size()) {
|
if (++n_finished == id_tasks.size()) {
|
||||||
break;
|
break;
|
||||||
|
@ -1533,7 +1541,7 @@ struct server_context {
|
||||||
res.id = task.id;
|
res.id = task.id;
|
||||||
res.n_idle_slots = n_idle_slots;
|
res.n_idle_slots = n_idle_slots;
|
||||||
res.n_processing_slots = n_processing_slots;
|
res.n_processing_slots = n_processing_slots;
|
||||||
res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
|
res.n_tasks_deferred = queue_tasks.queue_tasks_deferred.size();
|
||||||
res.t_start = metrics.t_start;
|
res.t_start = metrics.t_start;
|
||||||
|
|
||||||
res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
|
res.kv_cache_tokens_count = llama_get_kv_cache_token_count(ctx);
|
||||||
|
@ -1627,13 +1635,13 @@ struct server_context {
|
||||||
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
const double t_restore_ms = (t_end - t_start) / 1000.0;
|
||||||
|
|
||||||
server_task_result_slot_save_load result;
|
server_task_result_slot_save_load result;
|
||||||
result.id = task.id;
|
result.id = task.id;
|
||||||
result.id_slot = id_slot;
|
result.id_slot = id_slot;
|
||||||
result.filename = filename;
|
result.filename = filename;
|
||||||
result.is_save = false;
|
result.is_save = false;
|
||||||
result.n_saved = token_count;
|
result.n_restored = token_count;
|
||||||
result.n_read = nread;
|
result.n_read = nread;
|
||||||
result.t_ms = t_restore_ms;
|
result.t_ms = t_restore_ms;
|
||||||
queue_results.send(result);
|
queue_results.send(result);
|
||||||
} break;
|
} break;
|
||||||
case SERVER_TASK_TYPE_SLOT_ERASE:
|
case SERVER_TASK_TYPE_SLOT_ERASE:
|
||||||
|
|
|
@ -15,6 +15,7 @@
|
||||||
|
|
||||||
using json = nlohmann::ordered_json;
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
// cast a shared_ptr to a specific type using copy constructor
|
||||||
#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()))
|
#define copy_cast_ptr(TYPEOUT, ptr) *(static_cast<TYPEOUT*>(ptr.get()))
|
||||||
|
|
||||||
enum stop_type {
|
enum stop_type {
|
||||||
|
@ -281,23 +282,34 @@ struct server_task_result_cmpl_partial : server_task_result {
|
||||||
server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
|
server_task_result_cmpl_partial() : server_task_result(RESULT_TYPE_CMPL_PARTIAL) {}
|
||||||
int index = 0;
|
int index = 0;
|
||||||
std::string content;
|
std::string content;
|
||||||
|
|
||||||
|
bool truncated;
|
||||||
int32_t n_decoded;
|
int32_t n_decoded;
|
||||||
int32_t n_prompt_tokens;
|
int32_t n_prompt_tokens;
|
||||||
|
|
||||||
stop_type stop = STOP_TYPE_NONE;
|
stop_type stop = STOP_TYPE_NONE;
|
||||||
std::vector<completion_token_output> probs_output;
|
std::vector<completion_token_output> probs_output;
|
||||||
result_timings timings;
|
result_timings timings;
|
||||||
|
|
||||||
json to_json() {
|
json to_json() {
|
||||||
|
bool is_stop = stop != STOP_TYPE_NONE;
|
||||||
|
// non-OAI-compat JSON
|
||||||
json res = json {
|
json res = json {
|
||||||
{"index", index},
|
{"index", index},
|
||||||
{"content", content},
|
{"content", content},
|
||||||
{"stop", stop != STOP_TYPE_NONE},
|
{"stop_type", stop_type_to_str(stop)},
|
||||||
{"id_slot", id_slot},
|
{"stop", is_stop},
|
||||||
|
{"id_slot", id_slot},
|
||||||
|
{"tokens_predicted", n_decoded},
|
||||||
|
{"tokens_evaluated", n_prompt_tokens},
|
||||||
};
|
};
|
||||||
// populate the timings object when timings_per_token is set
|
// populate the timings object when needed (usually for the last response or with timings_per_token enabled)
|
||||||
if (timings.prompt_n > 0) {
|
if (timings.prompt_n > 0) {
|
||||||
res.push_back({"timings", timings.to_json()});
|
res.push_back({"timings", timings.to_json()});
|
||||||
}
|
}
|
||||||
|
if (is_stop) {
|
||||||
|
res.push_back({"truncated", truncated});
|
||||||
|
}
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -464,7 +476,7 @@ struct server_task_result_slot_erase : server_task_result {
|
||||||
{ "n_erased", n_erased },
|
{ "n_erased", n_erased },
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
static server_task_result_slot_erase from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
|
static server_task_result_slot_erase from_ptr(std::unique_ptr<server_task_result> & result_ptr) {
|
||||||
return copy_cast_ptr(server_task_result_slot_erase, result_ptr);
|
return copy_cast_ptr(server_task_result_slot_erase, result_ptr);
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,5 +1,9 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
|
|
||||||
|
# make sure we are in the right directory
|
||||||
|
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
|
||||||
|
cd $SCRIPT_DIR
|
||||||
|
|
||||||
set -eu
|
set -eu
|
||||||
|
|
||||||
if [ $# -lt 1 ]
|
if [ $# -lt 1 ]
|
||||||
|
|
|
@ -12,13 +12,13 @@ def create_server():
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
|
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||||
[
|
[
|
||||||
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
|
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
|
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
|
def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
|
||||||
global server
|
global server
|
||||||
server.start()
|
server.start()
|
||||||
res = server.make_request("POST", "/chat/completions", data={
|
res = server.make_request("POST", "/chat/completions", data={
|
||||||
|
@ -35,10 +35,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
||||||
choice = res.body["choices"][0]
|
choice = res.body["choices"][0]
|
||||||
assert "assistant" == choice["message"]["role"]
|
assert "assistant" == choice["message"]["role"]
|
||||||
assert match_regex(re_content, choice["message"]["content"])
|
assert match_regex(re_content, choice["message"]["content"])
|
||||||
if truncated:
|
assert choice["finish_reason"] == finish_reason
|
||||||
assert choice["finish_reason"] == "length"
|
|
||||||
else:
|
|
||||||
assert choice["finish_reason"] == "stop"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -93,7 +90,7 @@ def test_chat_completion_with_openai_library():
|
||||||
temperature=0.8,
|
temperature=0.8,
|
||||||
)
|
)
|
||||||
print(res)
|
print(res)
|
||||||
assert res.choices[0].finish_reason == "stop"
|
assert res.choices[0].finish_reason == "length"
|
||||||
assert res.choices[0].message.content is not None
|
assert res.choices[0].message.content is not None
|
||||||
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
assert match_regex("(Suddenly)+", res.choices[0].message.content)
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue