get vocabulary for exporting training checkpoint to llama compatible model file

This commit is contained in:
xaedes 2023-05-29 02:25:18 +02:00
parent 4b81c32d5b
commit 56895e28f6
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1943,6 +1943,25 @@ int main(int argc, char ** argv) {
struct llama_context * lctx = llama_init_from_file(fn_model, llama_params); struct llama_context * lctx = llama_init_from_file(fn_model, llama_params);
struct llama_vocab vocab;
{
std::vector<const char *> strings;
std::vector<float> scores;
int n_vocab = llama_n_vocab(lctx);
strings.resize(n_vocab, NULL);
scores.resize(n_vocab, 0);
n_vocab = llama_get_vocab(lctx, strings.data(), scores.data(), n_vocab);
GGML_ASSERT(n_vocab == llama_n_vocab(lctx));
vocab.id_to_token.resize(n_vocab);
for (int i=0; i<n_vocab; ++i) {
std::string tok = std::string(strings[i]);
float score = scores[i];
vocab.id_to_token[i].tok = tok;
vocab.id_to_token[i].score = score;
vocab.token_to_id.emplace(tok, i);
}
}
printf("%s: tokenize training data\n", __func__); printf("%s: tokenize training data\n", __func__);
std::vector<llama_token> train_tokens; std::vector<llama_token> train_tokens;
if (tokenize_file(lctx, fn_train, train_tokens) < 0) { if (tokenize_file(lctx, fn_train, train_tokens) < 0) {