diff --git a/examples/server/server.cpp b/examples/server/server.cpp index 1ab80412b..6ce64c94d 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -3176,10 +3176,10 @@ int main(int argc, char ** argv) { res.status = 200; // HTTP OK }; - const auto handle_get_control_vectors = [&ctx_server, &params](const httplib::Request & req, httplib::Response & res) { + const auto handle_get_control_vectors = [&ctx_server](const httplib::Request & req, httplib::Response & res) { json vectors = json::array(); - for (const auto & vec : params.control_vectors) { + for (const auto & vec : ctx_server.params.control_vectors) { vectors.push_back(json { { "fname", vec.fname }, { "strength", vec.strength } @@ -3187,13 +3187,13 @@ int main(int argc, char ** argv) { } json data = { { "vectors", vectors }, - { "layer_start", params.control_vector_layer_start }, - { "layer_end", params.control_vector_layer_end } + { "layer_start", ctx_server.params.control_vector_layer_start }, + { "layer_end", ctx_server.params.control_vector_layer_end } }; res.set_content(data.dump(), "application/json; charset=utf-8"); }; - const auto handle_set_control_vectors = [&ctx_server, &res_error, &params, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) { + const auto handle_set_control_vectors = [&ctx_server, &res_error, &handle_get_control_vectors](const httplib::Request & req, httplib::Response & res) { res.set_header("Access-Control-Allow-Origin", req.get_header_value("Origin")); json data = json::parse(req.body); @@ -3202,7 +3202,7 @@ int main(int argc, char ** argv) { if (data.contains("vectors") && data["vectors"].is_array()) { for (const auto &item : data["vectors"]) { auto v = item.get(); - // std::cout << "Add vector: " << v.fname << " " << v.strength << "\n"; + std::cout << "Add vector: " << v.fname << " " << v.strength << "\n"; vec_params.push_back(v); } } else { @@ -3210,44 +3210,47 @@ int main(int argc, char ** argv) { res_error(res, 
format_error_response("No vectors passed", ERROR_TYPE_SERVER)); return; } - for (auto v : params.control_vectors) { - // std::cout << "Subtract vector:" << v.fname << " " << v.strength << "\n"; - vec_params.push_back({ -v.strength, v.fname }); - } const auto cvec = llama_control_vector_load(vec_params); if (cvec.n_embd == -1) { - // std::cerr << "Could not load control vector\n"; + std::cerr << "Could not load control vector\n"; res_error(res, format_error_response("Could not load control vector", ERROR_TYPE_SERVER)); return; } - if (params.control_vector_layer_start <= 0) { - params.control_vector_layer_start = 1; + if (ctx_server.params.control_vector_layer_start <= 0) { + ctx_server.params.control_vector_layer_start = 1; } - if (params.control_vector_layer_end <= 0){ - params.control_vector_layer_end = llama_n_layer(ctx_server.model); + if (ctx_server.params.control_vector_layer_end <= 0){ + ctx_server.params.control_vector_layer_end = llama_n_layer(ctx_server.model); } int err = llama_control_vector_apply(ctx_server.ctx, cvec.data.data(), cvec.data.size(), cvec.n_embd, - params.control_vector_layer_start, - params.control_vector_layer_end); + ctx_server.params.control_vector_layer_start, + ctx_server.params.control_vector_layer_end); if (err) { std::cerr << "Could not apply control vector\n"; res_error(res, format_error_response("Could not apply control vector", ERROR_TYPE_SERVER)); return; } - auto s = params.control_vectors.size(); - auto s2 = vec_params.size(); - params.control_vectors.clear(); - unsigned i = 0; + ctx_server.params.control_vectors.clear(); for (auto v : vec_params) { - if (i++ < s2 - s) { - //std::cout << "set vector param: " << v.fname << " " << v.strength << "\n"; - params.control_vectors.push_back(v); - } + //std::cout << "set vector param: " << v.fname << " " << v.strength << "\n"; + ctx_server.params.control_vectors.push_back(v); } + + /*std::cerr << "Maybe we need to do this initiation ritual before it werks?\n"; // No, it's still all 
garbled bullshit. + + std::vector tmp = { llama_token_bos(ctx_server.model), llama_token_eos(ctx_server.model), }; + std::cerr << "decode, bro\n"; + llama_decode(ctx_server.ctx, llama_batch_get_one(tmp.data(), std::min(tmp.size(), (size_t) ctx_server.params.n_batch), 0, 0)); + std::cerr << "clear that fucking cache\n"; + llama_kv_cache_clear(ctx_server.ctx); + std::cerr << "symcr0nice or what\n"; + llama_synchronize(ctx_server.ctx); + std::cerr << "time will tell\n"; + llama_reset_timings(ctx_server.ctx);*/ handle_get_control_vectors(req, res); };