duo: v5
This commit is contained in:
parent
7c8699add6
commit
de26d49fbe
1 changed file with 21 additions and 20 deletions
|
@ -82,7 +82,8 @@ static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool
|
||||||
}
|
}
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
int res = 0;
|
int res = 0;
|
||||||
if (llama_decode(ctx, batch) != 0) {
|
if (llama_decode(ctx, batch) != 0)
|
||||||
|
{
|
||||||
fprintf(stderr, "llama_decode() failed\n");
|
fprintf(stderr, "llama_decode() failed\n");
|
||||||
res = 1;
|
res = 1;
|
||||||
}
|
}
|
||||||
|
@ -126,7 +127,8 @@ static int speculation(
|
||||||
}
|
}
|
||||||
|
|
||||||
auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
|
auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
|
||||||
if (next_tokens.size() != 1) {
|
if (next_tokens.size() != 1)
|
||||||
|
{
|
||||||
fprintf(stderr, "invalid next tokens\n");
|
fprintf(stderr, "invalid next tokens\n");
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
@ -157,9 +159,7 @@ static int speculation(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
|
decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
|
||||||
|
|
||||||
logit_idx = local.size() - match_len - 1;
|
logit_idx = local.size() - match_len - 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -179,9 +179,8 @@ static int target(
|
||||||
llama_batch batch = llama_batch_init(512, 0, 1);
|
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||||
decode(ctx, input.begin(), input.end(), 0, false, batch);
|
decode(ctx, input.begin(), input.end(), 0, false, batch);
|
||||||
|
|
||||||
// TODO: rename to n_accepted
|
size_t n_accepted = input.size();
|
||||||
size_t n_cur = input.size();
|
size_t n_decoded = 0;
|
||||||
size_t n_decode = 0;
|
|
||||||
|
|
||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
|
@ -192,7 +191,7 @@ static int target(
|
||||||
llama_tokens input_seq, next_tokens;
|
llama_tokens input_seq, next_tokens;
|
||||||
input_seq.push_back(input.back());
|
input_seq.push_back(input.back());
|
||||||
|
|
||||||
while (n_decode <= n_predict)
|
while (n_decoded < n_predict)
|
||||||
{
|
{
|
||||||
next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
|
next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
|
||||||
if (next_tokens.size() != input_seq.size())
|
if (next_tokens.size() != input_seq.size())
|
||||||
|
@ -201,16 +200,16 @@ static int target(
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t next_tokens_pos = n_cur;
|
size_t next_tokens_pos = n_accepted;
|
||||||
// we always accept at least one new token
|
// we always accept at least one new token
|
||||||
n_cur += 1;
|
n_accepted += 1;
|
||||||
n_decode += 1;
|
n_decoded += 1;
|
||||||
for (size_t i = 0; i + 1 < input_seq.size(); i++)
|
for (size_t i = 0; i + 1 < input_seq.size(); i++)
|
||||||
{
|
{
|
||||||
if (next_tokens[i] == input_seq[i + 1])
|
if (next_tokens[i] == input_seq[i + 1])
|
||||||
{
|
{
|
||||||
n_cur += 1;
|
n_accepted += 1;
|
||||||
n_decode += 1;
|
n_decoded += 1;
|
||||||
}
|
}
|
||||||
else
|
else
|
||||||
{
|
{
|
||||||
|
@ -222,7 +221,7 @@ static int target(
|
||||||
|
|
||||||
// empty the non-matching portion of kv cache.
|
// empty the non-matching portion of kv cache.
|
||||||
// n_cur is incremented at least once and will be > 0
|
// n_accepted is incremented at least once and will be > 0
|
||||||
llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);
|
llama_kv_cache_seq_rm(ctx, 0, n_accepted - 1, -1);
|
||||||
|
|
||||||
bool done = false;
|
bool done = false;
|
||||||
for (size_t i = 0; i < next_tokens.size(); i++)
|
for (size_t i = 0; i < next_tokens.size(); i++)
|
||||||
|
@ -263,14 +262,14 @@ static int target(
|
||||||
spec.push_back(tok);
|
spec.push_back(tok);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
input_seq.assign(spec.begin() + n_cur - 1, spec.end());
|
input_seq.assign(spec.begin() + n_accepted - 1, spec.end());
|
||||||
}
|
}
|
||||||
if (n_decode >= n_predict || done)
|
if (n_decoded >= n_predict || done)
|
||||||
{
|
{
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
|
decode(ctx, input_seq.begin(), input_seq.end(), n_accepted - 1, true, batch);
|
||||||
|
|
||||||
logits_from = 0;
|
logits_from = 0;
|
||||||
logits_to = input_seq.size();
|
logits_to = input_seq.size();
|
||||||
|
@ -279,7 +278,7 @@ static int target(
|
||||||
const auto t_main_end = ggml_time_us();
|
const auto t_main_end = ggml_time_us();
|
||||||
|
|
||||||
fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
|
fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
|
||||||
n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
n_decoded, (t_main_end - t_main_start) / 1000000.0f, n_decoded / ((t_main_end - t_main_start) / 1000000.0f));
|
||||||
|
|
||||||
llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
|
@ -295,11 +294,13 @@ static int target(
|
||||||
int main(int argc, char ** argv) {
|
int main(int argc, char ** argv) {
|
||||||
gpt_params params;
|
gpt_params params;
|
||||||
|
|
||||||
if (gpt_params_parse(argc, argv, params) == false) {
|
if (gpt_params_parse(argc, argv, params) == false)
|
||||||
|
{
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
if (params.seed == LLAMA_DEFAULT_SEED)
|
||||||
|
{
|
||||||
params.seed = time(NULL);
|
params.seed = time(NULL);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue