duo: v5

parent 7c8699add6
commit de26d49fbe

1 changed file with 21 additions and 20 deletions
@@ -82,7 +82,8 @@ static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool
     }
     batch.logits[batch.n_tokens - 1] = true;
     int res = 0;
-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode(ctx, batch) != 0)
+    {
         fprintf(stderr, "llama_decode() failed\n");
         res = 1;
     }
@@ -126,7 +127,8 @@ static int speculation(
     }
 
     auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
-    if (next_tokens.size() != 1) {
+    if (next_tokens.size() != 1)
+    {
         fprintf(stderr, "invalid next tokens\n");
         return 1;
     }
@@ -157,9 +159,7 @@ static int speculation(
             }
         }
-
-
         decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
 
         logit_idx = local.size() - match_len - 1;
     }
 
@@ -179,9 +179,8 @@ static int target(
     llama_batch batch = llama_batch_init(512, 0, 1);
     decode(ctx, input.begin(), input.end(), 0, false, batch);
 
-    // TODO: rename to n_accepted
-    size_t n_cur = input.size();
-    size_t n_decode = 0;
+    size_t n_accepted = input.size();
+    size_t n_decoded = 0;
 
     const auto t_main_start = ggml_time_us();
 
@@ -192,7 +191,7 @@ static int target(
     llama_tokens input_seq, next_tokens;
     input_seq.push_back(input.back());
 
-    while (n_decode <= n_predict)
+    while (n_decoded < n_predict)
     {
         next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
         if (next_tokens.size() != input_seq.size())
@@ -201,16 +200,16 @@ static int target(
             return 1;
         }
 
-        size_t next_tokens_pos = n_cur;
+        size_t next_tokens_pos = n_accepted;
         // we always accept at least one new token
-        n_cur += 1;
-        n_decode += 1;
+        n_accepted += 1;
+        n_decoded += 1;
         for (size_t i = 0; i + 1 < input_seq.size(); i++)
         {
             if (next_tokens[i] == input_seq[i + 1])
             {
-                n_cur += 1;
-                n_decode += 1;
+                n_accepted += 1;
+                n_decoded += 1;
             }
             else
             {
@@ -222,7 +221,7 @@ static int target(
 
         // empty the non-matching portion of kv cache.
         // n_cur is incremented at least once and will be > 0
-        llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);
+        llama_kv_cache_seq_rm(ctx, 0, n_accepted - 1, -1);
 
         bool done = false;
         for (size_t i = 0; i < next_tokens.size(); i++)
@@ -263,14 +262,14 @@ static int target(
                     spec.push_back(tok);
                 }
             }
-            input_seq.assign(spec.begin() + n_cur - 1, spec.end());
+            input_seq.assign(spec.begin() + n_accepted - 1, spec.end());
         }
-        if (n_decode >= n_predict || done)
+        if (n_decoded >= n_predict || done)
         {
             break;
         }
 
-        decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
+        decode(ctx, input_seq.begin(), input_seq.end(), n_accepted - 1, true, batch);
 
         logits_from = 0;
         logits_to = input_seq.size();
@@ -279,7 +278,7 @@ static int target(
     const auto t_main_end = ggml_time_us();
 
     fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
-            n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+            n_decoded, (t_main_end - t_main_start) / 1000000.0f, n_decoded / ((t_main_end - t_main_start) / 1000000.0f));
 
     llama_print_timings(ctx);
     fprintf(stderr, "\n");
@@ -295,11 +294,13 @@ static int target(
 int main(int argc, char ** argv) {
     gpt_params params;
 
-    if (gpt_params_parse(argc, argv, params) == false) {
+    if (gpt_params_parse(argc, argv, params) == false)
+    {
         return 1;
     }
 
-    if (params.seed == LLAMA_DEFAULT_SEED) {
+    if (params.seed == LLAMA_DEFAULT_SEED)
+    {
         params.seed = time(NULL);
     }
 