some renaming
This commit is contained in:
parent
479c80a0db
commit
60fe62e6eb
1 changed files with 30 additions and 21 deletions
|
@ -10,9 +10,9 @@
|
|||
static void dbg_color(const std::string & s, const std::string & fg)
|
||||
{
|
||||
static const std::string kReset = "\033[0m";
|
||||
static const std::string bold[] = { "", "\033[1m" };
|
||||
static const std::string kBold[] = { "", "\033[1m" };
|
||||
static size_t index = 0;
|
||||
std::cout << bold[index] << fg << s << kReset << std::flush;
|
||||
std::cout << kBold[index] << fg << s << kReset << std::flush;
|
||||
index = 1 - index;
|
||||
}
|
||||
|
||||
|
@ -98,7 +98,7 @@ static int speculation(
|
|||
std::vector<llama_model *> model,
|
||||
speculation_context * spec_ctx,
|
||||
std::vector<llama_context *> ctx,
|
||||
std::vector<llama_token> input /* copy here */) {
|
||||
llama_tokens input /* copy here */) {
|
||||
|
||||
int32_t active = 1;
|
||||
|
||||
|
@ -117,18 +117,18 @@ static int speculation(
|
|||
}
|
||||
|
||||
int logit_idx = batch.n_tokens - 1;
|
||||
std::vector<llama_token> local_spec = input;
|
||||
llama_tokens local = input;
|
||||
size_t match_len;
|
||||
|
||||
// TODO: here we need to not generate too many and wait
|
||||
while (true) {
|
||||
auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
|
||||
if (next_tokens.size() != 1) {
|
||||
fprintf(stderr, "invalid next tokens\n");
|
||||
return 1;
|
||||
fprintf(stderr, "invalid next tokens\n");
|
||||
return 1;
|
||||
}
|
||||
|
||||
local_spec.push_back(next_tokens[0]);
|
||||
local.push_back(next_tokens[0]);
|
||||
|
||||
{
|
||||
std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
|
||||
|
@ -136,12 +136,12 @@ static int speculation(
|
|||
{
|
||||
break;
|
||||
}
|
||||
auto& spec = spec_ctx->candidate;
|
||||
auto& shared = spec_ctx->candidate;
|
||||
bool match = true;
|
||||
match_len = local_spec.size() - 1;
|
||||
for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
|
||||
match_len = local.size() - 1;
|
||||
for (size_t i = 0; i < std::min(shared.size(), local.size()); i++)
|
||||
{
|
||||
if (spec[i] != local_spec[i])
|
||||
if (shared[i] != local[i])
|
||||
{
|
||||
match = false;
|
||||
match_len = i;
|
||||
|
@ -151,23 +151,28 @@ static int speculation(
|
|||
break;
|
||||
}
|
||||
}
|
||||
if (match) {
|
||||
spec = local_spec;
|
||||
} else {
|
||||
local_spec = spec;
|
||||
if (match && shared.size() < local.size())
|
||||
{
|
||||
shared = local;
|
||||
}
|
||||
else
|
||||
{
|
||||
local = shared;
|
||||
}
|
||||
active = spec_ctx->active_id;
|
||||
}
|
||||
|
||||
llama_batch_clear(batch);
|
||||
// TODO theoretically this can be empty?
|
||||
for (size_t i = match_len; i < local_spec.size(); i++) {
|
||||
llama_batch_add(batch, local_spec[i], i, { 0 }, true);
|
||||
for (size_t i = match_len; i < local.size(); i++)
|
||||
{
|
||||
llama_batch_add(batch, local[i], i, { 0 }, true);
|
||||
}
|
||||
|
||||
logit_idx = batch.n_tokens - 1;
|
||||
|
||||
if (llama_decode(ctx[active], batch)) {
|
||||
if (llama_decode(ctx[active], batch) != 0)
|
||||
{
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
|
@ -177,7 +182,11 @@ static int speculation(
|
|||
return 0;
|
||||
}
|
||||
|
||||
static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
|
||||
static int target(
|
||||
llama_model * model,
|
||||
llama_context * ctx,
|
||||
const llama_tokens& input,
|
||||
size_t n_predict)
|
||||
{
|
||||
dbg_default(to_string(ctx, input.begin(), input.end()));
|
||||
// TODO: batch size
|
||||
|
@ -300,8 +309,8 @@ static int target(llama_model * model, llama_context * ctx, const llama_tokens&
|
|||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
LOG_TEE("%s: decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
fprintf(stderr, "decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
|
||||
n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
fprintf(stderr, "\n");
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue