commit 96811fdf63
parent 78938bc0c9
Author: Oleksandr Kuvshynov
Date:   2024-05-25 14:23:57 -04:00


@@ -88,6 +88,24 @@ static llama_tokens greedy_tokens(llama_model * model, llama_context * ctx, int3
     return res;
 }
 
+template<typename iter_t>
+static int decode(llama_context * ctx, iter_t from, iter_t to, int offset, bool all_logits, llama_batch & batch)
+{
+    llama_batch_clear(batch);
+    size_t i = offset;
+    for (auto it = from; it != to; ++it)
+    {
+        llama_batch_add(batch, *it, i++, { 0 }, all_logits);
+    }
+    batch.logits[batch.n_tokens - 1] = true;
+    int res = 0;
+    if (llama_decode(ctx, batch) != 0) {
+        fprintf(stderr, "llama_decode() failed\n");
+        res = 1;
+    }
+    return res;
+}
+
 static int speculation(
     llama_model * model,
     speculation_context * spec_ctx,
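
Note: the new decode() helper above folds together the clear/fill/decode boilerplate that previously appeared nearly verbatim at four call sites in this file. A rough before/after sketch (not part of the commit; tokens, offset, and all_logits stand in for the per-call-site arguments):

    // before: repeated at every call site
    llama_batch_clear(batch);
    for (size_t i = 0; i < tokens.size(); i++) {
        llama_batch_add(batch, tokens[i], offset + i, { 0 }, all_logits);
    }
    batch.logits[batch.n_tokens - 1] = true;
    if (llama_decode(ctx, batch) != 0) {
        return 1;
    }

    // after: a single templated call
    decode(ctx, tokens.begin(), tokens.end(), offset, all_logits, batch);

One behavioral nit visible in the hunks below: the call sites ignore decode()'s return value, so a failing llama_decode() now logs to stderr but no longer aborts the caller.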
@@ -96,20 +114,9 @@ static int speculation(
     // TODO: check that input is non-empty
     llama_batch batch = llama_batch_init(512, 0, 1);
 
-    for (size_t i = 0; i < input.size(); i++)
-    {
-        llama_batch_add(batch, input[i], i, { 0 }, false);
-    }
-    batch.logits[batch.n_tokens - 1] = true;
-
-    if (llama_decode(ctx, batch) != 0) {
-        LOG_TEE("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
-
-    int logit_idx = batch.n_tokens - 1;
+    decode(ctx, input.begin(), input.end(), 0, false, batch);
+
+    int logit_idx = input.size() - 1;
 
     llama_tokens local = input;
     size_t match_len;
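
Note: decode() clears the batch and adds exactly the tokens in [from, to), so after this initial call batch.n_tokens equals input.size(); the rewritten logit_idx = input.size() - 1 is the same index as the old batch.n_tokens - 1, just computed without reaching into the batch. The identical substitution shows up again in target() below (logits_from / logits_to). A one-line sanity check of the invariant (hypothetical, not in the commit):

    // after decode(ctx, input.begin(), input.end(), 0, false, batch)
    assert((size_t) batch.n_tokens == input.size());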
@@ -167,20 +174,10 @@
             }
         }
 
-        llama_batch_clear(batch);
         // TODO theoretically this can be empty?
-        for (size_t i = match_len; i < local.size(); i++)
-        {
-            llama_batch_add(batch, local[i], i, { 0 }, true);
-        }
-        logit_idx = batch.n_tokens - 1;
-
-        if (llama_decode(ctx, batch) != 0)
-        {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
+        decode(ctx, local.begin() + match_len, local.end(), match_len, false, batch);
+
+        logit_idx = local.size() - match_len - 1;
     }
 
     llama_batch_free(batch);
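
Note: here only the unmatched suffix local[match_len:] is re-submitted, at positions match_len .. local.size() - 1, so the batch holds local.size() - match_len tokens and the last one lands at row local.size() - match_len - 1, which is exactly the new logit_idx. For example (hypothetical numbers), with local.size() == 10 and match_len == 7, three tokens are decoded and the last row is index 2. Also, all_logits flips from true in the removed loop to false in the new call; decode() still forces logits for the final token, and if that is the only row read afterwards, the narrower setting is safe.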
@@ -194,20 +191,11 @@ static int target(
     size_t n_predict)
 {
     dbg_default(to_string(ctx, input.begin(), input.end()));
-    // TODO: create int decode()
     llama_batch batch = llama_batch_init(512, 0, 1);
-    for (size_t i = 0; i < input.size(); i++)
-    {
-        llama_batch_add(batch, input[i], i, { 0 }, false);
-    }
-    batch.logits[batch.n_tokens - 1] = true;
-
-    if (llama_decode(ctx, batch) != 0) {
-        fprintf(stderr, "llama_decode() failed\n");
-        return 1;
-    }
+    decode(ctx, input.begin(), input.end(), 0, false, batch);
 
     // how many tokens are currently accepted
     // TODO: rename to n_accepted
     size_t n_cur = input.size();
     size_t n_decode = 0;
@@ -215,8 +203,8 @@
     const auto t_main_start = ggml_time_us();
 
     // we'll use logits from this position to determine next token
-    int logits_from = batch.n_tokens - 1;
-    int logits_to = batch.n_tokens;
+    int logits_from = input.size() - 1;
+    int logits_to = input.size();
 
     llama_tokens input_seq, next_tokens;
     input_seq.push_back(input.back());
@@ -299,15 +287,8 @@
             break;
         }
 
-        llama_batch_clear(batch);
-        for (size_t i = 0; i < input_seq.size(); i++)
-        {
-            llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
-        }
-        if (llama_decode(ctx, batch)) {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
+        decode(ctx, input_seq.begin(), input_seq.end(), n_cur - 1, true, batch);
+
        logits_from = 0;
        logits_to = input_seq.size();
    }
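
Note: this last call site keeps all_logits = true, presumably because the target model's verification step needs a logit row for every token in the speculated window, not just the final one; with the batch freshly filled from input_seq, those rows run from 0 to input_seq.size(), which the updated logits_from / logits_to now express directly. (Inside decode(), the forced batch.logits[batch.n_tokens - 1] = true is a no-op in this case.)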