calc diff

This commit is contained in:
ngxson 2024-05-24 11:46:47 +02:00
parent 0a46d73056
commit c31c118d86

View file

@ -97,6 +97,20 @@ static void padding_seq(llama_context * ctx, std::vector<llama_token> & tokens,
}
}
static void calc_diff(callback_data & cb_data) {
// TODO: assert cb_data.v_pos.size() == cb_data.v_neg.size()
const size_t n_elems = cb_data.n_embd * cb_data.n_tokens;
for (size_t il = 0; il < cb_data.v_pos.size(); il++) {
auto & inp_pos = cb_data.v_pos[il];
auto & inp_neg = cb_data.v_neg[il];
float * dest = (float *) malloc(n_elems * sizeof(float *));
for (size_t i = 0; i < n_elems; i++) {
dest[i] = inp_pos[i] - inp_neg[i];
}
cb_data.v_diff.push_back(dest);
}
}
int main(int argc, char ** argv) {
callback_data cb_data;
std::string prompt_pos = "happy";
@ -149,6 +163,9 @@ int main(int argc, char ** argv) {
printf("%f %f \n", cb_data.v_pos[0][4096], cb_data.v_pos[0][4096]);
printf("%f %f \n", cb_data.v_neg[0][4096], cb_data.v_neg[0][4096]);
calc_diff(cb_data);
printf("%f %f \n", cb_data.v_diff[0][4096], cb_data.v_diff[0][4096]);
//llama_print_timings(ctx);
llama_free(ctx);