duo: cleanup v2

commit 479c80a0db (parent eecdd3b0ce)
1 changed file with 99 additions and 112 deletions
@@ -18,21 +18,30 @@ static void dbg_color(const std::string & s, const std::string & fg)
 
 static void dbg_accepted(const std::string & accepted)
 {
-    static const std::string kGreen = "\033[32m";
-    dbg_color(accepted, kGreen);
+    dbg_color(accepted, /* green */ "\033[32m");
 }
 
-static void dbg_not_matched(const std::string & accepted)
+static void dbg_default(const std::string & accepted)
 {
     dbg_color(accepted, "");
 }
 
 static void dbg_rejected(const std::string & rejected)
 {
-    static const std::string kRed = "\033[31m";
-    dbg_color(rejected, kRed);
+    dbg_color(rejected, /* red */ "\033[31m");
 }
 
+template<typename Iterator>
+static std::string to_string(llama_context * ctx, Iterator from, Iterator to)
+{
+    std::string res = "";
+    for (auto it = from; it != to; ++it)
+    {
+        res += llama_token_to_piece(ctx, *it);
+    }
+    return res;
+}
+
 using llama_tokens = std::vector<llama_token>;
 
 struct speculation_context
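Note: the to_string helper added above is what lets the later hunks collapse three hand-rolled detokenization loops into single calls. A minimal standalone sketch of the same iterator-range pattern (a hypothetical piece table stands in for llama_context / llama_token_to_piece, which require a loaded model):

#include <cstdio>
#include <string>
#include <vector>

// Stand-in for llama_context: maps a token id to its text piece.
struct fake_ctx { std::vector<std::string> pieces; };

static std::string token_to_piece(fake_ctx * ctx, int tok) { return ctx->pieces[tok]; }

// Same shape as the helper introduced in this commit.
template<typename Iterator>
static std::string to_string(fake_ctx * ctx, Iterator from, Iterator to)
{
    std::string res = "";
    for (auto it = from; it != to; ++it)
    {
        res += token_to_piece(ctx, *it);
    }
    return res;
}

int main()
{
    fake_ctx ctx = {{"Hello", ",", " world", "!"}};
    std::vector<int> tokens = {0, 1, 2, 3};
    // Any sub-range detokenizes with one call instead of a manual loop:
    printf("%s\n", to_string(&ctx, tokens.begin(), tokens.begin() + 2).c_str());
    printf("%s\n", to_string(&ctx, tokens.begin() + 2, tokens.end()).c_str());
    return 0;
}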
@@ -93,97 +102,97 @@ static int speculation(
     int32_t active = 1;
 
-    llama_batch batch = llama_batch_init(512, 0, 1);
-
-    for (size_t i = 0; i < input.size(); i++)
-    {
-        llama_batch_add(batch, input[i], i, { 0 }, false);
-    }
-
-    batch.logits[batch.n_tokens - 1] = true;
-
-    if (llama_decode(ctx[active], batch) != 0) {
-        LOG_TEE("%s: llama_decode() failed\n", __func__);
-        return 1;
-    }
-
-    int logit_idx = batch.n_tokens - 1;
-    std::vector<llama_token> local_spec = input;
-    size_t match_len;
-
-    while (true) {
-        auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
-        if (next_tokens.size() != 1) {
-            fprintf(stderr, "invalid next tokens\n");
-            return 1;
-        }
-
-        local_spec.push_back(next_tokens[0]);
-
-        {
-            std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
-            if (spec_ctx->done)
-            {
-                break;
-            }
-            auto& spec = spec_ctx->candidate;
-            bool match = true;
-            match_len = local_spec.size() - 1;
-            for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
-            {
-                if (spec[i] != local_spec[i])
-                {
-                    match = false;
-                    match_len = i;
-                    // here we need to clear both contexts
-                    llama_kv_cache_seq_rm(ctx[0], 0, i, -1);
-                    llama_kv_cache_seq_rm(ctx[1], 0, i, -1);
-                    break;
-                }
-            }
-            if (match) {
-                spec = local_spec;
-            } else {
-                local_spec = spec;
-            }
-            active = spec_ctx->active_id;
-        }
-
-        llama_batch_clear(batch);
-        // TODO theoretically this can be empty?
-        for (size_t i = match_len; i < local_spec.size(); i++) {
-            llama_batch_add(batch, local_spec[i], i, { 0 }, true);
-        }
-
-        logit_idx = batch.n_tokens - 1;
-
-        if (llama_decode(ctx[active], batch)) {
-            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
-            return 1;
-        }
-    }
-
-    llama_batch_free(batch);
-    return 0;
-}
-
-static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
-{
-    // TODO: batch size
     llama_batch batch = llama_batch_init(512, 0, 1);
-    // evaluate the initial prompt
-    for (size_t i = 0; i < input.size(); i++) {
+    for (size_t i = 0; i < input.size(); i++)
+    {
         llama_batch_add(batch, input[i], i, { 0 }, false);
     }
 
     batch.logits[batch.n_tokens - 1] = true;
 
-    if (llama_decode(ctx, batch) != 0) {
+    if (llama_decode(ctx[active], batch) != 0) {
         LOG_TEE("%s: llama_decode() failed\n", __func__);
         return 1;
     }
 
+    int logit_idx = batch.n_tokens - 1;
+    std::vector<llama_token> local_spec = input;
+    size_t match_len;
+
+    // TODO: here we need to not generate too many and wait
+    while (true) {
+        auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
+        if (next_tokens.size() != 1) {
+            fprintf(stderr, "invalid next tokens\n");
+            return 1;
+        }
+
+        local_spec.push_back(next_tokens[0]);
+
+        {
+            std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
+            if (spec_ctx->done)
+            {
+                break;
+            }
+            auto& spec = spec_ctx->candidate;
+            bool match = true;
+            match_len = local_spec.size() - 1;
+            for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
+            {
+                if (spec[i] != local_spec[i])
+                {
+                    match = false;
+                    match_len = i;
+                    // here we need to clear both contexts
+                    llama_kv_cache_seq_rm(ctx[0], 0, i, -1);
+                    llama_kv_cache_seq_rm(ctx[1], 0, i, -1);
+                    break;
+                }
+            }
+            if (match) {
+                spec = local_spec;
+            } else {
+                local_spec = spec;
+            }
+            active = spec_ctx->active_id;
+        }
+
+        llama_batch_clear(batch);
+        // TODO theoretically this can be empty?
+        for (size_t i = match_len; i < local_spec.size(); i++) {
+            llama_batch_add(batch, local_spec[i], i, { 0 }, true);
+        }
+
+        logit_idx = batch.n_tokens - 1;
+
+        if (llama_decode(ctx[active], batch)) {
+            fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
+            return 1;
+        }
+    }
+
+    llama_batch_free(batch);
+    return 0;
+}
+
+static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
+{
+    dbg_default(to_string(ctx, input.begin(), input.end()));
+    // TODO: batch size
+    llama_batch batch = llama_batch_init(512, 0, 1);
+    for (size_t i = 0; i < input.size(); i++)
+    {
+        llama_batch_add(batch, input[i], i, { 0 }, false);
+    }
+    batch.logits[batch.n_tokens - 1] = true;
+
+    if (llama_decode(ctx, batch) != 0) {
+        fprintf(stderr, "llama_decode() failed\n");
+        return 1;
+    }
+
     // how many tokens are currently accepted
     // TODO: rename to n_accepted
     size_t n_cur = input.size();
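Note: the reconciliation step inside the new speculation() loop reduces to a common-prefix computation between the shared candidate and the locally extended sequence; everything past the first divergence is stale and, in the real code, is evicted from both KV caches with llama_kv_cache_seq_rm before re-decoding. A standalone sketch of just that index logic (plain int tokens, hypothetical values, illustrative helper name):

#include <algorithm>
#include <cstdio>
#include <vector>

// Mirrors match/match_len above: returns the first index where the two
// sequences disagree, or local.size() - 1 when local is a pure extension
// (then only the newly appended token still needs decoding).
static size_t compute_match_len(const std::vector<int> & spec, const std::vector<int> & local)
{
    for (size_t i = 0; i < std::min(spec.size(), local.size()); i++)
    {
        if (spec[i] != local[i])
        {
            return i;
        }
    }
    return local.size() - 1;
}

int main()
{
    std::vector<int> spec  = {1, 2, 3, 7, 8}; // candidate shared between threads
    std::vector<int> local = {1, 2, 3, 4};    // this thread's speculation
    // Prints 3: positions 3 and beyond diverge and get re-decoded.
    printf("match_len = %zu\n", compute_match_len(spec, local));
    return 0;
}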
@@ -195,7 +204,7 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
     int logits_from = batch.n_tokens - 1;
     int logits_to = batch.n_tokens;
 
-    llama_tokens input_seq, next_tokens, output;
+    llama_tokens input_seq, next_tokens;
     input_seq.push_back(input.back());
 
     while (n_decode <= n_predict)
@@ -241,7 +250,6 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
                 break;
             }
         }
-        output.insert(output.end(), next_tokens.begin(), next_tokens.end());
 
         {
            std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
@@ -259,31 +267,11 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
             }
         }
 
-        std::string accepted = "";
-        for (size_t i = next_tokens_pos; i < next_tokens_pos + n_match; i++)
-        {
-            accepted += llama_token_to_piece(ctx, spec[i]);
-        }
-        dbg_accepted(accepted);
-        if (n_match != next_tokens.size())
-        {
-            std::string rejected = "";
-            for (size_t i = next_tokens_pos + n_match; i < spec.size(); i++)
-            {
-                rejected += llama_token_to_piece(ctx, spec[i]);
-            }
-            dbg_rejected(rejected);
-            std::string not_matched = "";
-            for (size_t i = n_match; i < next_tokens.size(); i++)
-            {
-                not_matched += llama_token_to_piece(ctx, next_tokens[i]);
-            }
-            dbg_not_matched(not_matched);
-        }
-
-        // remove non-matched tokens
+        dbg_accepted(to_string(ctx, spec.begin() + next_tokens_pos, spec.begin() + next_tokens_pos + n_match));
         if (n_match != next_tokens.size())
         {
+            dbg_rejected(to_string(ctx, spec.begin() + next_tokens_pos + n_match, spec.end()));
+            dbg_default(to_string(ctx, next_tokens.begin() + n_match, next_tokens.end()));
             spec.erase(spec.begin() + next_tokens_pos, spec.end());
             for (const auto tok: next_tokens)
             {
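Note: the consolidated logging above splits one verification step's tokens into three disjoint ranges: the candidate tokens the target model confirmed, the candidate tokens it rejected, and its own replacements. A worked sketch of the index arithmetic with hypothetical values:

#include <cstdio>
#include <vector>

int main()
{
    // Hypothetical state after one verification step.
    std::vector<int> spec        = {10, 11, 12, 13, 14}; // speculated candidate
    std::vector<int> next_tokens = {12, 13, 99};         // target model's output
    size_t next_tokens_pos = 2; // offset of next_tokens within spec
    size_t n_match         = 2; // spec[2..3] agrees with next_tokens[0..1]

    // accepted: spec[next_tokens_pos, next_tokens_pos + n_match) -> 12 13
    for (size_t i = next_tokens_pos; i < next_tokens_pos + n_match; i++) printf("accepted: %d\n", spec[i]);
    // rejected: spec[next_tokens_pos + n_match, spec.size())     -> 14
    for (size_t i = next_tokens_pos + n_match; i < spec.size(); i++)     printf("rejected: %d\n", spec[i]);
    // replacements from the target: next_tokens[n_match, end)    -> 99
    for (size_t i = n_match; i < next_tokens.size(); i++)                printf("new:      %d\n", next_tokens[i]);
    return 0;
}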
@@ -337,7 +325,6 @@ int main(int argc, char ** argv) {
         params.seed = time(NULL);
     }
 
-    // parse 2 speculation rpc instances
     std::string draft_rpcs = params.rpc_servers_draft;
     size_t i = draft_rpcs.find(',');
     if (i == std::string::npos || draft_rpcs.find(',', i + 1) != std::string::npos)
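Note: the check kept above insists on exactly one comma, i.e. exactly two draft RPC endpoints (one per draft context). A minimal sketch of that split-and-validate step, using a hypothetical endpoint string:

#include <cstdio>
#include <string>

int main()
{
    // Hypothetical value for params.rpc_servers_draft.
    std::string draft_rpcs = "127.0.0.1:50052,127.0.0.1:50053";
    size_t i = draft_rpcs.find(',');
    // One comma and only one comma means exactly two endpoints.
    if (i == std::string::npos || draft_rpcs.find(',', i + 1) != std::string::npos)
    {
        fprintf(stderr, "expected exactly two comma-separated draft rpc servers\n");
        return 1;
    }
    std::string first  = draft_rpcs.substr(0, i);
    std::string second = draft_rpcs.substr(i + 1);
    printf("draft rpc servers: %s and %s\n", first.c_str(), second.c_str());
    return 0;
}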
@@ -360,7 +347,7 @@ int main(int argc, char ** argv) {
 
     // prepare draft model and contexts. No need for two model instances?
     std::vector<llama_model *> draft_models = {nullptr, nullptr};
     std::vector<llama_context *> draft_ctx = {nullptr, nullptr};
 
     params.model = params.model_draft;
     params.n_gpu_layers = params.n_gpu_layers_draft;