Oleksandr Kuvshynov 2024-05-25 22:19:23 -04:00
parent 7c8699add6
commit de26d49fbe


@@ -82,7 +82,8 @@ static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool
     }
     batch.logits[batch.n_tokens - 1] = true;
     int res = 0;
-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode(ctx, batch) != 0)
+    {
         fprintf(stderr, "llama_decode() failed\n");
         res = 1;
     }
@@ -126,7 +127,8 @@ static int speculation(
         }
         auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
-        if (next_tokens.size() != 1) {
+        if (next_tokens.size() != 1)
+        {
             fprintf(stderr, "invalid next tokens\n");
             return 1;
         }
@@ -157,9 +159,7 @@ static int speculation(
             }
         }
         decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
         logit_idx = local.size() - match_len - 1;
     }
@ -179,9 +179,8 @@ static int target(
llama_batch batch = llama_batch_init(512, 0, 1); llama_batch batch = llama_batch_init(512, 0, 1);
decode(ctx, input.begin(), input.end(), 0, false, batch); decode(ctx, input.begin(), input.end(), 0, false, batch);
// TODO: rename to n_accepted size_t n_accepted = input.size();
size_t n_cur = input.size(); size_t n_decoded = 0;
size_t n_decode = 0;
const auto t_main_start = ggml_time_us(); const auto t_main_start = ggml_time_us();
@@ -192,7 +191,7 @@ static int target(
     llama_tokens input_seq, next_tokens;
     input_seq.push_back(input.back());
-    while (n_decode <= n_predict)
+    while (n_decoded < n_predict)
     {
         next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
         if (next_tokens.size() != input_seq.size())
@@ -201,16 +200,16 @@ static int target(
             return 1;
         }
-        size_t next_tokens_pos = n_cur;
+        size_t next_tokens_pos = n_accepted;
         // we always accept at least one new token
-        n_cur += 1;
-        n_decode += 1;
+        n_accepted += 1;
+        n_decoded += 1;
         for (size_t i = 0; i + 1 < input_seq.size(); i++)
         {
             if (next_tokens[i] == input_seq[i + 1])
             {
-                n_cur += 1;
-                n_decode += 1;
+                n_accepted += 1;
+                n_decoded += 1;
             }
             else
             {
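The renamed counters implement the usual speculative-decoding acceptance rule: the target model always keeps its own first prediction, then keeps accepting draft tokens for as long as each one matches the corresponding greedy prediction. A minimal sketch of that rule in isolation (count_accepted is a hypothetical helper, not part of this commit; tokens are treated as plain int32_t values):

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper, not from this commit: next_tokens[i] is the target
// model's greedy prediction at position i, and input_seq[i + 1] is the draft
// token proposed for that same position. The first target token is always
// kept; acceptance then continues until the first mismatch.
static size_t count_accepted(const std::vector<int32_t> & next_tokens,
                             const std::vector<int32_t> & input_seq) {
    size_t n = 1; // at least one new token is always accepted
    for (size_t i = 0; i + 1 < input_seq.size(); i++) {
        if (next_tokens[i] != input_seq[i + 1]) {
            break; // first mismatch ends the accepted prefix
        }
        n += 1;
    }
    return n;
}

For example, with input_seq = {a, b, c} and next_tokens = {b, x, y}, this returns 2: b matches the prediction for position 0, but x does not match c.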
@@ -222,7 +221,7 @@ static int target(
         // empty the non-matching portion of kv cache.
-        // n_cur is incremented at least once and will be > 0
-        llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);
+        // n_accepted is incremented at least once and will be > 0
+        llama_kv_cache_seq_rm(ctx, 0, n_accepted - 1, -1);
 
         bool done = false;
         for (size_t i = 0; i < next_tokens.size(); i++)
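In the hunk above, llama_kv_cache_seq_rm(ctx, 0, n_accepted - 1, -1) drops every cached cell of sequence 0 from position n_accepted - 1 onward (a negative p1 means "to the end of the sequence"), so the rejected tail of the draft is forgotten before that position is re-decoded with the next speculative batch. A hedged sketch of the rollback step, assuming the llama.h API of this period:

#include <cstddef>
#include "llama.h"

// Sketch, not the commit's code: keep only the accepted prefix in the KV
// cache of sequence 0. Position n_accepted - 1 is decoded again with the
// next draft batch, so it is removed too; p1 = -1 means "up to the end".
static void rollback_to_accepted(llama_context * ctx, size_t n_accepted) {
    llama_kv_cache_seq_rm(ctx, 0, (llama_pos) (n_accepted - 1), -1);
}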
@@ -263,14 +262,14 @@ static int target(
                    spec.push_back(tok);
                }
            }
-            input_seq.assign(spec.begin() + n_cur - 1, spec.end());
+            input_seq.assign(spec.begin() + n_accepted - 1, spec.end());
         }
 
-        if (n_decode >= n_predict || done)
+        if (n_decoded >= n_predict || done)
         {
             break;
         }
 
-        decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
+        decode(ctx, input_seq.begin(), input_seq.end(), n_accepted - 1, true, batch);
         logits_from = 0;
         logits_to = input_seq.size();
@@ -279,7 +278,7 @@ static int target(
     const auto t_main_end = ggml_time_us();
 
     fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
-            n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+            n_decoded, (t_main_end - t_main_start) / 1000000.0f, n_decoded / ((t_main_end - t_main_start) / 1000000.0f));
     llama_print_timings(ctx);
     fprintf(stderr, "\n");
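The speed report divides the elapsed time by 1000000.0f because ggml_time_us() returns an int64_t microsecond timestamp. The same computation as a standalone sketch (report_speed is a hypothetical helper, not part of this commit):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Hypothetical helper mirroring the report above: convert microsecond
// timestamps from ggml_time_us() to seconds and derive tokens per second.
static void report_speed(size_t n_decoded, int64_t t_start_us, int64_t t_end_us) {
    const double dt_s = (t_end_us - t_start_us) / 1000000.0;
    fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
            n_decoded, dt_s, n_decoded / dt_s);
}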
@@ -295,11 +294,13 @@ static int target(
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    if (gpt_params_parse(argc, argv, params) == false) {
+    if (gpt_params_parse(argc, argv, params) == false)
+    {
         return 1;
     }
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
+    if (params.seed == LLAMA_DEFAULT_SEED)
+    {
         params.seed = time(NULL);
     }