some renaming

commit 60fe62e6eb
parent 479c80a0db
Author: Oleksandr Kuvshynov
Date:   2024-05-22 23:52:36 -04:00

@@ -10,9 +10,9 @@
 static void dbg_color(const std::string & s, const std::string & fg)
 {
     static const std::string kReset = "\033[0m";
-    static const std::string bold[] = { "", "\033[1m" };
+    static const std::string kBold[] = { "", "\033[1m" };
     static size_t index = 0;
-    std::cout << bold[index] << fg << s << kReset << std::flush;
+    std::cout << kBold[index] << fg << s << kReset << std::flush;
     index = 1 - index;
 }
@@ -98,7 +98,7 @@ static int speculation(
     std::vector<llama_model *> model,
     speculation_context * spec_ctx,
     std::vector<llama_context *> ctx,
-    std::vector<llama_token> input /* copy here */) {
+    llama_tokens input /* copy here */) {
     int32_t active = 1;
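
Note: llama_tokens is not defined in this hunk; for the rename to compile it has to alias the vector type it replaces, presumably along these lines (this is an assumption, the actual definition lives elsewhere in the change):

    // assumed alias backing the rename above; not part of this diff
    using llama_tokens = std::vector<llama_token>;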
@@ -117,7 +117,7 @@ static int speculation(
     }
     int logit_idx = batch.n_tokens - 1;
-    std::vector<llama_token> local_spec = input;
+    llama_tokens local = input;
     size_t match_len;
     // TODO: here we need to not generate too many and wait
@@ -128,7 +128,7 @@ static int speculation(
             return 1;
         }
-        local_spec.push_back(next_tokens[0]);
+        local.push_back(next_tokens[0]);
         {
             std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
@@ -136,12 +136,12 @@ static int speculation(
             {
                 break;
             }
-            auto& spec = spec_ctx->candidate;
+            auto& shared = spec_ctx->candidate;
             bool match = true;
-            match_len = local_spec.size() - 1;
-            for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
+            match_len = local.size() - 1;
+            for (size_t i = 0; i < std::min(shared.size(), local.size()); i++)
             {
-                if (spec[i] != local_spec[i])
+                if (shared[i] != local[i])
                 {
                     match = false;
                     match_len = i;
@@ -151,23 +151,28 @@ static int speculation(
                     break;
                 }
             }
-            if (match) {
-                spec = local_spec;
-            } else {
-                local_spec = spec;
+            if (match && shared.size() < local.size())
+            {
+                shared = local;
+            }
+            else
+            {
+                local = shared;
             }
             active = spec_ctx->active_id;
         }
         llama_batch_clear(batch);
         // TODO theoretically this can be empty?
-        for (size_t i = match_len; i < local_spec.size(); i++) {
-            llama_batch_add(batch, local_spec[i], i, { 0 }, true);
+        for (size_t i = match_len; i < local.size(); i++)
+        {
+            llama_batch_add(batch, local[i], i, { 0 }, true);
         }
         logit_idx = batch.n_tokens - 1;
-        if (llama_decode(ctx[active], batch)) {
+        if (llama_decode(ctx[active], batch) != 0)
+        {
             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
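
Note: besides the renames, this hunk changes behavior. The draft thread now publishes its local speculation into the shared candidate only when it is a strict extension (match && shared.size() < local.size()); on any divergence, or when it has nothing new, it resyncs from the shared state instead. A minimal standalone sketch of that reconciliation step, assuming the surrounding harness (the reconcile wrapper, the llama_token alias, and the main driver are illustrative, not from the commit):

    #include <algorithm>
    #include <cstdint>
    #include <cstdio>
    #include <mutex>
    #include <vector>

    using llama_token  = int32_t;                   // stand-in for the real type
    using llama_tokens = std::vector<llama_token>;

    // Mirrors the locked section above: find the common prefix, then either
    // publish the longer local speculation or fall back to the shared one.
    static size_t reconcile(llama_tokens & shared, llama_tokens & local,
                            std::mutex & mtx) {
        std::lock_guard<std::mutex> _lock(mtx);
        bool   match     = true;
        size_t match_len = local.size() - 1;
        for (size_t i = 0; i < std::min(shared.size(), local.size()); i++) {
            if (shared[i] != local[i]) {
                match     = false;
                match_len = i;
                break;
            }
        }
        if (match && shared.size() < local.size()) {
            shared = local;   // strict extension: publish it
        } else {
            local = shared;   // diverged or nothing new: adopt shared state
        }
        return match_len;     // caller re-submits tokens from match_len onward
    }

    int main() {
        std::mutex   mtx;
        llama_tokens shared = {1, 2, 3};
        llama_tokens local  = {1, 2, 3, 4};  // strict extension of shared
        printf("match_len = %zu\n", reconcile(shared, local, mtx));  // 3
        local = {1, 9};                      // diverges at index 1
        printf("match_len = %zu\n", reconcile(shared, local, mtx));  // 1
        return 0;
    }

As in the diff, match_len is the count of still-valid prefix tokens, so the batch rebuilt right after the lock only carries tokens from match_len onward.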
@@ -177,7 +182,11 @@ static int speculation(
     return 0;
 }
-static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
+static int target(
+    llama_model * model,
+    llama_context * ctx,
+    const llama_tokens& input,
+    size_t n_predict)
 {
     dbg_default(to_string(ctx, input.begin(), input.end()));
     // TODO: batch size
@@ -300,8 +309,8 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
     const auto t_main_end = ggml_time_us();
-    LOG_TEE("%s: decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
-            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+    fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
+            n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
     llama_print_timings(ctx);
     fprintf(stderr, "\n");
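
Note: the throughput arithmetic in the last hunk relies on ggml_time_us() returning microseconds, so dividing the elapsed ticks by 1000000.0f yields seconds. A standalone check of the same formula, with made-up timestamps and token count (none of these values come from the commit):

    #include <cstddef>
    #include <cstdint>
    #include <cstdio>

    int main() {
        int64_t t_main_start = 0;        // pretend ggml_time_us() at start
        int64_t t_main_end   = 2000000;  // 2.00 s later, in microseconds
        size_t  n_decode     = 128;
        double  secs = (t_main_end - t_main_start) / 1000000.0;
        // prints: decoded 128 tokens in 2.00 s, speed: 64.00 t/s
        fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
                n_decode, secs, n_decode / secs);
        return 0;
    }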