swift : fix llama-vocab api usage (#11645)

* swiftui : fix vocab api usage

* batched.swift : fix vocab api usage
This commit is contained in:
Jhen-Jie Hong 2025-02-04 19:15:24 +08:00 committed by GitHub
parent 534c46b53c
commit f117d84b48
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 17 additions and 10 deletions

View file

@ -24,6 +24,7 @@ func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
private var vocab: OpaquePointer
private var sampling: UnsafeMutablePointer<llama_sampler>
private var batch: llama_batch
private var tokens_list: [llama_token]
@ -47,6 +48,7 @@ actor LlamaContext {
self.sampling = llama_sampler_chain_init(sparams)
llama_sampler_chain_add(self.sampling, llama_sampler_init_temp(0.4))
llama_sampler_chain_add(self.sampling, llama_sampler_init_dist(1234))
vocab = llama_model_get_vocab(model)
}
deinit {
@ -79,7 +81,7 @@ actor LlamaContext {
ctx_params.n_threads = Int32(n_threads)
ctx_params.n_threads_batch = Int32(n_threads)
let context = llama_new_context_with_model(model, ctx_params)
let context = llama_init_from_model(model, ctx_params)
guard let context else {
print("Could not load context!")
throw LlamaError.couldNotInitializeContext
@ -151,7 +153,7 @@ actor LlamaContext {
new_token_id = llama_sampler_sample(sampling, context, batch.n_tokens - 1)
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
print("\n")
is_done = true
let new_token_str = String(cString: temporary_invalid_cchars + [0])
@ -297,7 +299,7 @@ actor LlamaContext {
let utf8Count = text.utf8.count
let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
var swiftTokens: [llama_token] = []
for i in 0..<tokenCount {
@ -316,7 +318,7 @@ actor LlamaContext {
defer {
result.deallocate()
}
let nTokens = llama_token_to_piece(model, token, result, 8, 0, false)
let nTokens = llama_token_to_piece(vocab, token, result, 8, 0, false)
if nTokens < 0 {
let newResult = UnsafeMutablePointer<Int8>.allocate(capacity: Int(-nTokens))
@ -324,7 +326,7 @@ actor LlamaContext {
defer {
newResult.deallocate()
}
let nNewTokens = llama_token_to_piece(model, token, newResult, -nTokens, 0, false)
let nNewTokens = llama_token_to_piece(vocab, token, newResult, -nTokens, 0, false)
let bufferPointer = UnsafeBufferPointer(start: newResult, count: Int(nNewTokens))
return Array(bufferPointer)
} else {