speculative : add tree-based sampling example (#3624)

* sampling : one sequence per sampling context

ggml-ci

* speculative : add tree-based sampling support

ggml-ci

* speculative : reuse the n_parallel CLI param

* speculative : refactor sampling

* examples : fix build after sampling refactoring

ggml-ci

* batched : fix n_seq_id

* sampling : fix malloc

ggml-ci

* swift : fix build

ggml-ci

* swift : try to fix build

ggml-ci

* prompts : add assistant.txt

* common : add llama_batch_add() and llama_batch_clear() helpers

* speculative : minor refactor

ggml-ci

* minor : comments + rename

ggml-ci

* speculative : fix off-by-one for n_drafted

* speculative : fix the n_drafted fix + p constants
This commit is contained in:
Georgi Gerganov 2023-10-18 16:21:57 +03:00 committed by GitHub
parent c67fe68e41
commit 0e89203b51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 737 additions and 578 deletions

View file

@ -17,7 +17,7 @@ inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int
if (n_eval > n_batch) {
n_eval = n_batch;
}
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, nullptr, *n_past, 1, 0, };
if (llama_decode(ctx_llama, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return false;

View file

@ -127,7 +127,7 @@ int main(int argc, char ** argv) {
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past, true);
eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:", params.n_batch, &n_past, true);
eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
eval_string(ctx_llama, (params.prompt + "\nASSISTANT:").c_str(), params.n_batch, &n_past, false);