lora per request

parent 2ba6efc561
commit 9d84127fa6

3 changed files with 125 additions and 54 deletions
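For reference, the per-request usage this change enables looks like the sketch below. It is a minimal illustration assuming a llama-server instance on localhost:8080 started with at least one --lora adapter; the host, port, and n_predict value are illustrative and not taken from this diff, while the "lora" field with {"id", "scale"} objects is exactly what the commit adds to /completion requests.

# Hypothetical client call; only the "lora" field is new behavior from this commit.
import requests

resp = requests.post(
    "http://localhost:8080/completion",      # assumed server address
    json={
        "prompt": "Look in thy glass",
        "lora": [{"id": 0, "scale": 0.5}],   # per-request adapter scales (new)
        "n_predict": 64,
        "temperature": 0.0,
    },
)
resp.raise_for_status()
print(resp.json()["content"])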
@@ -64,7 +64,6 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
-    SERVER_TASK_TYPE_SET_LORA,
 };

 // https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
@@ -91,6 +90,8 @@ struct slot_params {
     int64_t t_max_prompt_ms = -1; // TODO: implement
     int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit

+    std::vector<common_lora_adapter_container> lora;
+
     std::vector<std::string> antiprompt;
     std::vector<std::string> response_fields;
     bool timings_per_token = false;
@@ -114,6 +115,11 @@ struct slot_params {
             samplers.emplace_back(common_sampler_type_to_str(sampler));
         }

+        json lora = json::array();
+        for (size_t i = 0; i < this->lora.size(); ++i) {
+            lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
+        }
+
         return json {
             {"n_predict", n_predict}, // Server configured n_predict
             {"seed", sampling.seed},
@@ -154,6 +160,7 @@ struct slot_params {
             {"speculative.p_min", speculative.p_min},
             {"timings_per_token", timings_per_token},
             {"post_sampling_probs", post_sampling_probs},
+            {"lora", lora},
         };
     }
 };
@@ -189,6 +196,7 @@ struct server_task {
             const llama_model * model,
             const llama_context * ctx,
             const common_params & params_base,
+            const std::vector<common_lora_adapter_container> & base_lora,
             const json & data) {
         slot_params params;

@@ -245,6 +253,16 @@ struct server_task {
             params.speculative.n_min = std::max(params.speculative.n_min, 2);
             params.speculative.n_max = std::max(params.speculative.n_max, 0);

+        if (data.contains("lora")) {
+            if (data.at("lora").is_array()) {
+                params.lora = parse_lora_request(base_lora, data.at("lora"));
+            } else {
+                throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
+            }
+        } else {
+            params.lora = base_lora;
+        }
+
         // TODO: add more sanity checks for the input parameters

         if (params.sampling.penalty_last_n < -1) {
@@ -989,12 +1007,6 @@ struct server_task_result_slot_erase : server_task_result {
     }
 };

-struct server_task_result_apply_lora : server_task_result {
-    virtual json to_json() override {
-        return json {{ "success", true }};
-    }
-};
-
 struct server_slot {
     int id;
     int id_task = -1;
@@ -1009,6 +1021,8 @@ struct server_slot {

     common_speculative * spec = nullptr;

+    std::vector<common_lora_adapter_container> lora;
+
     // the index relative to completion multi-task request
     size_t index = 0;

@@ -1091,7 +1105,8 @@ struct server_slot {
     }

     bool can_batch_with(server_slot & other_slot) {
-        return is_non_causal() == other_slot.is_non_causal();
+        return is_non_causal() == other_slot.is_non_causal()
+            && are_lora_equal(lora, other_slot.lora);
     }

     bool has_budget(const common_params & global_params) {
@@ -1503,7 +1518,7 @@ struct server_context {

     llama_model * model = nullptr;
     llama_context * ctx = nullptr;
-    std::vector<common_lora_adapter_container> loras;
+    std::vector<common_lora_adapter_container> lora;

     llama_model * model_dft = nullptr;
     llama_context_params cparams_dft;
@@ -1570,7 +1585,7 @@ struct server_context {

         model = llama_init.model;
         ctx = llama_init.context;
-        loras = llama_init.lora_adapters;
+        lora = llama_init.lora_adapters;

         if (model == nullptr) {
             SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
@@ -1776,6 +1791,12 @@ struct server_context {
         slot.params = std::move(task.params);
         slot.prompt_tokens = std::move(task.prompt_tokens);

+        if (!are_lora_equal(task.params.lora, slot.lora)) {
+            // if lora is changed, we cannot reuse cached tokens
+            slot.cache_tokens.clear();
+            slot.lora = std::move(task.params.lora);
+        }
+
         SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());

         if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
@@ -2465,13 +2486,6 @@ struct server_context {
                     res->n_erased = n_erased;
                     queue_results.send(std::move(res));
                 } break;
-            case SERVER_TASK_TYPE_SET_LORA:
-                {
-                    common_lora_adapters_apply(ctx, loras);
-                    auto res = std::make_unique<server_task_result_apply_lora>();
-                    res->id = task.id;
-                    queue_results.send(std::move(res));
-                } break;
         }
     }

@@ -2808,8 +2822,12 @@ struct server_context {

         SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);

+        if (slot_batched) {
             // make sure we're in the right embedding mode
-            llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
+            llama_set_embeddings(ctx, slot_batched->is_non_causal());
+            // apply lora, only need to do it once per batch
+            common_lora_adapters_apply(ctx, slot_batched->lora);
+        }

         // process the created batch of tokens
         for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
@@ -3530,7 +3548,12 @@ int main(int argc, char ** argv) {
             task.index = i;

             task.prompt_tokens = std::move(tokenized_prompts[i]);
-            task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
+            task.params = server_task::params_from_json_cmpl(
+                ctx_server.model,
+                ctx_server.ctx,
+                ctx_server.params_base,
+                ctx_server.lora,
+                data);
             task.id_selected_slot = json_value(data, "id_slot", -1);

             // OAI-compat
@@ -3944,8 +3967,8 @@ int main(int argc, char ** argv) {

     const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
         json result = json::array();
-        for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
-            auto & lora = ctx_server.loras[i];
+        for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
+            auto & lora = ctx_server.lora[i];
             result.push_back({
                 {"id", i},
                 {"path", lora.path},
@@ -3957,40 +3980,13 @@ int main(int argc, char ** argv) {
     };

     const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
-        const std::vector<json> body = json::parse(req.body);
-        int max_idx = ctx_server.loras.size();
-
-        // clear existing value
-        for (auto & lora : ctx_server.loras) {
-            lora.scale = 0.0f;
-        }
-
-        // set value
-        for (auto entry : body) {
-            int id = entry.at("id");
-            float scale = entry.at("scale");
-            if (0 <= id && id < max_idx) {
-                ctx_server.loras[id].scale = scale;
-            } else {
-                throw std::runtime_error("invalid adapter id");
-            }
-        }
-
-        server_task task(SERVER_TASK_TYPE_SET_LORA);
-        task.id = ctx_server.queue_tasks.get_new_id();
-        ctx_server.queue_results.add_waiting_task_id(task.id);
-        ctx_server.queue_tasks.post(task);
-
-        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
-        ctx_server.queue_results.remove_waiting_task_id(task.id);
-
-        if (result->is_error()) {
-            res_error(res, result->to_json());
+        const json body = json::parse(req.body);
+        if (!body.is_array()) {
+            res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
-
-        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
-        res_ok(res, result->to_json());
+        ctx_server.lora = parse_lora_request(ctx_server.lora, body);
+        res_ok(res, json{{"success", true}});
     };

     //
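With the SET_LORA task removed, the rewritten handle_lora_adapters_apply above only validates that the body is an array and then replaces ctx_server.lora, which becomes the default (base_lora) for requests that do not send their own "lora" field. Below is a sketch of driving that handler from a client, assuming it is registered on POST /lora-adapters and the list handler on GET /lora-adapters; the routes and server address are assumptions, as they are not shown in this diff.

import requests

base = "http://localhost:8080"  # assumed server address

# body must be an array of {"id", "scale"} objects, per the new is_array() check
requests.post(f"{base}/lora-adapters", json=[{"id": 0, "scale": 1.0}]).raise_for_status()

# /completion requests without a "lora" field inherit these scales;
# requests that include "lora" override them for that request only
print(requests.get(f"{base}/lora-adapters").json())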
@@ -40,3 +40,37 @@ def test_lora(scale: float, re_content: str):
     assert res.status_code == 200
     assert match_regex(re_content, res.body["content"])

+
+def test_lora_per_request():
+    global server
+    server.n_slots = 4
+    server.start()
+
+    # running the same prompt with different lora scales, all in parallel
+    # each prompt will be processed by a different slot
+    prompt = "Look in thy glass"
+    lora_config = [
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+        ( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
+    ]
+    # FIXME: testing with scale between 0.0 and 1.0 (i.e. 0.2, 0.5, 0.7) produces unreliable results
+
+    tasks = [(
+        server.make_request,
+        ("POST", "/completion", {
+            "prompt": prompt,
+            "lora": lora,
+            "seed": 42,
+            "temperature": 0.0,
+        })
+    ) for lora, re_test in lora_config]
+    results = parallel_function_calls(tasks)
+
+    print(results)
+    assert all([res.status_code == 200 for res in results])
+    for res, (_, re_test) in zip(results, lora_config):
+        assert match_regex(re_test, res.body["content"])
@@ -771,3 +771,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx

     return cur;
 }
+
+static bool are_lora_equal(
+        const std::vector<common_lora_adapter_container> & l1,
+        const std::vector<common_lora_adapter_container> & l2) {
+    if (l1.size() != l2.size()) {
+        return false;
+    }
+    for (size_t i = 0; i < l1.size(); ++i) {
+        // we don't check lora.path to reduce the time complexity
+        if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
+            return false;
+        }
+    }
+    return true;
+}
+
+// parse lora config from JSON request, returns a copy of base_lora with updated scale
+static std::vector<common_lora_adapter_container> parse_lora_request(
+        const std::vector<common_lora_adapter_container> & base_lora,
+        const json & data) {
+    std::vector<common_lora_adapter_container> lora(base_lora);
+    int max_idx = lora.size();
+
+    // clear existing value
+    for (auto & entry : lora) {
+        entry.scale = 0.0f;
+    }
+
+    // set value
+    for (auto entry : data) {
+        int id = json_value(entry, "id", -1);
+        float scale = json_value(entry, "scale", 0.0f);
+        if (0 <= id && id < max_idx) {
+            lora[id].scale = scale;
+        } else {
+            throw std::runtime_error("invalid adapter id");
+        }
+    }
+
+    return lora;
+}
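Note that parse_lora_request above copies base_lora, zeroes every scale, and then applies only the ids present in the request, so any adapter omitted from the request is disabled for that request. The following is a small Python model of that merge logic, for illustration only; the server does this in C++ over common_lora_adapter_container objects, and the adapter paths used in the example are hypothetical.

def parse_lora_request(base_lora: list[dict], data: list[dict]) -> list[dict]:
    # copy base config and clear existing scales
    lora = [dict(entry, scale=0.0) for entry in base_lora]
    # set the scales that the request explicitly provides
    for entry in data:
        idx = entry.get("id", -1)
        if not (0 <= idx < len(lora)):
            raise ValueError("invalid adapter id")
        lora[idx]["scale"] = entry.get("scale", 0.0)
    return lora

# with two loaded adapters, requesting only id 1 disables adapter 0:
print(parse_lora_request(
    [{"path": "a.gguf", "scale": 1.0}, {"path": "b.gguf", "scale": 1.0}],
    [{"id": 1, "scale": 0.5}],
))  # -> [{'path': 'a.gguf', 'scale': 0.0}, {'path': 'b.gguf', 'scale': 0.5}]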