llama : add new llama_decode() API that works with llama_batch

This commit is contained in:
Georgi Gerganov 2023-09-18 14:23:52 +03:00
parent 58bb5110ca
commit 9f42e75489
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
13 changed files with 146 additions and 75 deletions

View file

@ -160,7 +160,7 @@ int main(int argc, char ** argv)
int n_past = 0;
if (llama_eval(ctx, tokens_list.data(), tokens_list.size(), n_past, params.n_threads))
if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), tokens_list.size(), n_past, 0), params.n_threads))
{
fprintf(stderr, "%s : failed to eval prompt.\n" , __func__ );
return 1;

View file

@ -79,7 +79,8 @@ bool eval_float(void * model, float * input, int N){
if (n_eval > n_batch) {
n_eval = n_batch;
}
if (llama_eval_embd(ctx, (input+i*n_emb), n_eval, n_past, params.n_threads)) {
llama_batch batch = { uint32_t(n_eval), nullptr, (input+i*n_emb), nullptr, nullptr, n_past, 1, 0, false };
if (llama_decode(ctx, batch, params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}
@ -100,7 +101,7 @@ bool eval_tokens(void * model, std::vector<llama_token> tokens) {
if (n_eval > params.n_batch) {
n_eval = params.n_batch;
}
if (llama_eval(ctx, &tokens[i], n_eval, n_past, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&tokens[i], n_eval, n_past, 0), params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;
}

View file

@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
while (!embd_inp.empty()) {
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
if (llama_eval(ctx, embd_inp.data(), n_tokens, n_past, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0), params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}

View file

@ -891,7 +891,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
int n_processed = 0;
while (n_processed < n_prompt) {
int n_tokens = std::min(n_prompt - n_processed, n_batch);
llama_eval(ctx, tokens.data(), n_tokens, n_past + n_processed, n_threads);
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_tokens, n_past + n_processed, 0), n_threads);
n_processed += n_tokens;
}
}
@ -899,7 +899,7 @@ static void test_prompt(llama_context * ctx, int n_prompt, int n_past, int n_bat
static void test_gen(llama_context * ctx, int n_gen, int n_past, int n_threads) {
llama_token token = llama_token_bos(ctx);
for (int i = 0; i < n_gen; i++) {
llama_eval(ctx, &token, 1, n_past + i, n_threads);
llama_decode(ctx, llama_batch_get_one(&token, 1, n_past + i, 0), n_threads);
}
}

View file

@ -571,7 +571,7 @@ int main(int argc, char ** argv) {
for (int i = 0; i < input_size; i += params.n_batch) {
int n_eval = std::min(input_size - i, params.n_batch);
if (llama_eval(ctx_guidance, input_buf + i, n_eval, n_past_guidance, params.n_threads)) {
if (llama_decode(ctx_guidance, llama_batch_get_one(input_buf + i, n_eval, n_past_guidance, 0), params.n_threads)) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}
@ -588,7 +588,7 @@ int main(int argc, char ** argv) {
LOG("eval: %s\n", LOG_TOKENS_TOSTR_PRETTY(ctx, embd));
if (llama_eval(ctx, &embd[i], n_eval, n_past, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&embd[i], n_eval, n_past, 0), params.n_threads)) {
LOG_TEE("%s : failed to eval\n", __func__);
return 1;
}

View file

@ -199,7 +199,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params &
const int batch_size = std::min(end - batch_start, n_batch);
//fprintf(stderr, " Batch %d: starts at %d, size is %d, n_past is %d\n",j,batch_start,batch_size,j * n_batch);
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
//fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@ -331,7 +331,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par
tokens[batch_start] = llama_token_bos(ctx);
}
if (llama_eval(ctx, tokens.data() + batch_start, batch_size, j * n_batch, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0), params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {tokens, -1, logit_history, prob_history};
}
@ -409,7 +409,7 @@ static std::vector<float> hellaswag_evaluate_tokens(
for (size_t i_chunk = 0; i_chunk < n_chunk; ++i_chunk) {
size_t n_tokens = tokens.size() - i_chunk * n_batch;
n_tokens = std::min(n_tokens, size_t(n_batch));
if (llama_eval(ctx, tokens.data() + i_chunk * n_batch, n_tokens, n_past, n_thread)) {
if (llama_decode(ctx, llama_batch_get_one(tokens.data() + i_chunk * n_batch, n_tokens, n_past, 0), n_thread)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return {};
}

View file

@ -34,11 +34,11 @@ int main(int argc, char ** argv) {
auto last_n_tokens_data = std::vector<llama_token>(params.repeat_last_n, 0);
// init
auto model = llama_load_model_from_file(params.model.c_str(), lparams);
auto * model = llama_load_model_from_file(params.model.c_str(), lparams);
if (model == nullptr) {
return 1;
}
auto ctx = llama_new_context_with_model(model, lparams);
auto * ctx = llama_new_context_with_model(model, lparams);
if (ctx == nullptr) {
llama_free_model(model);
return 1;
@ -53,7 +53,7 @@ int main(int argc, char ** argv) {
}
// evaluate prompt
llama_eval(ctx, tokens.data(), n_prompt_tokens, n_past, params.n_threads);
llama_decode(ctx, llama_batch_get_one(tokens.data(), n_prompt_tokens, n_past, 0), params.n_threads);
last_n_tokens_data.insert(last_n_tokens_data.end(), tokens.data(), tokens.data() + n_prompt_tokens);
n_past += n_prompt_tokens;
@ -77,7 +77,7 @@ int main(int argc, char ** argv) {
printf("\n%s", params.prompt.c_str());
for (auto i = 0; i < params.n_predict; i++) {
auto logits = llama_get_logits(ctx);
auto * logits = llama_get_logits(ctx);
auto n_vocab = llama_n_vocab(ctx);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -90,7 +90,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
if (llama_eval(ctx, &next_token, 1, n_past, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx);
llama_free_model(model);
@ -105,7 +105,7 @@ int main(int argc, char ** argv) {
llama_free(ctx);
// make new context
auto ctx2 = llama_new_context_with_model(model, lparams);
auto * ctx2 = llama_new_context_with_model(model, lparams);
// Load state (rng, logits, embedding and kv_cache) from file
{
@ -137,7 +137,7 @@ int main(int argc, char ** argv) {
// second run
for (auto i = 0; i < params.n_predict; i++) {
auto logits = llama_get_logits(ctx2);
auto * logits = llama_get_logits(ctx2);
auto n_vocab = llama_n_vocab(ctx2);
std::vector<llama_token_data> candidates;
candidates.reserve(n_vocab);
@ -150,7 +150,7 @@ int main(int argc, char ** argv) {
last_n_tokens_data.push_back(next_token);
printf("%s", next_token_str.c_str());
if (llama_eval(ctx2, &next_token, 1, n_past, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(&next_token, 1, n_past, 0), params.n_threads)) {
fprintf(stderr, "\n%s : failed to evaluate\n", __func__);
llama_free(ctx2);
llama_free_model(model);

View file

@ -434,7 +434,7 @@ struct llama_server_context
{
n_eval = params.n_batch;
}
if (llama_eval(ctx, &embd[n_past], n_eval, n_past, params.n_threads))
if (llama_decode(ctx, llama_batch_get_one(&embd[n_past], n_eval, n_past, 0), params.n_threads))
{
LOG_ERROR("failed to eval", {
{"n_eval", n_eval},

View file

@ -76,7 +76,7 @@ int main(int argc, char ** argv) {
while (n_cur < n_gen) {
// evaluate the transformer
if (llama_eval(ctx, tokens_list.data(), int(tokens_list.size()), n_cur, params.n_threads)) {
if (llama_decode(ctx, llama_batch_get_one(tokens_list.data(), int(tokens_list.size()), n_cur, 0), params.n_threads)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}

View file

@ -70,9 +70,9 @@ int main(int argc, char ** argv) {
const auto t_enc_start = ggml_time_us();
// eval the prompt with both models
llama_eval(ctx_tgt, inp.data(), int(inp.size() - 1), 0, params.n_threads);
llama_eval(ctx_tgt, &inp.back(), 1, inp.size() - 1, params.n_threads);
llama_eval(ctx_dft, inp.data(), int(inp.size()), 0, params.n_threads);
llama_decode(ctx_tgt, llama_batch_get_one( inp.data(), n_input - 1, 0, 0), params.n_threads);
llama_decode(ctx_tgt, llama_batch_get_one(&inp.back(), 1, n_input - 1, 0), params.n_threads);
llama_decode(ctx_dft, llama_batch_get_one( inp.data(), n_input, 0, 0), params.n_threads);
const auto t_enc_end = ggml_time_us();
@ -172,7 +172,7 @@ int main(int argc, char ** argv) {
LOG("out of drafted tokens\n");
}
llama_eval(ctx_dft, &id, 1, n_past_dft, params.n_threads);
llama_decode(ctx_dft, llama_batch_get_one(&id, 1, n_past_dft, 0), params.n_threads);
++n_past_dft;
// heuristic for n_draft
@ -256,7 +256,7 @@ int main(int argc, char ** argv) {
}
// evaluate the drafted token on the draft model
llama_eval(ctx_dft, &drafted.back(), 1, n_past_cur, params.n_threads);
llama_decode(ctx_dft, llama_batch_get_one(&drafted.back(), 1, n_past_cur, 0), params.n_threads);
++n_past_cur;
if (grammar_dft != NULL) {
@ -265,7 +265,7 @@ int main(int argc, char ** argv) {
}
// evaluate the target model on the drafted tokens
llama_eval(ctx_tgt, drafted.data(), drafted.size(), n_past_tgt, params.n_threads);
llama_decode(ctx_tgt, llama_batch_get_one(drafted.data(), drafted.size(), n_past_tgt, 0), params.n_threads);
++n_past_tgt;
// the first token is always proposed by the traget model before the speculation loop