llama : refactor sampling v2 (#9294)
- Add `struct llama_sampler` and `struct llama_sampler_i` - Add `llama_sampler_` API - Add `llama_sampler_chain_` API for chaining multiple samplers - Remove `LLAMA_API_INTERNAL` - Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
parent
947538acb8
commit
df270ef745
48 changed files with 3497 additions and 2914 deletions
|
@ -210,7 +210,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
|
|
|
@ -27,7 +27,6 @@ guard let model = llama_load_model_from_file(modelPath.cString(using: .utf8), mo
|
|||
print("Failed to load model")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
defer {
|
||||
llama_free_model(model)
|
||||
}
|
||||
|
@ -37,7 +36,6 @@ var tokens = tokenize(text: prompt, add_bos: true)
|
|||
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
|
||||
|
||||
var context_params = llama_context_default_params()
|
||||
context_params.seed = 1234
|
||||
context_params.n_ctx = n_kv_req
|
||||
context_params.n_batch = UInt32(max(n_len, n_parallel))
|
||||
context_params.n_threads = 8
|
||||
|
@ -48,11 +46,26 @@ guard context != nil else {
|
|||
print("Failed to initialize context")
|
||||
exit(1)
|
||||
}
|
||||
|
||||
defer {
|
||||
llama_free(context)
|
||||
}
|
||||
|
||||
var sparams = llama_sampler_chain_default_params()
|
||||
|
||||
let smpl = llama_sampler_chain_init(sparams)
|
||||
guard smpl != nil else {
|
||||
print("Failed to initialize sampling")
|
||||
exit(1)
|
||||
}
|
||||
defer {
|
||||
llama_sampler_free(smpl)
|
||||
}
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(40));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(0.9, 1));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (0.4));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (1234));
|
||||
|
||||
let n_ctx = llama_n_ctx(context)
|
||||
|
||||
print("\nn_len = \(n_len), n_ctx = \(n_ctx), n_batch = \(context_params.n_batch), n_parallel = \(n_parallel), n_kv_req = \(n_kv_req)\n")
|
||||
|
@ -125,32 +138,9 @@ while n_cur <= n_len {
|
|||
continue
|
||||
}
|
||||
|
||||
var n_vocab = llama_n_vocab(model)
|
||||
var logits = llama_get_logits_ith(context, i_batch[i])
|
||||
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
|
||||
|
||||
var candidates: [llama_token_data] = .init(repeating: llama_token_data(), count: Int(n_vocab))
|
||||
|
||||
for token_id in 0 ..< n_vocab {
|
||||
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
|
||||
}
|
||||
|
||||
var candidates_p: llama_token_data_array = .init(
|
||||
data: &candidates,
|
||||
size: candidates.count,
|
||||
sorted: false
|
||||
)
|
||||
|
||||
let top_k: Int32 = 40
|
||||
let top_p: Float = 0.9
|
||||
let temp: Float = 0.4
|
||||
|
||||
llama_sample_top_k(context, &candidates_p, top_k, 1)
|
||||
llama_sample_top_p(context, &candidates_p, top_p, 1)
|
||||
llama_sample_temp(context, &candidates_p, temp)
|
||||
|
||||
let new_token_id = llama_sample_token(context, &candidates_p)
|
||||
|
||||
// const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
llama_sampler_accept(smpl, new_token_id)
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
|
@ -210,9 +200,10 @@ if n_parallel > 1 {
|
|||
|
||||
let t_main_end = ggml_time_us()
|
||||
|
||||
print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n")
|
||||
print("decoded \(n_decode) tokens in \(String(format: "%.2f", Double(t_main_end - t_main_start) / 1_000_000.0)) s, speed: \(String(format: "%.2f", Double(n_decode) / (Double(t_main_end - t_main_start) / 1_000_000.0))) t/s\n\n")
|
||||
|
||||
llama_print_timings(context)
|
||||
llama_perf_print(UnsafeRawPointer(context), LLAMA_PERF_TYPE_CONTEXT)
|
||||
llama_perf_print(UnsafeRawPointer(smpl), LLAMA_PERF_TYPE_SAMPLER_CHAIN)
|
||||
|
||||
private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
|
||||
let utf8Count = text.utf8.count
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -65,6 +64,15 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_k(params.sparams.top_k));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_top_p(params.sparams.top_p, params.sparams.min_keep));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sparams.temp));
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sparams.seed));
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
|
@ -164,29 +172,9 @@ int main(int argc, char ** argv) {
|
|||
continue;
|
||||
}
|
||||
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, i_batch[i]);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
const int top_k = 40;
|
||||
const float top_p = 0.9f;
|
||||
const float temp = 0.4f;
|
||||
|
||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temp (ctx, &candidates_p, temp);
|
||||
|
||||
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
|
||||
|
||||
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation? -> mark the stream as finished
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
|
@ -244,12 +232,15 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
|
|
|
@ -90,13 +90,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
print_build_info();
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
@ -313,8 +307,10 @@ int main(int argc, char ** argv) {
|
|||
if (notArray) fprintf(stdout, "\n}\n");
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
// clean up
|
||||
llama_print_timings(ctx);
|
||||
llama_batch_free(batch);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -151,8 +151,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
print_build_info();
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
|
@ -183,7 +181,8 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -1,9 +1,5 @@
|
|||
#define LLAMA_API_INTERNAL
|
||||
|
||||
#include "grammar-parser.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "unicode.h"
|
||||
#include "llama-grammar.h"
|
||||
|
||||
#include <cstdio>
|
||||
#include <cstdlib>
|
||||
|
@ -12,29 +8,28 @@
|
|||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
static bool llama_sample_grammar_string(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
|
||||
auto decoded = decode_utf8(input_str, {});
|
||||
const auto & code_points = decoded.first;
|
||||
static bool llama_grammar_validate(struct llama_grammar * grammar, const std::string & input_str, size_t & error_pos, std::string & error_msg) {
|
||||
const auto cpts = unicode_cpts_from_utf8(input_str);
|
||||
|
||||
const llama_grammar_rules & rules = llama_grammar_get_rules (grammar);
|
||||
llama_grammar_stacks & cur_stacks = llama_grammar_get_stacks(grammar);
|
||||
llama_grammar_stacks & stacks_cur = llama_grammar_get_stacks(grammar);
|
||||
|
||||
size_t pos = 0;
|
||||
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
|
||||
const llama_grammar_stacks prev_stacks = llama_grammar_get_stacks(grammar); // copy
|
||||
for (const auto & cpt : cpts) {
|
||||
const llama_grammar_stacks stacks_prev = llama_grammar_get_stacks(grammar); // copy
|
||||
|
||||
llama_grammar_accept(rules, prev_stacks, *it, cur_stacks);
|
||||
llama_grammar_accept(rules, stacks_prev, cpt, stacks_cur);
|
||||
|
||||
if (cur_stacks.empty()) {
|
||||
if (stacks_cur.empty()) {
|
||||
error_pos = pos;
|
||||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(*it) + "'";
|
||||
cur_stacks = prev_stacks;
|
||||
error_msg = "Unexpected character '" + unicode_cpt_to_utf8(cpt) + "'";
|
||||
stacks_cur = stacks_prev;
|
||||
return false;
|
||||
}
|
||||
++pos;
|
||||
}
|
||||
|
||||
for (const auto & stack : cur_stacks) {
|
||||
for (const auto & stack : stacks_cur) {
|
||||
if (stack.empty()) {
|
||||
return true;
|
||||
}
|
||||
|
@ -85,27 +80,7 @@ int main(int argc, char** argv) {
|
|||
grammar_str = buffer.str();
|
||||
}
|
||||
|
||||
// Parse the GBNF grammar
|
||||
auto parsed_grammar = grammar_parser::parse(grammar_str.c_str());
|
||||
|
||||
// will be empty (default) if there are parse errors
|
||||
if (parsed_grammar.rules.empty()) {
|
||||
fprintf(stdout, "%s: failed to parse grammar\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// Ensure that there is a "root" node.
|
||||
if (parsed_grammar.symbol_ids.find("root") == parsed_grammar.symbol_ids.end()) {
|
||||
fprintf(stdout, "%s: grammar does not contain a 'root' symbol\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
std::vector<const llama_grammar_element *> grammar_rules(parsed_grammar.c_rules());
|
||||
|
||||
// Create the LLAMA grammar
|
||||
auto grammar = llama_grammar_init(
|
||||
grammar_rules.data(),
|
||||
grammar_rules.size(), parsed_grammar.symbol_ids.at("root"));
|
||||
llama_grammar * grammar = llama_grammar_init_impl(nullptr, grammar_str.c_str(), "root");
|
||||
if (grammar == nullptr) {
|
||||
throw std::runtime_error("Failed to initialize llama_grammar");
|
||||
}
|
||||
|
@ -122,7 +97,7 @@ int main(int argc, char** argv) {
|
|||
// Validate the input string against the grammar
|
||||
size_t error_pos;
|
||||
std::string error_msg;
|
||||
bool is_valid = llama_sample_grammar_string(grammar, input_str, error_pos, error_msg);
|
||||
bool is_valid = llama_grammar_validate(grammar, input_str, error_pos, error_msg);
|
||||
|
||||
if (is_valid) {
|
||||
fprintf(stdout, "Input string is valid according to the grammar.\n");
|
||||
|
@ -131,7 +106,7 @@ int main(int argc, char** argv) {
|
|||
}
|
||||
|
||||
// Clean up
|
||||
llama_grammar_free(grammar);
|
||||
llama_grammar_free_impl(grammar);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@
|
|||
static std::vector<std::vector<float>> encode(llama_context * ctx, const std::vector<std::string> & sentences, const std::string & instruction) {
|
||||
std::vector<std::vector<float>> result;
|
||||
|
||||
const llama_model * mdl = llama_get_model(ctx);
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
|
||||
llama_batch batch = llama_batch_init(llama_n_batch(ctx), 0, 1);
|
||||
|
||||
|
@ -18,16 +18,16 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||
|
||||
const std::string input_string = instruction + sentences[i];
|
||||
|
||||
std::vector<llama_token> inputs = llama_tokenize(mdl, input_string, true, false);
|
||||
std::vector<llama_token> inputs = llama_tokenize(model, input_string, true, false);
|
||||
|
||||
const int32_t n_toks = inputs.size();
|
||||
|
||||
// GritLM seems to have EOS = ""
|
||||
// https://github.com/ContextualAI/gritlm/blob/92025b16534712b31b3c4aaaf069350e222bd5f8/gritlm/gritlm.py#L18
|
||||
// inputs.push_back(llama_token_eos(mdl));
|
||||
// inputs.push_back(llama_token_eos(model));
|
||||
|
||||
// we want to ignore instruction tokens for mean pooling
|
||||
const int32_t n_inst = llama_tokenize(mdl, instruction, true, false).size();
|
||||
const int32_t n_inst = llama_tokenize(model, instruction, true, false).size();
|
||||
|
||||
#ifdef GRIT_DEBUG
|
||||
// debug tokens - should be matching as referenced in the GritLM sample
|
||||
|
@ -51,7 +51,7 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||
llama_decode(ctx, batch);
|
||||
|
||||
// get embedding dimensions
|
||||
uint64_t n_embd = llama_n_embd(mdl);
|
||||
uint64_t n_embd = llama_n_embd(model);
|
||||
|
||||
// allocate embedding output
|
||||
std::vector<float> emb_unorm(n_embd, 0.0f);
|
||||
|
@ -92,11 +92,11 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
|
|||
return result;
|
||||
}
|
||||
|
||||
static std::string generate(llama_context * ctx, const std::string & prompt, bool stream) {
|
||||
static std::string generate(llama_context * ctx, llama_sampler * smpl, const std::string & prompt, bool stream) {
|
||||
std::string result;
|
||||
|
||||
const llama_model * mdl = llama_get_model(ctx);
|
||||
llama_token eos_token = llama_token_eos(mdl);
|
||||
const llama_model * model = llama_get_model(ctx);
|
||||
llama_token eos_token = llama_token_eos(model);
|
||||
|
||||
llama_kv_cache_clear(ctx);
|
||||
llama_set_embeddings(ctx, false);
|
||||
|
@ -104,28 +104,25 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
|
|||
|
||||
llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
|
||||
|
||||
std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
|
||||
std::vector<llama_token> inputs = llama_tokenize(model, prompt, false, true);
|
||||
int32_t i_current_token = 0;
|
||||
|
||||
while (true) {
|
||||
llama_batch_clear(bat);
|
||||
auto n_inputs = (int32_t)inputs.size();
|
||||
for (int32_t i = 0; i < n_inputs; i++) {
|
||||
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
|
||||
{
|
||||
const int32_t n_inputs = inputs.size();
|
||||
|
||||
for (int32_t i = 0; i < n_inputs; i++) {
|
||||
llama_batch_add(bat, inputs[i], i_current_token++, { 0 }, i == n_inputs - 1);
|
||||
}
|
||||
}
|
||||
inputs.clear();
|
||||
|
||||
llama_decode(ctx, bat);
|
||||
auto logits = llama_get_logits_ith(ctx, bat.n_tokens - 1);
|
||||
|
||||
auto candidates = std::vector<llama_token_data>(llama_n_vocab(mdl));
|
||||
auto n_candidates = (int32_t)candidates.size();
|
||||
for (int32_t token = 0; token < n_candidates; token++) {
|
||||
candidates[token] = llama_token_data{ token, logits[token], 0.0f };
|
||||
}
|
||||
auto candidates_p = llama_token_data_array{ candidates.data(), candidates.size(), false };
|
||||
llama_token token = llama_sampler_sample(smpl, ctx, bat.n_tokens - 1);
|
||||
llama_sampler_accept(smpl, token);
|
||||
|
||||
llama_token token = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
if (token == eos_token) {
|
||||
break;
|
||||
}
|
||||
|
@ -167,10 +164,18 @@ int main(int argc, char * argv[]) {
|
|||
|
||||
llama_backend_init();
|
||||
|
||||
llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
|
||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), mparams);
|
||||
|
||||
// create generation context
|
||||
llama_context * ctx = llama_new_context_with_model(mdl, cparams);
|
||||
llama_context * ctx = llama_new_context_with_model(model, cparams);
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
|
||||
sparams.no_perf = false;
|
||||
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
|
||||
|
||||
// ### Embedding/Representation ###
|
||||
// samples taken from: https://github.com/ContextualAI/gritlm#basic
|
||||
|
@ -191,7 +196,7 @@ int main(int argc, char * argv[]) {
|
|||
const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
|
||||
const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
|
||||
|
||||
const int n_embd = llama_n_embd(mdl);
|
||||
const int n_embd = llama_n_embd(model);
|
||||
|
||||
const float cosine_sim_q0_d0 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[0].data(), n_embd);
|
||||
const float cosine_sim_q0_d1 = llama_embd_similarity_cos(q_rep[0].data(), d_rep[1].data(), n_embd);
|
||||
|
@ -208,11 +213,12 @@ int main(int argc, char * argv[]) {
|
|||
// GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
|
||||
{
|
||||
const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
|
||||
std::string response = generate(ctx, prompt, true);
|
||||
std::string response = generate(ctx, smpl, prompt, true);
|
||||
}
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_free(ctx);
|
||||
llama_free_model(mdl);
|
||||
llama_free_model(model);
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
|
|
|
@ -638,7 +638,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
g_collector.save_imatrix();
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -2,7 +2,6 @@
|
|||
|
||||
#include "console.h"
|
||||
#include "llama.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
#include <cassert>
|
||||
#include <cinttypes>
|
||||
|
@ -34,6 +33,7 @@
|
|||
|
||||
static llama_context ** g_ctx;
|
||||
static llama_model ** g_model;
|
||||
static gpt_sampler ** g_smpl;
|
||||
static gpt_params * g_params;
|
||||
static std::vector<llama_token> * g_input_tokens;
|
||||
static std::ostringstream * g_output_ss;
|
||||
|
@ -81,7 +81,7 @@ static void write_logfile(
|
|||
yaml_dump_string_multiline(logfile, "output", output.c_str());
|
||||
yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
|
||||
|
||||
llama_dump_timing_info_yaml(logfile, ctx);
|
||||
llama_perf_dump_yaml(logfile, ctx);
|
||||
fclose(logfile);
|
||||
}
|
||||
|
||||
|
@ -93,7 +93,7 @@ static void sigint_handler(int signo) {
|
|||
} else {
|
||||
console::cleanup();
|
||||
printf("\n");
|
||||
llama_print_timings(*g_ctx);
|
||||
gpt_perf_print(*g_ctx, *g_smpl);
|
||||
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
|
||||
_exit(130);
|
||||
}
|
||||
|
@ -103,7 +103,6 @@ static void sigint_handler(int signo) {
|
|||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
llama_sampling_params & sparams = params.sparams;
|
||||
g_params = ¶ms;
|
||||
|
||||
if (!gpt_params_parse(argc, argv, params)) {
|
||||
|
@ -111,6 +110,8 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
auto & sparams = params.sparams;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("infill", "log"));
|
||||
LOG_TEE("Log start\n");
|
||||
|
@ -156,26 +157,21 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
|
||||
}
|
||||
|
||||
LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
|
||||
LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
|
||||
print_build_info();
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
||||
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
|
||||
|
||||
LOG("%s: llama backend init\n", __func__);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
gpt_sampler * smpl = nullptr;
|
||||
|
||||
g_model = &model;
|
||||
g_ctx = &ctx;
|
||||
g_smpl = &smpl;
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
|
@ -305,7 +301,7 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("Input suffix: '%s'\n", params.input_suffix.c_str());
|
||||
}
|
||||
}
|
||||
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
|
||||
LOG_TEE("sampling: \n%s\n", sparams.print().c_str());
|
||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
LOG_TEE("\n\n");
|
||||
|
||||
|
@ -349,7 +345,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
std::vector<llama_token> embd;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||
smpl = gpt_sampler_init(model, sparams);
|
||||
|
||||
while (n_remain != 0 || params.interactive) {
|
||||
// predict
|
||||
|
@ -421,11 +417,11 @@ int main(int argc, char ** argv) {
|
|||
embd.clear();
|
||||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, nullptr);
|
||||
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
gpt_sampler_accept(smpl, id, true);
|
||||
|
||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
||||
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
|
||||
|
||||
embd.push_back(id);
|
||||
|
||||
|
@ -444,7 +440,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||
// for the prompt, we don't apply grammar rules
|
||||
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], false);
|
||||
gpt_sampler_accept(smpl, embd_inp[n_consumed], false);
|
||||
|
||||
++n_consumed;
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
|
@ -476,7 +472,7 @@ int main(int argc, char ** argv) {
|
|||
// if not currently processing queued inputs;
|
||||
if ((int) embd_inp.size() <= n_consumed) {
|
||||
// deal with eot token in infill mode
|
||||
if ((llama_sampling_last(ctx_sampling) == llama_token_eot(model) || is_interacting) && params.interactive){
|
||||
if ((gpt_sampler_last(smpl) == llama_token_eot(model) || is_interacting) && params.interactive){
|
||||
if (is_interacting && !params.interactive_first) {
|
||||
// print an eot token
|
||||
printf("%s", llama_token_to_piece(ctx, llama_token_eot(model)).c_str());
|
||||
|
@ -542,7 +538,7 @@ int main(int argc, char ** argv) {
|
|||
is_interacting = false;
|
||||
}
|
||||
// deal with end of generation tokens in interactive mode
|
||||
else if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
|
||||
else if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
|
||||
LOG("found EOS token\n");
|
||||
|
||||
if (params.interactive) {
|
||||
|
@ -615,7 +611,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if (n_past > 0) {
|
||||
if (is_interacting) {
|
||||
llama_sampling_reset(ctx_sampling);
|
||||
gpt_sampler_reset(smpl);
|
||||
}
|
||||
is_interacting = false;
|
||||
}
|
||||
|
@ -638,13 +634,14 @@ int main(int argc, char ** argv) {
|
|||
fflush(stdout);
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
gpt_perf_print(ctx, smpl);
|
||||
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
gpt_sampler_free(smpl);
|
||||
llama_backend_free();
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
|
|
|
@ -1630,7 +1630,7 @@ int main(int argc, char ** argv) {
|
|||
fflush(p_err->fout);
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
llama_free(ctx);
|
||||
|
||||
|
|
|
@ -120,8 +120,8 @@ Java_android_llama_cpp_LLamaAndroid_new_1context(JNIEnv *env, jobject, jlong jmo
|
|||
LOGi("Using %d threads", n_threads);
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
ctx_params.seed = 1234;
|
||||
ctx_params.n_ctx = 2048;
|
||||
|
||||
ctx_params.n_ctx = 2048;
|
||||
ctx_params.n_threads = n_threads;
|
||||
ctx_params.n_threads_batch = n_threads;
|
||||
|
||||
|
@ -380,11 +380,13 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||
JNIEnv * env,
|
||||
jobject,
|
||||
jlong context_pointer,
|
||||
jlong sampling_pointer,
|
||||
jlong batch_pointer,
|
||||
jint n_len,
|
||||
jobject intvar_ncur
|
||||
) {
|
||||
const auto context = reinterpret_cast<llama_context *>(context_pointer);
|
||||
const auto sampling = reinterpret_cast<llama_sampler *>(sampling_pointer);
|
||||
const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
|
||||
const auto model = llama_get_model(context);
|
||||
|
||||
|
@ -392,20 +394,10 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
|
|||
if (!la_int_var_value) la_int_var_value = env->GetMethodID(la_int_var, "getValue", "()I");
|
||||
if (!la_int_var_inc) la_int_var_inc = env->GetMethodID(la_int_var, "inc", "()V");
|
||||
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
auto logits = llama_get_logits_ith(context, batch->n_tokens - 1);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// sample the most likely token
|
||||
const auto new_token_id = llama_sample_token_greedy(context, &candidates_p);
|
||||
const auto new_token_id = llama_sampler_sample(sampling, context, batch->n_tokens - 1);
|
||||
|
||||
llama_sampler_accept(sampling, new_token_id);
|
||||
|
||||
const auto n_cur = env->CallIntMethod(intvar_ncur, la_int_var_value);
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
|
|
|
@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
|
|||
actor LlamaContext {
|
||||
private var model: OpaquePointer
|
||||
private var context: OpaquePointer
|
||||
private var sampling: UnsafeMutablePointer<llama_sampler>
|
||||
private var batch: llama_batch
|
||||
private var tokens_list: [llama_token]
|
||||
var is_done: Bool = false
|
||||
|
@ -42,9 +43,15 @@ actor LlamaContext {
|
|||
self.tokens_list = []
|
||||
self.batch = llama_batch_init(512, 0, 1)
|
||||
self.temporary_invalid_cchars = []
|
||||
let sparams = llama_sampler_chain_default_params()
|
||||
self.sampling = llama_sampler_chain_init(sparams)
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
|
||||
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
|
||||
}
|
||||
|
||||
deinit {
|
||||
llama_sampler_free(sampling)
|
||||
llama_batch_free(batch)
|
||||
llama_free(context)
|
||||
llama_free_model(model)
|
||||
|
@ -69,7 +76,6 @@ actor LlamaContext {
|
|||
print("Using \(n_threads) threads")
|
||||
|
||||
var ctx_params = llama_context_default_params()
|
||||
ctx_params.seed = 1234
|
||||
ctx_params.n_ctx = 2048
|
||||
ctx_params.n_threads = Int32(n_threads)
|
||||
ctx_params.n_threads_batch = Int32(n_threads)
|
||||
|
@ -144,20 +150,9 @@ actor LlamaContext {
|
|||
func completion_loop() -> String {
|
||||
var new_token_id: llama_token = 0
|
||||
|
||||
let n_vocab = llama_n_vocab(model)
|
||||
let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
|
||||
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
|
||||
|
||||
var candidates = Array<llama_token_data>()
|
||||
candidates.reserveCapacity(Int(n_vocab))
|
||||
|
||||
for token_id in 0..<n_vocab {
|
||||
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
|
||||
}
|
||||
candidates.withUnsafeMutableBufferPointer() { buffer in
|
||||
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
|
||||
|
||||
new_token_id = llama_sample_token_greedy(context, &candidates_p)
|
||||
}
|
||||
llama_sampler_accept(sampling, new_token_id)
|
||||
|
||||
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
|
||||
print("\n")
|
||||
|
|
|
@ -40,11 +40,11 @@ static bool eval_string(struct llama_context * ctx_llama, const char* str, int n
|
|||
return true;
|
||||
}
|
||||
|
||||
static const char * sample(struct llama_sampling_context * ctx_sampling,
|
||||
static const char * sample(struct gpt_sampler * smpl,
|
||||
struct llama_context * ctx_llama,
|
||||
int * n_past) {
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
|
||||
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
|
||||
const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
|
||||
gpt_sampler_accept(smpl, id, true);
|
||||
static std::string ret;
|
||||
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
||||
ret = "</s>";
|
||||
|
@ -191,15 +191,15 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
|
|||
|
||||
LOG_TEE("\n");
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
|
||||
if (!ctx_sampling) {
|
||||
struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
|
||||
if (!smpl) {
|
||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
std::string response = "";
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
|
||||
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
|
@ -211,7 +211,7 @@ static void process_prompt(struct llava_context * ctx_llava, struct llava_image_
|
|||
fflush(stdout);
|
||||
}
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
gpt_sampler_free(smpl);
|
||||
printf("\n");
|
||||
}
|
||||
|
||||
|
@ -310,7 +310,7 @@ int main(int argc, char ** argv) {
|
|||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
|
@ -327,7 +327,7 @@ int main(int argc, char ** argv) {
|
|||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt);
|
||||
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
|
||||
llava_image_embed_free(image_embed);
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
|
|
|
@ -163,11 +163,11 @@ static void process_image(struct llava_context * ctx_llava, struct llava_image_e
|
|||
LOG_TEE("%s: image token past: %d\n", __func__, n_past);
|
||||
}
|
||||
|
||||
static const char * sample(struct llama_sampling_context * ctx_sampling,
|
||||
static const char * sample(struct gpt_sampler * smpl,
|
||||
struct llama_context * ctx_llama,
|
||||
int * n_past) {
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx_llama, NULL);
|
||||
llama_sampling_accept(ctx_sampling, ctx_llama, id, true);
|
||||
const llama_token id = gpt_sampler_sample(smpl, ctx_llama, -1);
|
||||
gpt_sampler_accept(smpl, id, true);
|
||||
static std::string ret;
|
||||
if (llama_token_is_eog(llama_get_model(ctx_llama), id)) {
|
||||
ret = "</s>";
|
||||
|
@ -214,7 +214,7 @@ static struct llava_context * minicpmv_init(gpt_params * params, const std::stri
|
|||
return ctx_llava;
|
||||
}
|
||||
|
||||
static struct llama_sampling_context * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
|
||||
static struct gpt_sampler * llama_init(struct llava_context * ctx_llava, gpt_params * params, std::string prompt, int &n_past, bool is_first = false){
|
||||
std::string user_prompt = prompt;
|
||||
int has_minicpmv_projector = clip_is_minicpmv(ctx_llava->ctx_clip);
|
||||
if (!is_first) {
|
||||
|
@ -238,13 +238,13 @@ static struct llama_sampling_context * llama_init(struct llava_context * ctx_lla
|
|||
|
||||
LOG_TEE("\n");
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params->sparams);
|
||||
return ctx_sampling;
|
||||
struct gpt_sampler * smpl = gpt_sampler_init(ctx_llava->model, params->sparams);
|
||||
return smpl;
|
||||
}
|
||||
|
||||
static const char * llama_loop(struct llava_context * ctx_llava,struct llama_sampling_context * ctx_sampling, int &n_past){
|
||||
static const char * llama_loop(struct llava_context * ctx_llava,struct gpt_sampler * smpl, int &n_past){
|
||||
|
||||
const char * tmp = sample(ctx_sampling, ctx_llava->ctx_llama, &n_past);
|
||||
const char * tmp = sample(smpl, ctx_llava->ctx_llama, &n_past);
|
||||
return tmp;
|
||||
}
|
||||
|
||||
|
@ -278,12 +278,12 @@ int main(int argc, char ** argv) {
|
|||
if (!params.prompt.empty()) {
|
||||
LOG_TEE("<user>%s\n", params.prompt.c_str());
|
||||
LOG_TEE("<assistant>");
|
||||
auto ctx_sampling = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true);
|
||||
auto smpl = llama_init(ctx_llava, ¶ms, params.prompt.c_str(), n_past, true);
|
||||
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||
std::string response = "";
|
||||
bool have_tmp = false;
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
|
||||
auto tmp = llama_loop(ctx_llava, smpl, n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0){
|
||||
if(!have_tmp)continue;
|
||||
|
@ -296,18 +296,18 @@ int main(int argc, char ** argv) {
|
|||
|
||||
fflush(stdout);
|
||||
}
|
||||
llama_sampling_free(ctx_sampling);
|
||||
gpt_sampler_free(smpl);
|
||||
}else {
|
||||
while (true) {
|
||||
LOG_TEE("<user>");
|
||||
std::string prompt;
|
||||
std::getline(std::cin, prompt);
|
||||
LOG_TEE("<assistant>");
|
||||
auto ctx_sampling = llama_init(ctx_llava, ¶ms, prompt, n_past, true);
|
||||
auto smpl = llama_init(ctx_llava, ¶ms, prompt, n_past, true);
|
||||
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||
std::string response = "";
|
||||
for (int i = 0; i < max_tgt_len; i++) {
|
||||
auto tmp = llama_loop(ctx_llava, ctx_sampling, n_past);
|
||||
auto tmp = llama_loop(ctx_llava, smpl, n_past);
|
||||
response += tmp;
|
||||
if (strcmp(tmp, "</s>") == 0) break;
|
||||
if (strstr(tmp, "###")) break; // Yi-VL behavior
|
||||
|
@ -315,11 +315,11 @@ int main(int argc, char ** argv) {
|
|||
if (strstr(response.c_str(), "<user>")) break; // minicpm-v
|
||||
fflush(stdout);
|
||||
}
|
||||
llama_sampling_free(ctx_sampling);
|
||||
gpt_sampler_free(smpl);
|
||||
}
|
||||
}
|
||||
printf("\n");
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
llama_perf_print(ctx_llava->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
ctx_llava->model = NULL;
|
||||
llava_free(ctx_llava);
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
@ -118,7 +117,7 @@ int main(int argc, char ** argv) {
|
|||
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
|
||||
|
||||
// target model sampling context
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
|
||||
|
||||
// verification n-grams
|
||||
std::vector<ngram_data> ngrams_cur(G);
|
||||
|
@ -159,9 +158,9 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// sample first token
|
||||
{
|
||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
|
||||
id = gpt_sampler_sample(smpl, ctx, 0);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
gpt_sampler_accept(smpl, id, true);
|
||||
|
||||
{
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
|
@ -284,9 +283,9 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// sample the next token
|
||||
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
|
||||
id = gpt_sampler_sample(smpl, ctx, i_batch);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
gpt_sampler_accept(smpl, id, true);
|
||||
|
||||
// print
|
||||
{
|
||||
|
@ -361,7 +360,7 @@ int main(int argc, char ** argv) {
|
|||
if (v == 0) {
|
||||
// sample from the last level
|
||||
for (int i = 0; i < W; i++) {
|
||||
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
||||
tokens_j[N - 2][i] = gpt_sampler_sample(smpl, ctx, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < W; i++) {
|
||||
|
@ -468,10 +467,12 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("n_predict = %d\n", n_predict);
|
||||
LOG_TEE("n_accept = %d\n", n_accept);
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
gpt_perf_print(ctx, smpl);
|
||||
|
||||
gpt_sampler_free(smpl);
|
||||
|
||||
llama_kv_cache_view_free(&kvc_view);
|
||||
llama_sampling_free(ctx_sampling);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
|
|
|
@ -3,13 +3,11 @@
|
|||
#include "common.h"
|
||||
#include "ngram-cache.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdint>
|
||||
#include <cstdio>
|
||||
#include <fstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <unordered_map>
|
||||
|
||||
int main(int argc, char ** argv){
|
||||
gpt_params params;
|
||||
|
@ -106,7 +104,7 @@ int main(int argc, char ** argv){
|
|||
|
||||
bool has_eos = false;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
struct gpt_sampler * smpl = gpt_sampler_init(model, params.sparams);
|
||||
|
||||
std::vector<llama_token> draft;
|
||||
|
||||
|
@ -130,9 +128,9 @@ int main(int argc, char ** argv){
|
|||
int i_dft = 0;
|
||||
while (true) {
|
||||
// sample from the target model
|
||||
llama_token id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_dft);
|
||||
llama_token id = gpt_sampler_sample(smpl, ctx, i_dft);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, true);
|
||||
gpt_sampler_accept(smpl, id, true);
|
||||
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
|
||||
|
@ -240,10 +238,12 @@ int main(int argc, char ** argv){
|
|||
LOG_TEE("n_accept = %d\n", n_accept);
|
||||
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
|
||||
|
||||
LOG_TEE("\ntarget:\n");
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\ntarget:\n\n");
|
||||
llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
gpt_sampler_free(smpl);
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
llama_batch_free(batch_tgt);
|
||||
|
||||
llama_free(ctx);
|
||||
|
|
|
@ -33,6 +33,7 @@
|
|||
|
||||
static llama_context ** g_ctx;
|
||||
static llama_model ** g_model;
|
||||
static gpt_sampler ** g_smpl;
|
||||
static gpt_params * g_params;
|
||||
static std::vector<llama_token> * g_input_tokens;
|
||||
static std::ostringstream * g_output_ss;
|
||||
|
@ -92,7 +93,7 @@ static void write_logfile(
|
|||
yaml_dump_string_multiline(logfile, "output", output.c_str());
|
||||
yaml_dump_vector_int(logfile, "output_tokens", output_tokens);
|
||||
|
||||
llama_dump_timing_info_yaml(logfile, ctx);
|
||||
llama_perf_dump_yaml(logfile, ctx);
|
||||
fclose(logfile);
|
||||
}
|
||||
|
||||
|
@ -105,7 +106,7 @@ static void sigint_handler(int signo) {
|
|||
} else {
|
||||
console::cleanup();
|
||||
printf("\n");
|
||||
llama_print_timings(*g_ctx);
|
||||
gpt_perf_print(*g_ctx, *g_smpl);
|
||||
write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens);
|
||||
_exit(130);
|
||||
}
|
||||
|
@ -121,8 +122,7 @@ static void llama_log_callback_logTee(ggml_log_level level, const char * text, v
|
|||
|
||||
static std::string chat_add_and_format(struct llama_model * model, std::vector<llama_chat_msg> & chat_msgs, std::string role, std::string content) {
|
||||
llama_chat_msg new_msg{role, content};
|
||||
auto formatted = llama_chat_format_single(
|
||||
model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
||||
auto formatted = llama_chat_format_single(model, g_params->chat_template, chat_msgs, new_msg, role == "user");
|
||||
chat_msgs.push_back({role, content});
|
||||
LOG("formatted: %s\n", formatted.c_str());
|
||||
return formatted;
|
||||
|
@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
llama_sampling_params & sparams = params.sparams;
|
||||
auto & sparams = params.sparams;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("main", "log"));
|
||||
|
@ -183,27 +183,23 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: warning: scaling RoPE frequency by %g.\n", __func__, params.rope_freq_scale);
|
||||
}
|
||||
|
||||
LOG_TEE("%s: build = %d (%s)\n", __func__, LLAMA_BUILD_NUMBER, LLAMA_COMMIT);
|
||||
LOG_TEE("%s: built with %s for %s\n", __func__, LLAMA_COMPILER, LLAMA_BUILD_TARGET);
|
||||
print_build_info();
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
||||
LOG_TEE("%s: seed = %u\n", __func__, params.seed);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
|
||||
|
||||
LOG("%s: llama backend init\n", __func__);
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
||||
llama_model * model;
|
||||
llama_context * ctx;
|
||||
llama_context * ctx_guidance = NULL;
|
||||
llama_model * model = nullptr;
|
||||
llama_context * ctx = nullptr;
|
||||
gpt_sampler * smpl = nullptr;
|
||||
|
||||
std::vector<llama_chat_msg> chat_msgs;
|
||||
|
||||
g_model = &model;
|
||||
g_ctx = &ctx;
|
||||
g_smpl = &smpl;
|
||||
|
||||
// load the model and apply lora adapter, if any
|
||||
LOG("%s: load the model and apply lora adapter, if any\n", __func__);
|
||||
|
@ -211,10 +207,6 @@ int main(int argc, char ** argv) {
|
|||
|
||||
model = llama_init.model;
|
||||
ctx = llama_init.context;
|
||||
if (sparams.cfg_scale > 1.f) {
|
||||
struct llama_context_params lparams = llama_context_params_from_gpt_params(params);
|
||||
ctx_guidance = llama_new_context_with_model(model, lparams);
|
||||
}
|
||||
|
||||
if (model == NULL) {
|
||||
LOG_TEE("%s: error: unable to load model\n", __func__);
|
||||
|
@ -251,9 +243,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
llama_attach_threadpool(ctx, threadpool, threadpool_batch);
|
||||
if (ctx_guidance) {
|
||||
llama_attach_threadpool(ctx_guidance, threadpool, threadpool_batch);
|
||||
}
|
||||
|
||||
const int n_ctx_train = llama_n_ctx_train(model);
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
|
@ -337,24 +326,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// Tokenize negative prompt
|
||||
std::vector<llama_token> guidance_inp;
|
||||
int guidance_offset = 0;
|
||||
int original_prompt_len = 0;
|
||||
if (ctx_guidance) {
|
||||
LOG("cfg_negative_prompt: \"%s\"\n", log_tostr(sparams.cfg_negative_prompt));
|
||||
|
||||
guidance_inp = ::llama_tokenize(ctx_guidance, sparams.cfg_negative_prompt, true, true);
|
||||
LOG("guidance_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_guidance, guidance_inp).c_str());
|
||||
|
||||
std::vector<llama_token> original_inp = ::llama_tokenize(ctx, params.prompt, true, true);
|
||||
LOG("original_inp tokenized: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, original_inp).c_str());
|
||||
|
||||
original_prompt_len = original_inp.size();
|
||||
guidance_offset = (int)guidance_inp.size() - original_prompt_len;
|
||||
LOG("original_prompt_len: %s", log_tostr(original_prompt_len));
|
||||
LOG("guidance_offset: %s", log_tostr(guidance_offset));
|
||||
}
|
||||
|
||||
if ((int) embd_inp.size() > n_ctx - 4) {
|
||||
LOG_TEE("%s: error: prompt is too long (%d tokens, max %d)\n", __func__, (int) embd_inp.size(), n_ctx - 4);
|
||||
return 1;
|
||||
|
@ -421,15 +392,6 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
|
||||
}
|
||||
|
||||
if (ctx_guidance) {
|
||||
LOG_TEE("\n");
|
||||
LOG_TEE("%s: negative prompt: '%s'\n", __func__, sparams.cfg_negative_prompt.c_str());
|
||||
LOG_TEE("%s: number of tokens in negative prompt = %zu\n", __func__, guidance_inp.size());
|
||||
for (int i = 0; i < (int) guidance_inp.size(); i++) {
|
||||
LOG_TEE("%6d -> '%s'\n", guidance_inp[i], llama_token_to_piece(ctx, guidance_inp[i]).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
if (params.n_keep > add_bos) {
|
||||
LOG_TEE("%s: static prompt based on n_keep: '", __func__);
|
||||
for (int i = 0; i < params.n_keep; i++) {
|
||||
|
@ -495,8 +457,15 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
}
|
||||
LOG_TEE("sampling: \n%s\n", llama_sampling_print(sparams).c_str());
|
||||
LOG_TEE("sampling order: \n%s\n", llama_sampling_order_print(sparams).c_str());
|
||||
|
||||
smpl = gpt_sampler_init(model, sparams);
|
||||
if (!smpl) {
|
||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
LOG_TEE("sampling params: \n%s\n", sparams.print().c_str());
|
||||
LOG_TEE(" sampler constr: \n%s\n", gpt_sampler_print(smpl).c_str());
|
||||
LOG_TEE("generate: n_ctx = %d, n_batch = %d, n_predict = %d, n_keep = %d\n", n_ctx, params.n_batch, params.n_predict, params.n_keep);
|
||||
|
||||
// group-attention state
|
||||
|
@ -543,7 +512,6 @@ int main(int argc, char ** argv) {
|
|||
int n_remain = params.n_predict;
|
||||
int n_consumed = 0;
|
||||
int n_session_consumed = 0;
|
||||
int n_past_guidance = 0;
|
||||
|
||||
std::vector<int> input_tokens; g_input_tokens = &input_tokens;
|
||||
std::vector<int> output_tokens; g_output_tokens = &output_tokens;
|
||||
|
@ -555,7 +523,6 @@ int main(int argc, char ** argv) {
|
|||
display = params.display_prompt;
|
||||
|
||||
std::vector<llama_token> embd;
|
||||
std::vector<llama_token> embd_guidance;
|
||||
|
||||
// tokenized antiprompts
|
||||
std::vector<std::vector<llama_token>> antiprompt_ids;
|
||||
|
@ -565,12 +532,6 @@ int main(int argc, char ** argv) {
|
|||
antiprompt_ids.emplace_back(::llama_tokenize(ctx, antiprompt, false, true));
|
||||
}
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(sparams);
|
||||
if (!ctx_sampling) {
|
||||
fprintf(stderr, "%s: failed to initialize sampling subsystem\n", __func__);
|
||||
exit(1);
|
||||
}
|
||||
|
||||
if (llama_model_has_encoder(model)) {
|
||||
int enc_input_size = embd_inp.size();
|
||||
llama_token * enc_input_buf = embd_inp.data();
|
||||
|
@ -612,7 +573,7 @@ int main(int argc, char ** argv) {
|
|||
// if we run out of context:
|
||||
// - take the n_keep first tokens from the original prompt (via n_past)
|
||||
// - take half of the last (n_ctx - n_keep) tokens and recompute the logits in batches
|
||||
if (n_past + (int) embd.size() + std::max<int>(0, guidance_offset) >= n_ctx) {
|
||||
if (n_past + (int) embd.size() >= n_ctx) {
|
||||
if (params.n_predict == -2) {
|
||||
LOG_TEE("\n\n%s: context full and n_predict == -%d => stopping\n", __func__, params.n_predict);
|
||||
break;
|
||||
|
@ -629,11 +590,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
n_past -= n_discard;
|
||||
|
||||
if (ctx_guidance) {
|
||||
n_past_guidance -= n_discard;
|
||||
}
|
||||
|
||||
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
|
||||
LOG("after swap: n_past = %d\n", n_past);
|
||||
|
||||
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd).c_str());
|
||||
|
||||
|
@ -686,46 +643,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
// evaluate tokens in batches
|
||||
// embd is typically prepared beforehand to fit within a batch, but not always
|
||||
if (ctx_guidance) {
|
||||
int input_size = 0;
|
||||
llama_token * input_buf = NULL;
|
||||
|
||||
if (n_past_guidance < (int) guidance_inp.size()) {
|
||||
// Guidance context should have the same data with these modifications:
|
||||
//
|
||||
// * Replace the initial prompt
|
||||
// * Shift everything by guidance_offset
|
||||
embd_guidance = guidance_inp;
|
||||
if (embd.begin() + original_prompt_len < embd.end()) {
|
||||
embd_guidance.insert(
|
||||
embd_guidance.end(),
|
||||
embd.begin() + original_prompt_len,
|
||||
embd.end()
|
||||
);
|
||||
}
|
||||
|
||||
input_buf = embd_guidance.data();
|
||||
input_size = embd_guidance.size();
|
||||
|
||||
LOG("guidance context: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd_guidance).c_str());
|
||||
} else {
|
||||
input_buf = embd.data();
|
||||
input_size = embd.size();
|
||||
}
|
||||
|
||||
for (int i = 0; i < input_size; i += params.n_batch) {
|
||||
int n_eval = std::min(input_size - i, params.n_batch);
|
||||
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0))) {
|
||||
LOG_TEE("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
n_past_guidance += n_eval;
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = 0; i < (int) embd.size(); i += params.n_batch) {
|
||||
int n_eval = (int) embd.size() - i;
|
||||
if (n_eval > params.n_batch) {
|
||||
|
@ -755,7 +672,6 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
embd.clear();
|
||||
embd_guidance.clear();
|
||||
|
||||
if ((int) embd_inp.size() <= n_consumed && !is_interacting) {
|
||||
// optionally save the session on first sample (for faster prompt loading next time)
|
||||
|
@ -766,11 +682,11 @@ int main(int argc, char ** argv) {
|
|||
LOG("saved session to %s\n", path_session.c_str());
|
||||
}
|
||||
|
||||
const llama_token id = llama_sampling_sample(ctx_sampling, ctx, ctx_guidance);
|
||||
const llama_token id = gpt_sampler_sample(smpl, ctx, -1);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx, id, /* apply_grammar= */ true);
|
||||
gpt_sampler_accept(smpl, id, /* apply_grammar= */ true);
|
||||
|
||||
LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, ctx_sampling->prev).c_str());
|
||||
// LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, smpl->prev.to_vector()).c_str());
|
||||
|
||||
embd.push_back(id);
|
||||
|
||||
|
@ -789,7 +705,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// push the prompt in the sampling context in order to apply repetition penalties later
|
||||
// for the prompt, we don't apply grammar rules
|
||||
llama_sampling_accept(ctx_sampling, ctx, embd_inp[n_consumed], /* apply_grammar= */ false);
|
||||
gpt_sampler_accept(smpl, embd_inp[n_consumed], /* apply_grammar= */ false);
|
||||
|
||||
++n_consumed;
|
||||
if ((int) embd.size() >= params.n_batch) {
|
||||
|
@ -832,7 +748,7 @@ int main(int argc, char ** argv) {
|
|||
// check for reverse prompt in the last n_prev tokens
|
||||
if (!params.antiprompt.empty()) {
|
||||
const int n_prev = 32;
|
||||
const std::string last_output = llama_sampling_prev_str(ctx_sampling, ctx, n_prev);
|
||||
const std::string last_output = gpt_sampler_prev_str(smpl, ctx, n_prev);
|
||||
|
||||
is_antiprompt = false;
|
||||
// Check if each of the reverse prompts appears at the end of the output.
|
||||
|
@ -854,7 +770,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// check for reverse prompt using special tokens
|
||||
llama_token last_token = llama_sampling_last(ctx_sampling);
|
||||
llama_token last_token = gpt_sampler_last(smpl);
|
||||
for (std::vector<llama_token> ids : antiprompt_ids) {
|
||||
if (ids.size() == 1 && last_token == ids[0]) {
|
||||
if (params.interactive) {
|
||||
|
@ -871,7 +787,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// deal with end of generation tokens in interactive mode
|
||||
if (llama_token_is_eog(model, llama_sampling_last(ctx_sampling))) {
|
||||
if (llama_token_is_eog(model, gpt_sampler_last(smpl))) {
|
||||
LOG("found an EOG token\n");
|
||||
|
||||
if (params.interactive) {
|
||||
|
@ -892,7 +808,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// if current token is not EOG, we add it to current assistant message
|
||||
if (params.conversation) {
|
||||
auto id = llama_sampling_last(ctx_sampling);
|
||||
const auto id = gpt_sampler_last(smpl);
|
||||
assistant_ss << llama_token_to_piece(ctx, id, false);
|
||||
}
|
||||
|
||||
|
@ -988,7 +904,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
if (n_past > 0) {
|
||||
if (is_interacting) {
|
||||
llama_sampling_reset(ctx_sampling);
|
||||
gpt_sampler_reset(smpl);
|
||||
}
|
||||
is_interacting = false;
|
||||
}
|
||||
|
@ -1013,14 +929,15 @@ int main(int argc, char ** argv) {
|
|||
llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size());
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
gpt_perf_print(ctx, smpl);
|
||||
write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens);
|
||||
|
||||
if (ctx_guidance) { llama_free(ctx_guidance); }
|
||||
gpt_sampler_free(smpl);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
llama_backend_free();
|
||||
|
||||
ggml_threadpool_free(threadpool);
|
||||
|
|
|
@ -50,8 +50,8 @@ static std::vector<std::string> k_prompts = {
|
|||
|
||||
struct client {
|
||||
~client() {
|
||||
if (ctx_sampling) {
|
||||
llama_sampling_free(ctx_sampling);
|
||||
if (smpl) {
|
||||
gpt_sampler_free(smpl);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -72,7 +72,7 @@ struct client {
|
|||
std::string prompt;
|
||||
std::string response;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling = nullptr;
|
||||
struct gpt_sampler * smpl = nullptr;
|
||||
};
|
||||
|
||||
static void print_date_time() {
|
||||
|
@ -161,7 +161,7 @@ int main(int argc, char ** argv) {
|
|||
for (size_t i = 0; i < clients.size(); ++i) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.ctx_sampling = llama_sampling_init(params.sparams);
|
||||
client.smpl = gpt_sampler_init(model, params.sparams);
|
||||
}
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
|
@ -253,7 +253,7 @@ int main(int argc, char ** argv) {
|
|||
client.prompt = client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
|
||||
llama_sampling_reset(client.ctx_sampling);
|
||||
gpt_sampler_reset(client.smpl);
|
||||
|
||||
// do not prepend BOS because we have a system prompt!
|
||||
std::vector<llama_token> tokens_prompt;
|
||||
|
@ -341,9 +341,9 @@ int main(int argc, char ** argv) {
|
|||
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
|
||||
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
||||
|
||||
const llama_token id = llama_sampling_sample(client.ctx_sampling, ctx, NULL, client.i_batch - i);
|
||||
const llama_token id = gpt_sampler_sample(client.smpl, ctx, client.i_batch - i);
|
||||
|
||||
llama_sampling_accept(client.ctx_sampling, ctx, id, true);
|
||||
gpt_sampler_accept(client.smpl, id, true);
|
||||
|
||||
if (client.n_decoded == 1) {
|
||||
// start measuring generation time after the first token to make sure all concurrent clients
|
||||
|
@ -371,7 +371,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
|
||||
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_rm(ctx, client.id + 1, -1, -1);
|
||||
llama_kv_cache_seq_cp(ctx, 0, client.id + 1, -1, -1);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
@ -413,7 +413,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG_TEE("\n");
|
||||
|
||||
llama_print_timings(ctx);
|
||||
// TODO: print sampling/grammar timings for all clients
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
|
|
|
@ -26,8 +26,6 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
srand(params.seed == LLAMA_DEFAULT_SEED ? time(NULL) : params.seed);
|
||||
|
||||
int n_junk = params.n_junk;
|
||||
int n_keep = params.n_keep;
|
||||
int n_grp = params.grp_attn_n;
|
||||
|
@ -80,12 +78,17 @@ int main(int argc, char ** argv) {
|
|||
GGML_ASSERT(ctx_params.n_batch % n_grp == 0 && "n_batch must be divisible by n_grp");
|
||||
|
||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
|
||||
|
||||
// tokenize the prompt
|
||||
std::vector<llama_token> tokens_list;
|
||||
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||
|
@ -217,20 +220,9 @@ int main(int argc, char ** argv) {
|
|||
while (n_cur <= n_len) {
|
||||
// sample the next token
|
||||
{
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// sample the most likely token
|
||||
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_len) {
|
||||
|
@ -267,10 +259,13 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
|
|
|
@ -76,7 +76,7 @@ static void write_logfile(
|
|||
fprintf(logfile, "ppl_value: %f\n", results.ppl_value);
|
||||
yaml_dump_vector_float(logfile, "probs", results.probs);
|
||||
|
||||
llama_dump_timing_info_yaml(logfile, ctx);
|
||||
llama_perf_dump_yaml(logfile, ctx);
|
||||
fclose(logfile);
|
||||
}
|
||||
|
||||
|
@ -2007,13 +2007,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
print_build_info();
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
|
||||
fprintf(stderr, "%s: seed = %u\n", __func__, params.seed);
|
||||
|
||||
std::mt19937 rng(params.seed);
|
||||
LOG_TEE("%s: seed = %u\n", __func__, params.sparams.seed);
|
||||
|
||||
llama_backend_init();
|
||||
llama_numa_init(params.numa);
|
||||
|
@ -2054,7 +2048,8 @@ int main(int argc, char ** argv) {
|
|||
results = perplexity(ctx, params, n_ctx);
|
||||
}
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
write_logfile(ctx, params, model, results);
|
||||
|
||||
llama_free(ctx);
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
#define LLAMA_API_INTERNAL
|
||||
#include "common.h"
|
||||
#include "ggml.h"
|
||||
#include "llama.h"
|
||||
#include "llama-impl.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cassert>
|
||||
|
@ -319,8 +319,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
auto cparams = llama_context_default_params();
|
||||
cparams.n_ctx = 256;
|
||||
cparams.seed = 1;
|
||||
cparams.n_ctx = 256;
|
||||
|
||||
ctx = llama_new_context_with_model(model, cparams);
|
||||
|
||||
|
|
|
@ -293,9 +293,11 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
// clean up
|
||||
llama_batch_free(query_batch);
|
||||
llama_print_timings(ctx);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
llama_backend_free();
|
||||
|
|
|
@ -3,12 +3,12 @@
|
|||
|
||||
#include <vector>
|
||||
#include <cstdio>
|
||||
#include <chrono>
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
params.prompt = "The quick brown fox";
|
||||
params.sparams.seed = 1234;
|
||||
|
||||
if (!gpt_params_parse(argc, argv, params)) {
|
||||
gpt_params_print_usage(argc, argv, params);
|
||||
|
@ -38,6 +38,13 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_softmax());
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_dist(params.sparams.seed));
|
||||
|
||||
// tokenize prompt
|
||||
auto tokens = llama_tokenize(ctx, params.prompt, true);
|
||||
|
||||
|
@ -64,18 +71,11 @@ int main(int argc, char ** argv) {
|
|||
printf("\nfirst run: %s", params.prompt.c_str());
|
||||
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto * logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx, &candidates_p);
|
||||
auto next_token = llama_sampler_sample(smpl, ctx, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx, next_token);
|
||||
|
||||
llama_sampler_accept(smpl, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result0 += next_token_str;
|
||||
|
||||
|
@ -96,6 +96,11 @@ int main(int argc, char ** argv) {
|
|||
// make new context
|
||||
auto * ctx2 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
|
||||
llama_sampler * smpl2 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl2, llama_sampler_init_softmax());
|
||||
llama_sampler_chain_add(smpl2, llama_sampler_init_dist(params.sparams.seed));
|
||||
|
||||
printf("\nsecond run: %s", params.prompt.c_str());
|
||||
|
||||
// load state (rng, logits, embedding and kv_cache) from file
|
||||
|
@ -124,17 +129,11 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto * logits = llama_get_logits(ctx2);
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx2, &candidates_p);
|
||||
auto next_token = llama_sampler_sample(smpl2, ctx2, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx2, next_token);
|
||||
|
||||
llama_sampler_accept(smpl2, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result1 += next_token_str;
|
||||
|
||||
|
@ -157,7 +156,12 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// make new context
|
||||
auto* ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
auto * ctx3 = llama_new_context_with_model(model, llama_context_params_from_gpt_params(params));
|
||||
|
||||
llama_sampler * smpl3 = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl3, llama_sampler_init_softmax());
|
||||
llama_sampler_chain_add(smpl3, llama_sampler_init_dist(params.sparams.seed));
|
||||
|
||||
printf("\nsingle seq run: %s", params.prompt.c_str());
|
||||
|
||||
|
@ -215,17 +219,11 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// third run with seq 1 instead of 0
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto * logits = llama_get_logits(ctx3);
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{token_id, logits[token_id], 0.0f});
|
||||
}
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
auto next_token = llama_sample_token(ctx3, &candidates_p);
|
||||
auto next_token = llama_sampler_sample(smpl3, ctx3, -1);
|
||||
auto next_token_str = llama_token_to_piece(ctx3, next_token);
|
||||
|
||||
llama_sampler_accept(smpl3, next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
result2 += next_token_str;
|
||||
|
||||
|
@ -240,6 +238,10 @@ int main(int argc, char ** argv) {
|
|||
|
||||
printf("\n");
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_sampler_free(smpl2);
|
||||
llama_sampler_free(smpl3);
|
||||
|
||||
llama_free(ctx3);
|
||||
llama_free_model(model);
|
||||
|
||||
|
|
|
@ -470,8 +470,6 @@ node index.js
|
|||
|
||||
`frequency_penalty`: Repeat alpha frequency penalty. Default: `0.0`, which is disabled.
|
||||
|
||||
`penalty_prompt`: This will replace the `prompt` for the purpose of the penalty evaluation. Can be either `null`, a string or an array of numbers representing tokens. Default: `null`, which is to use the original `prompt`.
|
||||
|
||||
`mirostat`: Enable Mirostat sampling, controlling perplexity during text generation. Default: `0`, where `0` is disabled, `1` is Mirostat, and `2` is Mirostat 2.0.
|
||||
|
||||
`mirostat_tau`: Set the Mirostat target entropy, parameter tau. Default: `5.0`
|
||||
|
@ -724,7 +722,6 @@ Example:
|
|||
"stopping_word": ""
|
||||
},
|
||||
"penalize_nl": true,
|
||||
"penalty_prompt_tokens": [],
|
||||
"presence_penalty": 0.0,
|
||||
"prompt": "Say hello to llama.cpp",
|
||||
"repeat_last_n": 64,
|
||||
|
@ -748,8 +745,7 @@ Example:
|
|||
"tfs_z": 1.0,
|
||||
"top_k": 40,
|
||||
"top_p": 0.949999988079071,
|
||||
"typical_p": 1.0,
|
||||
"use_penalty_prompt_tokens": false
|
||||
"typical_p": 1.0
|
||||
}
|
||||
]
|
||||
```
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
#include "common.h"
|
||||
#include "json-schema-to-grammar.h"
|
||||
#include "llama.h"
|
||||
#include "grammar-parser.h"
|
||||
|
||||
// Change JSON_ASSERT from assert() to GGML_ASSERT:
|
||||
#define JSON_ASSERT GGML_ASSERT
|
||||
|
@ -169,11 +168,13 @@ struct server_slot {
|
|||
std::string stopping_word;
|
||||
|
||||
// sampling
|
||||
llama_token sampled;
|
||||
struct llama_sampling_params sparams;
|
||||
llama_sampling_context * ctx_sampling = nullptr;
|
||||
json json_schema;
|
||||
|
||||
struct gpt_sampler_params sparams;
|
||||
struct gpt_sampler * smpl = nullptr;
|
||||
|
||||
llama_token sampled;
|
||||
|
||||
int32_t ga_i = 0; // group-attention state
|
||||
int32_t ga_n = 1; // group-attention factor
|
||||
int32_t ga_w = 512; // group-attention width
|
||||
|
@ -651,8 +652,8 @@ struct server_context {
|
|||
|
||||
// Clear any sampling context
|
||||
for (server_slot & slot : slots) {
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
llama_sampling_free(slot.ctx_sampling);
|
||||
if (slot.smpl != nullptr) {
|
||||
gpt_sampler_free(slot.smpl);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -883,8 +884,8 @@ struct server_context {
|
|||
bool launch_slot_with_task(server_slot & slot, const server_task & task) {
|
||||
slot_params default_params;
|
||||
// Sampling parameter defaults are loaded from the global server context (but individual requests can still override them)
|
||||
llama_sampling_params default_sparams = params.sparams;
|
||||
auto & data = task.data;
|
||||
auto default_sparams = params.sparams;
|
||||
const auto & data = task.data;
|
||||
|
||||
if (data.count("__oaicompat") != 0) {
|
||||
slot.oaicompat = true;
|
||||
|
@ -901,7 +902,7 @@ struct server_context {
|
|||
slot.sparams.top_p = json_value(data, "top_p", default_sparams.top_p);
|
||||
slot.sparams.min_p = json_value(data, "min_p", default_sparams.min_p);
|
||||
slot.sparams.tfs_z = json_value(data, "tfs_z", default_sparams.tfs_z);
|
||||
slot.sparams.typical_p = json_value(data, "typical_p", default_sparams.typical_p);
|
||||
slot.sparams.typ_p = json_value(data, "typical_p", default_sparams.typ_p);
|
||||
slot.sparams.temp = json_value(data, "temperature", default_sparams.temp);
|
||||
slot.sparams.dynatemp_range = json_value(data, "dynatemp_range", default_sparams.dynatemp_range);
|
||||
slot.sparams.dynatemp_exponent = json_value(data, "dynatemp_exponent", default_sparams.dynatemp_exponent);
|
||||
|
@ -923,7 +924,8 @@ struct server_context {
|
|||
if (data.contains("json_schema") && !data.at("json_schema").is_null() && data.contains("grammar") && !data.at("grammar").is_null()) {
|
||||
send_error(task, "Either \"json_schema\" or \"grammar\" can be specified, but not both", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
} else if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
}
|
||||
if (data.contains("json_schema") && !data.contains("grammar")) {
|
||||
try {
|
||||
auto schema = json_value(data, "json_schema", json::object());
|
||||
slot.sparams.grammar = json_schema_to_grammar(schema);
|
||||
|
@ -973,56 +975,11 @@ struct server_context {
|
|||
}
|
||||
}
|
||||
|
||||
// penalize user-provided tokens
|
||||
{
|
||||
slot.sparams.penalty_prompt_tokens.clear();
|
||||
slot.sparams.use_penalty_prompt_tokens = false;
|
||||
|
||||
const auto & penalty_prompt = data.find("penalty_prompt");
|
||||
|
||||
if (penalty_prompt != data.end()) {
|
||||
if (penalty_prompt->is_string()) {
|
||||
const auto penalty_prompt_string = penalty_prompt->get<std::string>();
|
||||
slot.sparams.penalty_prompt_tokens = llama_tokenize(model, penalty_prompt_string, false);
|
||||
|
||||
if (slot.params.n_predict > 0) {
|
||||
slot.sparams.penalty_prompt_tokens.reserve(slot.sparams.penalty_prompt_tokens.size() + slot.params.n_predict);
|
||||
}
|
||||
slot.sparams.use_penalty_prompt_tokens = true;
|
||||
|
||||
LOG_VERBOSE("penalty_prompt_tokens", {
|
||||
{"id_slot", slot.id},
|
||||
{"tokens", slot.sparams.penalty_prompt_tokens},
|
||||
});
|
||||
}
|
||||
else if (penalty_prompt->is_array()) {
|
||||
const auto n_tokens = penalty_prompt->size();
|
||||
slot.sparams.penalty_prompt_tokens.reserve(n_tokens + std::max(0, slot.params.n_predict));
|
||||
|
||||
const int n_vocab = llama_n_vocab(model);
|
||||
for (const auto & penalty_token : *penalty_prompt) {
|
||||
if (penalty_token.is_number_integer()) {
|
||||
const auto tok = penalty_token.get<llama_token>();
|
||||
if (tok >= 0 && tok < n_vocab) {
|
||||
slot.sparams.penalty_prompt_tokens.push_back(tok);
|
||||
}
|
||||
}
|
||||
}
|
||||
slot.sparams.use_penalty_prompt_tokens = true;
|
||||
|
||||
LOG_VERBOSE("penalty_prompt_tokens", {
|
||||
{"id_slot", slot.id},
|
||||
{"tokens", slot.sparams.penalty_prompt_tokens},
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
slot.sparams.logit_bias.clear();
|
||||
|
||||
if (json_value(data, "ignore_eos", false) && has_eos_token) {
|
||||
slot.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
|
||||
slot.sparams.logit_bias.push_back({llama_token_eos(model), -INFINITY});
|
||||
}
|
||||
|
||||
const auto & logit_bias = data.find("logit_bias");
|
||||
|
@ -1043,12 +1000,12 @@ struct server_context {
|
|||
if (el[0].is_number_integer()) {
|
||||
llama_token tok = el[0].get<llama_token>();
|
||||
if (tok >= 0 && tok < n_vocab) {
|
||||
slot.sparams.logit_bias[tok] = bias;
|
||||
slot.sparams.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
} else if (el[0].is_string()) {
|
||||
auto toks = llama_tokenize(model, el[0].get<std::string>(), false);
|
||||
for (auto tok : toks) {
|
||||
slot.sparams.logit_bias[tok] = bias;
|
||||
slot.sparams.logit_bias.push_back({tok, bias});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1070,26 +1027,27 @@ struct server_context {
|
|||
}
|
||||
|
||||
{
|
||||
const auto & samplers_sequence = data.find("samplers");
|
||||
if (samplers_sequence != data.end() && samplers_sequence->is_array()) {
|
||||
const auto & samplers = data.find("samplers");
|
||||
if (samplers != data.end() && samplers->is_array()) {
|
||||
std::vector<std::string> sampler_names;
|
||||
for (const auto & sampler_name : *samplers_sequence) {
|
||||
if (sampler_name.is_string()) {
|
||||
sampler_names.emplace_back(sampler_name);
|
||||
for (const auto & name : *samplers) {
|
||||
if (name.is_string()) {
|
||||
sampler_names.emplace_back(name);
|
||||
}
|
||||
}
|
||||
slot.sparams.samplers_sequence = llama_sampling_types_from_names(sampler_names, false);
|
||||
slot.sparams.samplers = gpt_sampler_types_from_names(sampler_names, false);
|
||||
} else {
|
||||
slot.sparams.samplers_sequence = default_sparams.samplers_sequence;
|
||||
slot.sparams.samplers = default_sparams.samplers;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
if (slot.ctx_sampling != nullptr) {
|
||||
llama_sampling_free(slot.ctx_sampling);
|
||||
if (slot.smpl != nullptr) {
|
||||
gpt_sampler_free(slot.smpl);
|
||||
}
|
||||
slot.ctx_sampling = llama_sampling_init(slot.sparams);
|
||||
if (slot.ctx_sampling == nullptr) {
|
||||
|
||||
slot.smpl = gpt_sampler_init(model, slot.sparams);
|
||||
if (slot.smpl == nullptr) {
|
||||
// for now, the only error that may happen here is invalid grammar
|
||||
send_error(task, "Failed to parse grammar", ERROR_TYPE_INVALID_REQUEST);
|
||||
return false;
|
||||
|
@ -1178,11 +1136,6 @@ struct server_context {
|
|||
slot.generated_text += token_str;
|
||||
slot.has_next_token = true;
|
||||
|
||||
if (slot.ctx_sampling->params.use_penalty_prompt_tokens && result.tok != -1) {
|
||||
// we can change penalty_prompt_tokens because it is always created from scratch each request
|
||||
slot.ctx_sampling->params.penalty_prompt_tokens.push_back(result.tok);
|
||||
}
|
||||
|
||||
// check if there is incomplete UTF-8 character at the end
|
||||
bool incomplete = false;
|
||||
for (unsigned i = 1; i < 5 && i <= slot.generated_text.size(); ++i) {
|
||||
|
@ -1300,13 +1253,10 @@ struct server_context {
|
|||
}
|
||||
|
||||
json get_formated_generation(const server_slot & slot) const {
|
||||
const auto eos_bias = slot.sparams.logit_bias.find(llama_token_eos(model));
|
||||
const bool ignore_eos = eos_bias != slot.sparams.logit_bias.end() && eos_bias->second < 0.0f && std::isinf(eos_bias->second);
|
||||
|
||||
std::vector<std::string> samplers_sequence;
|
||||
samplers_sequence.reserve(slot.sparams.samplers_sequence.size());
|
||||
for (const auto & sampler_type : slot.sparams.samplers_sequence) {
|
||||
samplers_sequence.emplace_back(llama_sampling_type_to_str(sampler_type));
|
||||
std::vector<std::string> samplers;
|
||||
samplers.reserve(slot.sparams.samplers.size());
|
||||
for (const auto & sampler : slot.sparams.samplers) {
|
||||
samplers.emplace_back(gpt_sampler_type_to_str(sampler));
|
||||
}
|
||||
|
||||
return json {
|
||||
|
@ -1321,13 +1271,11 @@ struct server_context {
|
|||
{"top_p", slot.sparams.top_p},
|
||||
{"min_p", slot.sparams.min_p},
|
||||
{"tfs_z", slot.sparams.tfs_z},
|
||||
{"typical_p", slot.sparams.typical_p},
|
||||
{"typical_p", slot.sparams.typ_p},
|
||||
{"repeat_last_n", slot.sparams.penalty_last_n},
|
||||
{"repeat_penalty", slot.sparams.penalty_repeat},
|
||||
{"presence_penalty", slot.sparams.penalty_present},
|
||||
{"frequency_penalty", slot.sparams.penalty_freq},
|
||||
{"penalty_prompt_tokens", slot.sparams.penalty_prompt_tokens},
|
||||
{"use_penalty_prompt_tokens", slot.sparams.use_penalty_prompt_tokens},
|
||||
{"mirostat", slot.sparams.mirostat},
|
||||
{"mirostat_tau", slot.sparams.mirostat_tau},
|
||||
{"mirostat_eta", slot.sparams.mirostat_eta},
|
||||
|
@ -1336,13 +1284,13 @@ struct server_context {
|
|||
{"max_tokens", slot.params.n_predict}, // User configured n_predict
|
||||
{"n_keep", slot.params.n_keep},
|
||||
{"n_discard", slot.params.n_discard},
|
||||
{"ignore_eos", ignore_eos},
|
||||
{"ignore_eos", slot.sparams.ignore_eos},
|
||||
{"stream", slot.params.stream},
|
||||
{"logit_bias", slot.sparams.logit_bias},
|
||||
//{"logit_bias", slot.sparams.logit_bias},
|
||||
{"n_probs", slot.sparams.n_probs},
|
||||
{"min_keep", slot.sparams.min_keep},
|
||||
{"grammar", slot.sparams.grammar},
|
||||
{"samplers", samplers_sequence}
|
||||
{"samplers", samplers},
|
||||
};
|
||||
}
|
||||
|
||||
|
@ -2136,7 +2084,7 @@ struct server_context {
|
|||
GGML_ASSERT(slot.n_prompt_tokens < slot.n_ctx);
|
||||
}
|
||||
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
gpt_sampler_reset(slot.smpl);
|
||||
|
||||
if (!slot.params.cache_prompt) {
|
||||
slot.n_past_se = 0;
|
||||
|
@ -2149,7 +2097,7 @@ struct server_context {
|
|||
|
||||
// push the prompt into the sampling context (do not apply grammar)
|
||||
for (int i = 0; i < slot.n_past; ++i) {
|
||||
llama_sampling_accept(slot.ctx_sampling, ctx, slot.cache_tokens[i], false);
|
||||
gpt_sampler_accept(slot.smpl, slot.cache_tokens[i], false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -2202,7 +2150,7 @@ struct server_context {
|
|||
slot.n_past_se = 0;
|
||||
slot.ga_i = 0;
|
||||
// TODO: is the system prompt ever in the sampling context?
|
||||
llama_sampling_reset(slot.ctx_sampling);
|
||||
gpt_sampler_reset(slot.smpl);
|
||||
}
|
||||
|
||||
// remove the non-common part from the cache
|
||||
|
@ -2375,18 +2323,18 @@ struct server_context {
|
|||
slot.release();
|
||||
slot.i_batch = -1;
|
||||
continue; // continue loop of slots
|
||||
} else {
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
}
|
||||
|
||||
// prompt evaluated for next-token prediction
|
||||
slot.state = SLOT_STATE_GENERATING;
|
||||
} else if (slot.state != SLOT_STATE_GENERATING) {
|
||||
continue; // continue loop of slots
|
||||
}
|
||||
|
||||
completion_token_output result;
|
||||
const llama_token id = llama_sampling_sample(slot.ctx_sampling, ctx, NULL, slot.i_batch - i);
|
||||
const llama_token id = gpt_sampler_sample(slot.smpl, ctx, slot.i_batch - i);
|
||||
|
||||
llama_sampling_accept(slot.ctx_sampling, ctx, id, true);
|
||||
gpt_sampler_accept(slot.smpl, id, true);
|
||||
|
||||
slot.n_decoded += 1;
|
||||
if (slot.n_decoded == 1) {
|
||||
|
@ -2395,34 +2343,15 @@ struct server_context {
|
|||
metrics.on_prompt_eval(slot);
|
||||
}
|
||||
|
||||
llama_token_data_array cur_p = { slot.ctx_sampling->cur.data(), slot.ctx_sampling->cur.size(), false };
|
||||
result.tok = id;
|
||||
|
||||
const size_t n_probs = std::min(cur_p.size, (size_t) slot.sparams.n_probs);
|
||||
if (n_probs > 0) {
|
||||
const size_t n_valid = slot.ctx_sampling->n_valid;
|
||||
const auto * cur_p = gpt_sampler_get_candidates(slot.smpl);
|
||||
|
||||
// Make sure at least n_probs top tokens are at the front of the vector:
|
||||
if (slot.sparams.temp == 0.0f && n_probs > n_valid) {
|
||||
llama_sample_top_k(ctx, &cur_p, n_probs, 0);
|
||||
}
|
||||
|
||||
if (slot.sparams.temp == 0.0f) {
|
||||
// With greedy sampling the probabilities have possibly not been calculated.
|
||||
for (size_t i = 0; i < n_probs; ++i) {
|
||||
result.probs.push_back({
|
||||
cur_p.data[i].id,
|
||||
i == 0 ? 1.0f : 0.0f
|
||||
});
|
||||
}
|
||||
} else {
|
||||
for (size_t i = 0; i < n_probs; ++i) {
|
||||
result.probs.push_back({
|
||||
cur_p.data[i].id,
|
||||
i >= n_valid ? 0.0f : cur_p.data[i].p // Tokens filtered out due to e.g. top_k have 0 probability.
|
||||
});
|
||||
}
|
||||
}
|
||||
for (size_t i = 0; i < (size_t) slot.sparams.n_probs; ++i) {
|
||||
result.probs.push_back({
|
||||
cur_p->data[i].id,
|
||||
i >= cur_p->size ? 0.0f : cur_p->data[i].p,
|
||||
});
|
||||
}
|
||||
|
||||
if (!process_token(result, slot)) {
|
||||
|
|
|
@ -55,6 +55,14 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
auto sparams = llama_sampler_chain_default_params();
|
||||
|
||||
sparams.no_perf = false;
|
||||
|
||||
llama_sampler * smpl = llama_sampler_chain_init(sparams);
|
||||
|
||||
llama_sampler_chain_add(smpl, llama_sampler_init_greedy());
|
||||
|
||||
// tokenize the prompt
|
||||
|
||||
std::vector<llama_token> tokens_list;
|
||||
|
@ -110,20 +118,9 @@ int main(int argc, char ** argv) {
|
|||
while (n_cur <= n_predict) {
|
||||
// sample the next token
|
||||
{
|
||||
auto n_vocab = llama_n_vocab(model);
|
||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
const llama_token new_token_id = llama_sampler_sample(smpl, ctx, batch.n_tokens - 1);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// sample the most likely token
|
||||
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
llama_sampler_accept(smpl, new_token_id);
|
||||
|
||||
// is it an end of generation?
|
||||
if (llama_token_is_eog(model, new_token_id) || n_cur == n_predict) {
|
||||
|
@ -160,12 +157,14 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
LOG_TEE("\n");
|
||||
llama_perf_print(smpl, LLAMA_PERF_TYPE_SAMPLER_CHAIN);
|
||||
llama_perf_print(ctx, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_sampler_free(smpl);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
|
|
|
@ -21,7 +21,7 @@ struct seq_draft {
|
|||
std::vector<llama_token> tokens;
|
||||
std::vector<std::vector<llama_token_data>> dists;
|
||||
|
||||
struct llama_sampling_context * ctx_sampling;
|
||||
struct gpt_sampler * smpl = nullptr;
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
|
@ -43,10 +43,7 @@ int main(int argc, char ** argv) {
|
|||
// probability threshold for splitting a draft branch (only for n_seq_dft > 1)
|
||||
const float p_split = params.p_split;
|
||||
|
||||
if (params.seed == LLAMA_DEFAULT_SEED) {
|
||||
params.seed = time(NULL);
|
||||
}
|
||||
std::default_random_engine rng(params.seed);
|
||||
std::default_random_engine rng(params.sparams.seed);
|
||||
std::uniform_real_distribution<> u_dist;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
|
@ -179,19 +176,17 @@ int main(int argc, char ** argv) {
|
|||
// used to determine end of generation
|
||||
bool has_eos = false;
|
||||
|
||||
// target model sampling context
|
||||
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
|
||||
// target model sampling context (reuse the llama_context's sampling instance)
|
||||
struct gpt_sampler * smpl = gpt_sampler_init(model_tgt, params.sparams);
|
||||
|
||||
struct llama_sampler * softmax = llama_sampler_init_softmax();
|
||||
|
||||
// draft sequence data
|
||||
std::vector<seq_draft> drafts(n_seq_dft);
|
||||
|
||||
params.sparams.grammar.clear(); // the draft samplers will copy the target sampler's grammar
|
||||
if (params.sparams.temp == 0) {
|
||||
params.sparams.temp = -1.0f; // force greedy sampling with probs for the draft model
|
||||
}
|
||||
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
drafts[s].ctx_sampling = llama_sampling_init(params.sparams);
|
||||
// allocate gpt_sampler for each draft sequence
|
||||
drafts[s].smpl = gpt_sampler_init(model_dft, params.sparams);
|
||||
}
|
||||
|
||||
llama_batch batch_dft = llama_batch_init(params.n_ctx, 0, 1);
|
||||
|
@ -233,12 +228,12 @@ int main(int argc, char ** argv) {
|
|||
bool accept = false;
|
||||
if (params.sparams.temp > 0) {
|
||||
// stochastic verification
|
||||
gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
|
||||
|
||||
llama_token_data_array dist_tgt = llama_sampling_prepare(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft], true, NULL);
|
||||
llama_sample_softmax(ctx_tgt, &dist_tgt);
|
||||
float p_tgt = 0, p_dft = 0;
|
||||
auto & dist_tgt = *gpt_sampler_get_candidates(smpl);
|
||||
|
||||
// GGML_ASSERT(dist_tgt.size() == dist_dft.size());
|
||||
float p_tgt = 0.0f;
|
||||
float p_dft = 0.0f;
|
||||
|
||||
while (active_seqs.size() > 0) {
|
||||
// randomly select a sequence to verify from active sequences
|
||||
|
@ -257,9 +252,13 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
LOG("verifying sequence #%d at pos #%d from %d active sequence(s)\n", s, i_dft, (int) active_seqs.size());
|
||||
float r = u_dist(rng);
|
||||
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), true };
|
||||
llama_token_data_array dist_dft = { drafts[s].dists[i_dft].data() , drafts[s].dists[i_dft].size(), LLAMA_TOKEN_NULL, true };
|
||||
|
||||
//GGML_ASSERT(dist_tgt.size <= dist_dft.size);
|
||||
|
||||
// acquire the token probabilities assigned by the draft and target models
|
||||
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||
if (dist_tgt.data[i].id == drafts[s].tokens[i_dft]) {
|
||||
|
@ -278,7 +277,7 @@ int main(int argc, char ** argv) {
|
|||
accept = true;
|
||||
token_id = drafts[s].tokens[i_dft];
|
||||
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
gpt_sampler_accept(smpl, token_id, true);
|
||||
|
||||
LOG("draft token %d of sequence %d (%d, '%s') accepted\n", i_dft, s, token_id, token_str.c_str());
|
||||
break;
|
||||
|
@ -289,7 +288,6 @@ int main(int argc, char ** argv) {
|
|||
// calculate residual probability
|
||||
GGML_ASSERT(dist_tgt.sorted);
|
||||
GGML_ASSERT(dist_dft.sorted);
|
||||
float sum_probs = 0.0f;
|
||||
|
||||
// sort dist by id
|
||||
std::sort(dist_tgt.data, dist_tgt.data + dist_tgt.size, [](const llama_token_data &a, const llama_token_data &b) {
|
||||
|
@ -299,10 +297,18 @@ int main(int argc, char ** argv) {
|
|||
return a.id < b.id;
|
||||
});
|
||||
|
||||
float sum_probs = 0.0f;
|
||||
|
||||
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
|
||||
if (i < dist_dft.size) {
|
||||
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p - dist_dft.data[i].p);
|
||||
} else {
|
||||
dist_tgt.data[i].p = std::max(0.0f, dist_tgt.data[i].p);
|
||||
}
|
||||
|
||||
sum_probs += dist_tgt.data[i].p;
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < dist_tgt.size; i++) {
|
||||
dist_tgt.data[i].p /= sum_probs;
|
||||
}
|
||||
|
@ -332,21 +338,29 @@ int main(int argc, char ** argv) {
|
|||
// all drafted tokens were rejected
|
||||
// sample from the target model
|
||||
LOG("all drafted tokens were rejected, sampling from residual distribution\n");
|
||||
token_id = llama_sample_token(ctx_tgt, &dist_tgt);
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
std::vector<float> probs(dist_tgt.size);
|
||||
for (size_t i = 0; i < dist_tgt.size; ++i) {
|
||||
probs[i] = dist_tgt.data[i].p;
|
||||
}
|
||||
|
||||
std::discrete_distribution<> dist(probs.begin(), probs.end());
|
||||
|
||||
const int idx = dist(rng);
|
||||
|
||||
token_id = dist_tgt.data[idx].id;
|
||||
gpt_sampler_accept(smpl, token_id, true);
|
||||
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||
}
|
||||
|
||||
} else {
|
||||
// greedy verification
|
||||
|
||||
// sample from the target model
|
||||
LOG("sampling target: s_keep = %3d, i_dft = %3d, i_batch_tgt = %3d\n", s_keep, i_dft, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
token_id = llama_sampling_sample(ctx_sampling, ctx_tgt, NULL, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
token_id = gpt_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft]);
|
||||
|
||||
llama_sampling_accept(ctx_sampling, ctx_tgt, token_id, true);
|
||||
gpt_sampler_accept(smpl, token_id, true);
|
||||
|
||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, ctx_sampling->prev).c_str());
|
||||
//LOG("last: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx_tgt, smpl->prev).c_str());
|
||||
|
||||
token_str = llama_token_to_piece(ctx_tgt, token_id);
|
||||
|
||||
|
@ -434,7 +448,10 @@ int main(int argc, char ** argv) {
|
|||
break;
|
||||
}
|
||||
|
||||
llama_sampling_cp(ctx_sampling, drafts[0].ctx_sampling);
|
||||
if (drafts[0].smpl) {
|
||||
gpt_sampler_free(drafts[0].smpl);
|
||||
}
|
||||
drafts[0].smpl = gpt_sampler_clone(smpl);
|
||||
|
||||
int n_seq_cur = 1;
|
||||
int n_past_cur = n_past_dft;
|
||||
|
@ -463,20 +480,20 @@ int main(int argc, char ** argv) {
|
|||
continue;
|
||||
}
|
||||
|
||||
llama_sampling_sample(drafts[s].ctx_sampling, ctx_dft, NULL, drafts[s].i_batch_dft);
|
||||
gpt_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
|
||||
|
||||
const auto & cur_p = drafts[s].ctx_sampling->cur;
|
||||
const auto * cur_p = gpt_sampler_get_candidates(drafts[s].smpl);
|
||||
|
||||
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p.size()); ++k) {
|
||||
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
|
||||
LOG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",
|
||||
k, s, i, cur_p[k].id, cur_p[k].p, llama_token_to_piece(ctx_dft, cur_p[k].id).c_str());
|
||||
k, s, i, cur_p->data[k].id, cur_p->data[k].p, llama_token_to_piece(ctx_dft, cur_p->data[k].id).c_str());
|
||||
}
|
||||
|
||||
std::vector<int> sa(1, s);
|
||||
|
||||
// attempt to split the branch if the probability is high enough
|
||||
for (int f = 1; f < 8; ++f) {
|
||||
if (n_seq_cur < n_seq_dft && cur_p[f].p > p_split) {
|
||||
if (n_seq_cur < n_seq_dft && cur_p->data[f].p > p_split) {
|
||||
LOG("splitting seq %3d into %3d\n", s, n_seq_cur);
|
||||
|
||||
llama_kv_cache_seq_rm(ctx_dft, n_seq_cur, -1, -1);
|
||||
|
@ -503,7 +520,10 @@ int main(int argc, char ** argv) {
|
|||
drafts[n_seq_cur].i_batch_dft = drafts[s].i_batch_dft;
|
||||
drafts[n_seq_cur].i_batch_tgt = drafts[s].i_batch_tgt;
|
||||
|
||||
llama_sampling_cp(drafts[s].ctx_sampling, drafts[n_seq_cur].ctx_sampling);
|
||||
if (drafts[n_seq_cur].smpl) {
|
||||
gpt_sampler_free(drafts[n_seq_cur].smpl);
|
||||
}
|
||||
drafts[n_seq_cur].smpl = gpt_sampler_clone(drafts[s].smpl);
|
||||
|
||||
sa.push_back(n_seq_cur);
|
||||
|
||||
|
@ -515,15 +535,15 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// add drafted token for each sequence
|
||||
for (int is = 0; is < (int) sa.size(); ++is) {
|
||||
const llama_token id = cur_p[is].id;
|
||||
const llama_token id = cur_p->data[is].id;
|
||||
|
||||
const int s = sa[is];
|
||||
|
||||
llama_sampling_accept(drafts[s].ctx_sampling, ctx_dft, id, true);
|
||||
gpt_sampler_accept(drafts[s].smpl, id, true);
|
||||
|
||||
drafts[s].tokens.push_back(id);
|
||||
// save cur_p.data into drafts[s].dists
|
||||
drafts[s].dists.push_back(cur_p);
|
||||
drafts[s].dists.push_back({cur_p->data, cur_p->data + cur_p->size});
|
||||
|
||||
// add unique drafted tokens to the target batch
|
||||
drafts[s].i_batch_tgt.push_back(batch_tgt.n_tokens);
|
||||
|
@ -593,17 +613,19 @@ int main(int argc, char ** argv) {
|
|||
LOG_TEE("n_accept = %d\n", n_accept);
|
||||
LOG_TEE("accept = %.3f%%\n", 100.0f * n_accept / n_drafted);
|
||||
|
||||
LOG_TEE("\ndraft:\n");
|
||||
llama_print_timings(ctx_dft);
|
||||
LOG_TEE("\ndraft:\n\n");
|
||||
// TODO: print sampling/grammar timings for all drafts
|
||||
llama_perf_print(ctx_dft, LLAMA_PERF_TYPE_CONTEXT);
|
||||
|
||||
LOG_TEE("\ntarget:\n");
|
||||
llama_print_timings(ctx_tgt);
|
||||
LOG_TEE("\ntarget:\n\n");
|
||||
gpt_perf_print(ctx_tgt, smpl);
|
||||
|
||||
llama_sampling_free(ctx_sampling);
|
||||
gpt_sampler_free(smpl);
|
||||
for (int s = 0; s < n_seq_dft; ++s) {
|
||||
llama_sampling_free(drafts[s].ctx_sampling);
|
||||
gpt_sampler_free(drafts[s].smpl);
|
||||
}
|
||||
|
||||
llama_sampler_free(softmax);
|
||||
llama_batch_free(batch_dft);
|
||||
|
||||
llama_free(ctx_tgt);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue