llama : refactor sampling v2 (#9294)
- Add `struct llama_sampler` and `struct llama_sampler_i` - Add `llama_sampler_` API - Add `llama_sampler_chain_` API for chaining multiple samplers - Remove `LLAMA_API_INTERNAL` - Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
parent
947538acb8
commit
df270ef745
48 changed files with 3497 additions and 2914 deletions
|
@ -50,8 +50,8 @@ static std::vector<std::string> k_prompts = {
|
|||
|
||||
struct client {
|
||||
~client() {
|
||||
if (ctx_sampling) {
|
||||
llama_sampling_free(ctx_sampling);
|
||||
if (smpl) {
|
||||
gpt_sampler_free(smpl);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,7 +72,7 @@ struct client {
|
|||
std::string prompt;
|
||||
std::string response;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = nullptr;
|
||||
struct gpt_sampler * smpl = nullptr;
|
||||
};
|
||||
|
||||
static void print_date_time() {
|
||||
|
@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
|
|||
for (size_t i = 0; i < clients.size(); ++i) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.ctx_sampling = llama_sampling_init(params.sparams);
|
||||
client.smpl = gpt_sampler_init(model, params.sparams);
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
|
@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
|
|||
client.prompt = client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
|
||||
llama_sampling_reset(client.ctx_sampling);
|
||||
gpt_sampler_reset(client.smpl);
|
||||
|
||||
// do not prepend BOS because we have a system prompt!
|
||||
std::vector<llama_token> tokens_prompt;
|
||||
|
@ -341,9 +341,9 @@ int main(int argc, char ** argv) {
|
|||
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
|
||||
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
||||
|
||||
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
||||
const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);
|
||||
|
||||
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
|
||||
gpt_sampler_accept(client.smpl, id, true);
|
||||
|
||||
if (client.n_decoded == 1) {
|
||||
// start measuring generation time after the first token to make sure all concurrent clients
|
||||
|
@ -371,7 +371,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
|
||||
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
@ -413,7 +413,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG_TEE("\n");
|
||||
|
||||
llama_print_timings(ctx);
|
||||
// TODO: print sampling/grammar timings for all clients
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue