swift : fix llama-vocab api usage (#11645)

* swiftui : fix vocab api usage

* batched.swift : fix vocab api usage
This commit is contained in:
Jhen-Jie Hong 2025-02-04 19:15:24 +08:00 committed by GitHub
parent 534c46b53c
commit f117d84b48
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 17 additions and 10 deletions

View file

@ -31,6 +31,11 @@ defer {
llama_model_free(model)
}
guard let vocab = llama_model_get_vocab(model) else {
print("Failed to get vocab")
exit(1)
}
var tokens = tokenize(text: prompt, add_bos: true)
let n_kv_req = UInt32(tokens.count) + UInt32((n_len - Int(tokens.count)) * n_parallel)
@ -41,7 +46,7 @@ context_params.n_batch = UInt32(max(n_len, n_parallel))
context_params.n_threads = 8
context_params.n_threads_batch = 8
let context = llama_new_context_with_model(model, context_params)
let context = llama_init_from_model(model, context_params)
guard context != nil else {
print("Failed to initialize context")
exit(1)
@ -141,7 +146,7 @@ while n_cur <= n_len {
let new_token_id = llama_sampler_sample(smpl, context, i_batch[i])
// is it an end of stream? -> mark the stream as finished
if llama_vocab_is_eog(model, new_token_id) || n_cur == n_len {
if llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_len {
i_batch[i] = -1
// print("")
if n_parallel > 1 {
@ -207,7 +212,7 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
let utf8Count = text.utf8.count
let n_tokens = utf8Count + (add_bos ? 1 : 0)
let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
let tokenCount = llama_tokenize(vocab, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, /*special tokens*/ false)
var swiftTokens: [llama_token] = []
for i in 0 ..< tokenCount {
swiftTokens.append(tokens[Int(i)])
@ -218,12 +223,12 @@ private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
private func token_to_piece(token: llama_token, buffer: inout [CChar]) -> String? {
var result = [CChar](repeating: 0, count: 8)
let nTokens = llama_token_to_piece(model, token, &result, Int32(result.count), 0, false)
let nTokens = llama_token_to_piece(vocab, token, &result, Int32(result.count), 0, false)
if nTokens < 0 {
let actualTokensCount = -Int(nTokens)
result = .init(repeating: 0, count: actualTokensCount)
let check = llama_token_to_piece(
model,
vocab,
token,
&result,
Int32(result.count),