diff --git a/examples/control-vector-generator/control-vector-generator.cpp b/examples/control-vector-generator/control-vector-generator.cpp index 8d4983c88..2541fcb27 100644 --- a/examples/control-vector-generator/control-vector-generator.cpp +++ b/examples/control-vector-generator/control-vector-generator.cpp @@ -10,6 +10,8 @@ #include #include +// TODO read everything over and make sure it makes sense because you're dropping logic errors left and right + struct diff_wrapper { float * diff; // matrix of size [n_rows, cb_data.n_embd] with zero rows stripped size_t n_rows; // number of rows in the matrix for size calculation @@ -23,14 +25,14 @@ struct callback_data { // each element of the vector correspond to one layer std::vector v_pos; // vector of matrices of size [n_embd, n_tokens] std::vector v_neg; // vector of matrices of size [n_embd, n_tokens] - std::vector v_diff; // vector of matrices of size [n_embd, n_tokens] std::vector v_final; // vector of finished vectors of size [n_embd] + std::vector v_diff; // vector of matrices of size [n_embd, m] where m is some sum of concatenated matrices // each element of the outer vector correspond to one layer, each element of the inner vector correspond to one prompt pass std::vector> v_diffs_wrapped; // vector of compiled diff matrices to be concatenated ~callback_data() { for (auto ptr : v_pos) free(ptr); for (auto ptr : v_neg) free(ptr); - for (auto ptr : v_diff) free(ptr); + for (auto ptr : v_diff) free(ptr.diff); for (auto ptr : v_final) free(ptr); for (auto & vec : v_diffs_wrapped) for (auto ptr : vec) free(ptr.diff); } @@ -321,7 +323,10 @@ static void concatenate_diffs(callback_data & cb_data) { memcpy(diff + offset, origin, vec[j].n_rows * cb_data.n_embd * sizeof(float)); offset += vec[j].n_rows * cb_data.n_embd; } - cb_data.v_diff.push_back(diff); + struct diff_wrapper dw; + dw.n_rows = n_rows_total; + dw.diff = diff; + cb_data.v_diff.push_back(dw); } } @@ -335,8 +340,8 @@ static float* 
square_diff(callback_data & cb_data, size_t idx) { for (size_t i = 0; i < cb_data.n_embd; i++) { for (size_t j = 0; j < cb_data.n_embd; j++) { float sum = 0.0f; - for (size_t k = 0; k < cb_data.n_tokens; k++) { - sum += cb_data.v_diff[idx][i + cb_data.n_embd * k] * cb_data.v_diff[idx][j + cb_data.n_embd * k]; + for (size_t k = 0; k < cb_data.v_diff[idx].n_rows; k++) { + sum += cb_data.v_diff[idx].diff[i + cb_data.n_embd * k] * cb_data.v_diff[idx].diff[j + cb_data.n_embd * k]; } result[i * cb_data.n_embd + j] = sum; }