fix square_diff matmul index range and CRLF->LF line endings

fixed a logic error where square_diff would not multiply all rows

fixed a formatting error where the provided completions.txt had CRLF line endings
This commit is contained in:
Christian Zhou-Zheng 2024-05-31 21:08:25 -04:00
parent 4d88cd1af1
commit 4d7d71bc43

View file

@ -10,6 +10,8 @@
#include <iostream>
#include <fstream>
// TODO read everything over and make sure it makes sense because you're dropping logic errors left and right
struct diff_wrapper {
float * diff; // matrix of size [n_rows, cb_data.n_embd] with zero rows stripped
size_t n_rows; // number of rows in the matrix for size calculation
@ -23,14 +25,14 @@ struct callback_data {
// each element of the vector corresponds to one layer
std::vector<float *> v_pos; // vector of matrices of size [n_embd, n_tokens]
std::vector<float *> v_neg; // vector of matrices of size [n_embd, n_tokens]
std::vector<float *> v_diff; // vector of matrices of size [n_embd, n_tokens]
std::vector<float *> v_final; // vector of finished vectors of size [n_embd]
std::vector<diff_wrapper> v_diff; // vector of matrices of size [n_embd, m] where m is some sum of concatenated matrices
// each element of the outer vector corresponds to one layer, each element of the inner vector corresponds to one prompt pass
std::vector<std::vector<diff_wrapper>> v_diffs_wrapped; // vector of compiled diff matrices to be concatenated
~callback_data() {
for (auto ptr : v_pos) free(ptr);
for (auto ptr : v_neg) free(ptr);
for (auto ptr : v_diff) free(ptr);
for (auto ptr : v_diff) free(ptr.diff);
for (auto ptr : v_final) free(ptr);
for (auto & vec : v_diffs_wrapped) for (auto ptr : vec) free(ptr.diff);
}
@ -321,7 +323,10 @@ static void concatenate_diffs(callback_data & cb_data) {
memcpy(diff + offset, origin, vec[j].n_rows * cb_data.n_embd * sizeof(float));
offset += vec[j].n_rows * cb_data.n_embd;
}
cb_data.v_diff.push_back(diff);
struct diff_wrapper dw;
dw.n_rows = n_rows_total;
dw.diff = diff;
cb_data.v_diff.push_back(dw);
}
}
@ -335,8 +340,8 @@ static float* square_diff(callback_data & cb_data, size_t idx) {
for (size_t i = 0; i < cb_data.n_embd; i++) {
for (size_t j = 0; j < cb_data.n_embd; j++) {
float sum = 0.0f;
for (size_t k = 0; k < cb_data.n_tokens; k++) {
sum += cb_data.v_diff[idx][i + cb_data.n_embd * k] * cb_data.v_diff[idx][j + cb_data.n_embd * k];
for (size_t k = 0; k < cb_data.v_diff[idx].n_rows; k++) {
sum += cb_data.v_diff[idx].diff[i + cb_data.n_embd * k] * cb_data.v_diff[idx].diff[j + cb_data.n_embd * k];
}
result[i * cb_data.n_embd + j] = sum;
}