imatrix : allow loading mis-ordered tensors

Sums and counts tensors no longer need to be consecutive.

* imatrix : more sanity checks when loading multiple imatrix files

* imatrix : use ggml_format_name instead of std::string concatenation

Co-authored-by: Xuan Son Nguyen <son@huggingface.co>
This commit is contained in:
Francis Couture-Harpin 2024-09-10 11:31:49 -04:00
parent 2217247051
commit 8c13e16bb0
2 changed files with 87 additions and 37 deletions

View file

@ -6,6 +6,7 @@
#include <vector>
#include <string>
#include <unordered_map>
#include <map>
struct quant_option {
std::string name;
@ -125,6 +126,15 @@ static void usage(const char * executable) {
exit(1);
}
// TODO: share with implementation in imatrix.cpp
// Strips `suffix` from the end of `str` in place.
// Returns true when `str` ended with `suffix` (which has then been removed);
// returns false and leaves `str` untouched otherwise.
static bool str_remove_suffix(std::string & str, const std::string & suffix) {
    // A suffix longer than the string can never match.
    if (suffix.size() > str.size()) {
        return false;
    }
    const std::string::size_type stem_len = str.size() - suffix.size();
    if (str.compare(stem_len, suffix.size(), suffix) != 0) {
        return false;
    }
    // Drop the matched tail, keeping only the leading part.
    str.erase(stem_len);
    return true;
}
static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_dataset, std::unordered_map<std::string, std::vector<float>> & imatrix_data) {
struct ggml_context * ctx = nullptr;
@ -138,7 +148,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
exit(1);
}
const int32_t n_entries = gguf_get_n_tensors(ctx_gguf);
if (n_entries < 2) {
if (n_entries < 1) {
fprintf(stderr, "%s: no data in file %s\n", __func__, imatrix_file.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
@ -160,26 +170,35 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
const std::string sums_suffix{".sums"};
const std::string counts_suffix{".counts"};
// TODO: allow loading from mis-ordered imatrix files
for (int32_t i = 0; i < n_entries - 1; i += 2) {
std::string sums_name{gguf_get_tensor_name(ctx_gguf, i + 0)};
std::string counts_name{gguf_get_tensor_name(ctx_gguf, i + 1)};
// Using an ordered map to get a deterministic iteration order.
std::map<std::string, std::pair<struct ggml_tensor *, struct ggml_tensor *>> sums_counts_for;
if (sums_name.size() < sums_suffix.size() ||
counts_name.size() < counts_suffix.size() ||
!std::equal(sums_name.begin(), sums_name.end() - sums_suffix.size(), counts_name.begin()) ||
!std::equal(sums_suffix.rbegin(), sums_suffix.rend(), sums_name.rbegin()) ||
!std::equal(counts_suffix.rbegin(), counts_suffix.rend(), counts_name.rbegin())) {
fprintf(stderr, "%s: mismatched sums and counts for entry %d\n", __func__, i / 2);
for (struct ggml_tensor * cur = ggml_get_first_tensor(ctx); cur; cur = ggml_get_next_tensor(ctx, cur)) {
std::string name = cur->name;
if (name.empty()) { continue; }
if (str_remove_suffix(name, sums_suffix)) {
// sums
sums_counts_for[name].first = cur;
} else if (str_remove_suffix(name, counts_suffix)) {
// counts
sums_counts_for[name].second = cur;
} else {
fprintf(stderr, "%s: invalid imatrix tensor name: %s\n", __func__, name.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
}
}
for (const auto & sc : sums_counts_for) {
const std::string & name = sc.first;
const struct ggml_tensor * sums = sc.second.first;
const struct ggml_tensor * counts = sc.second.second;
struct ggml_tensor * sums = ggml_get_tensor(ctx, sums_name.c_str());
struct ggml_tensor * counts = ggml_get_tensor(ctx, counts_name.c_str());
if (!sums || !counts) {
fprintf(stderr, "%s: failed reading data for entry %d\n", __func__, i / 2);
fprintf(stderr, "%s: mismatched sums and counts for %s\n", __func__, name.c_str());
gguf_free(ctx_gguf);
ggml_free(ctx);
exit(1);
@ -187,7 +206,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
const int64_t ne0 = sums->ne[0];
const int64_t ne1 = sums->ne[1];
std::string name = sums_name.substr(0, sums_name.size() - sums_suffix.size());
auto & e = imatrix_data[name];
e.resize(ggml_nelements(sums));
float max_count = 0.0f;
@ -201,7 +220,7 @@ static int load_imatrix(const std::string & imatrix_file, std::string & imatrix_
}
}
if (getenv("LLAMA_TRACE")) {
printf("%s: loaded data (size = %6d, ncall = %6d) for '%s'\n", __func__, int(e.size()), int(max_count / chunk_size), name.c_str());
printf("%s: loaded data (size = %6d, n_tokens = %6d) for '%s'\n", __func__, int(e.size()), int(max_count), name.c_str());
}
}