fix matrix transpose multiplication
you have got to be kidding me
This commit is contained in:
parent
d446c6d887
commit
31f153fe9c
1 changed files with 1 additions and 1 deletions
|
@ -276,7 +276,7 @@ static float* square_diff(callback_data & cb_data, size_t idx) {
|
|||
for (size_t j = 0; j < cb_data.n_embd; j++) {
|
||||
float sum = 0.0f;
|
||||
for (size_t k = 0; k < cb_data.n_tokens; k++) {
|
||||
sum += cb_data.v_diff[idx][i * cb_data.n_tokens + k] * cb_data.v_diff[idx][j * cb_data.n_tokens + k];
|
||||
sum += cb_data.v_diff[idx][i + cb_data.n_embd * k] * cb_data.v_diff[idx][j + cb_data.n_embd * k];
|
||||
}
|
||||
result[i * cb_data.n_embd + j] = sum;
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue