Restrict control vectors to predefined options
This commit is contained in:
parent
6eae8bf5c3
commit
d4897432a1
2 changed files with 106 additions and 26 deletions
|
@ -120,6 +120,8 @@ struct server_params {
|
||||||
|
|
||||||
std::vector<std::string> api_keys;
|
std::vector<std::string> api_keys;
|
||||||
|
|
||||||
|
std::vector<llama_control_vector_load_option> control_vector_load_options;
|
||||||
|
|
||||||
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
#ifdef CPPHTTPLIB_OPENSSL_SUPPORT
|
||||||
std::string ssl_key_file = "";
|
std::string ssl_key_file = "";
|
||||||
std::string ssl_cert_file = "";
|
std::string ssl_cert_file = "";
|
||||||
|
@ -2735,6 +2737,25 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
||||||
}
|
}
|
||||||
params.control_vector_layer_end = std::stoi(argv[i]);
|
params.control_vector_layer_end = std::stoi(argv[i]);
|
||||||
break;
|
break;
|
||||||
|
} else if (arg == "--control-vector-option") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
char *name = argv[i];
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
size_t slen = strlen(argv[i]);
|
||||||
|
bool is_dir = slen < 5 || strncmp(argv[i] + slen - 5, ".gguf", 5) != 0;
|
||||||
|
|
||||||
|
// Append path separator for dirs
|
||||||
|
std::string fname = argv[i];
|
||||||
|
if (is_dir && argv[i][slen - 1] != '/')
|
||||||
|
fname += '/';
|
||||||
|
sparams.control_vector_load_options.push_back({ argv[i-1], fname, is_dir });
|
||||||
|
break;
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
server_print_usage(argv[0], default_params, default_sparams);
|
server_print_usage(argv[0], default_params, default_sparams);
|
||||||
|
@ -3183,6 +3204,16 @@ int main(int argc, char ** argv) {
|
||||||
res.status = 200; // HTTP OK
|
res.status = 200; // HTTP OK
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const auto handle_control_vector_options = [&sparams](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
json options = json::array();
|
||||||
|
|
||||||
|
for (const auto & opt : sparams.control_vector_load_options) {
|
||||||
|
options.push_back(opt.name);
|
||||||
|
}
|
||||||
|
res.set_content(options.dump(), "application/json; charset=utf-8");
|
||||||
|
};
|
||||||
|
|
||||||
const auto handle_get_control_vectors = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_get_control_vectors = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
json vectors = json::array();
|
json vectors = json::array();
|
||||||
|
@ -3201,23 +3232,62 @@ int main(int argc, char ** argv) {
|
||||||
res.set_content(data.dump(), "application/json; charset=utf-8");
|
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||||
};
|
};
|
||||||
|
|
||||||
const auto handle_set_control_vectors = [&ctx_server, &res_error, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_set_control_vectors = [&ctx_server, &sparams, &res_error, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
|
||||||
json data = json::parse(req.body);
|
json data = json::parse(req.body);
|
||||||
|
|
||||||
|
// vector parameters passed by user
|
||||||
std::vector<llama_control_vector_load_info> vec_params;
|
std::vector<llama_control_vector_load_info> vec_params;
|
||||||
|
// names translated to real file names
|
||||||
|
std::vector<llama_control_vector_load_info> real_vec_params;
|
||||||
|
|
||||||
if (data.contains("vectors") && data["vectors"].is_array()) {
|
if (data.contains("vectors") && data["vectors"].is_array()) {
|
||||||
for (const auto &item : data["vectors"]) {
|
for (const auto &item : data["vectors"]) {
|
||||||
auto v = item.get<llama_control_vector_load_info>();
|
llama_control_vector_load_info v = item.get<llama_control_vector_load_info>();
|
||||||
std::cout << "Add vector: " << v.fname << " " << v.strength << "\n";
|
std::string real_fname = "";
|
||||||
vec_params.push_back(v);
|
std::cout << "Check vec " << v.fname << "\n";
|
||||||
|
// check for path traversal attempt
|
||||||
|
if (v.fname.length() > 0 && v.fname[0] != '/' && v.fname[0] != '\\') {
|
||||||
|
if (v.fname.find("../") == -1 && v.fname.find("..\\") == -1 &&
|
||||||
|
v.fname.find("/..") == -1 && v.fname.find("\\..") == -1) {
|
||||||
|
|
||||||
|
// check if vector name matches allowed names
|
||||||
|
for (auto opt : sparams.control_vector_load_options) {
|
||||||
|
std::cout << "check option " << opt.name << " : " << opt.fname << " : " << opt.is_dir << "\n";
|
||||||
|
if (!opt.is_dir && opt.name == v.fname) {
|
||||||
|
std::cout << "file exact match\n";
|
||||||
|
real_fname = opt.fname;
|
||||||
|
break;
|
||||||
}
|
}
|
||||||
} else {
|
if (opt.is_dir && v.fname.rfind(opt.name, 0) == 0) {
|
||||||
std::cerr << "No vectors passed\n";
|
std::cout << "file exact match\n";
|
||||||
|
// opt.fname already includes '/' (or '\') while opt.name doesn't
|
||||||
|
real_fname = opt.fname + v.fname.substr(opt.name.length() + 1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (real_fname.length() == 0) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
res_error(res, format_error_response("No vectors passed", ERROR_TYPE_SERVER));
|
res_error(res, format_error_response("Control vector not allowed", ERROR_TYPE_SERVER));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const auto cvec = llama_control_vector_load(vec_params);
|
|
||||||
|
std::cout << "Add vector: " << v.fname << " -> " << real_fname << " " << v.strength << "\n";
|
||||||
|
llama_control_vector_load_info real_info = { v.strength, real_fname };
|
||||||
|
vec_params.push_back(v);
|
||||||
|
real_vec_params.push_back(real_info);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::cerr << "No vectors array passed\n";
|
||||||
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
res_error(res, format_error_response("No vectors array passed. If you want reset to 0, send an empty array.", ERROR_TYPE_SERVER));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto cvec = llama_control_vector_load(real_vec_params);
|
||||||
|
|
||||||
if (cvec.n_embd == -1) {
|
if (cvec.n_embd == -1) {
|
||||||
std::cerr << "Could not load control vector\n";
|
std::cerr << "Could not load control vector\n";
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
@ -3231,6 +3301,7 @@ int main(int argc, char ** argv) {
|
||||||
if (ctx_server.params.control_vector_layer_end <= 0){
|
if (ctx_server.params.control_vector_layer_end <= 0){
|
||||||
ctx_server.params.control_vector_layer_end = llama_n_layer(ctx_server.model);
|
ctx_server.params.control_vector_layer_end = llama_n_layer(ctx_server.model);
|
||||||
}
|
}
|
||||||
|
|
||||||
int err = llama_control_vector_apply(ctx_server.ctx,
|
int err = llama_control_vector_apply(ctx_server.ctx,
|
||||||
cvec.data.data(),
|
cvec.data.data(),
|
||||||
cvec.data.size(),
|
cvec.data.size(),
|
||||||
|
@ -3243,7 +3314,9 @@ int main(int argc, char ** argv) {
|
||||||
res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
|
res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx_server.params.control_vectors.clear();
|
ctx_server.params.control_vectors.clear();
|
||||||
|
|
||||||
for (auto v : vec_params) {
|
for (auto v : vec_params) {
|
||||||
std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
|
std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
|
||||||
ctx_server.params.control_vectors.push_back(v);
|
ctx_server.params.control_vectors.push_back(v);
|
||||||
|
@ -3605,6 +3678,7 @@ int main(int argc, char ** argv) {
|
||||||
svr->Get ("/props", handle_props);
|
svr->Get ("/props", handle_props);
|
||||||
svr->Get ("/v1/models", handle_models);
|
svr->Get ("/v1/models", handle_models);
|
||||||
svr->Get ("/control-vectors", handle_get_control_vectors);
|
svr->Get ("/control-vectors", handle_get_control_vectors);
|
||||||
|
svr->Get ("/control-vector-options", handle_control_vector_options);
|
||||||
svr->Post("/control-vectors", handle_set_control_vectors);
|
svr->Post("/control-vectors", handle_set_control_vectors);
|
||||||
svr->Post("/completion", handle_completions); // legacy
|
svr->Post("/completion", handle_completions); // legacy
|
||||||
svr->Post("/completions", handle_completions);
|
svr->Post("/completions", handle_completions);
|
||||||
|
|
|
@ -620,3 +620,9 @@ static void from_json(const json& j, llama_control_vector_load_info& l) {
|
||||||
j.at("strength").get_to(l.strength);
|
j.at("strength").get_to(l.strength);
|
||||||
j.at("fname").get_to(l.fname);
|
j.at("fname").get_to(l.fname);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct llama_control_vector_load_option {
|
||||||
|
std::string name;
|
||||||
|
std::string fname;
|
||||||
|
bool is_dir;
|
||||||
|
};
|
Loading…
Add table
Add a link
Reference in a new issue