duo: v2
This commit is contained in:
parent
78938bc0c9
commit
96811fdf63
1 changed files with 29 additions and 48 deletions
|
@ -88,6 +88,24 @@ static llama_tokens greedy_tokens(llama_model * model, llama_context * ctx, int3
|
||||||
return res;
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template<typename iter_t>
|
||||||
|
static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool all_logits, llama_batch & batch)
|
||||||
|
{
|
||||||
|
llama_batch_clear(batch);
|
||||||
|
size_t i = offset;
|
||||||
|
for (auto it = from; it != to; ++it)
|
||||||
|
{
|
||||||
|
llama_batch_add(batch, *it, i++, { 0 }, all_logits);
|
||||||
|
}
|
||||||
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
int res = 0;
|
||||||
|
if (llama_decode(ctx, batch) != 0) {
|
||||||
|
fprintf(stderr, "llama_decode() failed\n");
|
||||||
|
res = 1;
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
static int speculation(
|
static int speculation(
|
||||||
llama_model * model,
|
llama_model * model,
|
||||||
speculation_context * spec_ctx,
|
speculation_context * spec_ctx,
|
||||||
|
@ -96,20 +114,9 @@ static int speculation(
|
||||||
|
|
||||||
// TODO: check that input is non-empty
|
// TODO: check that input is non-empty
|
||||||
llama_batch batch = llama_batch_init(512, 0, 1);
|
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||||
|
decode(ctx, input.begin(), input.end(), 0, false, batch);
|
||||||
|
|
||||||
for (size_t i = 0; i < input.size(); i++)
|
int logit_idx = input.size() - 1;
|
||||||
{
|
|
||||||
llama_batch_add(batch, input[i], i, { 0 }, false);
|
|
||||||
}
|
|
||||||
|
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0) {
|
|
||||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
int logit_idx = batch.n_tokens - 1;
|
|
||||||
llama_tokens local = input;
|
llama_tokens local = input;
|
||||||
size_t match_len;
|
size_t match_len;
|
||||||
|
|
||||||
|
@ -167,20 +174,10 @@ static int speculation(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
|
||||||
// TODO theoretically this can be empty?
|
|
||||||
for (size_t i = match_len; i < local.size(); i++)
|
|
||||||
{
|
|
||||||
llama_batch_add(batch, local[i], i, { 0 }, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
logit_idx = batch.n_tokens - 1;
|
decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0)
|
logit_idx = local.size() - match_len - 1;
|
||||||
{
|
|
||||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_free(batch);
|
||||||
|
@ -194,20 +191,11 @@ static int target(
|
||||||
size_t n_predict)
|
size_t n_predict)
|
||||||
{
|
{
|
||||||
dbg_default(to_string(ctx, input.begin(), input.end()));
|
dbg_default(to_string(ctx, input.begin(), input.end()));
|
||||||
// TODO: create int decode()
|
|
||||||
|
|
||||||
llama_batch batch = llama_batch_init(512, 0, 1);
|
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||||
for (size_t i = 0; i < input.size(); i++)
|
decode(ctx, input.begin(), input.end(), 0, false, batch);
|
||||||
{
|
|
||||||
llama_batch_add(batch, input[i], i, { 0 }, false);
|
|
||||||
}
|
|
||||||
batch.logits[batch.n_tokens - 1] = true;
|
|
||||||
|
|
||||||
if (llama_decode(ctx, batch) != 0) {
|
|
||||||
fprintf(stderr, "llama_decode() failed\n");
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
// how many tokens are currently accepted
|
|
||||||
// TODO: rename to n_accepted
|
// TODO: rename to n_accepted
|
||||||
size_t n_cur = input.size();
|
size_t n_cur = input.size();
|
||||||
size_t n_decode = 0;
|
size_t n_decode = 0;
|
||||||
|
@ -215,8 +203,8 @@ static int target(
|
||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
// we'll use logits from this position to determine next token
|
// we'll use logits from this position to determine next token
|
||||||
int logits_from = batch.n_tokens - 1;
|
int logits_from = input.size() - 1;
|
||||||
int logits_to = batch.n_tokens;
|
int logits_to = input.size();
|
||||||
|
|
||||||
llama_tokens input_seq, next_tokens;
|
llama_tokens input_seq, next_tokens;
|
||||||
input_seq.push_back(input.back());
|
input_seq.push_back(input.back());
|
||||||
|
@ -299,15 +287,8 @@ static int target(
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_batch_clear(batch);
|
decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
|
||||||
for (size_t i = 0; i < input_seq.size(); i++)
|
|
||||||
{
|
|
||||||
llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
|
|
||||||
}
|
|
||||||
if (llama_decode(ctx, batch)) {
|
|
||||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
|
||||||
return 1;
|
|
||||||
}
|
|
||||||
logits_from = 0;
|
logits_from = 0;
|
||||||
logits_to = input_seq.size();
|
logits_to = input_seq.size();
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue