duo: v2
This commit is contained in:
parent
78938bc0c9
commit
96811fdf63
1 changed files with 29 additions and 48 deletions
|
@ -88,6 +88,24 @@ static llama_tokens greedy_tokens(llama_model * model, llama_context * ctx, int3
|
|||
return res;
|
||||
}
|
||||
|
||||
template<typename iter_t>
|
||||
static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool all_logits, llama_batch & batch)
|
||||
{
|
||||
llama_batch_clear(batch);
|
||||
size_t i = offset;
|
||||
for (auto it = from; it != to; ++it)
|
||||
{
|
||||
llama_batch_add(batch, *it, i++, { 0 }, all_logits);
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
int res = 0;
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
fprintf(stderr, "llama_decode() failed\n");
|
||||
res = 1;
|
||||
}
|
||||
return res;
|
||||
}
|
||||
|
||||
static int speculation(
|
||||
llama_model * model,
|
||||
speculation_context * spec_ctx,
|
||||
|
@ -96,20 +114,9 @@ static int speculation(
|
|||
|
||||
// TODO: check that input is non-empty
|
||||
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||
decode(ctx, input.begin(), input.end(), 0, false, batch);
|
||||
|
||||
for (size_t i = 0; i < input.size(); i++)
|
||||
{
|
||||
llama_batch_add(batch, input[i], i, { 0 }, false);
|
||||
}
|
||||
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
int logit_idx = batch.n_tokens - 1;
|
||||
int logit_idx = input.size() - 1;
|
||||
llama_tokens local = input;
|
||||
size_t match_len;
|
||||
|
||||
|
@ -167,20 +174,10 @@ static int speculation(
|
|||
}
|
||||
}
|
||||
|
||||
llama_batch_clear(batch);
|
||||
// TODO theoretically this can be empty?
|
||||
for (size_t i = match_len; i < local.size(); i++)
|
||||
{
|
||||
llama_batch_add(batch, local[i], i, { 0 }, true);
|
||||
}
|
||||
|
||||
logit_idx = batch.n_tokens - 1;
|
||||
decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
|
||||
|
||||
if (llama_decode(ctx, batch) != 0)
|
||||
{
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
logit_idx = local.size() - match_len - 1;
|
||||
}
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
@ -194,20 +191,11 @@ static int target(
|
|||
size_t n_predict)
|
||||
{
|
||||
dbg_default(to_string(ctx, input.begin(), input.end()));
|
||||
// TODO: create int decode()
|
||||
|
||||
|
||||
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||
for (size_t i = 0; i < input.size(); i++)
|
||||
{
|
||||
llama_batch_add(batch, input[i], i, { 0 }, false);
|
||||
}
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
decode(ctx, input.begin(), input.end(), 0, false, batch);
|
||||
|
||||
if (llama_decode(ctx, batch) != 0) {
|
||||
fprintf(stderr, "llama_decode() failed\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
// how many tokens are currently accepted
|
||||
// TODO: rename to n_accepted
|
||||
size_t n_cur = input.size();
|
||||
size_t n_decode = 0;
|
||||
|
@ -215,8 +203,8 @@ static int target(
|
|||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
// we'll use logits from this position to determine next token
|
||||
int logits_from = batch.n_tokens - 1;
|
||||
int logits_to = batch.n_tokens;
|
||||
int logits_from = input.size() - 1;
|
||||
int logits_to = input.size();
|
||||
|
||||
llama_tokens input_seq, next_tokens;
|
||||
input_seq.push_back(input.back());
|
||||
|
@ -299,15 +287,8 @@ static int target(
|
|||
break;
|
||||
}
|
||||
|
||||
llama_batch_clear(batch);
|
||||
for (size_t i = 0; i < input_seq.size(); i++)
|
||||
{
|
||||
llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
|
||||
}
|
||||
if (llama_decode(ctx, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
|
||||
|
||||
logits_from = 0;
|
||||
logits_to = input_seq.size();
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue