speculative : add tree-based sampling example (#3624)

* sampling : one sequence per sampling context

ggml-ci

* speculative : add tree-based sampling support

ggml-ci

* speculative : reuse the n_parallel CLI param

* speculative : refactor sampling

* examples : fix build after sampling refactoring

ggml-ci

* batched : fix n_seq_id

* sampling : fix malloc

ggml-ci

* swift : fix build

ggml-ci

* swift : try to fix build

ggml-ci

* prompts : add assistant.txt

* common : add llama_batch_add() and llama_batch_clear() helpers

* speculative : minor refactor

ggml-ci

* minor : comments + rename

ggml-ci

* speculative : fix off-by-one for n_drafted

* speculative : fix the n_drafted fix + p constants
This commit is contained in:
Georgi Gerganov 2023-10-18 16:21:57 +03:00 committed by GitHub
parent c67fe68e41
commit 0e89203b51
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
21 changed files with 737 additions and 578 deletions

View file

@ -69,7 +69,7 @@ for id: llama_token in tokens {
print("\n")
var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0)
var batch = llama_batch_init(max(Int32(tokens.count), Int32(n_parallel)), 0, 1)
defer {
llama_batch_free(batch)
}
@ -80,7 +80,12 @@ batch.n_tokens = Int32(tokens.count)
for (i, token) in tokens.enumerated() {
batch.token[i] = token
batch.pos[i] = Int32(i)
batch.seq_id[i] = 0
batch.n_seq_id[i] = 1
// batch.seq_id[i][0] = 0
// TODO: is this the proper way to do this?
if let seq_id = batch.seq_id[i] {
seq_id[0] = 0
}
batch.logits[i] = 0
}
@ -169,7 +174,10 @@ while n_cur <= n_len {
// push this new token for next evaluation
batch.token[Int(batch.n_tokens)] = new_token_id
batch.pos[Int(batch.n_tokens)] = n_cur
batch.seq_id[Int(batch.n_tokens)] = Int32(i)
batch.n_seq_id[Int(batch.n_tokens)] = 1
if let seq_id = batch.seq_id[Int(batch.n_tokens)] {
seq_id[0] = Int32(i)
}
batch.logits[Int(batch.n_tokens)] = 1
i_batch[i] = batch.n_tokens