move lora change task to queue

This commit is contained in:
Xuan Son Nguyen 2025-01-01 19:58:30 +01:00
parent bf7df95798
commit 1dbd16abb9
3 changed files with 37 additions and 3 deletions

View file

@ -947,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo
By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply`
Please note that this value will be overwritten by the `lora` field for each request.
If an adapter is disabled, the scale will be set to 0.
**Response format**
@ -968,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0.
### POST `/lora-adapters`: Set list of LoRA adapters
This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request.
To disable an adapter, either remove it from the list below, or set scale to 0.
**Request format**

View file

@ -64,6 +64,7 @@ enum server_task_type {
SERVER_TASK_TYPE_SLOT_SAVE,
SERVER_TASK_TYPE_SLOT_RESTORE,
SERVER_TASK_TYPE_SLOT_ERASE,
SERVER_TASK_TYPE_SET_LORA,
};
enum oaicompat_type {
@ -196,6 +197,9 @@ struct server_task {
// used by SERVER_TASK_TYPE_METRICS
bool metrics_reset_bucket = false;
// used by SERVER_TASK_TYPE_SET_LORA
std::vector<common_lora_adapter_container> set_lora;
server_task(server_task_type type) : type(type) {}
static slot_params params_from_json_cmpl(
@ -1108,6 +1112,12 @@ struct server_task_result_slot_erase : server_task_result {
}
};
struct server_task_result_apply_lora : server_task_result {
virtual json to_json() override {
return json {{ "success", true }};
}
};
struct server_slot {
int id;
int id_task = -1;
@ -2580,6 +2590,13 @@ struct server_context {
res->n_erased = n_erased;
queue_results.send(std::move(res));
} break;
case SERVER_TASK_TYPE_SET_LORA:
{
lora = std::move(task.set_lora);
auto res = std::make_unique<server_task_result_apply_lora>();
res->id = task.id;
queue_results.send(std::move(res));
} break;
}
}
@ -4099,8 +4116,22 @@ int main(int argc, char ** argv) {
res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
return;
}
ctx_server.lora = parse_lora_request(ctx_server.lora, body);
res_ok(res, json{{"success", true}});
server_task task(SERVER_TASK_TYPE_SET_LORA);
task.id = ctx_server.queue_tasks.get_new_id();
task.set_lora = parse_lora_request(ctx_server.lora, body);
ctx_server.queue_results.add_waiting_task_id(task.id);
ctx_server.queue_tasks.post(task);
server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
ctx_server.queue_results.remove_waiting_task_id(task.id);
if (result->is_error()) {
res_error(res, result->to_json());
return;
}
GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
res_ok(res, result->to_json());
};
//

View file

@ -1,5 +1,4 @@
import pytest
import os
from utils import *
server = ServerPreset.stories15m_moe()