duo: first ~working option
This commit is contained in:
parent
2849247c4f
commit
eecdd3b0ce
4 changed files with 333 additions and 72 deletions
|
@ -1068,6 +1068,14 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
|
||||||
params.rpc_servers = argv[i];
|
params.rpc_servers = argv[i];
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
if (arg == "--rpcd") {
|
||||||
|
if (++i >= argc) {
|
||||||
|
invalid_param = true;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
params.rpc_servers_draft = argv[i];
|
||||||
|
return true;
|
||||||
|
}
|
||||||
if (arg == "--no-mmap") {
|
if (arg == "--no-mmap") {
|
||||||
params.use_mmap = false;
|
params.use_mmap = false;
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -83,6 +83,7 @@ struct gpt_params {
|
||||||
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
int32_t yarn_orig_ctx = 0; // YaRN original context length
|
||||||
float defrag_thold = -1.0f; // KV cache defragmentation threshold
|
float defrag_thold = -1.0f; // KV cache defragmentation threshold
|
||||||
std::string rpc_servers = ""; // comma separated list of RPC servers
|
std::string rpc_servers = ""; // comma separated list of RPC servers
|
||||||
|
std::string rpc_servers_draft = ""; // comma separated list of RPC servers used for draft model
|
||||||
|
|
||||||
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
ggml_backend_sched_eval_callback cb_eval = nullptr;
|
||||||
void * cb_eval_user_data = nullptr;
|
void * cb_eval_user_data = nullptr;
|
||||||
|
|
|
@ -1 +1,7 @@
|
||||||
## duo
|
## duo
|
||||||
|
|
||||||
|
Minimal example. What's not implemented, but can be implemented separately in pieces:
|
||||||
|
* tree-based speculation
|
||||||
|
* correct sampling
|
||||||
|
* support more than 2 instances
|
||||||
|
*
|
|
@ -7,54 +7,169 @@
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
|
static void dbg_color(const std::string & s, const std::string & fg)
|
||||||
|
{
|
||||||
|
static const std::string kReset = "\033[0m";
|
||||||
|
static const std::string bold[] = { "", "\033[1m" };
|
||||||
|
static size_t index = 0;
|
||||||
|
std::cout << bold[index] << fg << s << kReset << std::flush;
|
||||||
|
index = 1 - index;
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dbg_accepted(const std::string & accepted)
|
||||||
|
{
|
||||||
|
static const std::string kGreen = "\033[32m";
|
||||||
|
dbg_color(accepted, kGreen);
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dbg_not_matched(const std::string & accepted)
|
||||||
|
{
|
||||||
|
dbg_color(accepted, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
static void dbg_rejected(const std::string & rejected)
|
||||||
|
{
|
||||||
|
static const std::string kRed = "\033[31m";
|
||||||
|
dbg_color(rejected, kRed);
|
||||||
|
}
|
||||||
|
|
||||||
using llama_tokens = std::vector<llama_token>;
|
using llama_tokens = std::vector<llama_token>;
|
||||||
|
|
||||||
struct speculation_context
|
struct speculation_context
|
||||||
{
|
{
|
||||||
llama_tokens speculation;
|
llama_tokens candidate;
|
||||||
int32_t instance_id;
|
int32_t active_id;
|
||||||
std::mutex mtx;
|
std::mutex mtx;
|
||||||
|
bool done;
|
||||||
};
|
};
|
||||||
|
|
||||||
speculation_context spec_ctx;
|
speculation_context spec_ctx;
|
||||||
|
|
||||||
static void split_done_cb(int split)
|
static void split_done_cb(int split)
|
||||||
{
|
{
|
||||||
//fprintf(stderr, "split done: %d\n", split);
|
|
||||||
if (split == 1 || split == 2)
|
if (split == 1 || split == 2)
|
||||||
{
|
{
|
||||||
std::lock_guard<std::mutex> guard(spec_ctx.mtx);
|
std::lock_guard<std::mutex> guard(spec_ctx.mtx);
|
||||||
spec_ctx.instance_id = 3 - split;
|
spec_ctx.active_id = 2 - split;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
int main(int argc, char ** argv) {
|
// this ignores all the other sampling criteria
|
||||||
gpt_params params;
|
static std::vector<llama_token> greedy_tokens(
|
||||||
|
llama_model * model,
|
||||||
|
llama_context * ctx,
|
||||||
|
int32_t from_idx,
|
||||||
|
int32_t to_idx)
|
||||||
|
{
|
||||||
|
auto n_vocab = llama_n_vocab(model);
|
||||||
|
std::vector<llama_token> res;
|
||||||
|
if (n_vocab <= 0)
|
||||||
|
{
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
if (gpt_params_parse(argc, argv, params) == false) {
|
for (int idx = from_idx; idx < to_idx; idx++)
|
||||||
|
{
|
||||||
|
auto * logits = llama_get_logits_ith(ctx, idx);
|
||||||
|
llama_token new_token_id = 0;
|
||||||
|
for (llama_token token_id = 1; token_id < n_vocab; token_id++)
|
||||||
|
{
|
||||||
|
if (logits[token_id] > logits[new_token_id])
|
||||||
|
{
|
||||||
|
new_token_id = token_id;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
res.push_back(new_token_id);
|
||||||
|
}
|
||||||
|
return res;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int speculation(
|
||||||
|
std::vector<llama_model *> model,
|
||||||
|
speculation_context * spec_ctx,
|
||||||
|
std::vector<llama_context *> ctx,
|
||||||
|
std::vector<llama_token> input /* copy here */) {
|
||||||
|
|
||||||
|
int32_t active = 1;
|
||||||
|
|
||||||
|
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||||
|
|
||||||
|
for (size_t i = 0; i < input.size(); i++)
|
||||||
|
{
|
||||||
|
llama_batch_add(batch, input[i], i, { 0 }, false);
|
||||||
|
}
|
||||||
|
|
||||||
|
batch.logits[batch.n_tokens - 1] = true;
|
||||||
|
|
||||||
|
if (llama_decode(ctx[active], batch) != 0) {
|
||||||
|
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
int logit_idx = batch.n_tokens - 1;
|
||||||
params.seed = time(NULL);
|
std::vector<llama_token> local_spec = input;
|
||||||
|
size_t match_len;
|
||||||
|
|
||||||
|
while (true) {
|
||||||
|
auto next_tokens = greedy_tokens(model[active], ctx[active], logit_idx, logit_idx + 1);
|
||||||
|
if (next_tokens.size() != 1) {
|
||||||
|
fprintf(stderr, "invalid next tokens\n");
|
||||||
|
return 1;
|
||||||
}
|
}
|
||||||
llama_backend_init();
|
|
||||||
llama_numa_init(params.numa);
|
|
||||||
|
|
||||||
llama_model * model = nullptr;
|
local_spec.push_back(next_tokens[0]);
|
||||||
llama_context * ctx = nullptr;
|
|
||||||
params.cb_split_done = split_done_cb;
|
|
||||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
|
||||||
|
|
||||||
llama_tokens input = llama_tokenize(ctx, params.prompt, true);
|
{
|
||||||
const size_t n_input = input.size();
|
std::lock_guard<std::mutex> _lock(spec_ctx->mtx);
|
||||||
|
if (spec_ctx->done)
|
||||||
// print the prompt token-by-token
|
{
|
||||||
for (auto id : input) {
|
break;
|
||||||
fprintf(stdout, "%s", llama_token_to_piece(ctx, id).c_str());
|
}
|
||||||
|
auto& spec = spec_ctx->candidate;
|
||||||
|
bool match = true;
|
||||||
|
match_len = local_spec.size() - 1;
|
||||||
|
for (size_t i = 0; i < std::min(spec.size(), local_spec.size()); i++)
|
||||||
|
{
|
||||||
|
if (spec[i] != local_spec[i])
|
||||||
|
{
|
||||||
|
match = false;
|
||||||
|
match_len = i;
|
||||||
|
// here we need to clear both contexts
|
||||||
|
llama_kv_cache_seq_rm(ctx[0], 0, i, -1);
|
||||||
|
llama_kv_cache_seq_rm(ctx[1], 0, i, -1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (match) {
|
||||||
|
spec = local_spec;
|
||||||
|
} else {
|
||||||
|
local_spec = spec;
|
||||||
|
}
|
||||||
|
active = spec_ctx->active_id;
|
||||||
}
|
}
|
||||||
fflush(stdout);
|
|
||||||
|
|
||||||
|
llama_batch_clear(batch);
|
||||||
|
// TODO theoretically this can be empty?
|
||||||
|
for (size_t i = match_len; i < local_spec.size(); i++) {
|
||||||
|
llama_batch_add(batch, local_spec[i], i, { 0 }, true);
|
||||||
|
}
|
||||||
|
|
||||||
|
logit_idx = batch.n_tokens - 1;
|
||||||
|
|
||||||
|
if (llama_decode(ctx[active], batch)) {
|
||||||
|
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_batch_free(batch);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
static int target(llama_model * model, llama_context * ctx, const llama_tokens& input, size_t n_predict)
|
||||||
|
{
|
||||||
|
// TODO: batch size
|
||||||
llama_batch batch = llama_batch_init(512, 0, 1);
|
llama_batch batch = llama_batch_init(512, 0, 1);
|
||||||
|
|
||||||
// evaluate the initial prompt
|
// evaluate the initial prompt
|
||||||
|
@ -69,78 +184,209 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
int n_cur = batch.n_tokens;
|
// how many tokens are currently accepted
|
||||||
int n_decode = 0;
|
// TODO: rename to n_accepted
|
||||||
|
size_t n_cur = input.size();
|
||||||
|
size_t n_decode = 0;
|
||||||
|
|
||||||
const auto t_main_start = ggml_time_us();
|
const auto t_main_start = ggml_time_us();
|
||||||
|
|
||||||
// we'll use logits from this position to determine next token
|
// we'll use logits from this position to determine next token
|
||||||
int logit_idx = batch.n_tokens - 1;
|
int logits_from = batch.n_tokens - 1;
|
||||||
|
int logits_to = batch.n_tokens;
|
||||||
|
|
||||||
while (n_decode <= params.n_predict) {
|
llama_tokens input_seq, next_tokens, output;
|
||||||
// sample the next token
|
input_seq.push_back(input.back());
|
||||||
|
|
||||||
|
while (n_decode <= n_predict)
|
||||||
{
|
{
|
||||||
auto n_vocab = llama_n_vocab(model);
|
next_tokens = greedy_tokens(model, ctx, logits_from, logits_to);
|
||||||
auto * logits = llama_get_logits_ith(ctx, logit_idx);
|
if (next_tokens.size() != input_seq.size())
|
||||||
|
{
|
||||||
std::vector<llama_token_data> candidates;
|
fprintf(stderr, "invalid next tokens\n");
|
||||||
candidates.reserve(n_vocab);
|
return 1;
|
||||||
|
|
||||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
|
||||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
size_t next_tokens_pos = n_cur;
|
||||||
|
// we always accept at least one new token
|
||||||
|
n_cur += 1;
|
||||||
|
n_decode += 1;
|
||||||
|
for (size_t i = 0; i + 1 < input_seq.size(); i++)
|
||||||
|
{
|
||||||
|
if (next_tokens[i] == input_seq[i + 1])
|
||||||
|
{
|
||||||
|
n_cur += 1;
|
||||||
|
n_decode += 1;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
// reject. next_tokens[i] is the last correct one.
|
||||||
|
next_tokens.erase(next_tokens.begin() + i + 1, next_tokens.end());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// sample the most likely token
|
// empty the non-matching portion of kv cache.
|
||||||
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
// n_cur is incremented at least once and will be > 0
|
||||||
|
llama_kv_cache_seq_rm(ctx, 0, n_cur - 1, -1);
|
||||||
|
|
||||||
// is it an end of generation?
|
bool done = false;
|
||||||
if (llama_token_is_eog(model, new_token_id) || n_decode >= params.n_predict) {
|
for (size_t i = 0; i < next_tokens.size(); i++)
|
||||||
|
{
|
||||||
|
// TODO: what should we do here, is this correct
|
||||||
|
if (next_tokens[i] == llama_token_eos(model) || llama_token_is_eog(model, next_tokens[i]))
|
||||||
|
{
|
||||||
|
done = true;
|
||||||
|
next_tokens.erase(next_tokens.begin() + i, next_tokens.end());
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
output.insert(output.end(), next_tokens.begin(), next_tokens.end());
|
||||||
|
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
|
||||||
|
auto & spec = spec_ctx.candidate;
|
||||||
|
size_t n_match = 0;
|
||||||
|
for (size_t i = 0; i < next_tokens.size() && i + next_tokens_pos < spec.size(); i++)
|
||||||
|
{
|
||||||
|
if (next_tokens[i] == spec[i + next_tokens_pos])
|
||||||
|
{
|
||||||
|
n_match++;
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string accepted = "";
|
||||||
|
for (size_t i = next_tokens_pos; i < next_tokens_pos + n_match; i++)
|
||||||
|
{
|
||||||
|
accepted += llama_token_to_piece(ctx, spec[i]);
|
||||||
|
}
|
||||||
|
dbg_accepted(accepted);
|
||||||
|
if (n_match != next_tokens.size())
|
||||||
|
{
|
||||||
|
std::string rejected = "";
|
||||||
|
for (size_t i = next_tokens_pos + n_match; i < spec.size(); i++)
|
||||||
|
{
|
||||||
|
rejected += llama_token_to_piece(ctx, spec[i]);
|
||||||
|
}
|
||||||
|
dbg_rejected(rejected);
|
||||||
|
std::string not_matched = "";
|
||||||
|
for (size_t i = n_match; i < next_tokens.size(); i++)
|
||||||
|
{
|
||||||
|
not_matched += llama_token_to_piece(ctx, next_tokens[i]);
|
||||||
|
}
|
||||||
|
dbg_not_matched(not_matched);
|
||||||
|
}
|
||||||
|
|
||||||
|
// remove non-matched tokens
|
||||||
|
if (n_match != next_tokens.size())
|
||||||
|
{
|
||||||
|
spec.erase(spec.begin() + next_tokens_pos, spec.end());
|
||||||
|
for (const auto tok: next_tokens)
|
||||||
|
{
|
||||||
|
spec.push_back(tok);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
input_seq.assign(spec.begin() + n_cur - 1, spec.end());
|
||||||
|
}
|
||||||
|
if (n_decode >= n_predict || done)
|
||||||
|
{
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
fprintf(stdout, "%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
|
||||||
fflush(stdout);
|
|
||||||
|
|
||||||
// prepare the next batch
|
|
||||||
llama_batch_clear(batch);
|
llama_batch_clear(batch);
|
||||||
|
for (size_t i = 0; i < input_seq.size(); i++)
|
||||||
// push this new token for next evaluation
|
{
|
||||||
llama_batch_add(batch, new_token_id, n_cur, { 0 }, true);
|
llama_batch_add(batch, input_seq[i], n_cur - 1 + i, { 0 }, true);
|
||||||
|
|
||||||
// we still use the 'original' token to sample on next iteration
|
|
||||||
logit_idx = batch.n_tokens - 1;
|
|
||||||
|
|
||||||
n_decode += 1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
n_cur += 1;
|
|
||||||
|
|
||||||
// evaluate the current batch with the transformer model
|
|
||||||
if (llama_decode(ctx, batch)) {
|
if (llama_decode(ctx, batch)) {
|
||||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
// remove the cached entries from mock tokens
|
logits_from = 0;
|
||||||
llama_kv_cache_seq_rm(ctx, 0, n_cur, -1);
|
logits_to = input_seq.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
LOG_TEE("\n");
|
|
||||||
|
|
||||||
const auto t_main_end = ggml_time_us();
|
const auto t_main_end = ggml_time_us();
|
||||||
|
|
||||||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
LOG_TEE("%s: decoded %zu tokens in %.2f s, speed: %.2f t/s\n",
|
||||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||||
|
|
||||||
//llama_print_timings(ctx);
|
llama_print_timings(ctx);
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
|
{
|
||||||
|
std::lock_guard<std::mutex> _lock(spec_ctx.mtx);
|
||||||
|
spec_ctx.done = true;
|
||||||
|
}
|
||||||
|
|
||||||
llama_batch_free(batch);
|
llama_batch_free(batch);
|
||||||
|
return 0;
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char ** argv) {
|
||||||
|
gpt_params params;
|
||||||
|
|
||||||
|
if (gpt_params_parse(argc, argv, params) == false) {
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||||
|
params.seed = time(NULL);
|
||||||
|
}
|
||||||
|
|
||||||
|
// parse 2 speculation rpc instances
|
||||||
|
std::string draft_rpcs = params.rpc_servers_draft;
|
||||||
|
size_t i = draft_rpcs.find(',');
|
||||||
|
if (i == std::string::npos || draft_rpcs.find(',', i + 1) != std::string::npos)
|
||||||
|
{
|
||||||
|
fprintf(stderr, "drpc must contain exactly two servers\n");
|
||||||
|
return 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
llama_backend_init();
|
||||||
|
llama_numa_init(params.numa);
|
||||||
|
|
||||||
|
// main model and context
|
||||||
|
llama_model * model = nullptr;
|
||||||
|
llama_context * ctx = nullptr;
|
||||||
|
params.cb_split_done = split_done_cb;
|
||||||
|
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||||
|
|
||||||
|
llama_tokens input = llama_tokenize(ctx, params.prompt, true);
|
||||||
|
spec_ctx.candidate = input;
|
||||||
|
|
||||||
|
// prepare draft model and contexts. No need for two model instances?
|
||||||
|
std::vector<llama_model *> draft_models = {nullptr, nullptr};
|
||||||
|
std::vector<llama_context *> draft_ctx = {nullptr, nullptr};
|
||||||
|
|
||||||
|
params.model = params.model_draft;
|
||||||
|
params.n_gpu_layers = params.n_gpu_layers_draft;
|
||||||
|
if (params.n_threads_draft > 0)
|
||||||
|
{
|
||||||
|
params.n_threads = params.n_threads_draft;
|
||||||
|
}
|
||||||
|
params.n_threads_batch = params.n_threads_batch_draft;
|
||||||
|
|
||||||
|
params.rpc_servers = draft_rpcs.substr(0, i);
|
||||||
|
std::tie(draft_models[0], draft_ctx[0]) = llama_init_from_gpt_params(params);
|
||||||
|
params.rpc_servers = draft_rpcs.substr(i + 1);
|
||||||
|
std::tie(draft_models[1], draft_ctx[1]) = llama_init_from_gpt_params(params);
|
||||||
|
std::thread spec_thread = std::thread(speculation, draft_models, &spec_ctx, draft_ctx, input);
|
||||||
|
|
||||||
|
target(model, ctx, input, params.n_predict);
|
||||||
|
|
||||||
|
spec_thread.join();
|
||||||
|
|
||||||
llama_free(ctx);
|
llama_free(ctx);
|
||||||
|
llama_free(draft_ctx[0]);
|
||||||
|
llama_free(draft_ctx[1]);
|
||||||
|
|
||||||
llama_free_model(model);
|
llama_free_model(model);
|
||||||
|
llama_free_model(draft_models[0]);
|
||||||
|
llama_free_model(draft_models[1]);
|
||||||
|
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue