llama : refactor sampling v2 (#9294)

- Add `struct llama_sampler` and `struct llama_sampler_i`
- Add `llama_sampler_` API
- Add `llama_sampler_chain_` API for chaining multiple samplers
- Remove `LLAMA_API_INTERNAL`
- Add `llama_perf_` API and remove old `llama_print_timings` and `llama_reset_timings`
This commit is contained in:
Georgi Gerganov 2024-09-07 15:16:19 +03:00 committed by GitHub
parent 947538acb8
commit df270ef745
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
48 changed files with 3497 additions and 2914 deletions

View file

@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
private var sampling: UnsafeMutablePointer<llama_sampler>
private var batch: llama_batch
private var tokens_list: [llama_token]
var is_done: Bool = false
@ -42,9 +43,15 @@ actor LlamaContext {
self.tokens_list = []
self.batch = llama_batch_init(512, 0, 1)
self.temporary_invalid_cchars = []
let sparams = llama_sampler_chain_default_params()
self.sampling = llama_sampler_chain_init(sparams)
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
llama_sampler_chain_add(self.sampling, llama_sampler_init_softmax())
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
}
deinit {
llama_sampler_free(sampling)
llama_batch_free(batch)
llama_free(context)
llama_free_model(model)
@ -69,7 +76,6 @@ actor LlamaContext {
print("Using \(n_threads) threads")
var ctx_params = llama_context_default_params()
ctx_params.seed = 1234
ctx_params.n_ctx = 2048
ctx_params.n_threads = Int32(n_threads)
ctx_params.n_threads_batch = Int32(n_threads)
@ -144,20 +150,9 @@ actor LlamaContext {
func completion_loop() -> String {
var new_token_id: llama_token = 0
let n_vocab = llama_n_vocab(model)
let logits = llama_get_logits_ith(context, batch.n_tokens - 1)
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
var candidates = Array<llama_token_data>()
candidates.reserveCapacity(Int(n_vocab))
for token_id in 0..<n_vocab {
candidates.append(llama_token_data(id: token_id, logit: logits![Int(token_id)], p: 0.0))
}
candidates.withUnsafeMutableBufferPointer() { buffer in
var candidates_p = llama_token_data_array(data: buffer.baseAddress, size: buffer.count, sorted: false)
new_token_id = llama_sample_token_greedy(context, &candidates_p)
}
llama_sampler_accept(sampling, new_token_id)
if llama_token_is_eog(model, new_token_id) || n_cur == n_len {
print("\n")