calc diff
This commit is contained in:
parent
0a46d73056
commit
c31c118d86
1 changed files with 17 additions and 0 deletions
|
@ -97,6 +97,20 @@ static void padding_seq(llama_context * ctx, std::vector<llama_token> & tokens,
|
|||
}
|
||||
}
|
||||
|
||||
static void calc_diff(callback_data & cb_data) {
|
||||
// TODO: assert cb_data.v_pos.size() == cb_data.v_neg.size()
|
||||
const size_t n_elems = cb_data.n_embd * cb_data.n_tokens;
|
||||
for (size_t il = 0; il < cb_data.v_pos.size(); il++) {
|
||||
auto & inp_pos = cb_data.v_pos[il];
|
||||
auto & inp_neg = cb_data.v_neg[il];
|
||||
float * dest = (float *) malloc(n_elems * sizeof(float *));
|
||||
for (size_t i = 0; i < n_elems; i++) {
|
||||
dest[i] = inp_pos[i] - inp_neg[i];
|
||||
}
|
||||
cb_data.v_diff.push_back(dest);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
callback_data cb_data;
|
||||
std::string prompt_pos = "happy";
|
||||
|
@ -149,6 +163,9 @@ int main(int argc, char ** argv) {
|
|||
printf("%f %f \n", cb_data.v_pos[0][4096], cb_data.v_pos[0][4096]);
|
||||
printf("%f %f \n", cb_data.v_neg[0][4096], cb_data.v_neg[0][4096]);
|
||||
|
||||
calc_diff(cb_data);
|
||||
printf("%f %f \n", cb_data.v_diff[0][4096], cb_data.v_diff[0][4096]);
|
||||
|
||||
//llama_print_timings(ctx);
|
||||
|
||||
llama_free(ctx);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue