add test
This commit is contained in:
parent
1b301dbec3
commit
28d8c91741
3 changed files with 22 additions and 5 deletions
|
@ -416,7 +416,7 @@ node index.js
|
|||
|
||||
`samplers`: The order the samplers should be applied in. An array of strings representing sampler type names. If a sampler is not set, it will not be used. If a sampler is specified more than once, it will be applied multiple times. Default: `["dry", "top_k", "typ_p", "top_p", "min_p", "xtc", "temperature"]` - these are all the available values.
|
||||
|
||||
`timing_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
|
||||
`timings_per_token`: Include prompt processing and text generation speed information in each response. Default: `false`
|
||||
|
||||
**Response format**
|
||||
|
||||
|
|
|
@ -177,7 +177,7 @@ struct server_slot {
|
|||
bool stopped_word = false;
|
||||
bool stopped_limit = false;
|
||||
|
||||
bool timing_per_token = false;
|
||||
bool timings_per_token = false;
|
||||
|
||||
bool oaicompat = false;
|
||||
|
||||
|
@ -884,7 +884,7 @@ struct server_context {
|
|||
slot.oaicompat_model = "";
|
||||
}
|
||||
|
||||
slot.timing_per_token = json_value(data, "timing_per_token", false);
|
||||
slot.timings_per_token = json_value(data, "timings_per_token", false);
|
||||
|
||||
slot.params.stream = json_value(data, "stream", false);
|
||||
slot.params.cache_prompt = json_value(data, "cache_prompt", true);
|
||||
|
@ -1283,7 +1283,7 @@ struct server_context {
|
|||
{"speculative.n_max", slot.params.speculative.n_max},
|
||||
{"speculative.n_min", slot.params.speculative.n_min},
|
||||
{"speculative.p_min", slot.params.speculative.p_min},
|
||||
{"timing_per_token", slot.timing_per_token},
|
||||
{"timings_per_token", slot.timings_per_token},
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -1341,7 +1341,7 @@ struct server_context {
|
|||
res.data["model"] = slot.oaicompat_model;
|
||||
}
|
||||
|
||||
if (slot.timing_per_token) {
|
||||
if (slot.timings_per_token) {
|
||||
res.data["timings"] = slot.get_formated_timings();
|
||||
}
|
||||
|
||||
|
|
|
@ -146,3 +146,20 @@ def test_invalid_chat_completion_req(messages):
|
|||
})
|
||||
assert res.status_code == 400 or res.status_code == 500
|
||||
assert "error" in res.body
|
||||
|
||||
|
||||
def test_chat_completion_with_timings_per_token():
|
||||
global server
|
||||
server.start()
|
||||
res = server.make_stream_request("POST", "/chat/completions", data={
|
||||
"max_tokens": 10,
|
||||
"messages": [{"role": "user", "content": "test"}],
|
||||
"stream": True,
|
||||
"timings_per_token": True,
|
||||
})
|
||||
for data in res:
|
||||
assert "timings" in data
|
||||
assert "prompt_per_second" in data["timings"]
|
||||
assert "predicted_per_second" in data["timings"]
|
||||
assert "predicted_n" in data["timings"]
|
||||
assert data["timings"]["predicted_n"] <= 10
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue