llama : make model stateless and context stateful (llama_state) (#1797)

* llama : make model stateless and context stateful

* llama : minor cleanup

* llama : update internal API declaration

* Apply suggestions from code review

fix style

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>

* Missing model memory release

* Fix style

* Add deprecated warning for public API function llama_init_from_file

* Update public API use cases: move away from deprecated llama_init_from_file

* Deprecate public API function llama_apply_lora_from_file

---------

Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
Didzis Gosko 2023-06-24 11:47:58 +03:00 committed by GitHub
parent d7b7484f74
commit 527b6fba1d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 244 additions and 92 deletions

View file

@ -35,12 +35,22 @@ int main(int argc, char ** argv) {
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
auto ctx = llama_init_from_file(params.model.c_str(), lparams);
auto model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == nullptr) {
return 1;
}
auto ctx = llama_new_context_with_model(model, lparams);
if (ctx == nullptr) {
llama_free_model(model);
return 1;
}
auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);
llama_free(ctx);
llama_free_model(model);
return 1;
}
@ -84,6 +94,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str);
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
llama_free_model(model);
return 1;
}
n_past += 1;
@ -91,23 +103,27 @@ int main(int argc, char ** argv) {
printf("\n\n");
// free old model
// free old context
llama_free(ctx);
// load new model
auto ctx2 = llama_init_from_file(params.model.c_str(), lparams);
// make new context
auto ctx2 = llama_new_context_with_model(model, lparams);
// Load state (rng, logits, embedding and kv_cache) from file
{
FILE *fp_read = fopen("dump_state.bin", "rb");
if (state_size != llama_get_state_size(ctx2)) {
fprintf(stderr, "\n%s : failed to validate state size\n", __func__);
llama_free(ctx2);
llama_free_model(model);
return 1;
}
const size_t ret = fread(state_mem, 1, state_size, fp_read);
if (ret != state_size) {
fprintf(stderr, "\n%s : failed to read state\n", __func__);
llama_free(ctx2);
llama_free_model(model);
return 1;
}
@ -138,6 +154,8 @@ int main(int argc, char ** argv) {
printf("%s", next_token_str);
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
llama_free_model(model);
return 1;
}
n_past += 1;
@ -145,5 +163,8 @@ int main(int argc, char ** argv) {
printf("\n\n");
llama_free(ctx2);
llama_free_model(model);
return 0;
}