llama : custom attention mask + parallel decoding + no context swaps (#3228)
* tests : verify that RoPE is "additive" * llama : replace ggml_diag_mask_inf with ggml_add (custom -inf mask) * ggml : ggml_rope now takes a vector with positions instead of n_past * metal : add rope_f16 kernel + optimize cpy kernels * llama : unified KV cache + batch inference API * llama : add new llama_decode() API that works with llama_batch * llama : add cell_max heuristic for more efficient kv_cache * llama : extend llama_kv_cache API * llama : more robust cell_max heuristic + wip shift * metal : disable concurrency optimization * llama : add llama_kv_cache_shift_seq + no more context swaps * llama : apply K-cache roping for Falcon and Baichuan * speculative : fix KV cache management * parallel : example for serving multiple users in parallel * parallel : disable hot-plug to avoid cache fragmentation * fixes : speculative KV cache + llama worst-case graph * llama : extend batch API to select which logits to output * llama : fix worst case graph build * ggml-cuda : update rope implementation for parallel decoding (#3254) * ggml-cuda : update rope implementation for parallel decoding * better solution for p0 computation * fix rope * simpler rope implementation --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * make : add parallel to build + fix static functions in llama.cpp * simple : fix token counting * parallel : various improvements * llama : fix cell_max logic + rename functions * parallel : try smaller batches when the KV cache is fragmented * parallel : fix sequence termination criteria * llama : silence errors KV cache errors * parallel : remove new line from prompt * parallel : process system prompt once + configurable paramters + llama API * parallel : remove question with short answers * parallel : count cache misses * parallel : print misses on each request * parallel : minor * llama : fix n_kv to never become 0 * parallel : rename hot-plug to continuous-batching * llama : improve llama_batch API + simplify parallel example * simple : add parallel decoding support * simple : improve comments + free batch * ggml-cuda : add rope f16, restore performance with parallel decoding (#3272) * ggml-cuda : add rope f16, restore performance * offload KQ_mask with all models * fix rope shift --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * llama : disable MPI for now ggml-ci * train : make KQ_pos memory buffer permanent via dummy scale op * ggml : revert change to ggml_cpy, add ggml_cont_Nd instead (#3275) ggml-ci * parallel : fix bug (extra BOS) + smaller token_prev array * parallel : fix cases where the input prompts can overflow the batch * parallel : add disabled experimental batch chunking in powers of two * llama : llama.h formatting + comments * simple : add README.md * llama : fix kv cache heuristic when context is less than 32 * parallel : fix crash when `-n -1` * llama : simplify returns if/else branches * metal : use mm kernels for batch size > 2 * examples : utilize new llama_get_logits_ith() * examples : add example for batched decoding * examples : do not eval prompt 2 times (close #3348) * server : clear the KV cache beyond n_past before llama_decode * server : avoid context swaps by shifting the KV cache --------- Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
45855b3f1c
commit
ec893798b7
35 changed files with 2700 additions and 673 deletions
|
@ -23,7 +23,9 @@ else()
|
|||
add_subdirectory(train-text-from-scratch)
|
||||
add_subdirectory(convert-llama2c-to-ggml)
|
||||
add_subdirectory(simple)
|
||||
add_subdirectory(batched)
|
||||
add_subdirectory(speculative)
|
||||
add_subdirectory(parallel)
|
||||
add_subdirectory(embd-input)
|
||||
add_subdirectory(llama-bench)
|
||||
add_subdirectory(beam-search)
|
||||
|
|
|
@ -554,6 +554,14 @@ static struct ggml_tensor * forward(
|
|||
struct ggml_tensor * kc = kv_self.k;
|
||||
struct ggml_tensor * vc = kv_self.v;
|
||||
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
{
|
||||
int * data = (int *) KQ_pos->data;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
data[i] = n_past + i;
|
||||
}
|
||||
}
|
||||
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
|
@ -581,8 +589,8 @@ static struct ggml_tensor * forward(
|
|||
// wk shape [n_embd, n_embd, 1, 1]
|
||||
// Qcur shape [n_embd/n_head, n_head, N, 1]
|
||||
// Kcur shape [n_embd/n_head, n_head, N, 1]
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), n_past, n_rot, 0, 0);
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_3d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N), KQ_pos, n_rot, 0, 0);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
|
@ -808,9 +816,18 @@ static struct ggml_tensor * forward_batch(
|
|||
struct ggml_tensor * kc = kv_self.k;
|
||||
struct ggml_tensor * vc = kv_self.v;
|
||||
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
{
|
||||
int * data = (int *) KQ_pos->data;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
data[i] = n_past + i;
|
||||
}
|
||||
}
|
||||
|
||||
// inpL shape [n_embd,N*n_batch,1]
|
||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
|
||||
assert_shape_2d(inpL, n_embd, N*n_batch);
|
||||
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
struct ggml_tensor * inpSA = inpL;
|
||||
|
||||
|
@ -838,8 +855,8 @@ static struct ggml_tensor * forward_batch(
|
|||
// wk shape [n_embd, n_embd, 1, 1]
|
||||
// Qcur shape [n_embd/n_head, n_head, N, n_batch]
|
||||
// Kcur shape [n_embd/n_head, n_head, N, n_batch]
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), n_past, n_rot, 0, 0);
|
||||
struct ggml_tensor * Qcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wq, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0, ggml_reshape_4d(ctx0, ggml_mul_mat(ctx0, model->layers[il].wk, cur), n_embd/n_head, n_head, N, n_batch), KQ_pos, n_rot, 0, 0);
|
||||
assert_shape_4d(Qcur, n_embd/n_head, n_head, N, n_batch);
|
||||
assert_shape_4d(Kcur, n_embd/n_head, n_head, N, n_batch);
|
||||
|
||||
|
@ -1097,6 +1114,14 @@ static struct ggml_tensor * forward_lora(
|
|||
struct ggml_tensor * kc = kv_self.k;
|
||||
struct ggml_tensor * vc = kv_self.v;
|
||||
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);
|
||||
{
|
||||
int * data = (int *) KQ_pos->data;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
data[i] = n_past + i;
|
||||
}
|
||||
}
|
||||
|
||||
// inpL shape [n_embd,N,1,1]
|
||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
|
||||
for (int il = 0; il < n_layer; ++il) {
|
||||
|
@ -1130,7 +1155,7 @@ static struct ggml_tensor * forward_lora(
|
|||
model->layers[il].wqb,
|
||||
cur)),
|
||||
n_embd/n_head, n_head, N),
|
||||
n_past, n_rot, 0, 0);
|
||||
KQ_pos, n_rot, 0, 0);
|
||||
struct ggml_tensor * Kcur = ggml_rope(ctx0,
|
||||
ggml_reshape_3d(ctx0,
|
||||
ggml_mul_mat(ctx0,
|
||||
|
@ -1139,7 +1164,7 @@ static struct ggml_tensor * forward_lora(
|
|||
model->layers[il].wkb,
|
||||
cur)),
|
||||
n_embd/n_head, n_head, N),
|
||||
n_past, n_rot, 0, 0);
|
||||
KQ_pos, n_rot, 0, 0);
|
||||
|
||||
// store key and value to memory
|
||||
{
|
||||
|
|
5
examples/batched/CMakeLists.txt
Normal file
5
examples/batched/CMakeLists.txt
Normal file
|
@ -0,0 +1,5 @@
|
|||
set(TARGET batched)
|
||||
add_executable(${TARGET} batched.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
44
examples/batched/README.md
Normal file
44
examples/batched/README.md
Normal file
|
@ -0,0 +1,44 @@
|
|||
# llama.cpp/example/batched
|
||||
|
||||
The example demonstrates batched generation from a given prompt
|
||||
|
||||
```bash
|
||||
./batched ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is" 4
|
||||
|
||||
...
|
||||
|
||||
main: n_len = 32, n_ctx = 2048, n_parallel = 4, n_kv_req = 113
|
||||
|
||||
Hello my name is
|
||||
|
||||
main: generating 4 sequences ...
|
||||
|
||||
main: stream 0 finished
|
||||
main: stream 1 finished
|
||||
main: stream 2 finished
|
||||
main: stream 3 finished
|
||||
|
||||
sequence 0:
|
||||
|
||||
Hello my name is Shirley. I am a 25-year-old female who has been working for over 5 years as a b
|
||||
|
||||
sequence 1:
|
||||
|
||||
Hello my name is Renee and I'm a 32 year old female from the United States. I'm looking for a man between
|
||||
|
||||
sequence 2:
|
||||
|
||||
Hello my name is Diana. I am looking for a housekeeping job. I have experience with children and have my own transportation. I am
|
||||
|
||||
sequence 3:
|
||||
|
||||
Hello my name is Cody. I am a 3 year old neutered male. I am a very friendly cat. I am very playful and
|
||||
|
||||
main: decoded 108 tokens in 3.57 s, speed: 30.26 t/s
|
||||
|
||||
llama_print_timings: load time = 587.00 ms
|
||||
llama_print_timings: sample time = 2.56 ms / 112 runs ( 0.02 ms per token, 43664.72 tokens per second)
|
||||
llama_print_timings: prompt eval time = 4089.11 ms / 118 tokens ( 34.65 ms per token, 28.86 tokens per second)
|
||||
llama_print_timings: eval time = 0.00 ms / 1 runs ( 0.00 ms per token, inf tokens per second)
|
||||
llama_print_timings: total time = 4156.04 ms
|
||||
```
|
246
examples/batched/batched.cpp
Normal file
246
examples/batched/batched.cpp
Normal file
|
@ -0,0 +1,246 @@
|
|||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
if (argc == 1 || argv[1][0] == '-') {
|
||||
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL]\n" , argv[0]);
|
||||
return 1 ;
|
||||
}
|
||||
|
||||
int n_parallel = 1;
|
||||
|
||||
if (argc >= 2) {
|
||||
params.model = argv[1];
|
||||
}
|
||||
|
||||
if (argc >= 3) {
|
||||
params.prompt = argv[2];
|
||||
}
|
||||
|
||||
if (argc >= 4) {
|
||||
n_parallel = std::atoi(argv[3]);
|
||||
}
|
||||
|
||||
if (params.prompt.empty()) {
|
||||
params.prompt = "Hello my name is";
|
||||
}
|
||||
|
||||
// total length of the sequences including the prompt
|
||||
const int n_len = 32;
|
||||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
|
||||
ctx_params.seed = 1234;
|
||||
ctx_params.n_ctx = n_len*n_parallel; // FIXME: use n_kv_req instead (tokenize with model after #3301)
|
||||
ctx_params.n_batch = std::max(n_len, n_parallel);
|
||||
// ctx_params.n_gpu_layers = 99; // offload all layers to the GPU
|
||||
|
||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
|
||||
|
||||
if (model == NULL) {
|
||||
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// tokenize the prompt
|
||||
|
||||
std::vector<llama_token> tokens_list;
|
||||
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size())*n_parallel;
|
||||
|
||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_batch = %d, n_parallel = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, ctx_params.n_batch, n_parallel, n_kv_req);
|
||||
|
||||
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
||||
if (n_kv_req > n_ctx) {
|
||||
LOG_TEE("%s: error: n_kv_req (%d) > n_ctx, the required KV cache size is not big enough\n", __func__, n_kv_req);
|
||||
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// print the prompt token-by-token
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
for (auto id : tokens_list) {
|
||||
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
|
||||
}
|
||||
|
||||
fflush(stderr);
|
||||
|
||||
// create a llama_batch with size 512
|
||||
// we use this object to submit token data for decoding
|
||||
|
||||
llama_batch batch = llama_batch_init(std::max(tokens_list.size(), (size_t)n_parallel), 0);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
batch.token[i] = tokens_list[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
if (llama_decode(ctx, batch, params.n_threads) != 0) {
|
||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
|
||||
for (int32_t i = 1; i < n_parallel; ++i) {
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, batch.n_tokens);
|
||||
}
|
||||
|
||||
if (n_parallel > 1) {
|
||||
LOG_TEE("\n\n%s: generating %d sequences ...\n", __func__, n_parallel);
|
||||
}
|
||||
|
||||
// main loop
|
||||
|
||||
// we will store the parallel decoded sequences in this vector
|
||||
std::vector<std::string> streams(n_parallel);
|
||||
|
||||
// remember the batch index of the last token for each parallel sequence
|
||||
// we need this to determine which logits to sample from
|
||||
std::vector<int32_t> i_batch(n_parallel, batch.n_tokens - 1);
|
||||
|
||||
int n_cur = batch.n_tokens;
|
||||
int n_decode = 0;
|
||||
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
while (n_cur <= n_len) {
|
||||
// prepare the next batch
|
||||
batch.n_tokens = 0;
|
||||
|
||||
// sample the next token for each parallel sequence / stream
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
if (i_batch[i] < 0) {
|
||||
// the stream has already finished
|
||||
continue;
|
||||
}
|
||||
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto * logits = llama_get_logits_ith(ctx, i_batch[i]);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
const int top_k = 40;
|
||||
const float top_p = 0.9f;
|
||||
const float temp = 0.4f;
|
||||
|
||||
llama_sample_top_k(ctx, &candidates_p, top_k, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temp (ctx, &candidates_p, temp);
|
||||
|
||||
const llama_token new_token_id = llama_sample_token(ctx, &candidates_p);
|
||||
|
||||
//const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
|
||||
// is it an end of stream? -> mark the stream as finished
|
||||
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
|
||||
i_batch[i] = -1;
|
||||
LOG_TEE("\n");
|
||||
if (n_parallel > 1) {
|
||||
LOG_TEE("%s: stream %d finished at n_cur = %d", __func__, i, n_cur);
|
||||
}
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
// if there is only one stream, we print immediately to stdout
|
||||
if (n_parallel == 1) {
|
||||
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||
fflush(stdout);
|
||||
}
|
||||
|
||||
streams[i] += llama_token_to_piece(ctx, new_token_id);
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.seq_id[batch.n_tokens] = i;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
i_batch[i] = batch.n_tokens;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
||||
// all streams are finished
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
||||
if (n_parallel > 1) {
|
||||
LOG_TEE("\n");
|
||||
|
||||
for (int32_t i = 0; i < n_parallel; ++i) {
|
||||
LOG_TEE("sequence %d:\n\n%s%s\n\n", i, params.prompt.c_str(), streams[i].c_str());
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -158,8 +158,9 @@ int main(int argc, char ** argv)
|
|||
}
|
||||
std::cout << std::flush;
|
||||
|
||||
int n_past = llama_get_kv_cache_token_count(ctx);
|
||||
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
|
||||
int n_past = 0;
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads))
|
||||
{
|
||||
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
|
||||
return 1;
|
||||
|
|
|
@ -80,7 +80,8 @@ bool eval_float(void * model, float * input, int N){
|
|||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
|
||||
llama_batch batch = { int32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, nullptr, n_past, 1, 0, };
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
@ -101,7 +102,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
|
|||
if (n_eval > params.n_batch) {
|
||||
n_eval = params.n_batch;
|
||||
}
|
||||
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
@ -183,11 +184,11 @@ llama_token sampling_id(struct MyModel* mymodel) {
|
|||
if (mirostat == 1) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||
} else if (mirostat == 2) {
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||
} else {
|
||||
// Temperature sampling
|
||||
|
@ -195,7 +196,7 @@ llama_token sampling_id(struct MyModel* mymodel) {
|
|||
llama_sample_tail_free(ctx, &candidates_p, tfs_z, 1);
|
||||
llama_sample_typical(ctx, &candidates_p, typical_p, 1);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, 1);
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
id = llama_sample_token(ctx, &candidates_p);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
while (!embd_inp.empty()) {
|
||||
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
|
||||
if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
|
|
@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||
int n_processed = 0;
|
||||
while (n_processed < n_prompt) {
|
||||
int n_tokens = std::min(n_prompt - n_processed, n_batch);
|
||||
llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads);
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
|
||||
n_processed += n_tokens;
|
||||
}
|
||||
}
|
||||
|
@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
|
|||
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
|
||||
llama_token token = llama_token_bos(ctx);
|
||||
for (int i = 0; i < n_gen; i++) {
|
||||
llama_eval(ctx, &token, 1, n_past + i, n_threads);
|
||||
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -977,6 +977,8 @@ int main(int argc, char ** argv) {
|
|||
|
||||
test t(inst, lmodel, ctx);
|
||||
|
||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
||||
|
||||
// warmup run
|
||||
if (t.n_prompt > 0) {
|
||||
test_prompt(ctx, std::min(2, t.n_batch), 0, t.n_batch, t.n_threads);
|
||||
|
@ -986,6 +988,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
for (int i = 0; i < params.reps; i++) {
|
||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
||||
|
||||
uint64_t t_start = get_time_ns();
|
||||
if (t.n_prompt > 0) {
|
||||
test_prompt(ctx, t.n_prompt, 0, t.n_batch, t.n_threads);
|
||||
|
|
|
@ -124,7 +124,7 @@ int main(int argc, char ** argv) {
|
|||
console::init(params.simple_io, params.use_color);
|
||||
atexit([]() { console::cleanup(); });
|
||||
|
||||
if (params.perplexity) {
|
||||
if (params.logits_all) {
|
||||
printf("\n************\n");
|
||||
printf("%s: please use the 'perplexity' tool for perplexity calculations\n", __func__);
|
||||
printf("************\n\n");
|
||||
|
@ -200,15 +200,6 @@ int main(int argc, char ** argv) {
|
|||
params.n_threads, std::thread::hardware_concurrency(), llama_print_system_info());
|
||||
}
|
||||
|
||||
// export the cgraph and exit
|
||||
if (params.export_cgraph) {
|
||||
llama_eval_export(ctx, "llama.ggml");
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::string path_session = params.path_prompt_cache;
|
||||
std::vector<llama_token> session_tokens;
|
||||
|
||||
|
@ -508,18 +499,23 @@ int main(int argc, char ** argv) {
|
|||
break;
|
||||
}
|
||||
|
||||
const int n_left = n_past - params.n_keep;
|
||||
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d\n", n_past, n_left, n_ctx, params.n_keep);
|
||||
const int n_left = n_past - params.n_keep - 1;
|
||||
const int n_discard = n_left/2;
|
||||
|
||||
// always keep the first token - BOS
|
||||
n_past = std::max(1, params.n_keep);
|
||||
n_past_guidance = std::max(1, params.n_keep + guidance_offset);
|
||||
LOG("context full, swapping: n_past = %d, n_left = %d, n_ctx = %d, n_keep = %d, n_discard = %d\n",
|
||||
n_past, n_left, n_ctx, params.n_keep, n_discard);
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
|
||||
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
|
||||
|
||||
n_past -= n_discard;
|
||||
|
||||
if (ctx_guidance) {
|
||||
n_past_guidance -= n_discard;
|
||||
}
|
||||
|
||||
LOG("after swap: n_past = %d, n_past_guidance = %d\n", n_past, n_past_guidance);
|
||||
|
||||
// insert n_left/2 tokens at the start of embd from last_tokens
|
||||
embd.insert(embd.begin(), last_tokens.begin() + n_ctx - n_left/2 - embd.size(), last_tokens.end() - embd.size());
|
||||
|
||||
LOG("embd: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
|
||||
|
||||
LOG("clear session path\n");
|
||||
|
@ -580,7 +576,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
for (int i = 0; i < input_size; i += params.n_batch) {
|
||||
int n_eval = std::min(input_size - i, params.n_batch);
|
||||
if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
|
||||
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) {
|
||||
LOG_TEE("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
@ -597,7 +593,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
|
||||
|
||||
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) {
|
||||
LOG_TEE("%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
|
8
examples/parallel/CMakeLists.txt
Normal file
8
examples/parallel/CMakeLists.txt
Normal file
|
@ -0,0 +1,8 @@
|
|||
set(TARGET parallel)
|
||||
add_executable(${TARGET} parallel.cpp)
|
||||
install(TARGETS ${TARGET} RUNTIME)
|
||||
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
|
||||
target_compile_features(${TARGET} PRIVATE cxx_std_11)
|
||||
if(TARGET BUILD_INFO)
|
||||
add_dependencies(${TARGET} BUILD_INFO)
|
||||
endif()
|
3
examples/parallel/README.md
Normal file
3
examples/parallel/README.md
Normal file
|
@ -0,0 +1,3 @@
|
|||
# llama.cpp/example/parallel
|
||||
|
||||
Simplified simluation for serving incoming requests in parallel
|
380
examples/parallel/parallel.cpp
Normal file
380
examples/parallel/parallel.cpp
Normal file
|
@ -0,0 +1,380 @@
|
|||
// A basic application simulating a server with multiple clients.
|
||||
// The clients submite requests to the server and they are processed in parallel.
|
||||
|
||||
#include "build-info.h"
|
||||
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
|
||||
#include <cmath>
|
||||
#include <cstdio>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
// trim whitespace from the beginning and end of a string
|
||||
static std::string trim(const std::string & str) {
|
||||
size_t start = 0;
|
||||
size_t end = str.size();
|
||||
|
||||
while (start < end && isspace(str[start])) {
|
||||
start += 1;
|
||||
}
|
||||
|
||||
while (end > start && isspace(str[end - 1])) {
|
||||
end -= 1;
|
||||
}
|
||||
|
||||
return str.substr(start, end - start);
|
||||
}
|
||||
|
||||
static std::string k_system =
|
||||
R"(Transcript of a never ending dialog, where the User interacts with an Assistant.
|
||||
The Assistant is helpful, kind, honest, good at writing, and never fails to answer the User's requests immediately and with precision.
|
||||
|
||||
User: Recommend a nice restaurant in the area.
|
||||
Assistant: I recommend the restaurant "The Golden Duck". It is a 5 star restaurant with a great view of the city. The food is delicious and the service is excellent. The prices are reasonable and the portions are generous. The restaurant is located at 123 Main Street, New York, NY 10001. The phone number is (212) 555-1234. The hours are Monday through Friday from 11:00 am to 10:00 pm. The restaurant is closed on Saturdays and Sundays.
|
||||
User: Who is Richard Feynman?
|
||||
Assistant: Richard Feynman was an American physicist who is best known for his work in quantum mechanics and particle physics. He was awarded the Nobel Prize in Physics in 1965 for his contributions to the development of quantum electrodynamics. He was a popular lecturer and author, and he wrote several books, including "Surely You're Joking, Mr. Feynman!" and "What Do You Care What Other People Think?".
|
||||
User:)";
|
||||
|
||||
static std::vector<std::string> k_prompts = {
|
||||
"What is the meaning of life?",
|
||||
"Tell me an interesting fact about llamas.",
|
||||
"What is the best way to cook a steak?",
|
||||
"Are you familiar with the Special Theory of Relativity and can you explain it to me?",
|
||||
"Recommend some interesting books to read.",
|
||||
"What is the best way to learn a new language?",
|
||||
"How to get a job at Google?",
|
||||
"If you could have any superpower, what would it be?",
|
||||
"I want to learn how to play the piano.",
|
||||
};
|
||||
|
||||
struct client {
|
||||
int32_t id = 0;
|
||||
|
||||
llama_seq_id seq_id = -1;
|
||||
|
||||
llama_token sampled;
|
||||
|
||||
int64_t t_start_prompt;
|
||||
int64_t t_start_gen;
|
||||
|
||||
int32_t n_prompt = 0;
|
||||
int32_t n_decoded = 0;
|
||||
int32_t i_batch = -1;
|
||||
|
||||
std::string input;
|
||||
std::string prompt;
|
||||
std::string response;
|
||||
|
||||
std::vector<llama_token> tokens_prev;
|
||||
};
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
srand(1234);
|
||||
|
||||
gpt_params params;
|
||||
|
||||
if (gpt_params_parse(argc, argv, params) == false) {
|
||||
return 1;
|
||||
}
|
||||
|
||||
// number of simultaneous "clients" to simulate
|
||||
const int32_t n_clients = params.n_parallel;
|
||||
|
||||
// requests to simulate
|
||||
const int32_t n_seq = params.n_sequences;
|
||||
|
||||
// insert new requests as soon as the previous one is done
|
||||
const bool cont_batching = params.cont_batching;
|
||||
|
||||
#ifndef LOG_DISABLE_LOGS
|
||||
log_set_target(log_filename_generator("parallel", "log"));
|
||||
LOG_TEE("Log start\n");
|
||||
log_dump_cmdline(argc, argv);
|
||||
#endif // LOG_DISABLE_LOGS
|
||||
|
||||
// init llama.cpp
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
llama_model * model = NULL;
|
||||
llama_context * ctx = NULL;
|
||||
|
||||
// load the target model
|
||||
params.logits_all = true;
|
||||
std::tie(model, ctx) = llama_init_from_gpt_params(params);
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
fflush(stderr);
|
||||
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
std::vector<client> clients(n_clients);
|
||||
for (size_t i = 0; i < clients.size(); ++i) {
|
||||
auto & client = clients[i];
|
||||
client.id = i;
|
||||
client.tokens_prev.resize(std::max(256, params.n_predict));
|
||||
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
||||
}
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
std::vector<llama_token> tokens_system;
|
||||
tokens_system = ::llama_tokenize(ctx, k_system, true);
|
||||
const int32_t n_tokens_system = tokens_system.size();
|
||||
|
||||
llama_seq_id g_seq_id = 0;
|
||||
|
||||
// the max batch size is as large as the context to handle cases where we get very long input prompt from multiple
|
||||
// users. regardless of the size, the main loop will chunk the batch into a maximum of params.n_batch tokens at a time
|
||||
llama_batch batch = llama_batch_init(params.n_ctx, 0);
|
||||
|
||||
int32_t n_total_prompt = 0;
|
||||
int32_t n_total_gen = 0;
|
||||
int32_t n_cache_miss = 0;
|
||||
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
LOG_TEE("%s: Simulating parallel requests from clients:\n", __func__);
|
||||
LOG_TEE("%s: n_parallel = %d, n_sequences = %d, cont_batching = %d, system tokens = %d\n", __func__, n_clients, n_seq, cont_batching, n_tokens_system);
|
||||
LOG_TEE("\n");
|
||||
|
||||
{
|
||||
LOG_TEE("%s: Evaluating the system prompt ...\n", __func__);
|
||||
|
||||
batch.n_tokens = n_tokens_system;
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; ++i) {
|
||||
batch.token[i] = tokens_system[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
if (llama_decode(ctx, batch, params.n_threads) != 0) {
|
||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// assign the system KV cache to all parallel sequences
|
||||
for (int32_t i = 1; i < n_clients; ++i) {
|
||||
llama_kv_cache_seq_cp(ctx, 0, i, 0, n_tokens_system);
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
}
|
||||
|
||||
LOG_TEE("Processing requests ...\n\n");
|
||||
|
||||
while (true) {
|
||||
batch.n_tokens = 0;
|
||||
|
||||
// decode any currently ongoing sequences
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1) {
|
||||
continue;
|
||||
}
|
||||
|
||||
batch.token [batch.n_tokens] = client.sampled;
|
||||
batch.pos [batch.n_tokens] = n_tokens_system + client.n_prompt + client.n_decoded;
|
||||
batch.seq_id[batch.n_tokens] = client.id;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
client.n_decoded += 1;
|
||||
client.i_batch = batch.n_tokens;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
// all sequences have ended - clear the entire KV cache
|
||||
for (int i = 0; i < n_clients; ++i) {
|
||||
llama_kv_cache_seq_rm(ctx, i, n_tokens_system, -1);
|
||||
}
|
||||
|
||||
LOG_TEE("%s: clearing the KV cache\n", __func__);
|
||||
}
|
||||
|
||||
// insert new sequences for decoding
|
||||
if (cont_batching || batch.n_tokens == 0) {
|
||||
for (auto & client : clients) {
|
||||
if (client.seq_id == -1 && g_seq_id < n_seq) {
|
||||
client.seq_id = g_seq_id;
|
||||
|
||||
client.t_start_prompt = ggml_time_us();
|
||||
client.t_start_gen = 0;
|
||||
|
||||
client.input = k_prompts[rand() % k_prompts.size()];
|
||||
client.prompt = client.input + "\nAssistant:";
|
||||
client.response = "";
|
||||
|
||||
std::fill(client.tokens_prev.begin(), client.tokens_prev.end(), 0);
|
||||
|
||||
// do not prepend BOS because we have a system prompt!
|
||||
std::vector<llama_token> tokens_prompt;
|
||||
tokens_prompt = ::llama_tokenize(ctx, client.prompt, false);
|
||||
|
||||
for (size_t i = 0; i < tokens_prompt.size(); ++i) {
|
||||
batch.token [batch.n_tokens] = tokens_prompt[i];
|
||||
batch.pos [batch.n_tokens] = i + n_tokens_system;
|
||||
batch.seq_id[batch.n_tokens] = client.id;
|
||||
batch.logits[batch.n_tokens] = false;
|
||||
batch.n_tokens += 1;
|
||||
}
|
||||
|
||||
// extract the logits only for the last token
|
||||
if (batch.n_tokens > 0) {
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
}
|
||||
|
||||
client.n_prompt = tokens_prompt.size();
|
||||
client.n_decoded = 0;
|
||||
client.i_batch = batch.n_tokens - 1;
|
||||
|
||||
LOG_TEE("\033[1mClient %3d, seq %4d, started decoding ...\033[0m\n", client.id, client.seq_id);
|
||||
|
||||
g_seq_id += 1;
|
||||
|
||||
// insert new requests one-by-one
|
||||
//if (cont_batching) {
|
||||
// break;
|
||||
//}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (batch.n_tokens == 0) {
|
||||
break;
|
||||
}
|
||||
|
||||
// process in chunks of params.n_batch
|
||||
int32_t n_batch = params.n_batch;
|
||||
|
||||
for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
|
||||
// experiment: process in powers of 2
|
||||
//if (i + n_batch > (int32_t) batch.n_tokens && n_batch > 32) {
|
||||
// n_batch /= 2;
|
||||
// i -= n_batch;
|
||||
// continue;
|
||||
//}
|
||||
|
||||
const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
|
||||
|
||||
llama_batch batch_view = {
|
||||
n_tokens,
|
||||
batch.token + i,
|
||||
nullptr,
|
||||
batch.pos + i,
|
||||
batch.seq_id + i,
|
||||
batch.logits + i,
|
||||
0, 0, 0, // unused
|
||||
};
|
||||
|
||||
const int ret = llama_decode(ctx, batch_view, params.n_threads);
|
||||
if (ret != 0) {
|
||||
if (n_batch == 1 || ret < 0) {
|
||||
// if you get here, it means the KV cache is full - try increasing it via the context size
|
||||
LOG_TEE("%s : failed to decode the batch, n_batch = %d, ret = %d\n", __func__, n_batch, ret);
|
||||
return 1;
|
||||
}
|
||||
|
||||
LOG("%s : failed to decode the batch, retrying with n_batch = %d\n", __func__, n_batch / 2);
|
||||
|
||||
n_cache_miss += 1;
|
||||
|
||||
// retry with half the batch size to try to find a free slot in the KV cache
|
||||
n_batch /= 2;
|
||||
i -= n_batch;
|
||||
|
||||
continue;
|
||||
}
|
||||
|
||||
LOG("%s : decoded batch of %d tokens\n", __func__, n_tokens);
|
||||
|
||||
for (auto & client : clients) {
|
||||
if (client.i_batch < (int) i || client.i_batch >= (int) (i + n_tokens)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
//printf("client %d, seq %d, token %d, pos %d, batch %d\n",
|
||||
// client.id, client.seq_id, client.sampled, client.n_decoded, client.i_batch);
|
||||
|
||||
const llama_token id = llama_sample_token(ctx, NULL, NULL, params, client.tokens_prev, candidates, client.i_batch - i);
|
||||
|
||||
if (client.n_decoded == 1) {
|
||||
// start measuring generation time after the first token to make sure all concurrent clients
|
||||
// have their prompt already processed
|
||||
client.t_start_gen = ggml_time_us();
|
||||
}
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
client.tokens_prev.erase(client.tokens_prev.begin());
|
||||
client.tokens_prev.push_back(id);
|
||||
|
||||
const std::string token_str = llama_token_to_piece(ctx, id);
|
||||
client.response += token_str;
|
||||
client.sampled = id;
|
||||
|
||||
//printf("client %d, seq %d, token %d, pos %d, batch %d: %s\n",
|
||||
// client.id, client.seq_id, id, client.n_decoded, client.i_batch, token_str.c_str());
|
||||
|
||||
if (client.n_decoded > 2 &&
|
||||
(id == llama_token_eos(ctx) ||
|
||||
(params.n_predict > 0 && client.n_decoded + client.n_prompt >= params.n_predict) ||
|
||||
client.response.find("User:") != std::string::npos ||
|
||||
client.response.find('\n') != std::string::npos)) {
|
||||
// basic reverse prompt
|
||||
const size_t pos = client.response.find("User:");
|
||||
if (pos != std::string::npos) {
|
||||
client.response = client.response.substr(0, pos);
|
||||
}
|
||||
|
||||
// delete only the generated part of the sequence, i.e. keep the system prompt in the cache
|
||||
llama_kv_cache_seq_rm(ctx, client.id, n_tokens_system, n_ctx);
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
LOG_TEE("\033[1mClient %3d, seq %4d, prompt %4d t, response %4d t, time %5.2f s, speed %5.2f t/s, cache miss %d \033[0m \n\nInput: %s\nResponse: %s\n\n",
|
||||
client.id, client.seq_id, client.n_prompt, client.n_decoded,
|
||||
(t_main_end - client.t_start_prompt) / 1e6,
|
||||
(double) (client.n_prompt + client.n_decoded) / (t_main_end - client.t_start_prompt) * 1e6,
|
||||
n_cache_miss,
|
||||
::trim(client.input).c_str(),
|
||||
::trim(client.response).c_str());
|
||||
|
||||
n_total_prompt += client.n_prompt;
|
||||
n_total_gen += client.n_decoded;
|
||||
|
||||
client.seq_id = -1;
|
||||
}
|
||||
|
||||
client.i_batch = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
LOG_TEE("\n\n");
|
||||
LOG_TEE("Total prompt tokens: %6d, speed: %5.2f t/s\n", n_total_prompt, (double) (n_total_prompt ) / (t_main_end - t_main_start) * 1e6);
|
||||
LOG_TEE("Total gen tokens: %6d, speed: %5.2f t/s\n", n_total_gen, (double) (n_total_gen ) / (t_main_end - t_main_start) * 1e6);
|
||||
LOG_TEE("Total speed (AVG): %6s speed: %5.2f t/s\n", "", (double) (n_total_prompt + n_total_gen) / (t_main_end - t_main_start) * 1e6);
|
||||
LOG_TEE("Cache misses: %6d\n", n_cache_miss);
|
||||
|
||||
LOG_TEE("\n\n");
|
||||
|
||||
llama_print_timings(ctx);
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
|
||||
return 0;
|
||||
}
|
|
@ -80,7 +80,9 @@ static void write_logfile(
|
|||
static std::vector<float> softmax(const std::vector<float>& logits) {
|
||||
std::vector<float> probs(logits.size());
|
||||
float max_logit = logits[0];
|
||||
for (float v : logits) max_logit = std::max(max_logit, v);
|
||||
for (float v : logits) {
|
||||
max_logit = std::max(max_logit, v);
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (size_t i = 0; i < logits.size(); i++) {
|
||||
// Subtract the maximum logit value from the current logit value for numerical stability
|
||||
|
@ -89,15 +91,21 @@ static std::vector<float> softmax(const std::vector<float>& logits) {
|
|||
sum_exp += exp_logit;
|
||||
probs[i] = exp_logit;
|
||||
}
|
||||
for (size_t i = 0; i < probs.size(); i++) probs[i] /= sum_exp;
|
||||
for (size_t i = 0; i < probs.size(); i++) {
|
||||
probs[i] /= sum_exp;
|
||||
}
|
||||
return probs;
|
||||
}
|
||||
|
||||
static results_log_softmax log_softmax(int n_vocab, const float * logits, int tok) {
|
||||
float max_logit = logits[0];
|
||||
for (int i = 1; i < n_vocab; ++i) max_logit = std::max(max_logit, logits[i]);
|
||||
for (int i = 1; i < n_vocab; ++i) {
|
||||
max_logit = std::max(max_logit, logits[i]);
|
||||
}
|
||||
double sum_exp = 0.0;
|
||||
for (int i = 0; i < n_vocab; ++i) sum_exp += expf(logits[i] - max_logit);
|
||||
for (int i = 0; i < n_vocab; ++i) {
|
||||
sum_exp += expf(logits[i] - max_logit);
|
||||
}
|
||||
return {logits[tok] - max_logit - log(sum_exp), logits[tok], expf(logits[tok] - max_logit) / (float) sum_exp};
|
||||
}
|
||||
|
||||
|
@ -108,7 +116,8 @@ static void process_logits(
|
|||
std::mutex mutex;
|
||||
int counter = 0;
|
||||
auto compute = [&mutex, &counter, &nll, &nll2, logit_history, prob_history, n_vocab, logits, tokens, n_token] () {
|
||||
double local_nll = 0, local_nll2 = 0;
|
||||
double local_nll = 0;
|
||||
double local_nll2 = 0;
|
||||
while (true) {
|
||||
std::unique_lock<std::mutex> lock(mutex);
|
||||
int i = counter++;
|
||||
|
@ -126,10 +135,13 @@ static void process_logits(
|
|||
prob_history[i] = results.prob;
|
||||
}
|
||||
};
|
||||
for (auto & w : workers) w = std::thread(compute);
|
||||
for (auto & w : workers) {
|
||||
w = std::thread(compute);
|
||||
}
|
||||
compute();
|
||||
for (auto & w : workers) w.join();
|
||||
|
||||
for (auto & w : workers) {
|
||||
w.join();
|
||||
}
|
||||
}
|
||||
|
||||
static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & params) {
|
||||
|
@ -152,8 +164,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||
return {std::move(tokens), 0., {}, {}};
|
||||
}
|
||||
|
||||
std::vector<float> logit_history;
|
||||
std::vector<float> prob_history;
|
||||
std::vector<float> logit_history;
|
||||
std::vector<float> prob_history;
|
||||
|
||||
logit_history.resize(tokens.size());
|
||||
prob_history.resize(tokens.size());
|
||||
|
@ -195,12 +207,15 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
|
|||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
||||
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
|
||||
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
|
||||
//fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return {tokens, -1, logit_history, prob_history};
|
||||
}
|
||||
|
@ -320,6 +335,9 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
|
||||
const auto t_start = std::chrono::high_resolution_clock::now();
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
||||
|
||||
for (int j = 0; j < num_batches; ++j) {
|
||||
const int batch_start = start + j * n_batch;
|
||||
const int batch_size = std::min(end - batch_start, n_batch);
|
||||
|
@ -332,7 +350,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
tokens[batch_start] = llama_token_bos(ctx);
|
||||
}
|
||||
|
||||
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return {tokens, -1, logit_history, prob_history};
|
||||
}
|
||||
|
@ -402,7 +420,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
|
|||
}
|
||||
|
||||
static std::vector<float> hellaswag_evaluate_tokens(
|
||||
llama_context * ctx, const std::vector<int>& tokens, int n_past, int n_batch, int n_vocab, int n_thread
|
||||
llama_context * ctx, std::vector<int> & tokens, int n_past, int n_batch, int n_vocab, int n_thread
|
||||
) {
|
||||
std::vector<float> result;
|
||||
result.reserve(tokens.size() * n_vocab);
|
||||
|
@ -410,7 +428,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
|
|||
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
|
||||
size_t n_tokens = tokens.size() - i_chunk * n_batch;
|
||||
n_tokens = std::min(n_tokens, size_t(n_batch));
|
||||
if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return {};
|
||||
}
|
||||
|
@ -550,6 +568,9 @@ static void hellaswag_score(llama_context * ctx, const gpt_params & params) {
|
|||
query_embd.resize(32);
|
||||
}
|
||||
|
||||
// clear the KV cache
|
||||
llama_kv_cache_tokens_rm(ctx, -1, -1);
|
||||
|
||||
auto logits = hellaswag_evaluate_tokens(ctx, query_embd, 0, params.n_batch, n_vocab, params.n_threads);
|
||||
if (logits.empty()) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
|
@ -661,7 +682,7 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
params.perplexity = true;
|
||||
params.logits_all = true;
|
||||
params.n_batch = std::min(params.n_batch, params.n_ctx);
|
||||
|
||||
if (params.ppl_stride > 0) {
|
||||
|
|
|
@ -35,11 +35,11 @@ int main(int argc, char ** argv) {
|
|||
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
|
||||
|
||||
// init
|
||||
auto model = llama_load_model_from_file(params.model.c_str(), lparams);
|
||||
auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
|
||||
if (model == nullptr) {
|
||||
return 1;
|
||||
}
|
||||
auto ctx = llama_new_context_with_model(model, lparams);
|
||||
auto * ctx = llama_new_context_with_model(model, lparams);
|
||||
if (ctx == nullptr) {
|
||||
llama_free_model(model);
|
||||
return 1;
|
||||
|
@ -54,7 +54,7 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// evaluate prompt
|
||||
llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
|
||||
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
|
||||
|
||||
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
|
||||
n_past += n_prompt_tokens;
|
||||
|
@ -78,7 +78,7 @@ int main(int argc, char ** argv) {
|
|||
printf("\n%s", params.prompt.c_str());
|
||||
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto logits = llama_get_logits(ctx);
|
||||
auto * logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
@ -91,7 +91,7 @@ int main(int argc, char ** argv) {
|
|||
last_n_tokens_data.push_back(next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
@ -106,7 +106,7 @@ int main(int argc, char ** argv) {
|
|||
llama_free(ctx);
|
||||
|
||||
// make new context
|
||||
auto ctx2 = llama_new_context_with_model(model, lparams);
|
||||
auto * ctx2 = llama_new_context_with_model(model, lparams);
|
||||
|
||||
// Load state (rng, logits, embedding and kv_cache) from file
|
||||
{
|
||||
|
@ -138,7 +138,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
// second run
|
||||
for (auto i = 0; i < params.n_predict; i++) {
|
||||
auto logits = llama_get_logits(ctx2);
|
||||
auto * logits = llama_get_logits(ctx2);
|
||||
auto n_vocab = llama_n_vocab(ctx2);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
@ -151,7 +151,7 @@ int main(int argc, char ** argv) {
|
|||
last_n_tokens_data.push_back(next_token);
|
||||
|
||||
printf("%s", next_token_str.c_str());
|
||||
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
|
||||
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
|
||||
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
|
||||
llama_free(ctx2);
|
||||
llama_free_model(model);
|
||||
|
|
|
@ -381,6 +381,10 @@ struct llama_server_context
|
|||
|
||||
// compare the evaluated prompt with the new prompt
|
||||
n_past = common_part(embd, prompt_tokens);
|
||||
|
||||
// since #3228 we now have to manually manage the KV cache
|
||||
llama_kv_cache_seq_rm(ctx, 0, n_past, params.n_ctx);
|
||||
|
||||
embd = prompt_tokens;
|
||||
if (n_past == num_prompt_tokens)
|
||||
{
|
||||
|
@ -411,19 +415,27 @@ struct llama_server_context
|
|||
|
||||
if (embd.size() >= (size_t)params.n_ctx)
|
||||
{
|
||||
// Reset context
|
||||
const int n_left = (params.n_ctx - params.n_keep) / 2;
|
||||
// Shift context
|
||||
|
||||
const int n_left = n_past - params.n_keep - 1;
|
||||
const int n_discard = n_left/2;
|
||||
|
||||
llama_kv_cache_seq_rm (ctx, 0, params.n_keep + 1 , params.n_keep + n_discard + 1);
|
||||
llama_kv_cache_seq_shift(ctx, 0, params.n_keep + 1 + n_discard, n_past, -n_discard);
|
||||
|
||||
for (size_t i = params.n_keep + 1 + n_discard; i < embd.size(); i++)
|
||||
{
|
||||
embd[i - n_discard] = embd[i];
|
||||
}
|
||||
embd.resize(embd.size() - n_discard);
|
||||
|
||||
n_past -= n_discard;
|
||||
|
||||
std::vector<llama_token> new_tokens(embd.begin(), embd.begin() + params.n_keep);
|
||||
new_tokens.insert(new_tokens.end(), embd.end() - n_left, embd.end());
|
||||
embd = new_tokens;
|
||||
n_past = params.n_keep;
|
||||
truncated = true;
|
||||
LOG_VERBOSE("input truncated", {
|
||||
{"n_ctx", params.n_ctx},
|
||||
{"n_keep", params.n_keep},
|
||||
{"n_left", n_left},
|
||||
{"new_tokens", tokens_to_str(ctx, new_tokens.cbegin(), new_tokens.cend())},
|
||||
});
|
||||
}
|
||||
|
||||
|
@ -434,7 +446,8 @@ struct llama_server_context
|
|||
{
|
||||
n_eval = params.n_batch;
|
||||
}
|
||||
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
|
||||
|
||||
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
|
||||
{
|
||||
LOG_ERROR("failed to eval", {
|
||||
{"n_eval", n_eval},
|
||||
|
@ -523,13 +536,13 @@ struct llama_server_context
|
|||
{
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
const int mirostat_m = 100;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
result.tok = llama_sample_token_mirostat(ctx, &candidates_p, mirostat_tau, mirostat_eta, mirostat_m, &mirostat_mu);
|
||||
}
|
||||
else if (mirostat == 2)
|
||||
{
|
||||
static float mirostat_mu = 2.0f * mirostat_tau;
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
result.tok = llama_sample_token_mirostat_v2(ctx, &candidates_p, mirostat_tau, mirostat_eta, &mirostat_mu);
|
||||
}
|
||||
else
|
||||
|
@ -540,7 +553,7 @@ struct llama_server_context
|
|||
llama_sample_tail_free(ctx, &candidates_p, tfs_z, min_keep);
|
||||
llama_sample_typical(ctx, &candidates_p, typical_p, min_keep);
|
||||
llama_sample_top_p(ctx, &candidates_p, top_p, min_keep);
|
||||
llama_sample_temperature(ctx, &candidates_p, temp);
|
||||
llama_sample_temp(ctx, &candidates_p, temp);
|
||||
result.tok = llama_sample_token(ctx, &candidates_p);
|
||||
}
|
||||
}
|
||||
|
|
21
examples/simple/README.md
Normal file
21
examples/simple/README.md
Normal file
|
@ -0,0 +1,21 @@
|
|||
# llama.cpp/example/simple
|
||||
|
||||
The purpose of this example is to demonstrate a minimal usage of llama.cpp for generating text with a given prompt.
|
||||
|
||||
```bash
|
||||
./simple ./models/llama-7b-v2/ggml-model-f16.gguf "Hello my name is"
|
||||
|
||||
...
|
||||
|
||||
main: n_len = 32, n_ctx = 2048, n_parallel = 1, n_kv_req = 32
|
||||
|
||||
Hello my name is Shawn and I'm a 20 year old male from the United States. I'm a 20 year old
|
||||
|
||||
main: decoded 27 tokens in 2.31 s, speed: 11.68 t/s
|
||||
|
||||
llama_print_timings: load time = 579.15 ms
|
||||
llama_print_timings: sample time = 0.72 ms / 28 runs ( 0.03 ms per token, 38888.89 tokens per second)
|
||||
llama_print_timings: prompt eval time = 655.63 ms / 10 tokens ( 65.56 ms per token, 15.25 tokens per second)
|
||||
llama_print_timings: eval time = 2180.97 ms / 27 runs ( 80.78 ms per token, 12.38 tokens per second)
|
||||
llama_print_timings: total time = 2891.13 ms
|
||||
```
|
|
@ -26,12 +26,18 @@ int main(int argc, char ** argv) {
|
|||
params.prompt = "Hello my name is";
|
||||
}
|
||||
|
||||
// total length of the sequence including the prompt
|
||||
const int n_len = 32;
|
||||
|
||||
// init LLM
|
||||
|
||||
llama_backend_init(params.numa);
|
||||
|
||||
llama_context_params ctx_params = llama_context_default_params();
|
||||
|
||||
ctx_params.seed = 1234;
|
||||
ctx_params.n_ctx = 2048;
|
||||
|
||||
llama_model * model = llama_load_model_from_file(params.model.c_str(), ctx_params);
|
||||
|
||||
if (model == NULL) {
|
||||
|
@ -41,20 +47,31 @@ int main(int argc, char ** argv) {
|
|||
|
||||
llama_context * ctx = llama_new_context_with_model(model, ctx_params);
|
||||
|
||||
if (ctx == NULL) {
|
||||
fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// tokenize the prompt
|
||||
|
||||
std::vector<llama_token> tokens_list;
|
||||
tokens_list = ::llama_tokenize(ctx, params.prompt, true);
|
||||
|
||||
const int max_context_size = llama_n_ctx(ctx);
|
||||
const int max_tokens_list_size = max_context_size - 4;
|
||||
const int n_ctx = llama_n_ctx(ctx);
|
||||
const int n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
|
||||
|
||||
if ((int) tokens_list.size() > max_tokens_list_size) {
|
||||
fprintf(stderr, "%s: error: prompt too long (%d tokens, max %d)\n", __func__, (int) tokens_list.size(), max_tokens_list_size);
|
||||
LOG_TEE("\n%s: n_len = %d, n_ctx = %d, n_kv_req = %d\n", __func__, n_len, n_ctx, n_kv_req);
|
||||
|
||||
// make sure the KV cache is big enough to hold all the prompt and generated tokens
|
||||
if (n_kv_req > n_ctx) {
|
||||
LOG_TEE("%s: error: n_kv_req > n_ctx, the required KV cache size is not big enough\n", __func__);
|
||||
LOG_TEE("%s: either reduce n_parallel or increase n_ctx\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
// print the prompt token-by-token
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
for (auto id : tokens_list) {
|
||||
fprintf(stderr, "%s", llama_token_to_piece(ctx, id).c_str());
|
||||
|
@ -62,63 +79,104 @@ int main(int argc, char ** argv) {
|
|||
|
||||
fflush(stderr);
|
||||
|
||||
// create a llama_batch with size 512
|
||||
// we use this object to submit token data for decoding
|
||||
|
||||
llama_batch batch = llama_batch_init(512, 0);
|
||||
|
||||
// evaluate the initial prompt
|
||||
batch.n_tokens = tokens_list.size();
|
||||
|
||||
for (int32_t i = 0; i < batch.n_tokens; i++) {
|
||||
batch.token[i] = tokens_list[i];
|
||||
batch.pos[i] = i;
|
||||
batch.seq_id[i] = 0;
|
||||
batch.logits[i] = false;
|
||||
}
|
||||
|
||||
// llama_decode will output logits only for the last token of the prompt
|
||||
batch.logits[batch.n_tokens - 1] = true;
|
||||
|
||||
if (llama_decode(ctx, batch, params.n_threads) != 0) {
|
||||
LOG_TEE("%s: llama_decode() failed\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
|
||||
// main loop
|
||||
|
||||
// The LLM keeps a contextual cache memory of previous token evaluation.
|
||||
// Usually, once this cache is full, it is required to recompute a compressed context based on previous
|
||||
// tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
|
||||
// example, we will just stop the loop once this cache is full or once an end of stream is detected.
|
||||
int n_cur = batch.n_tokens;
|
||||
int n_decode = 0;
|
||||
|
||||
const int n_gen = std::min(32, max_context_size);
|
||||
const auto t_main_start = ggml_time_us();
|
||||
|
||||
while (llama_get_kv_cache_token_count(ctx) < n_gen) {
|
||||
// evaluate the transformer
|
||||
while (n_cur <= n_len) {
|
||||
// sample the next token
|
||||
{
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
auto * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
|
||||
|
||||
if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), llama_get_kv_cache_token_count(ctx), params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
// sample the most likely token
|
||||
const llama_token new_token_id = llama_sample_token_greedy(ctx, &candidates_p);
|
||||
|
||||
// is it an end of stream?
|
||||
if (new_token_id == llama_token_eos(ctx) || n_cur == n_len) {
|
||||
LOG_TEE("\n");
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
LOG_TEE("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||
fflush(stdout);
|
||||
|
||||
// prepare the next batch
|
||||
batch.n_tokens = 0;
|
||||
|
||||
// push this new token for next evaluation
|
||||
batch.token [batch.n_tokens] = new_token_id;
|
||||
batch.pos [batch.n_tokens] = n_cur;
|
||||
batch.seq_id[batch.n_tokens] = 0;
|
||||
batch.logits[batch.n_tokens] = true;
|
||||
|
||||
batch.n_tokens += 1;
|
||||
|
||||
n_decode += 1;
|
||||
}
|
||||
|
||||
n_cur += 1;
|
||||
|
||||
// evaluate the current batch with the transformer model
|
||||
if (llama_decode(ctx, batch, params.n_threads)) {
|
||||
fprintf(stderr, "%s : failed to eval, return code %d\n", __func__, 1);
|
||||
return 1;
|
||||
}
|
||||
|
||||
tokens_list.clear();
|
||||
|
||||
// sample the next token
|
||||
|
||||
llama_token new_token_id = 0;
|
||||
|
||||
auto logits = llama_get_logits(ctx);
|
||||
auto n_vocab = llama_n_vocab(ctx);
|
||||
|
||||
std::vector<llama_token_data> candidates;
|
||||
candidates.reserve(n_vocab);
|
||||
|
||||
for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
|
||||
candidates.emplace_back(llama_token_data{ token_id, logits[token_id], 0.0f });
|
||||
}
|
||||
|
||||
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
|
||||
|
||||
new_token_id = llama_sample_token_greedy(ctx , &candidates_p);
|
||||
|
||||
// is it an end of stream ?
|
||||
if (new_token_id == llama_token_eos(ctx)) {
|
||||
fprintf(stderr, " [end of text]\n");
|
||||
break;
|
||||
}
|
||||
|
||||
// print the new token :
|
||||
printf("%s", llama_token_to_piece(ctx, new_token_id).c_str());
|
||||
fflush(stdout);
|
||||
|
||||
// push this new token for next evaluation
|
||||
tokens_list.push_back(new_token_id);
|
||||
}
|
||||
|
||||
LOG_TEE("\n");
|
||||
|
||||
const auto t_main_end = ggml_time_us();
|
||||
|
||||
LOG_TEE("%s: decoded %d tokens in %.2f s, speed: %.2f t/s\n",
|
||||
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
|
||||
|
||||
llama_print_timings(ctx);
|
||||
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
llama_batch_free(batch);
|
||||
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
fprintf(stderr, "\n\n");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
|
|||
llama_context * ctx_dft = NULL;
|
||||
|
||||
// load the target model
|
||||
params.perplexity = true; // HACK: enable logits_all = true
|
||||
params.logits_all = true;
|
||||
std::tie(model_tgt, ctx_tgt) = llama_init_from_gpt_params(params);
|
||||
|
||||
// load the draft model
|
||||
|
@ -70,9 +70,9 @@ int main(int argc, char ** argv) {
|
|||
const auto t_enc_start = ggml_time_us();
|
||||
|
||||
// eval the prompt with both models
|
||||
llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
|
||||
llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
|
||||
llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
|
||||
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads);
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads);
|
||||
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads);
|
||||
|
||||
const auto t_enc_end = ggml_time_us();
|
||||
|
||||
|
@ -134,7 +134,7 @@ int main(int argc, char ** argv) {
|
|||
|
||||
while (true) {
|
||||
// sample from the target model
|
||||
const llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
|
||||
llama_token id = llama_sample_token(ctx_tgt, NULL, grammar_tgt, params, last_tokens, candidates, i_dft);
|
||||
|
||||
// remember which tokens were sampled - used for repetition penalties during sampling
|
||||
last_tokens.erase(last_tokens.begin());
|
||||
|
@ -172,7 +172,8 @@ int main(int argc, char ** argv) {
|
|||
LOG("out of drafted tokens\n");
|
||||
}
|
||||
|
||||
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_dft, n_ctx);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
|
||||
++n_past_dft;
|
||||
|
||||
// heuristic for n_draft
|
||||
|
@ -256,7 +257,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// evaluate the drafted token on the draft model
|
||||
llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
|
||||
llama_kv_cache_seq_rm(ctx_dft, 0, n_past_cur, n_ctx);
|
||||
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
|
||||
++n_past_cur;
|
||||
|
||||
if (grammar_dft != NULL) {
|
||||
|
@ -265,7 +267,8 @@ int main(int argc, char ** argv) {
|
|||
}
|
||||
|
||||
// evaluate the target model on the drafted tokens
|
||||
llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
|
||||
llama_kv_cache_seq_rm(ctx_tgt, 0, n_past_tgt, n_ctx);
|
||||
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
|
||||
++n_past_tgt;
|
||||
|
||||
// the first token is always proposed by the traget model before the speculation loop
|
||||
|
|
|
@ -679,15 +679,23 @@ struct ggml_tensor * llama_build_train_graphs(
|
|||
}
|
||||
};
|
||||
|
||||
// KQ_pos - contains the positions
|
||||
struct ggml_tensor * KQ_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, N);
|
||||
{
|
||||
int * data = (int *) KQ_pos->data;
|
||||
for (int i = 0; i < N; ++i) {
|
||||
data[i] = n_past + i;
|
||||
}
|
||||
}
|
||||
|
||||
// rope has so much parameters that we make a custom function for it
|
||||
auto rope = [ctx, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
|
||||
auto rope = [ctx, KQ_pos, n_rot, n_ctx, rope_freq_base, rope_freq_scale]
|
||||
(struct ggml_tensor * t) -> struct ggml_tensor * {
|
||||
// not capturing these, to silcence warnings
|
||||
const int n_past = 0;
|
||||
const int rope_mode = 0;
|
||||
|
||||
return ggml_rope_custom(ctx,
|
||||
t, n_past, n_rot, rope_mode, n_ctx,
|
||||
t, KQ_pos, n_rot, rope_mode, n_ctx,
|
||||
rope_freq_base, rope_freq_scale);
|
||||
};
|
||||
|
||||
|
@ -787,6 +795,8 @@ struct ggml_tensor * llama_build_train_graphs(
|
|||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36, one));
|
||||
// input gradient
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, t36->grad, one));
|
||||
// KQ_pos
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, KQ_pos, one));
|
||||
GGML_ASSERT(t36->grad->data == NULL && !ggml_is_view(t36->grad));
|
||||
ggml_allocr_alloc(alloc, t36->grad);
|
||||
// gradient tensors (will be set to zero by ggml_graph_reset)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue