fix matrix transpose multiplication

you have got to be kidding me
This commit is contained in:
Christian Zhou-Zheng 2024-05-30 21:36:17 -04:00
parent d446c6d887
commit 31f153fe9c

View file

@ -276,7 +276,7 @@ static float* square_diff(callback_data & cb_data, size_t idx) {
for (size_t j = 0; j < cb_data.n_embd; j++) {
float sum = 0.0f;
for (size_t k = 0; k < cb_data.n_tokens; k++) {
sum += cb_data.v_diff[idx][i * cb_data.n_tokens + k] * cb_data.v_diff[idx][j * cb_data.n_tokens + k];
sum += cb_data.v_diff[idx][i + cb_data.n_embd * k] * cb_data.v_diff[idx][j + cb_data.n_embd * k];
}
result[i * cb_data.n_embd + j] = sum;
}