lora per request
This commit is contained in:
parent
2ba6efc561
commit
9d84127fa6
3 changed files with 125 additions and 54 deletions
|
@ -64,7 +64,6 @@ enum server_task_type {
|
|||
SERVER_TASK_TYPE_SLOT_SAVE,
|
||||
SERVER_TASK_TYPE_SLOT_RESTORE,
|
||||
SERVER_TASK_TYPE_SLOT_ERASE,
|
||||
SERVER_TASK_TYPE_SET_LORA,
|
||||
};
|
||||
|
||||
// https://community.openai.com/t/openai-chat-list-of-error-codes-and-types/357791/11
|
||||
|
@ -91,6 +90,8 @@ struct slot_params {
|
|||
int64_t t_max_prompt_ms = -1; // TODO: implement
|
||||
int64_t t_max_predict_ms = -1; // if positive, limit the generation phase to this time limit
|
||||
|
||||
std::vector<common_lora_adapter_container> lora;
|
||||
|
||||
std::vector<std::string> antiprompt;
|
||||
std::vector<std::string> response_fields;
|
||||
bool timings_per_token = false;
|
||||
|
@ -114,6 +115,11 @@ struct slot_params {
|
|||
samplers.emplace_back(common_sampler_type_to_str(sampler));
|
||||
}
|
||||
|
||||
json lora = json::array();
|
||||
for (size_t i = 0; i < this->lora.size(); ++i) {
|
||||
lora.push_back({{"id", i}, {"scale", this->lora[i].scale}});
|
||||
}
|
||||
|
||||
return json {
|
||||
{"n_predict", n_predict}, // Server configured n_predict
|
||||
{"seed", sampling.seed},
|
||||
|
@ -154,6 +160,7 @@ struct slot_params {
|
|||
{"speculative.p_min", speculative.p_min},
|
||||
{"timings_per_token", timings_per_token},
|
||||
{"post_sampling_probs", post_sampling_probs},
|
||||
{"lora", lora},
|
||||
};
|
||||
}
|
||||
};
|
||||
|
@ -189,6 +196,7 @@ struct server_task {
|
|||
const llama_model * model,
|
||||
const llama_context * ctx,
|
||||
const common_params & params_base,
|
||||
const std::vector<common_lora_adapter_container> & base_lora,
|
||||
const json & data) {
|
||||
slot_params params;
|
||||
|
||||
|
@ -245,6 +253,16 @@ struct server_task {
|
|||
params.speculative.n_min = std::max(params.speculative.n_min, 2);
|
||||
params.speculative.n_max = std::max(params.speculative.n_max, 0);
|
||||
|
||||
if (data.contains("lora")) {
|
||||
if (data.at("lora").is_array()) {
|
||||
params.lora = parse_lora_request(base_lora, data.at("lora"));
|
||||
} else {
|
||||
throw std::runtime_error("Error: 'lora' must be an array of objects with 'id' and 'scale' fields");
|
||||
}
|
||||
} else {
|
||||
params.lora = base_lora;
|
||||
}
|
||||
|
||||
// TODO: add more sanity checks for the input parameters
|
||||
|
||||
if (params.sampling.penalty_last_n < -1) {
|
||||
|
@ -989,12 +1007,6 @@ struct server_task_result_slot_erase : server_task_result {
|
|||
}
|
||||
};
|
||||
|
||||
struct server_task_result_apply_lora : server_task_result {
|
||||
virtual json to_json() override {
|
||||
return json {{ "success", true }};
|
||||
}
|
||||
};
|
||||
|
||||
struct server_slot {
|
||||
int id;
|
||||
int id_task = -1;
|
||||
|
@ -1009,6 +1021,8 @@ struct server_slot {
|
|||
|
||||
common_speculative * spec = nullptr;
|
||||
|
||||
std::vector<common_lora_adapter_container> lora;
|
||||
|
||||
// the index relative to completion multi-task request
|
||||
size_t index = 0;
|
||||
|
||||
|
@ -1091,7 +1105,8 @@ struct server_slot {
|
|||
}
|
||||
|
||||
bool can_batch_with(server_slot & other_slot) {
|
||||
return is_non_causal() == other_slot.is_non_causal();
|
||||
return is_non_causal() == other_slot.is_non_causal()
|
||||
&& are_lora_equal(lora, other_slot.lora);
|
||||
}
|
||||
|
||||
bool has_budget(const common_params & global_params) {
|
||||
|
@ -1503,7 +1518,7 @@ struct server_context {
|
|||
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
std::vector<common_lora_adapter_container> loras;
|
||||
std::vector<common_lora_adapter_container> lora;
|
||||
|
||||
llama_model * model_dft = nullptr;
|
||||
llama_context_params cparams_dft;
|
||||
|
@ -1570,7 +1585,7 @@ struct server_context {
|
|||
|
||||
model = llama_init.model;
|
||||
ctx = llama_init.context;
|
||||
loras = llama_init.lora_adapters;
|
||||
lora = llama_init.lora_adapters;
|
||||
|
||||
if (model == nullptr) {
|
||||
SRV_ERR("failed to load model, '%s'\n", params_base.model.c_str());
|
||||
|
@ -1776,6 +1791,12 @@ struct server_context {
|
|||
slot.params = std::move(task.params);
|
||||
slot.prompt_tokens = std::move(task.prompt_tokens);
|
||||
|
||||
if (!are_lora_equal(task.params.lora, slot.lora)) {
|
||||
// if lora is changed, we cannot reuse cached tokens
|
||||
slot.cache_tokens.clear();
|
||||
slot.lora = std::move(task.params.lora);
|
||||
}
|
||||
|
||||
SLT_DBG(slot, "launching slot : %s\n", safe_json_to_str(slot.to_json()).c_str());
|
||||
|
||||
if (slot.n_predict > 0 && slot.params.n_predict > slot.n_predict) {
|
||||
|
@ -2465,13 +2486,6 @@ struct server_context {
|
|||
res->n_erased = n_erased;
|
||||
queue_results.send(std::move(res));
|
||||
} break;
|
||||
case SERVER_TASK_TYPE_SET_LORA:
|
||||
{
|
||||
common_lora_adapters_apply(ctx, loras);
|
||||
auto res = std::make_unique<server_task_result_apply_lora>();
|
||||
res->id = task.id;
|
||||
queue_results.send(std::move(res));
|
||||
} break;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -2808,8 +2822,12 @@ struct server_context {
|
|||
|
||||
SRV_DBG("decoding batch, n_tokens = %d\n", batch.n_tokens);
|
||||
|
||||
// make sure we're in the right embedding mode
|
||||
llama_set_embeddings(ctx, slot_batched && slot_batched->is_non_causal());
|
||||
if (slot_batched) {
|
||||
// make sure we're in the right embedding mode
|
||||
llama_set_embeddings(ctx, slot_batched->is_non_causal());
|
||||
// apply lora, only need to do it once per batch
|
||||
common_lora_adapters_apply(ctx, slot_batched->lora);
|
||||
}
|
||||
|
||||
// process the created batch of tokens
|
||||
for (int32_t i = 0; i < batch.n_tokens; i += n_batch) {
|
||||
|
@ -3530,7 +3548,12 @@ int main(int argc, char ** argv) {
|
|||
task.index = i;
|
||||
|
||||
task.prompt_tokens = std::move(tokenized_prompts[i]);
|
||||
task.params = server_task::params_from_json_cmpl(ctx_server.model, ctx_server.ctx, ctx_server.params_base, data);
|
||||
task.params = server_task::params_from_json_cmpl(
|
||||
ctx_server.model,
|
||||
ctx_server.ctx,
|
||||
ctx_server.params_base,
|
||||
ctx_server.lora,
|
||||
data);
|
||||
task.id_selected_slot = json_value(data, "id_slot", -1);
|
||||
|
||||
// OAI-compat
|
||||
|
@ -3944,8 +3967,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
const auto handle_lora_adapters_list = [&](const httplib::Request &, httplib::Response & res) {
|
||||
json result = json::array();
|
||||
for (size_t i = 0; i < ctx_server.loras.size(); ++i) {
|
||||
auto & lora = ctx_server.loras[i];
|
||||
for (size_t i = 0; i < ctx_server.lora.size(); ++i) {
|
||||
auto & lora = ctx_server.lora[i];
|
||||
result.push_back({
|
||||
{"id", i},
|
||||
{"path", lora.path},
|
||||
|
@ -3957,40 +3980,13 @@ int main(int argc, char ** argv) {
|
|||
};
|
||||
|
||||
const auto handle_lora_adapters_apply = [&](const httplib::Request & req, httplib::Response & res) {
|
||||
const std::vector<json> body = json::parse(req.body);
|
||||
int max_idx = ctx_server.loras.size();
|
||||
|
||||
// clear existing value
|
||||
for (auto & lora : ctx_server.loras) {
|
||||
lora.scale = 0.0f;
|
||||
}
|
||||
|
||||
// set value
|
||||
for (auto entry : body) {
|
||||
int id = entry.at("id");
|
||||
float scale = entry.at("scale");
|
||||
if (0 <= id && id < max_idx) {
|
||||
ctx_server.loras[id].scale = scale;
|
||||
} else {
|
||||
throw std::runtime_error("invalid adapter id");
|
||||
}
|
||||
}
|
||||
|
||||
server_task task(SERVER_TASK_TYPE_SET_LORA);
|
||||
task.id = ctx_server.queue_tasks.get_new_id();
|
||||
ctx_server.queue_results.add_waiting_task_id(task.id);
|
||||
ctx_server.queue_tasks.post(task);
|
||||
|
||||
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
|
||||
ctx_server.queue_results.remove_waiting_task_id(task.id);
|
||||
|
||||
if (result->is_error()) {
|
||||
res_error(res, result->to_json());
|
||||
const json body = json::parse(req.body);
|
||||
if (!body.is_array()) {
|
||||
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
|
||||
return;
|
||||
}
|
||||
|
||||
GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
|
||||
res_ok(res, result->to_json());
|
||||
ctx_server.lora = parse_lora_request(ctx_server.lora, body);
|
||||
res_ok(res, json{{"success", true}});
|
||||
};
|
||||
|
||||
//
|
||||
|
|
|
@ -40,3 +40,37 @@ def test_lora(scale: float, re_content: str):
|
|||
assert res.status_code == 200
|
||||
assert match_regex(re_content, res.body["content"])
|
||||
|
||||
|
||||
def test_lora_per_request():
|
||||
global server
|
||||
server.n_slots = 4
|
||||
server.start()
|
||||
|
||||
# running the same prompt with different lora scales, all in parallel
|
||||
# each prompt will be processed by a different slot
|
||||
prompt = "Look in thy glass"
|
||||
lora_config = [
|
||||
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||
( [{"id": 0, "scale": 0.0}], "(bright|day|many|happy)+" ),
|
||||
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||
( [{"id": 0, "scale": 1.0}], "(eye|love|glass|sun)+" ),
|
||||
]
|
||||
# FIXME: tesing with scale between 0.0 and 1.0 (i.e. 0.2, 0.5, 0.7) produces unreliable results
|
||||
|
||||
tasks = [(
|
||||
server.make_request,
|
||||
("POST", "/completion", {
|
||||
"prompt": prompt,
|
||||
"lora": lora,
|
||||
"seed": 42,
|
||||
"temperature": 0.0,
|
||||
})
|
||||
) for lora, re_test in lora_config]
|
||||
results = parallel_function_calls(tasks)
|
||||
|
||||
print(results)
|
||||
assert all([res.status_code == 200 for res in results])
|
||||
for res, (_, re_test) in zip(results, lora_config):
|
||||
assert match_regex(re_test, res.body["content"])
|
||||
|
|
|
@ -771,3 +771,44 @@ static std::vector<llama_token_data> get_token_probabilities(llama_context * ctx
|
|||
|
||||
return cur;
|
||||
}
|
||||
|
||||
static bool are_lora_equal(
|
||||
const std::vector<common_lora_adapter_container> & l1,
|
||||
const std::vector<common_lora_adapter_container> & l2) {
|
||||
if (l1.size() != l2.size()) {
|
||||
return false;
|
||||
}
|
||||
for (size_t i = 0; i < l1.size(); ++i) {
|
||||
// we don't check lora.path to reduce the time complexity
|
||||
if (l1[i].scale != l2[i].scale || l1[i].adapter != l2[i].adapter) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// parse lora config from JSON request, returned a copy of base_lora with updated scale
|
||||
static std::vector<common_lora_adapter_container> parse_lora_request(
|
||||
const std::vector<common_lora_adapter_container> & base_lora,
|
||||
const json & data) {
|
||||
std::vector<common_lora_adapter_container> lora(base_lora);
|
||||
int max_idx = lora.size();
|
||||
|
||||
// clear existing value
|
||||
for (auto & entry : lora) {
|
||||
entry.scale = 0.0f;
|
||||
}
|
||||
|
||||
// set value
|
||||
for (auto entry : data) {
|
||||
int id = json_value(entry, "id", -1);
|
||||
float scale = json_value(entry, "scale", 0.0f);
|
||||
if (0 <= id && id < max_idx) {
|
||||
lora[id].scale = scale;
|
||||
} else {
|
||||
throw std::runtime_error("invalid adapter id");
|
||||
}
|
||||
}
|
||||
|
||||
return lora;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue