calc diff
This commit is contained in:
parent
0a46d73056
commit
c31c118d86
1 changed files with 17 additions and 0 deletions
|
@ -97,6 +97,20 @@ static void padding_seq(llama_context * ctx, std::vector<llama_token> & tokens,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static void calc_diff(callback_data & cb_data) {
|
||||||
|
// TODO: assert cb_data.v_pos.size() == cb_data.v_neg.size()
|
||||||
|
const size_t n_elems = cb_data.n_embd * cb_data.n_tokens;
|
||||||
|
for (size_t il = 0; il < cb_data.v_pos.size(); il++) {
|
||||||
|
auto & inp_pos = cb_data.v_pos[il];
|
||||||
|
auto & inp_neg = cb_data.v_neg[il];
|
||||||
|
float * dest = (float *) malloc(n_elems * sizeof(float *));
|
||||||
|
for (size_t i = 0; i < n_elems; i++) {
|
||||||
|
dest[i] = inp_pos[i] - inp_neg[i];
|
||||||
|
}
|
||||||
|
cb_data.v_diff.push_back(dest);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
callback_data cb_data;
|
callback_data cb_data;
|
||||||
std::string prompt_pos = "happy";
|
std::string prompt_pos = "happy";
|
||||||
|
@ -149,6 +163,9 @@ int main(int argc, char ** argv) {
|
||||||
printf("%f %f \n", cb_data.v_pos[0][4096], cb_data.v_pos[0][4096]);
|
printf("%f %f \n", cb_data.v_pos[0][4096], cb_data.v_pos[0][4096]);
|
||||||
printf("%f %f \n", cb_data.v_neg[0][4096], cb_data.v_neg[0][4096]);
|
printf("%f %f \n", cb_data.v_neg[0][4096], cb_data.v_neg[0][4096]);
|
||||||
|
|
||||||
|
calc_diff(cb_data);
|
||||||
|
printf("%f %f \n", cb_data.v_diff[0][4096], cb_data.v_diff[0][4096]);
|
||||||
|
|
||||||
//llama_print_timings(ctx);
|
//llama_print_timings(ctx);
|
||||||
|
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue