This commit is contained in:
Oleksandr Kuvshynov 2024-05-24 22:21:41 -04:00
parent 66982abcb1
commit 10d5aefed5

View file

@@ -136,7 +136,7 @@ static int speculation(
} }
if (wait) if (wait)
{ {
std::this_thread::sleep_for(std::chrono::milliseconds{10}); std::this_thread::sleep_for(std::chrono::milliseconds{5});
continue; continue;
} }
@@ -158,7 +158,6 @@ static int speculation(
auto& shared = spec_ctx->candidate; auto& shared = spec_ctx->candidate;
bool match = true; bool match = true;
match_len = local.size() - 1; match_len = local.size() - 1;
fprintf(stderr, "spec #%d: %zu | %zu\n", active, shared.size(), local.size());
for (size_t i = 0; i < std::min(shared.size(), local.size()); i++) for (size_t i = 0; i < std::min(shared.size(), local.size()); i++)
{ {
if (shared[i] != local[i]) if (shared[i] != local[i])
@@ -167,7 +166,7 @@ static int speculation(
match_len = i; match_len = i;
// here we need to clear both contexts // here we need to clear both contexts
llama_kv_cache_seq_rm(ctx[0], 0, i, -1); llama_kv_cache_seq_rm(ctx[0], 0, i, -1);
llama_kv_cache_seq_rm(ctx[1], 0, i, -1); //llama_kv_cache_seq_rm(ctx[1], 0, i, -1);
break; break;
} }
} }
@@ -318,17 +317,20 @@ static int target(
break; break;
} }
fprintf(stderr, "tgt: input_seq.size() = %zu\n", input_seq.size()); fprintf(stderr, "\ntgt: input_seq.size() = %zu\n", input_seq.size());
llama_batch_clear(batch); llama_batch_clear(batch);
for (size_t i = 0; i < input_seq.size(); i++) for (size_t i = 0; i < input_seq.size(); i++)
{ {
llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true); llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
} }
auto s_us = ggml_time_us();
if (llama_decode(ctx, batch)) { if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1); fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
return 1; return 1;
} }
auto eval_us = ggml_time_us() - s_us;
fprintf(stderr, "eval_time: %lld", eval_us);
logits_from = 0; logits_from = 0;
logits_to = input_seq.size(); logits_to = input_seq.size();
} }