duo: v0
parent 83aabb3fb7
commit 78938bc0c9

1 changed file with 10 additions and 18 deletions
@@ -32,8 +32,8 @@ static void dbg_rejected(const std::string & rejected)
     dbg_color(rejected, /* red */ "\033[31m");
 }
 
-template<typename Iterator>
-static std::string to_string(llama_context * ctx, Iterator from, Iterator to)
+template<typename iter_t>
+static std::string to_string(llama_context * ctx, iter_t from, iter_t to)
 {
     std::string res = "";
     for (auto it = from; it != to; ++it)
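
The body of this helper is cut off by the hunk. A minimal sketch of how the detokenizing loop might continue, assuming the common.h helper llama_token_to_piece and iterators that dereference to llama_token:

    template<typename iter_t>
    static std::string to_string(llama_context * ctx, iter_t from, iter_t to)
    {
        std::string res = "";
        for (auto it = from; it != to; ++it)
        {
            // llama_token_to_piece comes from llama.cpp's common.h
            res += llama_token_to_piece(ctx, *it);
        }
        return res;
    }
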
@@ -55,6 +55,7 @@ struct speculation_context
 
 speculation_context spec_ctx;
 
+// pass void * spec_ctx
 static void split_done_cb(int split)
 {
     if (split == 1 || split == 2)
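
The new `// pass void * spec_ctx` note points at removing the global. One hypothetical shape for that, assuming the callback registration accepts a user-data pointer (the signature below is illustrative, not an existing llama.cpp API):

    // Hypothetical: thread the context through instead of using the global.
    static void split_done_cb(int split, void * user_data)
    {
        auto * sctx = static_cast<speculation_context *>(user_data);
        if (split == 1 || split == 2)
        {
            // ... update sctx rather than the global spec_ctx ...
        }
    }
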
@@ -65,20 +66,12 @@ static void split_done_cb(int split)
 }
 
 // this ignores all the other sampling criteria
-static std::vector<llama_token> greedy_tokens(
-        llama_model * model,
-        llama_context * ctx,
-        int32_t from_idx,
-        int32_t to_idx)
+static llama_tokens greedy_tokens(llama_model * model, llama_context * ctx, int32_t from, int32_t to)
 {
     auto n_vocab = llama_n_vocab(model);
     std::vector<llama_token> res;
-    if (n_vocab <= 0)
-    {
-        return res;
-    }
 
-    for (int idx = from_idx; idx < to_idx; idx++)
+    for (int idx = from; idx < to; idx++)
     {
         auto * logits = llama_get_logits_ith(ctx, idx);
         llama_token new_token_id = 0;
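
The rest of this loop falls outside the hunk. Given the "ignores all the other sampling criteria" comment, the pick is presumably a plain argmax over the logits; a sketch under that assumption:

    for (int idx = from; idx < to; idx++)
    {
        auto * logits = llama_get_logits_ith(ctx, idx);
        llama_token new_token_id = 0;
        // greedy argmax over the vocabulary; no temperature, top-k, etc.
        for (llama_token token_id = 1; token_id < n_vocab; token_id++)
        {
            if (logits[token_id] > logits[new_token_id])
            {
                new_token_id = token_id;
            }
        }
        res.push_back(new_token_id);
    }
    return res;
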
@@ -99,8 +92,9 @@ static int speculation(
         llama_model * model,
         speculation_context * spec_ctx,
         llama_context * ctx,
-        llama_tokens input /* copy here */) {
+        const llama_tokens & input) {
 
+    // TODO: check that input is non-empty
     llama_batch batch = llama_batch_init(512, 0, 1);
 
     for (size_t i = 0; i < input.size(); i++)
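
The prompt-feeding loop continues past the hunk. A sketch of the usual llama.cpp prefill pattern, assuming the common.h helpers llama_batch_clear and llama_batch_add:

    llama_batch_clear(batch);
    for (size_t i = 0; i < input.size(); i++)
    {
        // sequence 0; request logits only for the last prompt token
        llama_batch_add(batch, input[i], i, { 0 }, i == input.size() - 1);
    }
    if (llama_decode(ctx, batch) != 0)
    {
        fprintf(stderr, "llama_decode() failed\n");
        return 1;
    }
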
@@ -122,7 +116,7 @@ static int speculation(
     // TODO: here we need to not generate too many and wait
     while (true)
     {
-        // silliest thing ever
+        // TODO: cond var instead
        bool wait = false;
        {
            std::lock_guard<std::mutex> g(spec_ctx->mtx);
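
The `// TODO: cond var instead` note refers to replacing this poll-and-sleep wait with a blocking one. A sketch under the assumption that speculation_context gains a std::condition_variable member (cv) and a flag (has_work), both hypothetical here; the target thread would notify cv whenever the shared candidate changes:

    // Requires <condition_variable>; cv and has_work are assumed additions
    // to speculation_context, not fields shown in this diff.
    std::unique_lock<std::mutex> lock(spec_ctx->mtx);
    spec_ctx->cv.wait(lock, [&] { return spec_ctx->has_work; });
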
@@ -141,7 +135,6 @@ static int speculation(
            continue;
        }
 
-
        auto next_tokens = greedy_tokens(model, ctx, logit_idx, logit_idx + 1);
        if (next_tokens.size() != 1) {
            fprintf(stderr, "invalid next tokens\n");
@@ -149,7 +142,6 @@ static int speculation(
        }
-
        local.push_back(next_tokens[0]);
 
        {
            std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
            auto& shared = spec_ctx->candidate;
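
What happens to shared after this point is not shown. The usual reconciliation step in this kind of draft loop is a prefix check between the local draft and the shared candidate; the following is an assumption about the elided logic, not the commit's actual code:

    // Requires <algorithm>. If local still extends the shared candidate,
    // publish it; if the target thread rewrote the candidate, adopt its
    // version and restart drafting from there.
    bool match = local.size() >= shared.size()
              && std::equal(shared.begin(), shared.end(), local.begin());
    if (match) {
        shared = local;
    } else {
        local = shared;
    }
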
@@ -202,7 +194,7 @@ static int target(
         size_t n_predict)
 {
     dbg_default(to_string(ctx, input.begin(), input.end()));
-    // TODO: batch size
+    // TODO: create int decode()
     llama_batch batch = llama_batch_init(512, 0, 1);
     for (size_t i = 0; i < input.size(); i++)
     {
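
The new `// TODO: create int decode()` suggests factoring this repeated fill-and-decode pattern into a helper. One hypothetical shape, reusing the common.h batch helpers (the function itself does not exist in this commit):

    // Hypothetical helper the TODO asks for: feed a token range and decode.
    static int decode(llama_context * ctx, llama_batch & batch,
                      const llama_tokens & tokens, llama_pos pos0, bool last_logits)
    {
        llama_batch_clear(batch);
        for (size_t i = 0; i < tokens.size(); i++)
        {
            llama_batch_add(batch, tokens[i], pos0 + (llama_pos) i, { 0 },
                            last_logits && i == tokens.size() - 1);
        }
        return llama_decode(ctx, batch);
    }
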
@@ -359,7 +351,7 @@ int main(int argc, char ** argv) {
     llama_tokens input = llama_tokenize(ctx, params.prompt, true);
     spec_ctx.candidate = input;
 
-    // prepare draft model and contexts. No need for two model instances?
+    // prepare draft model and contexts.
     llama_model * draft_model = nullptr;
     llama_context * draft_ctx = nullptr;
 
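
The draft model and context are declared here but initialized outside the hunk. A minimal sketch of the standard llama.cpp loading sequence, assuming the draft GGUF path arrives via a field such as params.model_draft (an assumption; the diff does not show it):

    // Sketch: load the draft model and create its context.
    llama_model_params model_params = llama_model_default_params();
    draft_model = llama_load_model_from_file(params.model_draft.c_str(), model_params);
    if (draft_model == nullptr) {
        fprintf(stderr, "failed to load draft model\n");
        return 1;
    }

    llama_context_params ctx_params = llama_context_default_params();
    draft_ctx = llama_new_context_with_model(draft_model, ctx_params);
    if (draft_ctx == nullptr) {
        fprintf(stderr, "failed to create draft context\n");
        return 1;
    }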