Add backward compatibility for the older model format

This commit is contained in:
Jakub Horak 2023-03-26 16:23:11 +02:00
parent b391579db9
commit 417bd2d677

View file

@ -314,27 +314,26 @@ static bool llama_model_load(
fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str()); fprintf(stderr, "%s: failed to open '%s'\n", __func__, fname.c_str());
return false; return false;
} }
uint32_t format_version = 0;
// verify magic // verify magic
{ {
uint32_t magic; uint32_t magic;
fin.read((char *) &magic, sizeof(magic)); fin.read((char *) &magic, sizeof(magic));
if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) { if (magic == LLAMA_FILE_MAGIC_UNVERSIONED) {
fprintf(stderr, "%s: invalid model file '%s' (too old, regenerate your model files!)\n", fprintf(stderr, "%s: model '%s' is too old, continuing in compatibility mode with degraded performance\n",
__func__, fname.c_str()); __func__, fname.c_str());
return false; format_version = 0;
} } else if (magic == LLAMA_FILE_MAGIC) {
if (magic != LLAMA_FILE_MAGIC) { fin.read((char *) &format_version, sizeof(format_version));
fprintf(stderr, "%s: invalid model file '%s' (bad magic)\n", __func__, fname.c_str());
return false;
}
uint32_t format_version; if (format_version > LLAMA_FILE_VERSION) {
fin.read((char *) &format_version, sizeof(format_version)); fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n",
__func__, fname.c_str(), format_version, LLAMA_FILE_VERSION);
if (format_version != LLAMA_FILE_VERSION) { return false;
fprintf(stderr, "%s: invalid model file '%s' (unsupported format version %" PRIu32 ", expected %d)\n", }
__func__, fname.c_str(), format_version, LLAMA_FILE_VERSION); } else {
fprintf(stderr, "%s: invalid model file '%s' (bad magic %08x)\n", __func__, fname.c_str(), magic);
return false; return false;
} }
} }
@ -417,7 +416,14 @@ static bool llama_model_load(
} }
float score; float score;
fin.read((char *) &score, sizeof(score)); if (format_version == 0) {
// Older version doesn't have embedded token score, use approximation: length^2
// TODO: Maybe read it from tokenizer.model as a fallback?
score = word.length();
score *= score;
} else {
fin.read((char *) &score, sizeof(score));
}
vocab.token_to_id[word] = i; vocab.token_to_id[word] = i;