Review notes (text before the first "diff --git" header is ignored by patch tools):

  * --control-vector-option <name> <path> registers a control vector that
    clients may request by <name>. A <path> ending in ".gguf" is a single
    file; anything else is a directory, and clients address files below it
    as "<name>/<file>".
  * An empty <name> or <path> is rejected. The previous revision indexed
    argv[i][slen - 1] with slen == 0 for an empty path.
  * Directory options only match when the option name is followed by a path
    separator and a non-empty remainder. The previous revision called
    substr(name.length() + 1) on a bare "<name>" request (std::out_of_range)
    and accepted any character after the prefix, e.g. "dirXfile" for "dir".
  * std::string::find results are compared against std::string::npos.
  * NOTE(review): the trailing "break;" in the new option handler mirrors
    the handler directly above it; confirm it does not end argument parsing
    early.
  * NOTE(review): template arguments, indentation and column alignment were
    missing from the copy of the patch this was rebuilt from. They are
    restored from the element types used in the patch; the api_keys element
    type is assumed to be std::string. The from_json -/+ pair differs only
    in whitespace that is not recoverable here. The index hashes are
    carried over unchanged.

diff --git a/examples/server/server.cpp b/examples/server/server.cpp
index 33ceda319..894afbd54 100644
--- a/examples/server/server.cpp
+++ b/examples/server/server.cpp
@@ -120,6 +120,8 @@ struct server_params {
 
     std::vector<std::string> api_keys;
 
+    std::vector<llama_control_vector_load_option> control_vector_load_options;
+
 #ifdef CPPHTTPLIB_OPENSSL_SUPPORT
     std::string ssl_key_file = "";
     std::string ssl_cert_file = "";
@@ -2735,6 +2737,33 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
             }
             params.control_vector_layer_end = std::stoi(argv[i]);
             break;
+        } else if (arg == "--control-vector-option") {
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            const std::string name = argv[i];
+            if (++i >= argc) {
+                invalid_param = true;
+                break;
+            }
+            std::string fname = argv[i];
+            if (name.empty() || fname.empty()) {
+                // an empty name or path cannot describe a usable option
+                invalid_param = true;
+                break;
+            }
+            // everything that does not end in ".gguf" is treated as a directory
+            const std::string suffix = ".gguf";
+            const bool is_dir = fname.size() < suffix.size() ||
+                fname.compare(fname.size() - suffix.size(), suffix.size(), suffix) != 0;
+
+            // Append path separator for dirs
+            if (is_dir && fname.back() != '/') {
+                fname += '/';
+            }
+            sparams.control_vector_load_options.push_back({ name, fname, is_dir });
+            break;
         } else {
             fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
             server_print_usage(argv[0], default_params, default_sparams);
@@ -3183,6 +3212,16 @@ int main(int argc, char ** argv) {
         res.status = 200; // HTTP OK
     };
 
+    const auto handle_control_vector_options = [&sparams](const httplib::Request & req, httplib::Response & res) {
+        res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+        json options = json::array();
+
+        for (const auto & opt : sparams.control_vector_load_options) {
+            options.push_back(opt.name);
+        }
+        res.set_content(options.dump(), "application/json; charset=utf-8");
+    };
+
     const auto handle_get_control_vectors = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
         res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
         json vectors = json::array();
@@ -3201,23 +3240,71 @@ int main(int argc, char ** argv) {
         res.set_content(data.dump(), "application/json; charset=utf-8");
     };
 
-    const auto handle_set_control_vectors = [&ctx_server, &res_error, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
+    const auto handle_set_control_vectors = [&ctx_server, &sparams, &res_error, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
         json data = json::parse(req.body);
+
+        // vector parameters passed by user
         std::vector<llama_control_vector_load_info> vec_params;
+        // names translated to real file names
+        std::vector<llama_control_vector_load_info> real_vec_params;
 
         if (data.contains("vectors") && data["vectors"].is_array()) {
             for (const auto &item : data["vectors"]) {
-                auto v = item.get<llama_control_vector_load_info>();
-                std::cout << "Add vector: " << v.fname << " " << v.strength << "\n";
+                llama_control_vector_load_info v = item.get<llama_control_vector_load_info>();
+                std::string real_fname;
+                std::cout << "Check vec " << v.fname << "\n";
+                // check for path traversal attempt
+                const bool is_relative = !v.fname.empty() && v.fname[0] != '/' && v.fname[0] != '\\';
+                const bool has_dotdot =
+                    v.fname.find("../") != std::string::npos || v.fname.find("..\\") != std::string::npos ||
+                    v.fname.find("/..") != std::string::npos || v.fname.find("\\..") != std::string::npos;
+
+                if (is_relative && !has_dotdot) {
+                    // check if vector name matches allowed names
+                    for (const auto & opt : sparams.control_vector_load_options) {
+                        std::cout << "check option " << opt.name << " : " << opt.fname << " : " << opt.is_dir << "\n";
+                        if (!opt.is_dir) {
+                            if (opt.name == v.fname) {
+                                std::cout << "file exact match\n";
+                                real_fname = opt.fname;
+                                break;
+                            }
+                            continue;
+                        }
+                        // directory option: the request must look like "<name>/<file>", i.e. the
+                        // option name is a whole leading path component and a file part follows
+                        const size_t name_len = opt.name.length();
+                        if (v.fname.length() > name_len + 1 &&
+                            v.fname.compare(0, name_len, opt.name) == 0 &&
+                            (v.fname[name_len] == '/' || v.fname[name_len] == '\\')) {
+                            std::cout << "directory prefix match\n";
+                            // opt.fname already includes '/' (or '\') while opt.name doesn't
+                            real_fname = opt.fname + v.fname.substr(name_len + 1);
+                            break;
+                        }
+                    }
+                }
+
+                if (real_fname.empty()) {
+                    res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
+                    res_error(res, format_error_response("Control vector not allowed", ERROR_TYPE_SERVER));
+                    return;
+                }
+
+                std::cout << "Add vector: " << v.fname << " -> " << real_fname << " " << v.strength << "\n";
+                llama_control_vector_load_info real_info = { v.strength, real_fname };
                 vec_params.push_back(v);
+                real_vec_params.push_back(real_info);
             }
         } else {
-            std::cerr << "No vectors passed\n";
+            std::cerr << "No vectors array passed\n";
             res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
-            res_error(res, format_error_response("No vectors passed", ERROR_TYPE_SERVER));
+            res_error(res, format_error_response("No vectors array passed. If you want reset to 0, send an empty array.", ERROR_TYPE_SERVER));
             return;
         }
-        const auto cvec = llama_control_vector_load(vec_params);
+
+        const auto cvec = llama_control_vector_load(real_vec_params);
+
         if (cvec.n_embd == -1) {
             std::cerr << "Could not load control vector\n";
             res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
@@ -3231,6 +3318,7 @@ int main(int argc, char ** argv) {
         if (ctx_server.params.control_vector_layer_end <= 0){
             ctx_server.params.control_vector_layer_end = llama_n_layer(ctx_server.model);
         }
+
         int err = llama_control_vector_apply(ctx_server.ctx,
                                              cvec.data.data(),
                                              cvec.data.size(),
@@ -3243,7 +3331,9 @@ int main(int argc, char ** argv) {
             res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
             return;
         }
+
         ctx_server.params.control_vectors.clear();
+
         for (auto v : vec_params) {
             std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
             ctx_server.params.control_vectors.push_back(v);
@@ -3599,24 +3689,25 @@ int main(int argc, char ** argv) {
                 json_schema_to_grammar_mjs, json_schema_to_grammar_mjs_len, "text/javascript; charset=utf-8"));
 
     // register API routes
-    svr->Get ("/health",              handle_health);
-    svr->Get ("/slots",               handle_slots);
-    svr->Get ("/metrics",             handle_metrics);
-    svr->Get ("/props",               handle_props);
-    svr->Get ("/v1/models",           handle_models);
-    svr->Get ("/control-vectors",     handle_get_control_vectors);
-    svr->Post("/control-vectors",     handle_set_control_vectors);
-    svr->Post("/completion",          handle_completions); // legacy
-    svr->Post("/completions",         handle_completions);
-    svr->Post("/v1/completions",      handle_completions);
-    svr->Post("/chat/completions",    handle_chat_completions);
-    svr->Post("/v1/chat/completions", handle_chat_completions);
-    svr->Post("/infill",              handle_infill);
-    svr->Post("/embedding",           handle_embeddings); // legacy
-    svr->Post("/embeddings",          handle_embeddings);
-    svr->Post("/v1/embeddings",       handle_embeddings);
-    svr->Post("/tokenize",            handle_tokenize);
-    svr->Post("/detokenize",          handle_detokenize);
+    svr->Get ("/health",                 handle_health);
+    svr->Get ("/slots",                  handle_slots);
+    svr->Get ("/metrics",                handle_metrics);
+    svr->Get ("/props",                  handle_props);
+    svr->Get ("/v1/models",              handle_models);
+    svr->Get ("/control-vectors",        handle_get_control_vectors);
+    svr->Get ("/control-vector-options", handle_control_vector_options);
+    svr->Post("/control-vectors",        handle_set_control_vectors);
+    svr->Post("/completion",             handle_completions); // legacy
+    svr->Post("/completions",            handle_completions);
+    svr->Post("/v1/completions",         handle_completions);
+    svr->Post("/chat/completions",       handle_chat_completions);
+    svr->Post("/v1/chat/completions",    handle_chat_completions);
+    svr->Post("/infill",                 handle_infill);
+    svr->Post("/embedding",              handle_embeddings); // legacy
+    svr->Post("/embeddings",             handle_embeddings);
+    svr->Post("/v1/embeddings",          handle_embeddings);
+    svr->Post("/tokenize",               handle_tokenize);
+    svr->Post("/detokenize",             handle_detokenize);
 
     //
     // Start the server
diff --git a/examples/server/utils.hpp b/examples/server/utils.hpp
index 79928264b..b6de21b82 100644
--- a/examples/server/utils.hpp
+++ b/examples/server/utils.hpp
@@ -617,6 +617,12 @@ static json format_error_response(const std::string & message, const enum error_
 }
 
 static void from_json(const json& j, llama_control_vector_load_info& l) {
-    j.at("strength").get_to(l.strength);
-    j.at("fname").get_to(l.fname);
+    j.at("strength").get_to(l.strength);
+    j.at("fname").get_to(l.fname);
 }
+
+struct llama_control_vector_load_option {
+    std::string name;
+    std::string fname;
+    bool is_dir;
+};