fix save-load-state example
This commit is contained in:
parent
7264596a5c
commit
6395174a54
1 changed file with 25 additions and 5 deletions
@@ -48,9 +48,16 @@ int main(int argc, char ** argv) {
     // tokenize prompt
     auto tokens = common_tokenize(ctx, params.prompt, true);
 
+    // prepare the batch
+    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
+    for (size_t i = 0; i < tokens.size(); i++) {
+        common_batch_add(batch, tokens[i], i, {0}, false);
+    }
+    batch.logits[batch.n_tokens - 1] = true; // generate next token
+
     // evaluate prompt
-    llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()));
-    n_past += tokens.size();
+    llama_decode(ctx, batch);
+    n_past += batch.n_tokens;
 
     // save state (rng, logits, embedding and kv_cache) to file
     {
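For readers unfamiliar with the batch API, the pattern introduced above looks roughly like this as standalone code (a minimal sketch assuming a valid ctx and a tokenized prompt; common_batch_add and common_batch_clear are helpers from llama.cpp's common library, and error handling is omitted):

    // Sketch: evaluate a prompt with an explicitly managed llama_batch.
    // Capacity = number of prompt tokens, no embedding inputs, 1 seq id per token.
    llama_batch batch = llama_batch_init(tokens.size(), 0, 1);
    for (size_t i = 0; i < tokens.size(); i++) {
        // args: batch, token id, position, sequence ids, compute-logits flag
        common_batch_add(batch, tokens[i], i, {0}, false);
    }
    batch.logits[batch.n_tokens - 1] = true; // request logits only for the last prompt token
    llama_decode(ctx, batch);                // evaluate the whole prompt in one call
    // ... generation loops reuse the same batch ...
    llama_batch_free(batch);                 // must be freed on every exit path

Compared to llama_batch_get_one(), the explicit batch gives the example control over token positions, sequence ids, and which tokens produce logits, which the hunks below rely on.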
@@ -77,8 +84,12 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result0 += next_token_str;
 
-        if (llama_decode(ctx, llama_batch_get_one(&next_token, 1))) {
+        common_batch_clear(batch);
+        common_batch_add(batch, next_token, n_past, {0}, true);
+
+        if (llama_decode(ctx, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_batch_free(batch);
             llama_free(ctx);
             llama_free_model(model);
             return 1;
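The same change recurs in the two hunks below for ctx2 and ctx3: instead of building a throwaway batch with llama_batch_get_one(&next_token, 1), each step reuses the batch allocated during prompt evaluation. Per generated token the pattern is (sketch, with the failure branch elided):

    common_batch_clear(batch);                               // drop the previous step's token
    common_batch_add(batch, next_token, n_past, {0}, true);  // one token at position n_past, logits on
    if (llama_decode(ctx, batch)) {
        // handle the error; see the hunk above
    }

Note the llama_batch_free(batch) added to each error path: since the batch now outlives a single decode call, every early return has to release it.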
@@ -133,8 +144,12 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result1 += next_token_str;
 
-        if (llama_decode(ctx2, llama_batch_get_one(&next_token, 1))) {
+        common_batch_clear(batch);
+        common_batch_add(batch, next_token, n_past, {0}, true);
+
+        if (llama_decode(ctx2, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_batch_free(batch);
             llama_free(ctx2);
             llama_free_model(model);
             return 1;
@@ -221,8 +236,12 @@ int main(int argc, char ** argv) {
         printf("%s", next_token_str.c_str());
         result2 += next_token_str;
 
-        if (llama_decode(ctx3, llama_batch_get_one(&next_token, 1))) {
+        common_batch_clear(batch);
+        common_batch_add(batch, next_token, n_past, {1}, true);
+
+        if (llama_decode(ctx3, batch)) {
             fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
+            llama_batch_free(batch);
             llama_free(ctx3);
             llama_free_model(model);
             return 1;
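One detail worth flagging: the ctx3 loop above passes {1} as the sequence id where the other two loops pass {0}, presumably because this part of the example decodes against a sequence that was copied to seq id 1. Only that argument changes:

    common_batch_clear(batch);
    common_batch_add(batch, next_token, n_past, {1}, true); // append to sequence 1, not 0
    // then decode with ctx3 as in the hunk above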
@@ -236,6 +255,7 @@ int main(int argc, char ** argv) {
     llama_sampler_free(smpl2);
     llama_sampler_free(smpl3);
 
+    llama_batch_free(batch);
     llama_free(ctx3);
     llama_free_model(model);
 