duo: cleanup v2

Oleksandr Kuvshynov 2024-05-22 23:31:23 -04:00
parent eecdd3b0ce
commit 479c80a0db

@@ -18,19 +18,28 @@ static void dbg_color(const std::string & s, const std::string & fg)
 static void dbg_accepted(const std::string & accepted)
 {
-    static const std::string kGreen = "\033[32m";
-    dbg_color(accepted, kGreen);
+    dbg_color(accepted, /* green */ "\033[32m");
 }
-static void dbg_not_matched(const std::string & accepted)
+static void dbg_default(const std::string & accepted)
 {
     dbg_color(accepted, "");
 }
 static void dbg_rejected(const std::string & rejected)
 {
-    static const std::string kRed = "\033[31m";
-    dbg_color(rejected, kRed);
+    dbg_color(rejected, /* red */ "\033[31m");
 }
+
+template<typename Iterator>
+static std::string to_string(llama_context * ctx, Iterator from, Iterator to)
+{
+    std::string res = "";
+    for (auto it = from; it != to; ++it)
+    {
+        res += llama_token_to_piece(ctx, *it);
+    }
+    return res;
+}
 
 using llama_tokens = std::vector<llama_token>;
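
Note: the new `to_string` helper above detokenizes any iterator range, so callers can print a slice of a token sequence in one call. A minimal usage sketch under that assumption (`print_window` is a hypothetical name, not part of the commit):

```cpp
// Hypothetical caller of the to_string helper introduced above:
// detokenize the half-open window [from, to) of a sequence and print
// it through the default (uncolored) debug channel.
static void print_window(llama_context * ctx, const llama_tokens & toks,
                         size_t from, size_t to)
{
    dbg_default(to_string(ctx, toks.begin() + from, toks.begin() + to));
}
```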
@@ -111,6 +120,7 @@ static int speculation(
     std::vector<llama_token> local_spec = input;
     size_t match_len;
 
+    // TODO: here we need to not generate too many and wait
     while (true) {
         auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
         if (next_tokens.size() != 1) {
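
Note: the TODO added here is about pacing: the draft loop can otherwise keep generating unverified tokens indefinitely. One possible shape of the fix, purely a sketch and not part of this commit, is a cap on how far the draft runs ahead:

```cpp
// Hypothetical pacing check for the TODO above (assumed policy, not in
// this commit): pause drafting once `max_ahead` tokens are queued beyond
// the prefix the target model has already verified.
static bool should_pause_draft(size_t drafted, size_t verified,
                               size_t max_ahead = 16)
{
    return drafted >= verified + max_ahead;
}
```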
@@ -169,18 +179,17 @@ static int speculation(
 static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
 {
+    dbg_default(to_string(ctx, input.begin(), input.end()));
     // TODO: batch size
     llama_batch batch = llama_batch_init(512, 0, 1);
 
     // evaluate the initial prompt
-    for (size_t i = 0; i < input.size(); i++) {
+    for (size_t i = 0; i < input.size(); i++)
+    {
         llama_batch_add(batch, input[i], i, { 0 }, false);
     }
     batch.logits[batch.n_tokens - 1] = true;
 
     if (llama_decode(ctx, batch) != 0) {
-        LOG_TEE("%s: llama_decode() failed\n", __func__);
+        fprintf(stderr, "llama_decode() failed\n");
         return 1;
     }
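
Note: the prompt-evaluation pattern in this hunk is the standard llama.cpp one: fill a batch, request logits only for the last position, then decode. A self-contained sketch of the same pattern, factored into a helper (`decode_prompt` is an illustrative name; the 512 capacity mirrors the TODO above):

```cpp
#include "common.h"   // llama_batch_add
#include "llama.h"

// Sketch: evaluate a prompt in one batch, mirroring the hunk above.
// Returns 0 on success, like llama_decode.
static int decode_prompt(llama_context * ctx, const std::vector<llama_token> & prompt)
{
    llama_batch batch = llama_batch_init(512, 0, 1);  // TODO: batch size
    for (size_t i = 0; i < prompt.size(); i++)
    {
        // token at position i, sequence 0, no logits for prompt tokens
        llama_batch_add(batch, prompt[i], i, { 0 }, false);
    }
    batch.logits[batch.n_tokens - 1] = true;  // logits only for the last token
    int res = llama_decode(ctx, batch);
    llama_batch_free(batch);
    return res;
}
```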
@@ -195,7 +204,7 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
     int logits_from = batch.n_tokens - 1;
     int logits_to = batch.n_tokens;
 
-    llama_tokens input_seq, next_tokens, output;
+    llama_tokens input_seq, next_tokens;
     input_seq.push_back(input.back());
 
     while (n_decode <= n_predict)
@@ -241,7 +250,6 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
                 break;
             }
         }
-        output.insert(output.end(), next_tokens.begin(), next_tokens.end());
 
         {
             std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
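
Note: the lock in this hunk guards the speculation state shared between the draft and target threads; the struct itself is defined elsewhere in duo.cpp, so the field names below are assumptions. A sketch of the shape implied by the usage:

```cpp
#include <mutex>
#include <vector>
#include "llama.h"

// Assumed shape of the shared speculation context (field names are
// guesses based on usage; only `mtx` is visible in this hunk).
struct speculation_context
{
    std::vector<llama_token> candidate;  // draft tokens awaiting verification
    std::mutex               mtx;        // serializes draft and target threads
    bool                     done = false;
};
```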
@@ -259,31 +267,11 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
                 }
             }
-            std::string accepted = "";
-            for (size_t i = next_tokens_pos; i < next_tokens_pos + n_match; i++)
-            {
-                accepted += llama_token_to_piece(ctx, spec[i]);
-            }
-            dbg_accepted(accepted);
-            if (n_match != next_tokens.size())
-            {
-                std::string rejected = "";
-                for (size_t i = next_tokens_pos + n_match; i < spec.size(); i++)
-                {
-                    rejected += llama_token_to_piece(ctx, spec[i]);
-                }
-                dbg_rejected(rejected);
-                std::string not_matched = "";
-                for (size_t i = n_match; i < next_tokens.size(); i++)
-                {
-                    not_matched += llama_token_to_piece(ctx, next_tokens[i]);
-                }
-                dbg_not_matched(not_matched);
-            }
-            // remove non-matched tokens
+            dbg_accepted(to_string(ctx, spec.begin() + next_tokens_pos, spec.begin() + next_tokens_pos + n_match));
+            if (n_match != next_tokens.size())
+            {
+                dbg_rejected(to_string(ctx, spec.begin() + next_tokens_pos + n_match, spec.end()));
+                dbg_default(to_string(ctx, next_tokens.begin() + n_match, next_tokens.end()));
             spec.erase(spec.begin() + next_tokens_pos, spec.end());
             for (const auto tok: next_tokens)
             {
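
Note: this hunk consumes an `n_match` value computed earlier: the number of target-confirmed tokens that agree with the draft starting at `next_tokens_pos`. A stand-alone sketch of that acceptance rule (illustrative; the commit's own computation lives outside this hunk):

```cpp
#include <vector>
#include "llama.h"

// Sketch of the acceptance rule: count how many confirmed tokens match
// the speculated sequence, starting at next_tokens_pos.
static size_t count_matches(const std::vector<llama_token> & spec,
                            const std::vector<llama_token> & next_tokens,
                            size_t next_tokens_pos)
{
    size_t n_match = 0;
    while (n_match < next_tokens.size() &&
           next_tokens_pos + n_match < spec.size() &&
           spec[next_tokens_pos + n_match] == next_tokens[n_match])
    {
        n_match++;
    }
    return n_match;
}
```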
@@ -337,7 +325,6 @@ int main(int argc, char ** argv) {
         params.seed = time(NULL);
     }
 
-    // parse 2 speculation rpc instances
     std::string draft_rpcs = params.rpc_servers_draft;
     size_t i = draft_rpcs.find(',');
     if (i == std::string::npos || draft_rpcs.find(',', i + 1) != std::string::npos)
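
Note: the check in this last hunk enforces exactly one comma in the draft RPC server list, i.e. exactly two draft instances. A sketch of the full split that check implies (`split_two_rpcs` is a hypothetical helper, not in the commit):

```cpp
#include <string>

// Hypothetical helper matching the validation above: succeed only when
// the list contains exactly two comma-separated RPC instances.
static bool split_two_rpcs(const std::string & draft_rpcs,
                           std::string & first, std::string & second)
{
    size_t i = draft_rpcs.find(',');
    if (i == std::string::npos || draft_rpcs.find(',', i + 1) != std::string::npos)
    {
        return false;
    }
    first  = draft_rpcs.substr(0, i);
    second = draft_rpcs.substr(i + 1);
    return true;
}
```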