Merge branch 'master' into concedo_experimental

# Conflicts:
#	Makefile
#	README.md
Concedo 2023-11-27 14:06:14 +08:00
commit 8acd7be734
20 changed files with 1274 additions and 136 deletions

.gitignore (vendored)

@@ -36,6 +36,7 @@ models-mnt
/libllama.so
/llama-bench
/llava-cli
/lookahead
/main
/metal
/perplexity


@@ -13,6 +13,7 @@
#include <regex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include <cinttypes>
@@ -496,6 +497,8 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
params.chatml = true;
} else if (arg == "--infill") {
params.infill = true;
} else if (arg == "-dkvc" || arg == "--dump-kv-cache") {
params.dump_kv_cache = true;
} else if (arg == "--multiline-input") {
params.multiline_input = true;
} else if (arg == "--simple-io") {
@@ -836,6 +839,8 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
#endif // GGML_USE_CUBLAS
#endif
printf(" --verbose-prompt print prompt before generation\n");
printf(" -dkvc, --dump-kv-cache\n");
printf(" verbose print of the KV cache\n");
printf(" --simple-io use basic IO for better compatibility in subprocesses and limited consoles\n");
printf(" --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
printf(" --lora-scaled FNAME S apply LoRA adapter with user defined scaling S (implies --no-mmap)\n");
@@ -1387,3 +1392,77 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, "typical_p: %f # default: 1.0\n", sparams.typical_p);
fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false");
}
//
// KV cache utils
//
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = ".123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz+";
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
int seq_count = 0;
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] >= 0) { seq_count++; }
}
putchar(slot_chars[std::min(sizeof(slot_chars) - 2, size_t(seq_count))]);
}
printf("\n=== Done dumping\n");
}
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size) {
static const char slot_chars[] = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz";
printf("=== Dumping KV cache. total cells %d, max sequences per cell %d, populated cells %d, total tokens in cache %d, largest empty slot=%d @ %d\n",
view.n_cells, view.n_max_seq, view.used_cells, view.token_count, view.max_contiguous, view.max_contiguous_idx);
std::unordered_map<llama_seq_id, size_t> seqs;
llama_kv_cache_view_cell * c_curr = view.cells;
llama_seq_id * cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] < 0) { continue; }
if (seqs.find(cs_curr[j]) == seqs.end()) {
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
seqs[cs_curr[j]] = seqs.size();
}
}
if (seqs.size() + 1 >= sizeof(slot_chars)) { break; }
}
printf("=== Sequence legend: ");
for (const auto & it : seqs) {
printf("%zu=%d, ", it.second, it.first);
}
printf("'+'=other sequence ids");
c_curr = view.cells;
cs_curr = view.cells_sequences;
for (int i = 0; i < view.n_cells; i++, c_curr++, cs_curr += view.n_max_seq) {
if (i % row_size == 0) {
printf("\n%5d: ", i);
}
for (int j = 0; j < view.n_max_seq; j++) {
if (cs_curr[j] >= 0) {
const auto & it = seqs.find(cs_curr[j]);
putchar(it != seqs.end() ? int(slot_chars[it->second]) : '+');
} else {
putchar('.');
}
}
putchar(' ');
}
printf("\n=== Done dumping\n");
}


@@ -130,6 +130,7 @@ struct gpt_params {
bool numa = false; // attempt optimizations that help on some NUMA systems
bool verbose_prompt = false; // print prompt tokens before generation
bool infill = false; // use infill mode
bool dump_kv_cache = false; // dump the KV cache contents for debugging purposes
// multimodal models (see examples/llava)
std::string mmproj = ""; // path to multimodal projector
@@ -226,3 +227,13 @@ std::string get_sortable_timestamp();
void dump_non_result_info_yaml(
FILE * stream, const gpt_params & params, const llama_context * lctx,
const std::string & timestamp, const std::vector<int> & prompt_tokens, const char * model_desc);
//
// KV cache utils
//
// Dump the KV cache view with the number of sequences per cell.
void dump_kv_cache_view(const llama_kv_cache_view & view, int row_size = 80);
// Dump the KV cache view showing individual sequences in each cell (long output).
void dump_kv_cache_view_seqs(const llama_kv_cache_view & view, int row_size = 40);
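A minimal usage sketch of these helpers (illustrative only; it mirrors how the lookahead and parallel examples below drive the KV-cache-view API, and the wrapper name `debug_dump_kv` and its arguments are placeholders, not part of the patch):

```cpp
#include "common.h"
#include "llama.h"

// Assumes an already-initialized llama_context * ctx.
static void debug_dump_kv(llama_context * ctx, int n_max_seq) {
    // create a view that tracks up to n_max_seq sequence ids per cell
    struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_max_seq);

    llama_kv_cache_view_update(ctx, &kvc_view); // refresh the snapshot from the context
    dump_kv_cache_view(kvc_view, 80);           // compact: one char = number of sequences in a cell
    dump_kv_cache_view_seqs(kvc_view, 40);      // verbose: one char per (cell, sequence) slot

    llama_kv_cache_view_free(&kvc_view);
}
```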


@@ -59,7 +59,7 @@ class Model:
from safetensors import safe_open
ctx = cast(ContextManager[Any], safe_open(self.dir_model / part_name, framework="pt", device="cpu"))
else:
-ctx = contextlib.nullcontext(torch.load(self.dir_model / part_name, map_location="cpu"))
+ctx = contextlib.nullcontext(torch.load(str(self.dir_model / part_name), map_location="cpu", mmap=True, weights_only=True))
with ctx as model_part:
for name in model_part.keys():
@@ -880,20 +880,21 @@ print(f"Loading model: {dir_model.name}")
hparams = Model.load_hparams(dir_model)
-model_class = Model.from_model_architecture(hparams["architectures"][0])
-model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
-print("Set model parameters")
-model_instance.set_gguf_parameters()
-print("Set model tokenizer")
-model_instance.set_vocab()
-if args.vocab_only:
-    print(f"Exporting model vocab to '{fname_out}'")
-    model_instance.write_vocab()
-else:
-    print(f"Exporting model to '{fname_out}'")
-    model_instance.write()
-print(f"Model successfully exported to '{fname_out}'")
+with torch.inference_mode():
+    model_class = Model.from_model_architecture(hparams["architectures"][0])
+    model_instance = model_class(dir_model, ftype_map[args.outtype], fname_out, args.bigendian)
+    print("Set model parameters")
+    model_instance.set_gguf_parameters()
+    print("Set model tokenizer")
+    model_instance.set_vocab()
+    if args.vocab_only:
+        print(f"Exporting model vocab to '{fname_out}'")
+        model_instance.write_vocab()
+    else:
+        print(f"Exporting model to '{fname_out}'")
+        model_instance.write()
+    print(f"Model successfully exported to '{fname_out}'")

convert.py (mode changed: normal file → executable file; no content changes)

docs/llama-star/idea-arch.key (new executable file; binary file not shown)

(binary file not shown)


@@ -32,6 +32,7 @@ else()
add_subdirectory(save-load-state)
add_subdirectory(simple)
add_subdirectory(speculative)
add_subdirectory(lookahead)
add_subdirectory(train-text-from-scratch)
if (LLAMA_METAL)
add_subdirectory(metal)


@@ -153,7 +153,7 @@ while n_cur <= n_len {
// const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
// is it an end of stream? -> mark the stream as finished
-if new_token_id == llama_token_eos(context) || n_cur == n_len {
+if new_token_id == llama_token_eos(model) || n_cur == n_len {
i_batch[i] = -1
// print("")
if n_parallel > 1 {


@@ -21,7 +21,7 @@ wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/s
./bin/main -m open-llama-3b-v2-q8_0.gguf --lora lora-open-llama-3b-v2-q8_0-shakespeare-LATEST.bin
```
-Finetune output files will be saved every N iterations (config with `--save-every N`).
+**Only llama based models are supported!** The output files will be saved every N iterations (config with `--save-every N`).
The pattern 'ITERATION' in the output filenames will be replaced with the iteration number and with 'LATEST' for the latest output.
So in above example after 10 iterations these files will be written:
- chk-lora-open-llama-3b-v2-q8_0-shakespeare-10.gguf


@@ -0,0 +1,5 @@
set(TARGET lookahead)
add_executable(${TARGET} lookahead.cpp)
install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)


@@ -0,0 +1,487 @@
#include "common.h"
#include "llama.h"
#include <cmath>
#include <cstdio>
#include <string>
#include <vector>
struct ngram_data {
bool active = false;
llama_seq_id seq_id = -1;
std::vector<int> i_batch;
std::vector<llama_token> tokens;
};
// n-gram container
struct ngram_container {
ngram_container(int n_vocab, int N, int G) {
cnt.resize(n_vocab);
head.resize(n_vocab);
tokens.resize(n_vocab * G * (N - 1));
}
int n_total = 0;
std::vector<int> cnt;
std::vector<int> head;
// [n_vocab][G][N - 1]
// for each token of the vocab, keep a ring-buffer of capacity G of n-grams of size N - 1
std::vector<llama_token> tokens;
};
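// Layout note (illustrative, not part of the example): for vocab token t,
// ring-buffer slot g in [0, G) and position j in [0, N-1), the flat index into
// `tokens` is  t*(N - 1)*G + g*(N - 1) + j  -- the same arithmetic as the
// `id*(N - 1)*G + g*(N - 1)` expressions used further down.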
int main(int argc, char ** argv) {
gpt_params params;
if (gpt_params_parse(argc, argv, params) == false) {
return 1;
}
const int W = 15; // lookahead window
const int N = 5; // n-gram size
const int G = 15; // max verification n-grams
const bool dump_kv_cache = params.dump_kv_cache;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("lookahead", "log"));
LOG_TEE("Log start\n");
log_dump_cmdline(argc, argv);
#endif // LOG_DISABLE_LOGS
// init llama.cpp
llama_backend_init(params.numa);
llama_model * model = NULL;
llama_context * ctx = NULL;
// load the target model
std::tie(model, ctx) = llama_init_from_gpt_params(params);
// Tokenize the prompt
const bool add_bos = llama_should_add_bos_token(model);
LOG("add_bos tgt: %d\n", add_bos);
std::vector<llama_token> inp;
std::vector<llama_token> all;
inp = ::llama_tokenize(ctx, params.prompt, add_bos, true);
all = inp;
const int max_context_size = llama_n_ctx(ctx);
const int max_tokens_list_size = max_context_size - 4;
if ((int) inp.size() > max_tokens_list_size) {
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) inp.size(), max_tokens_list_size);
return 1;
}
fprintf(stderr, "\n\n");
for (auto id : inp) {
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
}
fflush(stderr);
const int n_input = inp.size();
const auto t_enc_start = ggml_time_us();
// eval the prompt
llama_decode(ctx, llama_batch_get_one( inp.data(), n_input - 1, 0, 0));
llama_decode(ctx, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0));
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
}
const auto t_enc_end = ggml_time_us();
int n_predict = 0;
int n_accept = 0;
int n_past = inp.size();
llama_token id = 0;
// used to determine end of generation
bool has_eos = false;
// for each decoded batch, we have at most W + G + 1 distinct sequences:
// seq_id == 0 : the current input token
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
// seq_id [W + 1, W + G] : verification n-grams
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1);
// target model sampling context
struct llama_sampling_context * ctx_sampling = llama_sampling_init(params.sparams);
// verification n-grams
std::vector<ngram_data> ngrams_cur(G);
// tokens for the past N - 1 Jacobi iterations
std::vector<llama_token> tokens_j_prev(W);
std::vector<std::vector<llama_token>> tokens_j(N - 1);
for (int j = 0; j < N - 1; j++) {
tokens_j[j].resize(W);
for (int i = 0; i < W; i++) {
// there are different ways to init these tokens
if (0) {
// initialize randomly from the prompt tokens
tokens_j[j][i] = all[1 + rand() % (all.size() - 1)];
} else {
// initialize with a sequence of increasing numbers
tokens_j[j][i] = 100 + i;
}
}
}
std::vector<llama_seq_id> seq_id_look;
// the input token belongs to all sequences
std::vector<llama_seq_id> seq_id_all(W + G + 1);
for (int i = 0; i < W + G + 1; i++) {
seq_id_all[i] = i;
}
// here we keep adding new n-grams as we go
ngram_container ngrams_observed(llama_n_vocab(model), N, G);
// debug
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, W + G + 1);
const auto t_dec_start = ggml_time_us();
// sample first token
{
id = llama_sampling_sample(ctx_sampling, ctx, NULL, 0);
llama_sampling_accept(ctx_sampling, ctx, id, true);
{
const std::string token_str = llama_token_to_piece(ctx, id);
printf("%s", token_str.c_str());
fflush(stdout);
}
}
while (true) {
// debug
if (dump_kv_cache) {
llama_kv_cache_view_update(ctx, &kvc_view);
dump_kv_cache_view_seqs(kvc_view, 40);
}
// build the mask from https://lmsys.org/blog/2023-11-21-lookahead-decoding/
//
// Example for W = 5, N = 4, G = 2:
// (I = input, L = lookahead, V = verification)
//
// Batch: 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20
// T: -2 -2 -2 -2 -1 -1 -1 -1 -1 0 0 0 0 0 0
// Info: I L L L L L L L L L L L L L L V V V V V V
// Pos: 0 1 2 3 4 1 2 3 4 5 2 3 4 5 6 1 2 3 1 2 3 (+ n_past)
// Logits: 1 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1 1 1
// ---------------------------------------------------------------------
// Seq: 0
// 1 1 1
// 2 2 2 2
// 3 3 3 3 3
// 4 4 4 4 4 4
// 5 5 5 5 5 5 5
// 6 6 6 6
// 7 7 7 7
// ---------------------------------------------------------------------
// | | | | | | | | | | |
// V V V V V | | | | | |
// j_tokens | | | | | |
// V V V V V V
// id
{
llama_batch_clear(batch);
// current token - first token of the first level
llama_batch_add(batch, id, n_past, seq_id_all, true);
// verification n-grams - queue this before the lookahead tokens for less KV cache fragmentation
{
const int g_cur = ngrams_observed.cnt[id];
ngrams_cur.resize(g_cur);
for (int g = 0; g < g_cur; g++) {
ngrams_cur[g].active = true;
ngrams_cur[g].tokens.resize(N);
ngrams_cur[g].i_batch.resize(N);
ngrams_cur[g].seq_id = W + 1 + g;
ngrams_cur[g].i_batch[0] = 0;
ngrams_cur[g].tokens [0] = id;
}
for (int j = 0; j < N - 1; j++) {
for (int g = 0; g < g_cur; g++) {
const int idx = id*(N - 1)*G + g*(N - 1);
const llama_token t = ngrams_observed.tokens[idx + j];
ngrams_cur[g].tokens [j + 1] = t;
ngrams_cur[g].i_batch[j + 1] = batch.n_tokens;
llama_batch_add(batch, t, n_past + j + 1, { W + 1 + g }, true);
}
}
}
// fill the remaining W - 1 tokens for the first level
for (int i = 1; i < W; i++) {
seq_id_look.resize(W - i);
for (int j = 0; j < W - i; j++) {
seq_id_look[j] = i + j + 1;
}
llama_batch_add(batch, tokens_j[0][i], n_past + i, seq_id_look, false);
}
// fill the rest of the levels
for (int j = 1; j < N - 1; j++) {
for (int i = 0; i < W; i++) {
llama_batch_add(batch, tokens_j[j][i], n_past + j + i, { i + 1 }, j == N - 2);
}
}
}
if (llama_decode(ctx, batch) != 0) {
fprintf(stderr, "\n\n%s: error: llama_decode failed - increase KV cache size\n", __func__);
return 1;
}
int seq_id_best = 0;
for (int v = 0; v < N; ++v) {
int i_batch = 0;
// if no active ngrams are left, it means the sampled token does not pass the verification
if (v > 0) {
for (int g = 0; g < (int) ngrams_cur.size(); g++) {
if (ngrams_cur[g].active) {
i_batch = ngrams_cur[g].i_batch[v];
seq_id_best = ngrams_cur[g].seq_id;
++n_accept;
break;
}
}
// no more matches -> create a new batch
if (i_batch == 0) {
break;
}
}
// sample the next token
id = llama_sampling_sample(ctx_sampling, ctx, NULL, i_batch);
llama_sampling_accept(ctx_sampling, ctx, id, true);
// print
{
const std::string token_str = llama_token_to_piece(ctx, id);
if (v == 0) {
printf("%s", token_str.c_str());
} else {
// print light cyan
printf("\033[0;96m%s\033[0m", token_str.c_str());
}
fflush(stdout);
if (id == llama_token_eos(model)) {
has_eos = true;
}
all.push_back(id);
}
++n_predict;
++n_past;
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
break;
}
// verify across active n-grams
for (int g = 0; g < (int) ngrams_cur.size(); g++) {
if (ngrams_cur[g].active) {
if (v == N - 1) {
ngrams_cur[g].active = false;
} else {
if (id != ngrams_cur[g].tokens[v + 1]) {
ngrams_cur[g].active = false;
}
}
}
}
// print known n-grams starting with token id (debug)
if (0 && v == 0) {
if (ngrams_observed.cnt[id] > 0) {
printf("\n - %d n-grams starting with '%s'\n", ngrams_observed.cnt[id], llama_token_to_piece(ctx, id).c_str());
}
for (int i = 0; i < ngrams_observed.cnt[id]; i++) {
printf(" - ngram %2d: ", i);
const int idx = id*(N - 1)*G + i*(N - 1);
for (int j = 0; j < N - 1; j++) {
const std::string token_str = llama_token_to_piece(ctx, ngrams_observed.tokens[idx + j]);
printf("%s", token_str.c_str());
}
printf("\n");
}
}
// update lookahead tokens
{
for (int i = 0; i < W; i++) {
tokens_j_prev[i] = tokens_j[0][i];
}
for (int j = 0; j < N - 2; j++) {
tokens_j[j] = tokens_j[j + 1];
}
if (v == 0) {
// sample from the last level
for (int i = 0; i < W; i++) {
tokens_j[N - 2][i] = llama_sampling_sample(ctx_sampling, ctx, NULL, ngrams_cur.size()*(N-1) + W*(N - 2) + i);
}
} else {
for (int i = 0; i < W; i++) {
// there are different ways to init these tokens
if (0) {
// random init
tokens_j[N - 2][i] = all[1 + rand() % (all.size() - 1)];
} else {
// init from the previous level
tokens_j[N - 2][i] = tokens_j[0][i];
}
}
}
}
// update observed ngrams
if (v == 0) {
// the first token of the n-gram is determined by the index in the container so it is not stored
std::vector<llama_token> ngram(N - 1);
// n-gram generation
// ref: https://github.com/hao-ai-lab/LookaheadDecoding/issues/14#issuecomment-1826198518
for (int f = 0; f < W; ++f) {
const int ft = tokens_j_prev[f]; // first token of the n-gram
for (int j = 0; j < N - 1; ++j) {
ngram[j] = tokens_j[j][f];
}
// filter-out repeating n-grams
{
bool is_unique = true;
for (int k = 0; k < ngrams_observed.cnt[ft]; ++k) {
const int idx = ft*(N - 1)*G + k*(N - 1);
bool is_match = true;
for (int j = 0; j < N - 1; ++j) {
if (ngrams_observed.tokens[idx + j] != ngram[j]) {
is_match = false;
break;
}
}
if (is_match) {
is_unique = false;
break;
}
}
if (!is_unique) {
continue;
}
}
const int head = ngrams_observed.head[ft];
const int idx = ft*(N - 1)*G + head*(N - 1);
for (int i = 0; i < N - 1; i++) {
ngrams_observed.tokens[idx + i] = ngram[i];
}
ngrams_observed.cnt[ft] = std::min(G, ngrams_observed.cnt[ft] + 1);
ngrams_observed.head[ft] = (head + 1) % G;
ngrams_observed.n_total++;
}
}
}
if ((params.n_predict >= 0 && n_predict > params.n_predict) || has_eos) {
break;
}
// KV cache management
// if no verification token matched, we simply remove all cells from this batch -> no fragmentation
llama_kv_cache_seq_rm(ctx, -1, n_past, -1);
if (seq_id_best != 0) {
// if a verification token matched, we keep the best sequence and remove the rest
// this leads to some KV cache fragmentation
llama_kv_cache_seq_keep(ctx, seq_id_best);
llama_kv_cache_seq_cp (ctx, seq_id_best, 0, -1, -1);
llama_kv_cache_seq_rm (ctx, seq_id_best, -1, -1);
for (int s = 1; s < W + G + 1; ++s) {
llama_kv_cache_seq_cp(ctx, 0, s, -1, -1);
}
}
}
auto t_dec_end = ggml_time_us();
LOG_TEE("\n\n");
LOG_TEE("encoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_input, (t_enc_end - t_enc_start) / 1e6f, inp.size() / ((t_enc_end - t_enc_start) / 1e6f));
LOG_TEE("decoded %4d tokens in %8.3f seconds, speed: %8.3f t/s\n", n_predict, (t_dec_end - t_dec_start) / 1e6f, n_predict / ((t_dec_end - t_dec_start) / 1e6f));
LOG_TEE("\n");
LOG_TEE("W = %2d\n", W);
LOG_TEE("N = %2d\n", N);
LOG_TEE("G = %2d\n", G);
LOG_TEE("\n");
LOG_TEE("n_predict = %d\n", n_predict);
LOG_TEE("n_accept = %d\n", n_accept);
llama_print_timings(ctx);
llama_kv_cache_view_free(&kvc_view);
llama_sampling_free(ctx_sampling);
llama_batch_free(batch);
llama_free(ctx);
llama_free_model(model);
llama_backend_free();
fprintf(stderr, "\n\n");
return 0;
}


@@ -1,5 +1,5 @@
// A basic application simulating a server with multiple clients.
-// The clients submite requests to the server and they are processed in parallel.
+// The clients submit requests to the server and they are processed in parallel.
#include "build-info.h"
@@ -115,6 +115,8 @@ int main(int argc, char ** argv) {
// insert new requests as soon as the previous one is done
const bool cont_batching = params.cont_batching;
const bool dump_kv_cache = params.dump_kv_cache;
#ifndef LOG_DISABLE_LOGS
log_set_target(log_filename_generator("parallel", "log"));
LOG_TEE("Log start\n");
@@ -174,6 +176,8 @@ int main(int argc, char ** argv) {
int32_t n_total_gen = 0;
int32_t n_cache_miss = 0;
struct llama_kv_cache_view kvc_view = llama_kv_cache_view_init(ctx, n_clients);
const auto t_main_start = ggml_time_us();
LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
@@ -203,6 +207,11 @@ int main(int argc, char ** argv) {
LOG_TEE("Processing requests ...\n\n");
while (true) {
if (dump_kv_cache) {
llama_kv_cache_view_update(ctx, &kvc_view);
dump_kv_cache_view_seqs(kvc_view, 40);
}
llama_batch_clear(batch);
// decode any currently ongoing sequences


@@ -234,6 +234,55 @@ node index.js
- **GET** `/props`: Return the required assistant name and anti-prompt to generate the prompt in case you have specified a system prompt for all slots.

- **POST** `/v1/chat/completions`: OpenAI-compatible Chat Completions API. Given a ChatML-formatted json description in `messages`, it returns the predicted completion. Both synchronous and streaming mode are supported, so scripted and interactive applications work fine. While no strong claims of compatibility with the OpenAI API spec are being made, in our experience it suffices to support many apps. Only ChatML-tuned models, such as Dolphin, OpenOrca, OpenHermes, OpenChat-3.5, etc., can be used with this endpoint. Compared to `api_like_OAI.py` this API implementation does not require a wrapper to be served.

*Options:*

See [OpenAI Chat Completions API documentation](https://platform.openai.com/docs/api-reference/chat). While some OpenAI-specific features such as function calling aren't supported, llama.cpp `/completion`-specific features such as `mirostat` are supported.
*Examples:*
You can use either the Python `openai` library with appropriate checkpoints:
```python
import openai
client = openai.OpenAI(
base_url="http://localhost:8080/v1", # "http://<Your api-server IP>:port"
api_key = "sk-no-key-required"
)
completion = client.chat.completions.create(
model="gpt-3.5-turbo",
messages=[
{"role": "system", "content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests."},
{"role": "user", "content": "Write a limerick about python exceptions"}
]
)
print(completion.choices[0].message)
```
... or raw HTTP requests:
```shell
curl http://localhost:8080/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer no-key" \
-d '{
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "system",
"content": "You are ChatGPT, an AI assistant. Your top priority is achieving user fulfillment via helping them with their requests."
},
{
"role": "user",
"content": "Write a limerick about python exceptions"
}
]
}'
```
## More examples
### Change system prompt on runtime


@@ -30,6 +30,8 @@
#define SERVER_VERBOSE 1
#endif
#define DEFAULT_OAICOMPAT_MODEL "gpt-3.5-turbo-0613"
using json = nlohmann::json;
struct server_params
@@ -60,6 +62,10 @@ static bool server_verbose = false;
#define LOG_WARNING(MSG, ...) server_log("WARNING", __func__, __LINE__, MSG, __VA_ARGS__)
#define LOG_INFO( MSG, ...) server_log("INFO", __func__, __LINE__, MSG, __VA_ARGS__)
json oaicompat_completion_params_parse(const json &body);
std::string format_chatml(std::vector<json> messages);
//
// base64 utils (TODO: move to common in the future)
//
@@ -379,6 +385,9 @@ struct llama_client_slot
bool stopped_word = false;
bool stopped_limit = false;
bool oaicompat = false;
std::string oaicompat_model;
std::string stopping_word;
// sampling
@@ -478,7 +487,7 @@ struct llama_client_slot
};
}
-void print_timings() {
+void print_timings() const {
LOG_TEE("\n");
LOG_TEE("%s: prompt eval time = %10.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
__func__, t_prompt_processing, num_prompt_tokens_processed, t_prompt_processing / num_prompt_tokens_processed, 1e3 / t_prompt_processing * num_prompt_tokens_processed);
@@ -610,6 +619,11 @@ struct llama_server_context
std::vector<llama_token> tokenize(const json & json_prompt, bool add_bos) const
{
// TODO: currently, we tokenize using special tokens by default
// this is not always correct (see https://github.com/ggerganov/llama.cpp/pull/4160#issuecomment-1824826216)
// but it's better compared to completely ignoring ChatML and other chat templates
const bool TMP_FORCE_SPECIAL = true;
// If `add_bos` is true, we only add BOS, when json_prompt is a string,
// or the first element of the json_prompt array is a string.
std::vector<llama_token> prompt_tokens;
@@ -625,12 +639,12 @@ struct llama_server_context
std::vector<llama_token> p;
if (first)
{
-p = ::llama_tokenize(ctx, s, add_bos);
+p = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
first = false;
}
else
{
-p = ::llama_tokenize(ctx, s, false);
+p = ::llama_tokenize(ctx, s, false, TMP_FORCE_SPECIAL);
}
prompt_tokens.insert(prompt_tokens.end(), p.begin(), p.end());
}
@@ -647,7 +661,7 @@ struct llama_server_context
else
{
auto s = json_prompt.template get<std::string>();
-prompt_tokens = ::llama_tokenize(ctx, s, add_bos);
+prompt_tokens = ::llama_tokenize(ctx, s, add_bos, TMP_FORCE_SPECIAL);
}
return prompt_tokens;
@@ -678,6 +692,14 @@ struct llama_server_context
slot_params default_params;
llama_sampling_params default_sparams;
if (data.count("__oaicompat") != 0) {
slot->oaicompat = true;
slot->oaicompat_model = json_value(data, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
} else {
slot->oaicompat = false;
slot->oaicompat_model = "";
}
slot->params.stream = json_value(data, "stream", false);
slot->params.cache_prompt = json_value(data, "cache_prompt", false);
slot->params.n_predict = json_value(data, "n_predict", default_params.n_predict);
@@ -1096,6 +1118,7 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_results);
task_result res;
res.id = id;
res.stop = false;
res.error = true;
res.result_json = { { "content", error } };
queue_results.push_back(res);
@@ -1170,6 +1193,12 @@ struct llama_server_context
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs_output);
}
if (slot.oaicompat)
{
res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
res.result_json["model"] = slot.oaicompat_model;
}
queue_results.push_back(res);
}
@@ -1217,6 +1246,12 @@ struct llama_server_context
res.result_json["completion_probabilities"] = probs_vector_to_json(ctx, probs);
}
if (slot.oaicompat)
{
res.result_json["oaicompat_token_ctr"] = slot.n_decoded;
res.result_json["model"] = slot.oaicompat_model;
}
queue_results.push_back(res);
}
@@ -1256,7 +1291,8 @@ struct llama_server_context
std::lock_guard<std::mutex> lock(mutex_tasks);
task_server task;
task.id = id_gen++;
-task.data = data;
+task.target_id = 0;
+task.data = std::move(data);
task.infill_mode = infill;
task.embedding_mode = embedding;
task.type = COMPLETION_TASK;
@@ -2179,6 +2215,233 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
}
}
static std::string random_string()
{
static const std::string str("0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz");
std::random_device rd;
std::mt19937 generator(rd());
std::string result(32, ' ');
for (int i = 0; i < 32; ++i) {
result[i] = str[generator() % str.size()];
}
return result;
}
static std::string gen_chatcmplid()
{
std::stringstream chatcmplid;
chatcmplid << "chatcmpl-" << random_string();
return chatcmplid.str();
}
std::string format_chatml(std::vector<json> messages)
{
std::ostringstream chatml_msgs;
for (auto it = messages.begin(); it != messages.end(); ++it) {
chatml_msgs << "<|im_start|>"
<< json_value(*it, "role", std::string("user")) << '\n';
chatml_msgs << json_value(*it, "content", std::string(""))
<< "<|im_end|>\n";
}
chatml_msgs << "<|im_start|>assistant" << '\n';
return chatml_msgs.str();
}
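// Illustration (not part of the patch): for
//   messages = [{"role": "system", "content": "You are a helpful assistant."},
//               {"role": "user",   "content": "Hello!"}]
// format_chatml() returns the prompt
//   <|im_start|>system
//   You are a helpful assistant.<|im_end|>
//   <|im_start|>user
//   Hello!<|im_end|>
//   <|im_start|>assistant
// leaving the assistant turn open for the model to complete.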
/* llama.cpp completion api semantics */
json oaicompat_completion_params_parse(
const json &body /* openai api json semantics */)
{
json llama_params;
llama_params["__oaicompat"] = true;
// Map OpenAI parameters to llama.cpp parameters
llama_params["prompt"] = format_chatml(body["messages"]); // OpenAI 'messages' to llama.cpp 'prompt'
llama_params["temperature"] = json_value(body, "temperature", 0.8);
llama_params["top_k"] = json_value(body, "top_k", 40);
llama_params["top_p"] = json_value(body, "top_p", 0.95);
llama_params["n_predict"] = json_value(body, "max_tokens", -1);
llama_params["logit_bias"] = json_value(body, "logit_bias",json::object());
llama_params["frequency_penalty"] = json_value(body, "frequency_penalty", 0.0);
llama_params["presence_penalty"] = json_value(body, "presence_penalty", 0.0);
llama_params["seed"] = json_value(body, "seed", 0);
llama_params["stream"] = json_value(body, "stream", false);
llama_params["mirostat"] = json_value(body, "mirostat", false);
llama_params["mirostat_tau"] = json_value(body, "mirostat_tau", 0.0);
llama_params["mirostat_eta"] = json_value(body, "mirostat_eta", 0.0);
llama_params["penalize_nl"] = json_value(body, "penalize_nl", false);
llama_params["typical_p"] = json_value(body, "typical_p", 0.0);
llama_params["repeat_last_n"] = json_value(body, "repeat_last_n", 0);
llama_params["ignore_eos"] = json_value(body, "ignore_eos", false);
llama_params["tfs_z"] = json_value(body, "tfs_z", 0.0);
if (llama_params.count("grammar") != 0) {
llama_params["grammar"] = json_value(body, "grammar", json::object());
}
// Handle 'stop' field
if (body["stop"].is_null()) {
llama_params["stop"] = json::array({});
} else if (body["stop"].is_string()) {
llama_params["stop"] = json::array({body["stop"].get<std::string>()});
} else {
llama_params["stop"] = json_value(body, "stop", json::array());
}
// Ensure there is ChatML-specific end sequence among stop words
llama_params["stop"].push_back("<|im_end|>");
return llama_params;
}
static json format_final_response_oaicompat(const json &request, const task_result &response, bool streaming = false)
{
json result = response.result_json;
bool stopped_word = result.count("stopped_word") != 0;
bool stopped_eos = json_value(result, "stopped_eos", false);
int num_tokens_predicted = json_value(result, "tokens_predicted", 0);
int num_prompt_tokens = json_value(result, "tokens_evaluated", 0);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason = "length";
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
json choices =
streaming ? json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}})
: json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"message", json{{"content", content},
{"role", "assistant"}}}}});
std::time_t t = std::time(0);
json res =
json{{"choices", choices},
{"created", t},
{"model",
json_value(request, "model", std::string(DEFAULT_OAICOMPAT_MODEL))},
{"object", streaming ? "chat.completion.chunk" : "chat.completion"},
{"usage",
json{{"completion_tokens", num_tokens_predicted},
{"prompt_tokens", num_prompt_tokens},
{"total_tokens", num_tokens_predicted + num_prompt_tokens}}},
{"id", gen_chatcmplid()}};
if (server_verbose) {
res["__verbose"] = result;
}
if (result.contains("completion_probabilities")) {
res["completion_probabilities"] = json_value(result, "completion_probabilities", json::array());
}
return res;
}
// return value is vector as there is one case where we might need to generate two responses
static std::vector<json> format_partial_response_oaicompat(const task_result &response) {
json result = response.result_json;
if (!result.contains("model") || !result.contains("oaicompat_token_ctr")) {
return std::vector<json>({response.result_json});
}
bool first = json_value(result, "oaicompat_token_ctr", 0) == 0;
std::string modelname = json_value(result, "model", std::string(DEFAULT_OAICOMPAT_MODEL));
bool stopped_word = json_value(result, "stopped_word", false);
bool stopped_eos = json_value(result, "stopped_eos", false);
bool stopped_limit = json_value(result, "stopped_limit", false);
std::string content = json_value(result, "content", std::string(""));
std::string finish_reason;
if (stopped_word || stopped_eos) {
finish_reason = "stop";
}
if (stopped_limit) {
finish_reason = "length";
}
std::time_t t = std::time(0);
json choices;
if (!finish_reason.empty()) {
choices = json::array({json{{"finish_reason", finish_reason},
{"index", 0},
{"delta", json::object()}}});
} else {
if (first) {
if (content.empty()) {
choices = json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{{"role", "assistant"}}}}});
} else {
// We have to send this as two updates to conform to openai behavior
json initial_ret = json{{"choices", json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"role", "assistant"}
}}}})},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}};
json second_ret = json{
{"choices", json::array({json{{"finish_reason", nullptr},
{"index", 0},
{"delta", json{
{"content", content}}}
}})},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}};
return std::vector<json>({initial_ret, second_ret});
}
} else {
// Some idiosyncrasy in task processing logic makes several trailing calls
// with empty content, we ignore these at the callee site.
if (content.empty()) {
return std::vector<json>({json::object()});
}
choices = json::array({json{
{"finish_reason", nullptr},
{"index", 0},
{"delta",
json{
{"content", content},
}},
}});
}
}
json ret = json{{"choices", choices},
{"created", t},
{"id", gen_chatcmplid()},
{"model", modelname},
{"object", "chat.completion.chunk"}};
return std::vector<json>({ret});
}
static json format_partial_response(
llama_server_context &llama, llama_client_slot *slot, const std::string &content, const std::vector<completion_token_output> &probs
) {
@@ -2355,9 +2618,9 @@ int main(int argc, char **argv)
task_result result = llama.next_result(task_id);
if (!result.error) {
const std::string str =
"data: " +
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
@@ -2370,9 +2633,9 @@ int main(int argc, char **argv)
}
} else {
const std::string str =
"error: " +
result.result_json.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
@@ -2397,6 +2660,98 @@ int main(int argc, char **argv)
}
});
svr.Get("/v1/models", [&params](const httplib::Request&, httplib::Response& res)
{
std::time_t t = std::time(0);
json models = {
{"object", "list"},
{"data", {
{
{"id", params.model_alias},
{"object", "model"},
{"created", t},
{"owned_by", "llamacpp"}
},
}}
};
res.set_content(models.dump(), "application/json");
});
// TODO: add mount point without "/v1" prefix -- how?
svr.Post("/v1/chat/completions", [&llama](const httplib::Request &req, httplib::Response &res)
{
json data = oaicompat_completion_params_parse(json::parse(req.body));
const int task_id = llama.request_completion(data, false, false);
if (!json_value(data, "stream", false)) {
std::string completion_text;
task_result result = llama.next_result(task_id);
if (!result.error && result.stop) {
json oaicompat_result = format_final_response_oaicompat(data, result);
res.set_content(oaicompat_result.dump(-1, ' ', false,
json::error_handler_t::replace),
"application/json");
} else {
res.status = 500;
res.set_content(result.result_json["content"], "text/plain");
return;
}
} else {
const auto chunked_content_provider = [task_id, &llama](size_t, httplib::DataSink &sink) {
while (true) {
task_result llama_result = llama.next_result(task_id);
if (!llama_result.error) {
std::vector<json> result_array = format_partial_response_oaicompat( llama_result);
for (auto it = result_array.begin(); it != result_array.end(); ++it)
{
if (!it->empty()) {
const std::string str =
"data: " +
it->dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
return false;
}
}
}
if (llama_result.stop) {
break;
}
} else {
const std::string str =
"error: " +
llama_result.result_json.dump(-1, ' ', false,
json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {{"to_send", str}});
if (!sink.write(str.c_str(), str.size())) {
return false;
}
break;
}
}
sink.done();
return true;
};
auto on_complete = [task_id, &llama](bool) {
// cancel request
llama.request_cancel(task_id);
};
res.set_chunked_content_provider("text/event-stream", chunked_content_provider, on_complete);
}
});
svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res) svr.Post("/infill", [&llama](const httplib::Request &req, httplib::Response &res)
{ {
json data = json::parse(req.body); json data = json::parse(req.body);


@@ -1,4 +1,5 @@
#include <algorithm>
#include <cinttypes>
#include <cstddef>
#include <cstdint>
#include <limits>
@@ -4610,8 +4611,8 @@ static __global__ void rope(
template<typename T, bool has_pos>
static __global__ void rope_neox(
-const T * x, T * dst, int ncols, const int32_t * pos, float freq_scale, int p_delta_rows, float freq_base,
-float ext_factor, float attn_factor, rope_corr_dims corr_dims
+const T * x, T * dst, int ncols, int n_dims, const int32_t * pos, float freq_scale, int p_delta_rows,
+float ext_factor, float attn_factor, rope_corr_dims corr_dims, float theta_scale, float inv_ndims
) {
const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
@@ -4620,23 +4621,25 @@ static __global__ void rope_neox(
}
const int row = blockDim.x*blockIdx.x + threadIdx.x;
-const int i = row*ncols + col/2;
+const int ib = col / n_dims;
+const int ic = col % n_dims;
+const int i = row*ncols + ib*n_dims + ic/2;
const int i2 = row/p_delta_rows;
-// simplified from `(ib * ncols + col) * (-1 / ncols)`, where ib is assumed to be zero
-const float cur_rot = -float(col)/ncols;
+float cur_rot = inv_ndims * ic - ib;
const int p = has_pos ? pos[i2] : 0;
-const float theta_base = p*powf(freq_base, cur_rot);
+const float theta_base = p*freq_scale*powf(theta_scale, col/2.0f);
float cos_theta, sin_theta;
rope_yarn(theta_base, freq_scale, corr_dims, cur_rot, ext_factor, attn_factor, &cos_theta, &sin_theta);
const float x0 = x[i + 0];
-const float x1 = x[i + ncols/2];
+const float x1 = x[i + n_dims/2];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
-dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
+dst[i + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_glm_f32(
@@ -5739,20 +5742,26 @@ static void rope_cuda(
template<typename T>
static void rope_neox_cuda(
-const T * x, T * dst, int ncols, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
+const T * x, T * dst, int ncols, int n_dims, int nrows, const int32_t * pos, float freq_scale, int p_delta_rows,
float freq_base, float ext_factor, float attn_factor, rope_corr_dims corr_dims, cudaStream_t stream
) {
GGML_ASSERT(ncols % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
+const float theta_scale = powf(freq_base, -2.0f/n_dims);
+const float inv_ndims = -1.0f / n_dims;
if (pos == nullptr) {
rope_neox<T, false><<<block_nums, block_dims, 0, stream>>>(
-x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+theta_scale, inv_ndims
);
} else {
rope_neox<T, true><<<block_nums, block_dims, 0, stream>>>(
-x, dst, ncols, pos, freq_scale, p_delta_rows, freq_base, ext_factor, attn_factor, corr_dims
+x, dst, ncols, n_dims, pos, freq_scale, p_delta_rows, ext_factor, attn_factor, corr_dims,
+theta_scale, inv_ndims
);
}
}
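// In math form (read off the updated kernel above, not an addition to the patch):
//   theta_scale = freq_base^(-2 / n_dims)
//   theta_base  = p * freq_scale * theta_scale^(col / 2)
// with ib = col / n_dims and ic = col % n_dims; cur_rot = -(ic / n_dims) - ib is the
// fraction handed to rope_yarn() for the YaRN corrections, and each rotation now pairs
// elements i and i + n_dims/2 instead of i and i + ncols/2.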
@@ -6695,15 +6704,14 @@ inline void ggml_cuda_op_rope(
GGML_ASSERT(false);
rope_glm_f32_cuda(src0_dd, dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, n_ctx, main_stream);
} else if (is_neox) {
-GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
if (src0->type == GGML_TYPE_F32) {
rope_neox_cuda(
-(const float *)src0_dd, (float *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+(const float *)src0_dd, (float *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
);
} else if (src0->type == GGML_TYPE_F16) {
rope_neox_cuda(
-(const half *)src0_dd, (half *)dst_dd, ne00, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
+(const half *)src0_dd, (half *)dst_dd, ne00, n_dims, nrows, pos, freq_scale, ne01, freq_base, ext_factor,
attn_factor, corr_dims, main_stream
);
} else {
@@ -8053,7 +8061,7 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
if (tensor->op == GGML_OP_MUL_MAT) {
if (tensor->src[0]->ne[3] != tensor->src[1]->ne[3]) {
#ifndef NDEBUG
-fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %d, src1->ne[3] = %d - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
+fprintf(stderr, "%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, tensor->name, tensor->src[0]->ne[3], tensor->src[1]->ne[3]);
#endif
return false;
}


@@ -1433,7 +1433,8 @@ void ggml_metal_graph_compute(
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
-const int n_orig_ctx = ((int32_t *) dst->op_params)[3];
+// skip 3, n_ctx, used in GLM RoPE, unimplemented in metal
+const int n_orig_ctx = ((int32_t *) dst->op_params)[4];
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float));

ggml.c

@@ -15689,13 +15689,14 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
{
n_tasks = 1;
} break;
-case GGML_OP_COUNT:
-{
-GGML_ASSERT(false);
-} break;
default:
{
-printf("%s: op %s not implemented\n", __func__, ggml_op_name(node->op));
+fprintf(stderr, "%s: op not implemented: ", __func__);
+if (node->op < GGML_OP_COUNT) {
+fprintf(stderr, "%s\n", ggml_op_name(node->op));
+} else {
+fprintf(stderr, "%d\n", node->op);
+}
GGML_ASSERT(false);
} break;
}

llama.cpp

@@ -1123,6 +1123,12 @@ static std::string llama_token_to_str(const struct llama_context * ctx, llama_to
//
struct llama_state {
llama_state() {
#ifdef GGML_USE_METAL
ggml_metal_log_set_callback(log_callback, log_callback_user_data);
#endif
}
// We save the log callback globally
ggml_log_callback log_callback = llama_log_callback_default;
void * log_callback_user_data = nullptr;
@@ -1285,6 +1291,7 @@ struct llama_kv_cache {
// cannot be freely changed after a slot has been allocated.
uint32_t head = 0;
uint32_t size = 0;
uint32_t used = 0; // used cells (i.e. at least one seq_id)
// computed before each graph build
uint32_t n = 0;
@@ -1514,6 +1521,7 @@ static bool llama_kv_cache_init(
cache.head = 0;
cache.size = n_ctx;
cache.used = 0;
cache.cells.clear();
cache.cells.resize(n_ctx);
@@ -1615,6 +1623,8 @@ static bool llama_kv_cache_find_slot(
}
}
cache.used += n_tokens;
return true;
}
@@ -1635,6 +1645,7 @@ static void llama_kv_cache_clear(struct llama_kv_cache & cache) {
cache.cells[i].seq_id.clear();
}
cache.head = 0;
cache.used = 0;
}
static void llama_kv_cache_seq_rm(
@@ -1657,6 +1668,9 @@ static void llama_kv_cache_seq_rm(
continue;
}
if (cache.cells[i].seq_id.empty()) {
// keep count of the number of used cells
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
if (new_head == cache.size) new_head = i;
}
@@ -1664,7 +1678,7 @@ static void llama_kv_cache_seq_rm(
}
// If we freed up a slot, set head to it so searching can start there.
-if (new_head != cache.size) cache.head = new_head;
+if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}
static void llama_kv_cache_seq_cp(
@@ -1690,6 +1704,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
for (uint32_t i = 0; i < cache.size; ++i) {
if (!cache.cells[i].has_seq_id(seq_id)) {
if (cache.cells[i].pos >= 0) cache.used--;
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
@@ -1700,7 +1715,7 @@ static void llama_kv_cache_seq_keep(struct llama_kv_cache & cache, llama_seq_id
}
// If we freed up a slot, set head to it so searching can start there.
-if (new_head != cache.size) cache.head = new_head;
+if (new_head != cache.size && new_head < cache.head) cache.head = new_head;
}
static void llama_kv_cache_seq_shift(
@@ -1721,6 +1736,7 @@ static void llama_kv_cache_seq_shift(
cache.cells[i].delta += delta;
if (cache.cells[i].pos < 0) {
if (!cache.cells[i].seq_id.empty()) cache.used--;
cache.cells[i].pos = -1;
cache.cells[i].seq_id.clear();
if (new_head == cache.size) new_head = i;
@@ -3489,7 +3505,7 @@ static void llm_build_k_shift(
struct ggml_cgraph * graph,
llm_rope_type type,
int64_t n_ctx,
-int64_t n_rot,
+int n_rot,
float freq_base,
float freq_scale,
const llm_build_cb & cb) {
@@ -3521,7 +3537,7 @@ static void llm_build_k_shift(
// we rotate only the first n_rot dimensions
ggml_rope_custom_inplace(ctx,
ggml_view_3d(ctx, kv.k,
-n_rot, n_head_kv, n_ctx,
+n_embd_head, n_head_kv, n_ctx,
ggml_element_size(kv.k)*n_embd_head,
ggml_element_size(kv.k)*n_embd_gqa,
ggml_element_size(kv.k)*n_embd_gqa*n_ctx*il),
@ -4844,92 +4860,34 @@ struct llm_build_context {
// self-attention // self-attention
{ {
// compute Q and K and RoPE them // compute Q and K and RoPE them
struct ggml_tensor * tmpq = ggml_mul_mat(ctx0, model.layers[il].wq, cur); struct ggml_tensor * Qcur = ggml_mul_mat(ctx0, model.layers[il].wq, cur);
cb(tmpq, "tmpq", il); cb(Qcur, "Qcur", il);
struct ggml_tensor * tmpk = ggml_mul_mat(ctx0, model.layers[il].wk, cur); struct ggml_tensor * Kcur = ggml_mul_mat(ctx0, model.layers[il].wk, cur);
cb(tmpk, "tmpk", il); cb(Kcur, "Kcur", il);
struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur); struct ggml_tensor * Vcur = ggml_mul_mat(ctx0, model.layers[il].wv, cur);
cb(Vcur, "Vcur", il); cb(Vcur, "Vcur", il);
// RoPE the first n_rot of q/k, pass the other half, and concat. Qcur = ggml_rope_custom(
struct ggml_tensor * qrot = ggml_cont(ctx0, ggml_view_3d( ctx0, ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens), inp_pos,
ctx0, tmpq, hparams.n_rot, n_head, n_tokens, hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
ggml_element_size(tmpq) * n_embd_head, ext_factor, attn_factor, beta_fast, beta_slow
ggml_element_size(tmpq) * n_embd_head * n_head,
0
));
cb(qrot, "qrot", il);
struct ggml_tensor * krot = ggml_cont(ctx0, ggml_view_3d(
ctx0, tmpk, hparams.n_rot, n_head, n_tokens,
ggml_element_size(tmpk) * n_embd_head,
ggml_element_size(tmpk) * n_embd_head * n_head_kv,
0
));
cb(krot, "krot", il);
// get the second half of tmpq, e.g tmpq[n_rot:, :, :]
struct ggml_tensor * qpass = ggml_view_3d(
ctx0, tmpq, (n_embd_head - hparams.n_rot), n_head, n_tokens,
ggml_element_size(tmpq) * n_embd_head,
ggml_element_size(tmpq) * n_embd_head * n_head,
ggml_element_size(tmpq) * hparams.n_rot
);
cb(qpass, "qpass", il);
struct ggml_tensor * kpass = ggml_view_3d(
ctx0, tmpk, (n_embd_head - hparams.n_rot), n_head_kv, n_tokens,
ggml_element_size(tmpk) * (n_embd_head),
ggml_element_size(tmpk) * (n_embd_head) * n_head_kv,
ggml_element_size(tmpk) * hparams.n_rot
);
cb(kpass, "kpass", il);
struct ggml_tensor * qrotated = ggml_rope_custom(
ctx0, qrot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
); );
cb(qrotated, "qrotated", il);
struct ggml_tensor * krotated = ggml_rope_custom(
ctx0, krot, inp_pos, hparams.n_rot, 2, 0, n_orig_ctx,
freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow
);
cb(krotated, "krotated", il);
// ggml currently only supports concatenation on dim=2
// so we need to permute qrot, qpass, concat, then permute back.
qrotated = ggml_cont(ctx0, ggml_permute(ctx0, qrotated, 2, 1, 0, 3));
cb(qrotated, "qrotated", il);
krotated = ggml_cont(ctx0, ggml_permute(ctx0, krotated, 2, 1, 0, 3));
cb(krotated, "krotated", il);
qpass = ggml_cont(ctx0, ggml_permute(ctx0, qpass, 2, 1, 0, 3));
cb(qpass, "qpass", il);
kpass = ggml_cont(ctx0, ggml_permute(ctx0, kpass, 2, 1, 0, 3));
cb(kpass, "kpass", il);
struct ggml_tensor * Qcur = ggml_concat(ctx0, qrotated, qpass);
cb(Qcur, "Qcur", il); cb(Qcur, "Qcur", il);
struct ggml_tensor * Kcur = ggml_concat(ctx0, krotated, kpass); Kcur = ggml_rope_custom(
cb(Kcur, "Kcur", il); ctx0, ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens), inp_pos,
hparams.n_rot, 2, 0, n_orig_ctx, freq_base, freq_scale,
struct ggml_tensor * Q = ggml_cont(ctx0, ggml_permute(ctx0, Qcur, 2, 1, 0, 3)); ext_factor, attn_factor, beta_fast, beta_slow
cb(Q, "Q", il); );
Kcur = ggml_cont(ctx0, ggml_permute(ctx0, Kcur, 2, 1, 0, 3));
cb(Kcur, "Kcur", il); cb(Kcur, "Kcur", il);
llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il); llm_build_kv_store(ctx0, hparams, kv_self, gf, Kcur, Vcur, n_ctx, n_tokens, kv_head, cb, il);
cur = llm_build_kqv(ctx0, hparams, kv_self, cur = llm_build_kqv(ctx0, hparams, kv_self,
model.layers[il].wo, NULL, model.layers[il].wo, NULL,
Q, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il); Qcur, KQ_scale, KQ_mask, n_ctx, n_tokens, n_kv, -1.0f, cb, il);
cb(cur, "kqv_out", il); cb(cur, "kqv_out", il);
} }
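
The rewritten self-attention drops the manual split of each head into a rotated slice of hparams.n_rot dimensions and a pass-through slice that then had to be permuted and concatenated; instead Q and K are reshaped to (n_embd_head, n_heads, n_tokens) and passed straight to ggml_rope_custom, which applies the rotation only to the first n_rot dimensions of each head and leaves the rest untouched. An illustrative scalar sketch of that partial, NeoX-style rotation on a single head vector (not the ggml kernel; freq_scale and the YaRN factors are omitted):

#include <cmath>
#include <vector>

// Rotate only the first n_rot dimensions of one head vector, pairing dimension i
// with i + n_rot/2 (NeoX-style), and pass dimensions [n_rot, n_embd_head) through.
static void partial_rope_neox(std::vector<float> & head, int pos, int n_rot, float freq_base) {
    const int half = n_rot / 2;
    for (int i = 0; i < half; ++i) {
        const float theta = pos * std::pow(freq_base, -2.0f * i / n_rot);
        const float c = std::cos(theta);
        const float s = std::sin(theta);
        const float x0 = head[i];
        const float x1 = head[i + half];
        head[i]        = x0 * c - x1 * s;
        head[i + half] = x0 * s + x1 * c;
        // dimensions beyond n_rot are left unchanged
    }
}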
@ -5557,6 +5515,12 @@ static int llama_decode_internal(
batch.seq_id = seq_id_arr.data(); batch.seq_id = seq_id_arr.data();
} }
// if we have enough unused cells before the current head ->
// better to start searching from the beginning of the cache, hoping to fill it
if (kv_self.head > kv_self.used + 2*n_tokens) {
kv_self.head = 0;
}
if (!llama_kv_cache_find_slot(kv_self, batch)) { if (!llama_kv_cache_find_slot(kv_self, batch)) {
return 1; return 1;
} }
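
The check added above resets the slot search when the cache has become sparse behind the write head: once head has advanced more than 2*n_tokens cells past the number of cells actually in use, restarting at cell 0 lets llama_kv_cache_find_slot reuse freed cells instead of pushing further toward the end of the cache. For example, with used = 16, head = 200 and n_tokens = 32, we have 200 > 16 + 64, so the search restarts from the beginning. Restated as a predicate (illustrative only, not the llama.cpp code):

#include <cstdint>

// Prefer restarting the slot search at cell 0 when most cells before `head` are already free.
static bool should_restart_search(uint32_t head, uint32_t used, uint32_t n_tokens) {
    return head > used + 2 * n_tokens;
}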
@ -5567,7 +5531,7 @@ static int llama_decode_internal(
//kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA? //kv_self.n = std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)); // TODO: this might be better for CUDA?
kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self))); kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, llama_kv_cache_cell_max(kv_self)));
//printf("kv_self.n = %d\n", kv_self.n); //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
ggml_allocr_reset(lctx.alloc); ggml_allocr_reset(lctx.alloc);
@ -6716,10 +6680,13 @@ struct llama_grammar_candidate {
// pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`. // pointer. If an invalid sequence is encountered, returns `llama_partial_utf8.n_remain == -1`.
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8( static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
const char * src, const char * src,
size_t n_src,
llama_partial_utf8 partial_start) { llama_partial_utf8 partial_start) {
static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 }; static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
const char * pos = src; const char * pos = src;
std::vector<uint32_t> code_points; std::vector<uint32_t> code_points;
// common English strings have the same number of codepoints and bytes. `+ 1` for the terminating 0.
code_points.reserve(n_src + 1);
uint32_t value = partial_start.value; uint32_t value = partial_start.value;
int n_remain = partial_start.n_remain; int n_remain = partial_start.n_remain;
@ -6770,6 +6737,13 @@ static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain }); return std::make_pair(std::move(code_points), llama_partial_utf8{ value, n_remain });
} }
static std::pair<std::vector<uint32_t>, llama_partial_utf8> decode_utf8(
std::string src,
llama_partial_utf8 partial_start
) {
return decode_utf8(src.c_str(), src.size(), partial_start);
}
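
decode_utf8 now takes the source length so the code-point vector can be reserved up front, and the new std::string overload simply forwards c_str()/size(). The lookup table maps the high nibble of the lead byte to the sequence length: 1 for ASCII, 0 for continuation bytes, 2 to 4 for multi-byte lead bytes. A standalone sketch of that nibble-driven decode loop, assuming well-formed input and ignoring the partial-sequence carry handled by llama_partial_utf8 (helper name is illustrative):

#include <cstdint>
#include <string>
#include <vector>

// Decode a complete UTF-8 string into code points using the high-nibble length table.
// Simplified sketch: assumes well-formed input and no partial trailing sequence.
static std::vector<uint32_t> decode_utf8_simple(const std::string & src) {
    static const int lookup[] = { 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 2, 2, 3, 4 };
    std::vector<uint32_t> code_points;
    code_points.reserve(src.size() + 1);   // ASCII-heavy text: roughly one code point per byte
    for (size_t i = 0; i < src.size(); ) {
        const uint8_t first = uint8_t(src[i]);
        const int     len   = lookup[first >> 4];          // sequence length from the high nibble
        const uint8_t mask  = uint8_t((1 << (8 - len)) - 1);
        uint32_t value = first & mask;
        for (int j = 1; j < len && i + j < src.size(); ++j) {
            value = (value << 6) + (uint8_t(src[i + j]) & 0x3F);
        }
        code_points.push_back(value);
        i += len > 0 ? len : 1;                             // skip stray continuation bytes
    }
    code_points.push_back(0);                               // terminating 0, as in the original
    return code_points;
}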
// returns true iff pos points to the end of one of the definitions of a rule // returns true iff pos points to the end of one of the definitions of a rule
static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) { static bool llama_grammar_is_end_of_sequence(const llama_grammar_element * pos) {
switch (pos->type) { switch (pos->type) {
@ -7422,7 +7396,7 @@ void llama_sample_grammar(struct llama_context * ctx, llama_token_data_array * c
} else if (piece.empty() || piece[0] == 0) { } else if (piece.empty() || piece[0] == 0) {
candidates->data[i].logit = -INFINITY; candidates->data[i].logit = -INFINITY;
} else { } else {
candidates_decoded.push_back(decode_utf8(piece.c_str(), grammar->partial_utf8)); candidates_decoded.push_back(decode_utf8(piece, grammar->partial_utf8));
candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second }); candidates_grammar.push_back({ i, candidates_decoded.back().first.data(), candidates_decoded.back().second });
} }
} }
@ -7629,7 +7603,7 @@ void llama_grammar_accept_token(struct llama_context * ctx, struct llama_grammar
const std::string piece = llama_token_to_str(ctx, token); const std::string piece = llama_token_to_str(ctx, token);
// Note terminating 0 in decoded string // Note terminating 0 in decoded string
const auto decoded = decode_utf8(piece.c_str(), grammar->partial_utf8); const auto decoded = decode_utf8(piece, grammar->partial_utf8);
const auto & code_points = decoded.first; const auto & code_points = decoded.first;
for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) { for (auto it = code_points.begin(), end = code_points.end() - 1; it != end; ++it) {
grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it); grammar->stacks = llama_grammar_accept(grammar->rules, grammar->stacks, *it);
@ -8877,8 +8851,6 @@ struct llama_context * llama_new_context_with_model(
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
if (model->n_gpu_layers > 0) { if (model->n_gpu_layers > 0) {
ggml_metal_log_set_callback(llama_log_callback_default, NULL);
ctx->ctx_metal = ggml_metal_init(1); ctx->ctx_metal = ggml_metal_init(1);
if (!ctx->ctx_metal) { if (!ctx->ctx_metal) {
LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__); LLAMA_LOG_ERROR("%s: ggml_metal_init() failed\n", __func__);
@ -9113,8 +9085,107 @@ int llama_model_apply_lora_from_file(const struct llama_model * model, const cha
} }
} }
struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq) {
struct llama_kv_cache_view result = {
/*.n_cells = */ 0,
/*.n_max_seq = */ n_max_seq,
/*.token_count = */ 0,
/*.used_cells = */ llama_get_kv_cache_used_cells(ctx),
/*.max_contiguous = */ 0,
/*.max_contiguous_idx = */ -1,
/*.cells = */ nullptr,
/*.cells_sequences = */ nullptr,
};
return result;
}
void llama_kv_cache_view_free(struct llama_kv_cache_view * view) {
if (view->cells != nullptr) {
free(view->cells);
view->cells = nullptr;
}
if (view->cells_sequences != nullptr) {
free(view->cells_sequences);
view->cells_sequences = nullptr;
}
}
void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view) {
if (uint32_t(view->n_cells) < ctx->kv_self.size || view->cells == nullptr) {
view->n_cells = int32_t(ctx->kv_self.size);
void * p = realloc(view->cells, sizeof(struct llama_kv_cache_view_cell) * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells");
view->cells = (struct llama_kv_cache_view_cell *)p;
p = realloc(view->cells_sequences, sizeof(llama_seq_id) * view->n_max_seq * view->n_cells);
GGML_ASSERT(p != nullptr && "Failed to alloc kv_cache_view cells sequences");
view->cells_sequences = (llama_seq_id *)p;
}
const std::vector<llama_kv_cell> & kv_cells = ctx->kv_self.cells;
llama_kv_cache_view_cell * c_curr = view->cells;
llama_seq_id * cs_curr = view->cells_sequences;
int32_t used_cells = 0;
int32_t token_count = 0;
int32_t curr_contig_idx = -1;
uint32_t max_contig = 0;
int32_t max_contig_idx = -1;
for (int32_t i = 0; i < int32_t(ctx->kv_self.size); i++, c_curr++, cs_curr += view->n_max_seq) {
const size_t curr_size = kv_cells[i].seq_id.size();
token_count += curr_size;
c_curr->pos = kv_cells[i].pos + kv_cells[i].delta;
if (curr_size > 0) {
if (curr_contig_idx >= 0 && uint32_t(i - curr_contig_idx) > max_contig) {
max_contig = i - curr_contig_idx;
max_contig_idx = curr_contig_idx;
}
curr_contig_idx = -1;
} else if (curr_contig_idx < 0) {
curr_contig_idx = i;
}
int seq_idx = 0;
for (const llama_seq_id it : kv_cells[i].seq_id) {
if (seq_idx >= view->n_max_seq) {
break;
}
cs_curr[seq_idx] = it;
seq_idx++;
}
if (seq_idx != 0) {
used_cells++;
}
for (; seq_idx < view->n_max_seq; seq_idx++) {
cs_curr[seq_idx] = -1;
}
}
if (curr_contig_idx >= 0 && kv_cells.size() - curr_contig_idx > max_contig) {
max_contig_idx = curr_contig_idx;
max_contig = kv_cells.size() - curr_contig_idx;
}
view->max_contiguous = max_contig;
view->max_contiguous_idx = max_contig_idx;
view->token_count = token_count;
view->used_cells = used_cells;
if (uint32_t(used_cells) != ctx->kv_self.used) {
LLAMA_LOG_ERROR("%s: used cells mismatch. kv_cache says %d but we calculated %d\n",
__func__, ctx->kv_self.used, used_cells);
}
}
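
Together these give a debug-only snapshot of the cache: init fills in the static fields without allocating, update (re)allocates cells and cells_sequences on demand and recomputes token_count, used_cells and the longest contiguous run of free cells, and free releases the arrays. A minimal usage sketch, assuming an already created llama_context (print_kv_cache_summary is an illustrative helper name):

#include <cstdio>
#include "llama.h"

// Debug helper sketch: print a summary of KV cache occupancy for an existing context.
static void print_kv_cache_summary(const llama_context * ctx) {
    llama_kv_cache_view view = llama_kv_cache_view_init(ctx, /*n_max_seq =*/ 4);
    llama_kv_cache_view_update(ctx, &view);
    printf("cells = %d, used = %d, tokens = %d, largest free run = %d starting at %d\n",
           (int) view.n_cells, (int) view.used_cells, (int) view.token_count,
           (int) view.max_contiguous, (int) view.max_contiguous_idx);
    llama_kv_cache_view_free(&view);
}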
int llama_get_kv_cache_token_count(const struct llama_context * ctx) { int llama_get_kv_cache_token_count(const struct llama_context * ctx) {
return ctx->kv_self.head; int result = 0;
for (uint32_t i = 0; i < ctx->kv_self.size; i++) {
result += ctx->kv_self.cells[i].seq_id.size();
}
return result;
}
int llama_get_kv_cache_used_cells(const struct llama_context * ctx) {
return ctx->kv_self.used;
} }
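
The reworked token count walks every cell and sums its sequence ids, while the used-cell count comes straight from kv_self.used, so a cell shared by two sequences contributes two tokens but only one used cell. A toy illustration of the two counters (not llama.cpp code):

#include <cassert>
#include <set>
#include <vector>

// Toy example: cell 0 is shared by sequences {0, 1}, cell 1 holds {0}, cell 2 is empty.
int main() {
    std::vector<std::set<int>> cells = { {0, 1}, {0}, {} };   // cell -> sequence ids
    int tokens = 0, used = 0;
    for (const auto & c : cells) {
        tokens += (int) c.size();
        used   += c.empty() ? 0 : 1;
    }
    assert(tokens == 3 && used == 2);   // 3 tokens in the cache, 2 used cells
    return 0;
}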
void llama_kv_cache_clear(struct llama_context * ctx) { void llama_kv_cache_clear(struct llama_context * ctx) {
@ -9284,10 +9355,12 @@ static void llama_copy_state_data_internal(struct llama_context * ctx, llama_dat
const size_t kv_buf_size = kv_self.buf.size; const size_t kv_buf_size = kv_self.buf.size;
const uint32_t kv_head = kv_self.head; const uint32_t kv_head = kv_self.head;
const uint32_t kv_size = kv_self.size; const uint32_t kv_size = kv_self.size;
const uint32_t kv_used = kv_self.used;
data_ctx->write(&kv_buf_size, sizeof(kv_buf_size)); data_ctx->write(&kv_buf_size, sizeof(kv_buf_size));
data_ctx->write(&kv_head, sizeof(kv_head)); data_ctx->write(&kv_head, sizeof(kv_head));
data_ctx->write(&kv_size, sizeof(kv_size)); data_ctx->write(&kv_size, sizeof(kv_size));
data_ctx->write(&kv_used, sizeof(kv_used));
if (kv_buf_size) { if (kv_buf_size) {
const size_t elt_size = ggml_element_size(kv_self.k); const size_t elt_size = ggml_element_size(kv_self.k);
@ -9410,10 +9483,12 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
size_t kv_buf_size; size_t kv_buf_size;
uint32_t kv_head; uint32_t kv_head;
uint32_t kv_size; uint32_t kv_size;
uint32_t kv_used;
memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size); memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head); memcpy(&kv_head, inp, sizeof(kv_head)); inp += sizeof(kv_head);
memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size); memcpy(&kv_size, inp, sizeof(kv_size)); inp += sizeof(kv_size);
memcpy(&kv_used, inp, sizeof(kv_used)); inp += sizeof(kv_used);
if (kv_buf_size) { if (kv_buf_size) {
GGML_ASSERT(kv_self.buf.size == kv_buf_size); GGML_ASSERT(kv_self.buf.size == kv_buf_size);
@ -9448,6 +9523,7 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) {
ctx->kv_self.head = kv_head; ctx->kv_self.head = kv_head;
ctx->kv_self.size = kv_size; ctx->kv_self.size = kv_size;
ctx->kv_self.used = kv_used;
ctx->kv_self.cells.resize(kv_size); ctx->kv_self.cells.resize(kv_size);
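
The session state now carries the used-cell counter: the writer appends kv_used after kv_buf_size, kv_head and kv_size, and the loader reads it back at the matching offset and restores it into kv_self.used, so both sides must keep the same field order. An illustrative sketch of reading those header fields back out of a state blob (function and parameter names are not from llama.cpp):

#include <cstddef>
#include <cstdint>
#include <cstring>

// Read the KV section header fields in the same order they are written above.
// kv_buf_size is a size_t on the writing side, so its width follows the host ABI.
static const uint8_t * read_kv_header(const uint8_t * inp,
                                      size_t   & kv_buf_size,
                                      uint32_t & kv_head,
                                      uint32_t & kv_size,
                                      uint32_t & kv_used) {
    memcpy(&kv_buf_size, inp, sizeof(kv_buf_size)); inp += sizeof(kv_buf_size);
    memcpy(&kv_head,     inp, sizeof(kv_head));     inp += sizeof(kv_head);
    memcpy(&kv_size,     inp, sizeof(kv_size));     inp += sizeof(kv_size);
    memcpy(&kv_used,     inp, sizeof(kv_used));     inp += sizeof(kv_used);   // new field
    return inp;
}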
@ -9923,6 +9999,9 @@ const std::vector<std::pair<std::string, struct ggml_tensor *>> & llama_internal
void llama_log_set(ggml_log_callback log_callback, void * user_data) { void llama_log_set(ggml_log_callback log_callback, void * user_data) {
g_state.log_callback = log_callback ? log_callback : llama_log_callback_default; g_state.log_callback = log_callback ? log_callback : llama_log_callback_default;
g_state.log_callback_user_data = user_data; g_state.log_callback_user_data = user_data;
#ifdef GGML_USE_METAL
ggml_metal_log_set_callback(g_state.log_callback, g_state.log_callback_user_data);
#endif
} }
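
With Metal builds, llama_log_set now forwards the callback to ggml_metal_log_set_callback, so Metal messages go through the same sink instead of the default logger previously installed in llama_new_context_with_model. A minimal sketch of installing a custom sink, assuming the ggml_log_callback signature of (level, text, user_data):

#include <cstdio>
#include "llama.h"

// Hedged sketch: route llama.cpp (and, in Metal builds, ggml-metal) log output to stderr.
static void my_log_sink(ggml_log_level level, const char * text, void * user_data) {
    (void) level;
    (void) user_data;
    fputs(text, stderr);
}

// somewhere during startup:
//   llama_log_set(my_log_sink, nullptr);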
static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) { static void llama_log_internal_v(ggml_log_level level, const char * format, va_list args) {
59
llama.h
View file
@ -185,7 +185,7 @@ extern "C" {
// ref: https://github.com/ggerganov/llama.cpp/pull/2054 // ref: https://github.com/ggerganov/llama.cpp/pull/2054
float rope_freq_base; // RoPE base frequency, 0 = from model float rope_freq_base; // RoPE base frequency, 0 = from model
float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model float rope_freq_scale; // RoPE frequency scaling factor, 0 = from model
float yarn_ext_factor; // YaRN extrapolation mix factor, NaN = from model float yarn_ext_factor; // YaRN extrapolation mix factor, negative = from model
float yarn_attn_factor; // YaRN magnitude scaling factor float yarn_attn_factor; // YaRN magnitude scaling factor
float yarn_beta_fast; // YaRN low correction dim float yarn_beta_fast; // YaRN low correction dim
float yarn_beta_slow; // YaRN high correction dim float yarn_beta_slow; // YaRN high correction dim
@ -361,9 +361,60 @@ extern "C" {
// KV cache // KV cache
// //
// Returns the number of tokens in the KV cache // Information associated with an individual cell in the KV cache view.
LLAMA_API DEPRECATED(int llama_get_kv_cache_token_count(const struct llama_context * ctx), struct llama_kv_cache_view_cell {
"avoid using this, it will be removed in the future, instead - count the tokens in user code"); // The position for this cell. Takes KV cache shifts into account.
// May be negative if the cell is not populated.
llama_pos pos;
};
// An updateable view of the KV cache.
struct llama_kv_cache_view {
// Number of KV cache cells. This will be the same as the context size.
int32_t n_cells;
// Maximum number of sequences that can exist in a cell. It's not an error
// if there are more sequences in a cell than this value; however, they will
// not be visible in the view cells_sequences.
int32_t n_max_seq;
// Number of tokens in the cache. For example, if there are two populated
// cells, the first with 1 sequence id in it and the second with 2 sequence
// ids, then you'll have 3 tokens.
int32_t token_count;
// Number of populated cache cells.
int32_t used_cells;
// Maximum contiguous empty slots in the cache.
int32_t max_contiguous;
// Index to the start of the max_contiguous slot range. Can be negative
// when the cache is full.
int32_t max_contiguous_idx;
// Information for an individual cell.
struct llama_kv_cache_view_cell * cells;
// The sequences for each cell. There will be n_max_seq items per cell.
llama_seq_id * cells_sequences;
};
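
cells and cells_sequences are parallel arrays: cell i owns entries [i*n_max_seq, (i+1)*n_max_seq) of cells_sequences, and llama_kv_cache_view_update fills unused entries with -1. A short sketch that tallies how many cells each sequence id occupies in an updated view (illustrative helper; assumes sequence ids are small non-negative integers):

#include <cstddef>
#include <cstdint>
#include <vector>
#include "llama.h"

// Count, per sequence id, how many cells of an updated view reference it.
static std::vector<int> cells_per_sequence(const llama_kv_cache_view & view, int n_seq_ids) {
    std::vector<int> counts(n_seq_ids, 0);
    for (int32_t i = 0; i < view.n_cells; ++i) {
        const llama_seq_id * seqs = view.cells_sequences + (size_t) i * view.n_max_seq;
        for (int32_t j = 0; j < view.n_max_seq; ++j) {
            if (seqs[j] >= 0 && seqs[j] < n_seq_ids) {
                counts[seqs[j]]++;
            }
        }
    }
    return counts;
}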
// Create an empty KV cache view. (use only for debugging purposes)
LLAMA_API struct llama_kv_cache_view llama_kv_cache_view_init(const struct llama_context * ctx, int32_t n_max_seq);
// Free a KV cache view. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_free(struct llama_kv_cache_view * view);
// Update the KV cache view structure with the current state of the KV cache. (use only for debugging purposes)
LLAMA_API void llama_kv_cache_view_update(const struct llama_context * ctx, struct llama_kv_cache_view * view);
// Returns the number of tokens in the KV cache (slow, use only for debug)
// If a KV cell has multiple sequences assigned to it, it will be counted multiple times
LLAMA_API int llama_get_kv_cache_token_count(const struct llama_context * ctx);
// Returns the number of used KV cells (i.e. have at least one sequence assigned to them)
LLAMA_API int llama_get_kv_cache_used_cells(const struct llama_context * ctx);
// Clear the KV cache // Clear the KV cache
LLAMA_API void llama_kv_cache_clear( LLAMA_API void llama_kv_cache_clear(