diff --git a/examples/server/README.md b/examples/server/README.md
index 91b5c9424..3ce16945a 100644
--- a/examples/server/README.md
+++ b/examples/server/README.md
@@ -947,6 +947,8 @@ This endpoint returns the loaded LoRA adapters. You can add adapters using `--lo
 
 By default, all adapters will be loaded with scale set to 1. To initialize all adapters scale to 0, add `--lora-init-without-apply`
 
+Please note that this value will be overwritten by the `lora` field for each request.
+
 If an adapter is disabled, the scale will be set to 0.
 
 **Response format**
@@ -968,6 +970,8 @@ If an adapter is disabled, the scale will be set to 0.
 
 ### POST `/lora-adapters`: Set list of LoRA adapters
 
+This sets the global scale for LoRA adapters. Please note that this value will be overwritten by the `lora` field for each request.
+
 To disable an adapter, either remove it from the list below, or set scale to 0.
 
 **Request format**
diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 315eaf94b..8b02c1195 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -64,6 +64,7 @@ enum server_task_type {
     SERVER_TASK_TYPE_SLOT_SAVE,
     SERVER_TASK_TYPE_SLOT_RESTORE,
     SERVER_TASK_TYPE_SLOT_ERASE,
+    SERVER_TASK_TYPE_SET_LORA,
 };
 
 enum oaicompat_type {
@@ -196,6 +197,9 @@ struct server_task {
     // used by SERVER_TASK_TYPE_METRICS
     bool metrics_reset_bucket = false;
 
+    // used by SERVER_TASK_TYPE_SET_LORA
+    std::vector<common_lora_adapter_container> set_lora;
+
     server_task(server_task_type type) : type(type) {}
 
     static slot_params params_from_json_cmpl(
@@ -1108,6 +1112,12 @@
     }
 };
 
+struct server_task_result_apply_lora : server_task_result {
+    virtual json to_json() override {
+        return json {{ "success", true }};
+    }
+};
+
 struct server_slot {
     int id;
     int id_task = -1;
@@ -2580,6 +2590,13 @@
                     res->n_erased = n_erased;
                     queue_results.send(std::move(res));
                 } break;
+            case SERVER_TASK_TYPE_SET_LORA:
+                {
+                    lora = std::move(task.set_lora);
+                    auto res = std::make_unique<server_task_result_apply_lora>();
+                    res->id = task.id;
+                    queue_results.send(std::move(res));
+                } break;
         }
     }
 
@@ -4099,8 +4116,22 @@ int main(int argc, char ** argv) {
             res_error(res, format_error_response("Request body must be an array", ERROR_TYPE_INVALID_REQUEST));
             return;
         }
-        ctx_server.lora = parse_lora_request(ctx_server.lora, body);
-        res_ok(res, json{{"success", true}});
+        server_task task(SERVER_TASK_TYPE_SET_LORA);
+        task.id = ctx_server.queue_tasks.get_new_id();
+        task.set_lora = parse_lora_request(ctx_server.lora, body);
+        ctx_server.queue_results.add_waiting_task_id(task.id);
+        ctx_server.queue_tasks.post(task);
+
+        server_task_result_ptr result = ctx_server.queue_results.recv(task.id);
+        ctx_server.queue_results.remove_waiting_task_id(task.id);
+
+        if (result->is_error()) {
+            res_error(res, result->to_json());
+            return;
+        }
+
+        GGML_ASSERT(dynamic_cast<server_task_result_apply_lora*>(result.get()) != nullptr);
+        res_ok(res, result->to_json());
     };
 
     //
diff --git a/examples/server/tests/unit/test_lora.py b/examples/server/tests/unit/test_lora.py
index 0481e62c0..c1aa8be70 100644
--- a/examples/server/tests/unit/test_lora.py
+++ b/examples/server/tests/unit/test_lora.py
@@ -1,5 +1,4 @@
 import pytest
-import os
 from utils import *
 
 server = ServerPreset.stories15m_moe()
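
For context, here is a minimal client-side sketch of the behaviour the README hunks above describe. It is not part of the patch: it assumes a `llama-server` instance listening on `localhost:8080` that was started with at least one `--lora` adapter, and that request/response bodies use the `{"id": ..., "scale": ...}` shape from the existing `/lora-adapters` documentation together with the per-request `lora` field mentioned in the new README notes.

```python
# Illustrative only (not part of this diff). Assumes llama-server is listening on
# localhost:8080 and was started with at least one --lora adapter.
import requests

BASE = "http://localhost:8080"

# Inspect the loaded adapters and their current global scales.
print(requests.get(f"{BASE}/lora-adapters").json())

# Set the global scale of adapter 0; with this patch the update is routed through
# the task queue as SERVER_TASK_TYPE_SET_LORA instead of mutating ctx_server.lora
# directly in the HTTP handler.
print(requests.post(f"{BASE}/lora-adapters", json=[{"id": 0, "scale": 0.5}]).json())

# A per-request "lora" field overrides the global scales for that request only,
# which is what the new README notes point out.
print(requests.post(
    f"{BASE}/completion",
    json={"prompt": "Hello", "n_predict": 16, "lora": [{"id": 0, "scale": 0.0}]},
).json())
```

Routing the update through the task queue means the global scales are changed on the server thread between tasks rather than from the HTTP thread while slots may be reading them, which is presumably the motivation for the new task type.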