fix model_alias and completion_probabilities
This commit is contained in:
parent
a43e1dc66c
commit
fb4b9be602
5 changed files with 73 additions and 31 deletions
|
@ -215,7 +215,7 @@ struct common_params {
|
|||
struct common_params_speculative speculative;
|
||||
|
||||
std::string model = ""; // model path // NOLINT
|
||||
std::string model_alias = "unknown"; // model alias // NOLINT
|
||||
std::string model_alias = ""; // model alias // NOLINT
|
||||
std::string model_url = ""; // model url to download // NOLINT
|
||||
std::string hf_token = ""; // HF token // NOLINT
|
||||
std::string hf_repo = ""; // HF repo // NOLINT
|
||||
|
|
|
@ -250,29 +250,29 @@ struct completion_token_output {
|
|||
std::string text_to_send;
|
||||
struct token_prob {
|
||||
llama_token tok;
|
||||
std::string tok_str;
|
||||
float prob;
|
||||
};
|
||||
std::vector<token_prob> probs;
|
||||
|
||||
json to_json(const llama_context * ctx) const {
|
||||
json to_json() const {
|
||||
json probs_for_token = json::array();
|
||||
for (const auto & p : probs) {
|
||||
const std::string tok_str = tokens_to_output_formatted_string(ctx, p.tok);
|
||||
probs_for_token.push_back(json {
|
||||
{"tok_str", tok_str},
|
||||
{"tok_str", p.tok_str},
|
||||
{"prob", p.prob},
|
||||
});
|
||||
}
|
||||
return probs_for_token;
|
||||
}
|
||||
|
||||
static json probs_vector_to_json(const llama_context * ctx, const std::vector<completion_token_output> & probs) {
|
||||
static json probs_vector_to_json(const std::vector<completion_token_output> & probs) {
|
||||
json out = json::array();
|
||||
for (const auto & prob : probs) {
|
||||
const std::string tok_str = tokens_to_output_formatted_string(ctx, prob.tok);
|
||||
const std::string tok_str = prob.text_to_send;
|
||||
out.push_back(json {
|
||||
{"content", tok_str},
|
||||
{"probs", prob.to_json(ctx)},
|
||||
{"probs", prob.to_json()},
|
||||
});
|
||||
}
|
||||
return out;
|
||||
|
@ -309,7 +309,7 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
|
||||
virtual json to_json() override {
|
||||
// non-OAI-compat JSON
|
||||
return json {
|
||||
json res = json {
|
||||
{"index", index},
|
||||
{"content", content},
|
||||
{"id_slot", id_slot},
|
||||
|
@ -326,6 +326,10 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
{"tokens_cached", n_tokens_cached},
|
||||
{"timings", timings.to_json()},
|
||||
};
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
virtual json to_json_oai_compat() override {
|
||||
|
@ -362,12 +366,6 @@ struct server_task_result_cmpl_final : server_task_result {
|
|||
if (verbose) {
|
||||
res["__verbose"] = to_json();
|
||||
}
|
||||
|
||||
// TODO: fix this
|
||||
// if (result.contains("completion_probabilities")) {
|
||||
// res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
|
||||
// }
|
||||
|
||||
if (timings.prompt_n >= 0) {
|
||||
res.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
|
@ -418,6 +416,9 @@ struct server_task_result_cmpl_partial : server_task_result {
|
|||
if (timings.prompt_n > 0) {
|
||||
res.push_back({"timings", timings.to_json()});
|
||||
}
|
||||
if (!probs_output.empty()) {
|
||||
res["completion_probabilities"] = completion_token_output::probs_vector_to_json(probs_output);
|
||||
}
|
||||
if (is_stop) {
|
||||
res.push_back({"truncated", truncated});
|
||||
}
|
||||
|
@ -2786,9 +2787,11 @@ struct server_context {
|
|||
const auto * cur_p = common_sampler_get_candidates(slot.smpl);
|
||||
|
||||
for (size_t i = 0; i < (size_t) slot.params.sampling.n_probs; ++i) {
|
||||
auto tok_id = cur_p->data[i].id;
|
||||
result.probs.push_back({
|
||||
cur_p->data[i].id,
|
||||
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
|
||||
tok_id,
|
||||
tokens_to_output_formatted_string(ctx, tok_id),
|
||||
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -2920,10 +2923,6 @@ int main(int argc, char ** argv) {
|
|||
// struct that contains llama context and inference
|
||||
server_context ctx_server;
|
||||
|
||||
if (params.model_alias == "unknown") {
|
||||
params.model_alias = params.model;
|
||||
}
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
|
|
|
@ -44,4 +44,10 @@ To run with stdout/stderr display in real time (verbose output, but useful for d
|
|||
DEBUG=1 ./tests.sh -s -v -x
|
||||
```
|
||||
|
||||
Hint: You can compile and run test in single command, useful for local developement:
|
||||
|
||||
```shell
|
||||
cmake --build build -j --target llama-server && ./examples/server/tests/tests.sh
|
||||
```
|
||||
|
||||
To see all available arguments, please refer to [pytest documentation](https://docs.pytest.org/en/stable/how-to/usage.html)
|
||||
|
|
|
@ -14,7 +14,7 @@ def create_server():
|
|||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
[
|
||||
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
(None, "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
|
||||
]
|
||||
)
|
||||
|
@ -30,6 +30,7 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
|||
],
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert res.body["model"] == model if model is not None else server.model_alias
|
||||
assert res.body["usage"]["prompt_tokens"] == n_prompt
|
||||
assert res.body["usage"]["completion_tokens"] == n_predicted
|
||||
choice = res.body["choices"][0]
|
||||
|
@ -39,17 +40,17 @@ def test_chat_completion(model, system_prompt, user_prompt, max_tokens, re_conte
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,truncated",
|
||||
"model,system_prompt,user_prompt,max_tokens,re_content,n_prompt,n_predicted,finish_reason",
|
||||
[
|
||||
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, False),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, False),
|
||||
("llama-2", "Book", "What is the best book", 8, "(Suddenly)+", 77, 8, "length"),
|
||||
("codellama70b", "You are a coding assistant.", "Write the fibonacci function in c++.", 128, "(Aside|she|felter|alonger)+", 104, 64, "length"),
|
||||
]
|
||||
)
|
||||
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, truncated):
|
||||
def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, re_content, n_prompt, n_predicted, finish_reason):
|
||||
global server
|
||||
server.model_alias = None
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/chat/completions", data={
|
||||
"model": model,
|
||||
"max_tokens": max_tokens,
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
|
@ -60,16 +61,13 @@ def test_chat_completion_stream(model, system_prompt, user_prompt, max_tokens, r
|
|||
content = ""
|
||||
for data in res:
|
||||
choice = data["choices"][0]
|
||||
assert "gpt-3.5" in data["model"] # DEFAULT_OAICOMPAT_MODEL, maybe changed in the future
|
||||
if choice["finish_reason"] in ["stop", "length"]:
|
||||
assert data["usage"]["prompt_tokens"] == n_prompt
|
||||
assert data["usage"]["completion_tokens"] == n_predicted
|
||||
assert "content" not in choice["delta"]
|
||||
assert match_regex(re_content, content)
|
||||
# FIXME: not sure why this is incorrect in stream mode
|
||||
# if truncated:
|
||||
# assert choice["finish_reason"] == "length"
|
||||
# else:
|
||||
# assert choice["finish_reason"] == "stop"
|
||||
assert choice["finish_reason"] == finish_reason
|
||||
else:
|
||||
assert choice["finish_reason"] is None
|
||||
content += choice["delta"]["content"]
|
||||
|
|
|
@ -51,6 +51,24 @@ def test_completion_stream(prompt: str, n_predict: int, re_content: str, n_promp
|
|||
content += data["content"]
|
||||
|
||||
|
||||
def test_completion_stream_vs_non_stream():
|
||||
global server
|
||||
server.start()
|
||||
res_stream = server.make_stream_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"stream": True,
|
||||
})
|
||||
res_non_stream = server.make_request("POST", "/completion", data={
|
||||
"n_predict": 8,
|
||||
"prompt": "I believe the meaning of life is",
|
||||
})
|
||||
content_stream = ""
|
||||
for data in res_stream:
|
||||
content_stream += data["content"]
|
||||
assert content_stream == res_non_stream.body["content"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("n_slots", [1, 2])
|
||||
def test_consistent_result_same_seed(n_slots: int):
|
||||
global server
|
||||
|
@ -221,3 +239,24 @@ def test_completion_parallel_slots(n_slots: int, n_requests: int):
|
|||
assert len(res.body["content"]) > 10
|
||||
# FIXME: the result is not deterministic when using other slot than slot 0
|
||||
# assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
def test_n_probs():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_request("POST", "/completion", data={
|
||||
"prompt": "I believe the meaning of life is",
|
||||
"n_probs": 10,
|
||||
"temperature": 0.0,
|
||||
"n_predict": 5,
|
||||
})
|
||||
assert res.status_code == 200
|
||||
assert "completion_probabilities" in res.body
|
||||
assert len(res.body["completion_probabilities"]) == 5
|
||||
for tok in res.body["completion_probabilities"]:
|
||||
assert "probs" in tok
|
||||
assert len(tok["probs"]) == 10
|
||||
for prob in tok["probs"]:
|
||||
assert "prob" in prob
|
||||
assert "tok_str" in prob
|
||||
assert 0.0 <= prob["prob"] <= 1.0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue