From c31c118d86d0725448933b37349db8304867fc59 Mon Sep 17 00:00:00 2001 From: ngxson Date: Fri, 24 May 2024 11:46:47 +0200 Subject: [PATCH] calc diff --- .../control-vector-generator.cpp | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/examples/control-vector-generator/control-vector-generator.cpp b/examples/control-vector-generator/control-vector-generator.cpp index 5c64c3b74..2195e28fa 100644 --- a/examples/control-vector-generator/control-vector-generator.cpp +++ b/examples/control-vector-generator/control-vector-generator.cpp @@ -97,6 +97,20 @@ static void padding_seq(llama_context * ctx, std::vector & tokens, } } +static void calc_diff(callback_data & cb_data) { + // TODO: assert cb_data.v_pos.size() == cb_data.v_neg.size() + const size_t n_elems = cb_data.n_embd * cb_data.n_tokens; + for (size_t il = 0; il < cb_data.v_pos.size(); il++) { + auto & inp_pos = cb_data.v_pos[il]; + auto & inp_neg = cb_data.v_neg[il]; + float * dest = (float *) malloc(n_elems * sizeof(float *)); + for (size_t i = 0; i < n_elems; i++) { + dest[i] = inp_pos[i] - inp_neg[i]; + } + cb_data.v_diff.push_back(dest); + } +} + int main(int argc, char ** argv) { callback_data cb_data; std::string prompt_pos = "happy"; @@ -149,6 +163,9 @@ int main(int argc, char ** argv) { printf("%f %f \n", cb_data.v_pos[0][4096], cb_data.v_pos[0][4096]); printf("%f %f \n", cb_data.v_neg[0][4096], cb_data.v_neg[0][4096]); + calc_diff(cb_data); + printf("%f %f \n", cb_data.v_diff[0][4096], cb_data.v_diff[0][4096]); + //llama_print_timings(ctx); llama_free(ctx);