server : implement cancellable request (#11285)

* server : implement cancellable request
* fix typo
* httplib 0.18.5
* fix i underflow

parent f26c874179
commit f30f099228

4 changed files with 1396 additions and 431 deletions
(the diff of one file is suppressed because it is too large)
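The change replaces the blocking receive on the result queue with a polled receive: the HTTP handler wakes up at most every HTTP_POLLING_SECONDS, checks whether the client is still connected, and cancels the pending tasks when it is not. Below is a self-contained toy model of that scheme (standard library only; it is not the server code, and recv_with_timeout / connection_closed merely mirror the names introduced in the diff that follows):

// toy model of the cancellation scheme (a sketch, not the server code)
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <mutex>
#include <queue>
#include <thread>

static std::mutex              mutex_results;
static std::condition_variable condition_results;
static std::queue<int>         queue_results;            // stand-in for the task result queue
static std::atomic<bool>       connection_closed{false}; // stand-in for is_connection_closed()

// analogue of server_response::recv_with_timeout(); -1 means "timed out, nothing arrived"
static int recv_with_timeout(int timeout_seconds) {
    std::unique_lock<std::mutex> lock(mutex_results);
    bool have_result = condition_results.wait_for(
        lock, std::chrono::seconds(timeout_seconds), [] { return !queue_results.empty(); });
    if (!have_result) {
        return -1;
    }
    int res = queue_results.front();
    queue_results.pop();
    return res;
}

int main() {
    // worker: would normally stream generated tokens; here it pushes a result every 3 seconds
    std::thread worker([] {
        for (int i = 0; i < 3 && !connection_closed; i++) {
            std::this_thread::sleep_for(std::chrono::seconds(3));
            {
                std::lock_guard<std::mutex> lock(mutex_results);
                queue_results.push(i);
            }
            condition_results.notify_all();
        }
    });

    // "client": disconnects after 4 seconds
    std::thread client([] {
        std::this_thread::sleep_for(std::chrono::seconds(4));
        connection_closed = true;
    });

    // handler loop: the shape of receive_multi_results() / receive_cmpl_results_stream() after this commit
    while (true) {
        int result = recv_with_timeout(/*HTTP_POLLING_SECONDS =*/ 1);
        if (connection_closed) {
            std::printf("connection closed -> cancel pending tasks, free the slot\n");
            break; // the real code calls cancel_tasks(id_tasks) and returns
        }
        if (result == -1) {
            continue; // timeout: nothing arrived yet, poll again
        }
        std::printf("got result %d\n", result);
    }

    client.join();
    worker.join();
    return 0;
}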
@@ -19,6 +19,7 @@
 #include "loading.html.hpp"
 
 #include <atomic>
+#include <chrono>
 #include <condition_variable>
 #include <cstddef>
 #include <cinttypes>
@@ -32,6 +33,8 @@
 using json = nlohmann::ordered_json;
 
+constexpr int HTTP_POLLING_SECONDS = 1;
+
 enum stop_type {
     STOP_TYPE_NONE,
     STOP_TYPE_EOS,
@@ -1602,6 +1605,30 @@ struct server_response {
         // should never reach here
     }
 
+    // same as recv(), but have timeout in seconds
+    // if timeout is reached, nullptr is returned
+    server_task_result_ptr recv_with_timeout(const std::unordered_set<int> & id_tasks, int timeout) {
+        while (true) {
+            std::unique_lock<std::mutex> lock(mutex_results);
+            bool cr_res = condition_results.wait_for(lock, std::chrono::seconds(timeout), [&]{
+                return !queue_results.empty();
+            });
+            if (!cr_res) {
+                return nullptr;
+            }
+
+            for (int i = 0; i < (int) queue_results.size(); i++) {
+                if (id_tasks.find(queue_results[i]->id) != id_tasks.end()) {
+                    server_task_result_ptr res = std::move(queue_results[i]);
+                    queue_results.erase(queue_results.begin() + i);
+                    return res;
+                }
+            }
+        }
+
+        // should never reach here
+    }
+
     // single-task version of recv()
     server_task_result_ptr recv(int id_task) {
         std::unordered_set<int> id_tasks = {id_task};
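One detail in recv_with_timeout() above: the predicate overload of std::condition_variable::wait_for returns the predicate's final value, so a false cr_res means the timeout elapsed with the queue still empty (spurious wakeups are absorbed by the predicate). A minimal illustration of that return-value semantics:

#include <chrono>
#include <condition_variable>
#include <cstdio>
#include <mutex>

int main() {
    std::mutex m;
    std::condition_variable cv;
    bool ready = false; // nothing will ever set this

    std::unique_lock<std::mutex> lock(m);
    // false: the predicate is still false when the 1-second timeout elapses
    bool got = cv.wait_for(lock, std::chrono::seconds(1), [&] { return ready; });
    std::printf("got = %s\n", got ? "true" : "false"); // prints "got = false"
    return 0;
}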
@@ -2322,10 +2349,21 @@ struct server_context {
     void receive_multi_results(
             const std::unordered_set<int> & id_tasks,
             const std::function<void(std::vector<server_task_result_ptr>&)> & result_handler,
-            const std::function<void(json)> & error_handler) {
+            const std::function<void(json)> & error_handler,
+            const std::function<bool()> & is_connection_closed) {
         std::vector<server_task_result_ptr> results(id_tasks.size());
-        for (size_t i = 0; i < id_tasks.size(); i++) {
-            server_task_result_ptr result = queue_results.recv(id_tasks);
+        for (int i = 0; i < (int)id_tasks.size(); i++) {
+            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+
+            if (is_connection_closed()) {
+                cancel_tasks(id_tasks);
+                return;
+            }
+
+            if (result == nullptr) {
+                i--; // retry
+                continue;
+            }
+
             if (result->is_error()) {
                 error_handler(result->to_json());
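The loop index was also changed from size_t to int, which is the "fix i underflow" item in the commit message: the new timeout/retry path executes i-- before the loop's i++, and decrementing an unsigned zero wraps around to a huge value instead of going to -1. A small illustration of the difference:

#include <cstddef>
#include <cstdio>

int main() {
    size_t u = 0;
    u--;                     // unsigned wrap-around, not -1
    std::printf("%zu\n", u); // prints 18446744073709551615 with a 64-bit size_t
    int s = 0;
    s--;                     // signed decrement is simply -1
    std::printf("%d\n", s);  // prints -1
    return 0;
}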
@@ -2349,10 +2387,20 @@ struct server_context {
     void receive_cmpl_results_stream(
             const std::unordered_set<int> & id_tasks,
             const std::function<bool(server_task_result_ptr&)> & result_handler,
-            const std::function<void(json)> & error_handler) {
+            const std::function<void(json)> & error_handler,
+            const std::function<bool()> & is_connection_closed) {
         size_t n_finished = 0;
         while (true) {
-            server_task_result_ptr result = queue_results.recv(id_tasks);
+            server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, HTTP_POLLING_SECONDS);
+
+            if (is_connection_closed()) {
+                cancel_tasks(id_tasks);
+                return;
+            }
+
+            if (result == nullptr) {
+                continue; // retry
+            }
+
             if (result->is_error()) {
                 error_handler(result->to_json());
@@ -3633,6 +3681,7 @@ int main(int argc, char ** argv) {
     const auto handle_completions_impl = [&ctx_server, &res_error, &res_ok](
             server_task_type type,
             json & data,
+            std::function<bool()> is_connection_closed,
             httplib::Response & res,
             oaicompat_type oaicompat) {
         GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
@@ -3694,7 +3743,7 @@ int main(int argc, char ** argv) {
                 }
             }, [&](const json & error_data) {
                 res_error(res, error_data);
-            });
+            }, is_connection_closed);
 
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         } else {
@@ -3704,6 +3753,7 @@ int main(int argc, char ** argv) {
                 if (res_json.is_array()) {
                     for (const auto & res : res_json) {
                         if (!server_sent_event(sink, "data", res)) {
+                            // sending failed (HTTP connection closed), cancel the generation
                             return false;
                         }
                     }
@@ -3713,6 +3763,9 @@ int main(int argc, char ** argv) {
                     }
                 }, [&](const json & error_data) {
                     server_sent_event(sink, "error", error_data);
+                }, [&sink]() {
+                    // note: do not use req.is_connection_closed here because req is already destroyed
+                    return !sink.is_writable();
                 });
                 if (oaicompat != OAICOMPAT_TYPE_NONE) {
                     static const std::string ev_done = "data: [DONE]\n\n";
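In the streaming branch the httplib::Request object no longer exists by the time results arrive (the chunked content provider runs after the handler has returned), so the is_connection_closed closure inspects the DataSink instead: !sink.is_writable() means the client has gone away. A rough sketch of that pattern, assuming cpp-httplib's set_chunked_content_provider API; the /stream path and the payload here are illustrative, not the server's endpoints:

#include "httplib.h" // cpp-httplib, bundled with the server

#include <string>

int main() {
    httplib::Server svr;

    svr.Get("/stream", [](const httplib::Request &, httplib::Response & res) {
        res.set_chunked_content_provider(
            "text/event-stream",
            [](size_t /*offset*/, httplib::DataSink & sink) {
                // do not use req.is_connection_closed here: the request is already destroyed
                if (!sink.is_writable()) {
                    return false; // client disconnected -> stop producing data
                }
                std::string chunk = "data: {\"content\": \"...\"}\n\n";
                sink.write(chunk.data(), chunk.size());
                sink.done();  // no more data (a real server would keep looping instead)
                return true;
            });
    });

    svr.listen("127.0.0.1", 8080);
    return 0;
}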
@@ -3735,6 +3788,7 @@ int main(int argc, char ** argv) {
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
+            req.is_connection_closed,
             res,
             OAICOMPAT_TYPE_NONE);
     };
@@ -3744,6 +3798,7 @@ int main(int argc, char ** argv) {
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
+            req.is_connection_closed,
             res,
             OAICOMPAT_TYPE_COMPLETION);
     };
@@ -3820,6 +3875,7 @@ int main(int argc, char ** argv) {
         return handle_completions_impl(
             SERVER_TASK_TYPE_INFILL,
             data,
+            req.is_connection_closed,
             res,
             OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
     };
@@ -3834,6 +3890,7 @@ int main(int argc, char ** argv) {
         return handle_completions_impl(
             SERVER_TASK_TYPE_COMPLETION,
             data,
+            req.is_connection_closed,
             res,
             OAICOMPAT_TYPE_CHAT);
     };
@@ -3980,7 +4037,7 @@ int main(int argc, char ** argv) {
             }, [&](const json & error_data) {
                 res_error(res, error_data);
                 error = true;
-            });
+            }, req.is_connection_closed);
 
             ctx_server.queue_results.remove_waiting_task_ids(task_ids);
         }
@@ -4070,7 +4127,7 @@ int main(int argc, char ** argv) {
             }, [&](const json & error_data) {
                 res_error(res, error_data);
                 error = true;
-            });
+            }, req.is_connection_closed);
         }
 
         if (error) {
@@ -1,4 +1,5 @@
 import pytest
+import requests
 import time
 from openai import OpenAI
 from utils import *
@@ -405,3 +406,23 @@ def test_n_probs_post_sampling():
             assert "bytes" in prob and type(prob["bytes"]) == list
         # because the test model usually output token with either 100% or 0% probability, we need to check all the top_probs
         assert any(prob["prob"] == 1.0 for prob in tok["top_probs"])
+
+
+def test_cancel_request():
+    global server
+    server.n_ctx = 4096
+    server.n_predict = -1
+    server.n_slots = 1
+    server.server_slots = True
+    server.start()
+    # send a request that will take a long time, but cancel it before it finishes
+    try:
+        server.make_request("POST", "/completion", data={
+            "prompt": "I believe the meaning of life is",
+        }, timeout=0.1)
+    except requests.exceptions.ReadTimeout:
+        pass # expected
+    # make sure the slot is free
+    time.sleep(1) # wait for HTTP_POLLING_SECONDS
+    res = server.make_request("GET", "/slots")
+    assert res.body[0]["is_processing"] == False
@@ -219,17 +219,18 @@ class ServerProcess:
         path: str,
         data: dict | Any | None = None,
         headers: dict | None = None,
+        timeout: float | None = None,
     ) -> ServerResponse:
         url = f"http://{self.server_host}:{self.server_port}{path}"
         parse_body = False
         if method == "GET":
-            response = requests.get(url, headers=headers)
+            response = requests.get(url, headers=headers, timeout=timeout)
             parse_body = True
         elif method == "POST":
-            response = requests.post(url, headers=headers, json=data)
+            response = requests.post(url, headers=headers, json=data, timeout=timeout)
             parse_body = True
         elif method == "OPTIONS":
-            response = requests.options(url, headers=headers)
+            response = requests.options(url, headers=headers, timeout=timeout)
         else:
             raise ValueError(f"Unimplemented method: {method}")
         result = ServerResponse()