fixes
This commit is contained in:
parent
02e2c91d01
commit
66982abcb1
2 changed files with 31 additions and 4 deletions
|
@ -4,4 +4,4 @@ Minimal example. What's not implemented, but can be implemented separately in pi
|
|||
* tree-based speculation
|
||||
* correct sampling
|
||||
* support more than 2 instances
|
||||
*
|
||||
* just one instance speculates
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <chrono>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <memory>
|
||||
|
@ -59,6 +60,7 @@ static void split_done_cb(int split)
|
|||
if (split == 1 || split == 2)
|
||||
{
|
||||
std::lock_guard<std::mutex> guard(spec_ctx.mtx);
|
||||
fprintf(stderr, "split_done = %d\n", split);
|
||||
spec_ctx.active_id = split - 1;
|
||||
}
|
||||
}
|
||||
|
@ -121,7 +123,24 @@ static int speculation(
|
|||
size_t match_len;
|
||||
|
||||
// TODO: throttle here -- avoid generating too many tokens ahead, and wait instead
|
||||
while (true) {
|
||||
while (true)
|
||||
{
|
||||
// silliest thing ever
|
||||
bool wait = false;
|
||||
{
|
||||
std::lock_guard<std::mutex> g(spec_ctx->mtx);
|
||||
if (spec_ctx->active_id != 0)
|
||||
{
|
||||
wait = true;
|
||||
}
|
||||
}
|
||||
if (wait)
|
||||
{
|
||||
std::this_thread::sleep_for(std::chrono::milliseconds{10});
|
||||
continue;
|
||||
}
|
||||
|
||||
|
||||
auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
|
||||
if (next_tokens.size() != 1) {
|
||||
fprintf(stderr, "invalid next tokens\n");
|
||||
|
@ -139,6 +158,7 @@ static int speculation(
|
|||
auto& shared = spec_ctx->candidate;
|
||||
bool match = true;
|
||||
match_len = local.size() - 1;
|
||||
fprintf(stderr, "spec #%d: %zu | %zu\n", active, shared.size(), local.size());
|
||||
for (size_t i = 0; i < std::min(shared.size(), local.size()); i++)
|
||||
{
|
||||
if (shared[i] != local[i])
|
||||
|
@ -159,7 +179,11 @@ static int speculation(
|
|||
{
|
||||
local = shared;
|
||||
}
|
||||
active = spec_ctx->active_id;
|
||||
if (active != spec_ctx->active_id)
|
||||
{
|
||||
active = spec_ctx->active_id;
|
||||
fprintf(stderr, "updating active_id = %d\n", active);
|
||||
}
|
||||
}
|
||||
|
||||
llama_batch_clear(batch);
|
||||
|
@ -294,6 +318,8 @@ static int target(
|
|||
break;
|
||||
}
|
||||
|
||||
fprintf(stderr, "tgt: input_seq.size() = %zu\n", input_seq.size());
|
||||
|
||||
llama_batch_clear(batch);
|
||||
for (size_t i = 0; i < input_seq.size(); i++)
|
||||
{
|
||||
|
@ -365,7 +391,8 @@ int main(int argc, char ** argv) {
|
|||
params.n_threads = params.n_threads_draft;
|
||||
}
|
||||
params.n_threads_batch = params.n_threads_batch_draft;
|
||||
|
||||
|
||||
params.cb_split_done = nullptr;
|
||||
params.rpc_servers = draft_rpcs.substr(0, i);
|
||||
std::tie(draft_models[0], draft_ctx[0]) = llama_init_from_gpt_params(params);
|
||||
params.rpc_servers = draft_rpcs.substr(i + 1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue