some renaming

Oleksandr Kuvshynov 2024-05-22 23:52:36 -04:00
parent 479c80a0db
commit 60fe62e6eb


@@ -10,9 +10,9 @@
 static void dbg_color(const std::string & s, const std::string & fg)
 {
     static const std::string kReset = "\033[0m";
-    static const std::string bold[] = { "", "\033[1m" };
+    static const std::string kBold[] = { "", "\033[1m" };
     static size_t index = 0;
-    std::cout << bold[index] << fg << s << kReset << std::flush;
+    std::cout << kBold[index] << fg << s << kReset << std::flush;
     index = 1 - index;
 }
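
The dbg_color helper above alternates between plain and bold output on successive calls, so adjacent chunks printed to the same stream stay visually separable. A standalone sketch of the same trick (not part of this commit):

#include <iostream>
#include <string>

// A static index flips between 0 and 1 on every call, so successive
// chunks alternate between plain and bold ANSI styling.
static void print_alternating(const std::string & s) {
    static const std::string kReset  = "\033[0m";
    static const std::string kBold[] = { "", "\033[1m" };
    static size_t index = 0;
    std::cout << kBold[index] << s << kReset << std::flush;
    index = 1 - index;
}

int main() {
    print_alternating("first ");   // plain
    print_alternating("second ");  // bold
    print_alternating("third\n");  // plain again
    return 0;
}
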
@@ -98,7 +98,7 @@ static int speculation(
     std::vector<llama_model *> model,
     speculation_context * spec_ctx,
     std::vector<llama_context *> ctx,
-    std::vector<llama_token> input /* copy here */) {
+    llama_tokens input /* copy here */) {
     int32_t active = 1;
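
The recurring change from std::vector<llama_token> to llama_tokens points at a type alias defined elsewhere in this repository; presumably something along these lines (an assumption, not visible in this diff):

#include <cstdint>
#include <vector>

typedef int32_t llama_token;                    // as declared in llama.h
using llama_tokens = std::vector<llama_token>;  // presumed alias behind the rename
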
@@ -117,7 +117,7 @@ static int speculation(
     }
     int logit_idx = batch.n_tokens - 1;
-    std::vector<llama_token> local_spec = input;
+    llama_tokens local = input;
     size_t match_len;
     // TODO: here we need to not generate too many and wait
@@ -128,7 +128,7 @@ static int speculation(
             return 1;
         }
-        local_spec.push_back(next_tokens[0]);
+        local.push_back(next_tokens[0]);
         {
             std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
@@ -136,12 +136,12 @@ static int speculation(
             {
                 break;
             }
-            auto& spec = spec_ctx->candidate;
+            auto& shared = spec_ctx->candidate;
             bool match = true;
-            match_len = local_spec.size() - 1;
-            for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
+            match_len = local.size() - 1;
+            for (size_t i = 0; i < std::min(shared.size(), local.size()); i++)
             {
-                if (spec[i] != local_spec[i])
+                if (shared[i] != local[i])
                 {
                     match = false;
                     match_len = i;
@@ -151,23 +151,28 @@ static int speculation(
                     break;
                 }
             }
-            if (match) {
-                spec = local_spec;
-            } else {
-                local_spec = spec;
+            if (match && shared.size() < local.size())
+            {
+                shared = local;
+            }
+            else
+            {
+                local = shared;
             }
             active = spec_ctx->active_id;
         }
         llama_batch_clear(batch);
         // TODO theoretically this can be empty?
-        for (size_t i = match_len; i < local_spec.size(); i++) {
-            llama_batch_add(batch, local_spec[i], i, { 0 }, true);
+        for (size_t i = match_len; i < local.size(); i++)
+        {
+            llama_batch_add(batch, local[i], i, { 0 }, true);
         }
         logit_idx = batch.n_tokens - 1;
-        if (llama_decode(ctx[active], batch)) {
+        if (llama_decode(ctx[active], batch) != 0)
+        {
             fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
             return 1;
         }
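
Taken together, the rewritten block above reconciles the thread-local continuation (local) with the mutex-guarded shared candidate (spec_ctx->candidate): when the two sequences still agree and the local one is strictly longer, the local extension is published; otherwise the local copy adopts the shared sequence. The next batch is then rebuilt starting at match_len, the first position that needs (re)decoding. A self-contained sketch of that merge step, with illustrative names that are not from the commit:

#include <algorithm>
#include <cstdint>
#include <vector>

using llama_tokens = std::vector<int32_t>;

// Returns the index from which decoding must restart. 'shared' is assumed
// to be accessed under the caller's lock; 'local' is assumed non-empty
// (in the diff it always holds the prompt plus at least one new token).
static size_t reconcile(llama_tokens & shared, llama_tokens & local) {
    bool match = true;
    size_t match_len = local.size() - 1;
    for (size_t i = 0; i < std::min(shared.size(), local.size()); i++) {
        if (shared[i] != local[i]) {
            match = false;
            match_len = i;
            break;
        }
    }
    if (match && shared.size() < local.size()) {
        shared = local;   // still a prefix: publish the longer local speculation
    } else {
        local = shared;   // diverged (or shared is ahead): adopt the shared candidate
    }
    return match_len;
}
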
@@ -177,7 +182,11 @@ static int speculation(
     return 0;
 }
 
-static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
+static int target(
+    llama_model * model,
+    llama_context * ctx,
+    const llama_tokens& input,
+    size_t n_predict)
 {
     dbg_default(to_string(ctx, input.begin(), input.end()));
     // TODO: batch size
@@ -300,8 +309,8 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
     const auto t_main_end = ggml_time_us();
-    LOG_TEE("%s: decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
-            __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
+    fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
+            n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
     llama_print_timings(ctx);
     fprintf(stderr, "\n");
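
For reference, the throughput printed here is tokens divided by elapsed wall time, with ggml_time_us() reporting microseconds. A minimal check of the arithmetic with stand-in timestamps (plain C++, no ggml calls):

#include <cstdio>

int main() {
    const long long t_main_start = 0;        // stand-in start timestamp, microseconds
    const long long t_main_end   = 2500000;  // stand-in end timestamp, microseconds
    const size_t n_decode        = 100;      // tokens decoded in that window

    const float seconds = (t_main_end - t_main_start) / 1000000.0f;
    fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
            n_decode, seconds, n_decode / seconds);  // 100 / 2.5 s = 40.00 t/s
    return 0;
}
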