Control vectors in server
This commit is contained in:
parent
a0e584defd
commit
0274e6b364
2 changed files with 111 additions and 1 deletions
|
@ -624,7 +624,6 @@ struct server_response {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
struct server_context {
|
struct server_context {
|
||||||
llama_model * model = nullptr;
|
llama_model * model = nullptr;
|
||||||
llama_context * ctx = nullptr;
|
llama_context * ctx = nullptr;
|
||||||
|
@ -2700,6 +2699,35 @@ static void server_params_parse(int argc, char ** argv, server_params & sparams,
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
params.kv_overrides.push_back(kvo);
|
params.kv_overrides.push_back(kvo);
|
||||||
|
} else if (arg == "--control-vector") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.control_vectors.push_back({ 1.0f, argv[i], });
|
||||||
|
} else if (arg == "--control-vector-scaled") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
const char* fname = argv[i];
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.control_vectors.push_back({ std::stof(argv[i]), fname, });
|
||||||
|
} else if (arg == "--control-vector-layer-range") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.control_vector_layer_start = std::stoi(argv[i]);
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
params.control_vector_layer_end = std::stoi(argv[i]);
|
||||||
|
break;
|
||||||
} else {
|
} else {
|
||||||
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
fprintf(stderr, "error: unknown argument: %s\n", arg.c_str());
|
||||||
server_print_usage(argv[0], default_params, default_sparams);
|
server_print_usage(argv[0], default_params, default_sparams);
|
||||||
|
@ -3148,6 +3176,81 @@ int main(int argc, char ** argv) {
|
||||||
res.status = 200; // HTTP OK
|
res.status = 200; // HTTP OK
|
||||||
};
|
};
|
||||||
|
|
||||||
|
const auto handle_get_control_vectors = [&ctx_server, ¶ms](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
json vectors = json::array();
|
||||||
|
|
||||||
|
for (const auto & vec : params.control_vectors) {
|
||||||
|
vectors.push_back(json {
|
||||||
|
{ "fname", vec.fname },
|
||||||
|
{ "strength", vec.strength }
|
||||||
|
});
|
||||||
|
}
|
||||||
|
json data = {
|
||||||
|
{ "vectors", vectors },
|
||||||
|
{ "layer_start", params.control_vector_layer_start },
|
||||||
|
{ "layer_end", params.control_vector_layer_end }
|
||||||
|
};
|
||||||
|
res.set_content(data.dump(), "application/json; charset=utf-8");
|
||||||
|
};
|
||||||
|
|
||||||
|
const auto handle_set_control_vectors = [&ctx_server, &res_error, ¶ms, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) {
|
||||||
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
|
|
||||||
|
json data = json::parse(req.body);
|
||||||
|
std::vector<llama_control_vector_load_info> vec_params;
|
||||||
|
|
||||||
|
if (data.contains("vectors") && data["vectors"].is_array()) {
|
||||||
|
for (const auto &item : data["vectors"]) {
|
||||||
|
auto v = item.get<llama_control_vector_load_info>();
|
||||||
|
// std::cout << "Add vector: " << v.fname << " " << v.strength << "\n";
|
||||||
|
vec_params.push_back(v);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
std::cerr << "No vectors passed\n";
|
||||||
|
res_error(res, format_error_response("No vectors passed", ERROR_TYPE_SERVER));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
for (auto v : params.control_vectors) {
|
||||||
|
// std::cout << "Subtract vector:" << v.fname << " " << v.strength << "\n";
|
||||||
|
vec_params.push_back({ -v.strength, v.fname });
|
||||||
|
}
|
||||||
|
const auto cvec = llama_control_vector_load(vec_params);
|
||||||
|
if (cvec.n_embd == -1) {
|
||||||
|
// std::cerr << "Could not load control vector\n";
|
||||||
|
res_error(res, format_error_response("Could not load control vector", ERROR_TYPE_SERVER));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.control_vector_layer_start <= 0) {
|
||||||
|
params.control_vector_layer_start = 1;
|
||||||
|
}
|
||||||
|
if (params.control_vector_layer_end <= 0){
|
||||||
|
params.control_vector_layer_end = llama_n_layer(ctx_server.model);
|
||||||
|
}
|
||||||
|
int err = llama_control_vector_apply(ctx_server.ctx,
|
||||||
|
cvec.data.data(),
|
||||||
|
cvec.data.size(),
|
||||||
|
cvec.n_embd,
|
||||||
|
params.control_vector_layer_start,
|
||||||
|
params.control_vector_layer_end);
|
||||||
|
if (err) {
|
||||||
|
std::cerr << "Could not apply control vector\n";
|
||||||
|
res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER));
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
auto s = params.control_vectors.size();
|
||||||
|
auto s2 = vec_params.size();
|
||||||
|
params.control_vectors.clear();
|
||||||
|
unsigned i = 0;
|
||||||
|
for (auto v : vec_params) {
|
||||||
|
if (i++ < s2 - s) {
|
||||||
|
//std::cout << "set vector param: " << v.fname << " " << v.strength << "\n";
|
||||||
|
params.control_vectors.push_back(v);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
handle_get_control_vectors(req, res);
|
||||||
|
};
|
||||||
|
|
||||||
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
const auto handle_props = [&ctx_server](const httplib::Request & req, httplib::Response & res) {
|
||||||
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin"));
|
||||||
json data = {
|
json data = {
|
||||||
|
@ -3497,8 +3600,10 @@ int main(int argc, char ** argv) {
|
||||||
svr->Get ("/health", handle_health);
|
svr->Get ("/health", handle_health);
|
||||||
svr->Get ("/slots", handle_slots);
|
svr->Get ("/slots", handle_slots);
|
||||||
svr->Get ("/metrics", handle_metrics);
|
svr->Get ("/metrics", handle_metrics);
|
||||||
|
svr->Get ("/control-vectors", handle_get_control_vectors);
|
||||||
svr->Get ("/props", handle_props);
|
svr->Get ("/props", handle_props);
|
||||||
svr->Get ("/v1/models", handle_models);
|
svr->Get ("/v1/models", handle_models);
|
||||||
|
svr->Post("/control-vectors", handle_set_control_vectors);
|
||||||
svr->Post("/completion", handle_completions); // legacy
|
svr->Post("/completion", handle_completions); // legacy
|
||||||
svr->Post("/completions", handle_completions);
|
svr->Post("/completions", handle_completions);
|
||||||
svr->Post("/v1/completions", handle_completions);
|
svr->Post("/v1/completions", handle_completions);
|
||||||
|
|
|
@ -615,3 +615,8 @@ static json format_error_response(const std::string & message, const enum error_
|
||||||
{"type", type_str},
|
{"type", type_str},
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void from_json(const json& j, llama_control_vector_load_info& l) {
|
||||||
|
j.at("strength").get_to(l.strength);
|
||||||
|
j.at("fname").get_to(l.fname);
|
||||||
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue