diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
index 03d7c6897..4a2da4982 100644
--- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
+++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift
@@ -9,13 +9,18 @@ enum LlamaError: Error {
 actor LlamaContext {
     private var model: OpaquePointer
     private var context: OpaquePointer
-
+    private var batch: llama_batch
     private var tokens_list: [llama_token]

+    var n_len: Int32 = 32
+    var n_cur: Int32 = 0
+    var n_decode: Int32 = 0
+
     init(model: OpaquePointer, context: OpaquePointer) {
         self.model = model
         self.context = context
         self.tokens_list = []
+        self.batch = llama_batch_init(512, 0)
     }

     deinit {
@@ -26,14 +31,15 @@ actor LlamaContext {
     static func createContext(path: String) throws -> LlamaContext {
         llama_backend_init(false)
-        let params = llama_context_default_params()
-        let model = llama_load_model_from_file(path, params)
+        let model_params = llama_model_default_params()
+
+        let model = llama_load_model_from_file(path, model_params)

         guard let model else {
             print("Could not load model at \(path)")
             throw LlamaError.couldNotInitializeContext
         }
-
-        let context = llama_new_context_with_model(model, params)
+        let ctx_params = llama_context_default_params()
+        let context = llama_new_context_with_model(model, ctx_params)
         guard let context else {
             print("Could not load context!")
             throw LlamaError.couldNotInitializeContext
@@ -42,50 +48,52 @@ actor LlamaContext {
         return LlamaContext(model: model, context: context)
     }

-    func get_kv_cache() -> Int32 {
-        return llama_get_kv_cache_token_count(context)
+    func get_n_tokens() -> Int32 {
+        return batch.n_tokens;
     }

-    func completion_init(text: String) -> Int32 {
+    func completion_init(text: String) {
         print("attempting to complete \(text)...")

         tokens_list = tokenize(text: text, add_bos: true)

-        let max_context_size = llama_n_ctx(context)
-        let max_tokens_list_size = max_context_size - 4
+        let n_ctx = llama_n_ctx(context)
+        let n_kv_req = tokens_list.count + (Int(n_len) - tokens_list.count)

-        if tokens_list.count > max_tokens_list_size {
-            print("error: prompt too long (\(tokens_list.count) tokens, max \(max_tokens_list_size)")
+        print("\n n_len = \(n_len), n_ctx = \(n_ctx), n_kv_req = \(n_kv_req)")
+
+        if n_kv_req > n_ctx {
+            print("error: n_kv_req > n_ctx, the required KV cache size is not big enough")
         }

         for id in tokens_list {
             print(token_to_piece(token: id))
         }

-        let n_gen = min(32, max_context_size)
-        return n_gen
+        // batch = llama_batch_init(512, 0) // done in init()
+        batch.n_tokens = Int32(tokens_list.count)
+
+        for i1 in 0..<batch.n_tokens {
+            let i = Int(i1)
+            batch.token[i] = tokens_list[i]
+            batch.pos[i] = i1
+            batch.seq_id[i] = 0
+            batch.logits[i] = 0 // false
+        }
+        batch.logits[Int(batch.n_tokens) - 1] = 1 // true
+
+        if llama_decode(context, batch) != 0 {
+            print("llama_decode() failed")
+        }
+
+        n_cur = batch.n_tokens
     }

     func completion_loop() -> String {
-        var done = false
-        tokens_list.withUnsafeBufferPointer() { cArray in
-            let res = llama_eval(context, cArray.baseAddress, Int32(tokens_list.count), llama_get_kv_cache_token_count(context), 8)
-            if res != 0 {
-                print("error evaluating llama!")
-                done = true
-                return
-            }
-        }
-        if done {
-            return ""
-        }
-
-        tokens_list.removeAll()
-
         var new_token_id: llama_token = 0

-        let logits = llama_get_logits(context)
         let n_vocab = llama_n_vocab(context)
+        let logits = llama_get_logits(context)

         var candidates = Array<llama_token_data>()
         candidates.reserveCapacity(Int(n_vocab))
@@ -98,9 +106,30 @@ actor LlamaContext {
             new_token_id = llama_sample_token_greedy(context, &candidates_p)
         }
+
+        if new_token_id == llama_token_eos(context) || n_cur == n_len {
+            print("\n")
+            return ""
+        }
+

         let new_token_str = token_to_piece(token: new_token_id)
         print(new_token_str)
         tokens_list.append(new_token_id)
+
+        batch.token[Int(batch.n_tokens)] = new_token_id
+        batch.pos[Int(batch.n_tokens)] = n_cur
+        batch.seq_id[Int(batch.n_tokens)] = 0
+        batch.logits[Int(batch.n_tokens)] = 1 // true
+
+        batch.n_tokens += 1
+        n_decode += 1
+
+        n_cur += 1
+
+        if llama_decode(context, batch) != 0 {
+            print("failed to evaluate llama!")
+        }
+

        return new_token_str
    }
diff --git a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift
index 1c02bdd26..ed1573f4e 100644
--- a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift
+++ b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift
@@ -33,17 +33,14 @@ class LlamaState: ObservableObject {
             return
         }
         messageLog += "Attempting to complete text...\n"
-        let n_ctx = await llamaContext.completion_init(text: text)
-        messageLog += "context size: \(n_ctx)\n"
+        await llamaContext.completion_init(text: text)
         messageLog += "\(text)"

-        if n_ctx > 0 {
-            while await llamaContext.get_kv_cache() < n_ctx {
-                let result = await llamaContext.completion_loop()
-                messageLog += "\(result)"
-            }
-            await llamaContext.clear()
-            messageLog += "\n\ndone\n"
+        while await llamaContext.n_cur <= llamaContext.n_len {
+            let result = await llamaContext.completion_loop()
+            messageLog += "\(result)"
         }
+        await llamaContext.clear()
+        messageLog += "\n\ndone\n"
     }
 }
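For reference, the calling pattern the new API expects, as exercised by the `LlamaState` change: one `completion_init` call tokenizes the prompt and decodes it in a single `llama_decode`, then each `completion_loop` call samples one token, appends it to the batch at position `n_cur`, and decodes it. Below is a minimal end-to-end sketch using only members shown in this diff; the model path and prompt are placeholders, and `clear()` is the existing reset helper the loop already calls.

```swift
func runDemo() async throws {
    // Placeholder path: point this at a real GGUF model file.
    let llamaContext = try LlamaContext.createContext(path: "/path/to/model.gguf")

    // Evaluate the whole prompt with one llama_decode() call.
    await llamaContext.completion_init(text: "Hello my name is")

    // Sample and decode one token per iteration; completion_loop()
    // returns "" once EOS is sampled or n_cur reaches n_len.
    while await llamaContext.n_cur <= llamaContext.n_len {
        let piece = await llamaContext.completion_loop()
        print(piece, terminator: "")
    }

    await llamaContext.clear()
}
```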
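One note on the batch bookkeeping: the same four stores into `batch.token`, `batch.pos`, `batch.seq_id`, and `batch.logits` now appear twice, once per prompt token in `completion_init` and once per sampled token in `completion_loop`. If we want to avoid the duplication, the pattern could be factored into a small helper. The sketch below is hypothetical and not part of this patch; it assumes the `llama_batch` layout used here (a flat `seq_id` array, as allocated by the two-argument `llama_batch_init(512, 0)`) and the imported C typedefs `llama_token`/`llama_pos`.

```swift
// Hypothetical helper (not in this patch): append one token to the batch,
// mirroring the bookkeeping done inline in completion_init/completion_loop.
func llama_batch_add(_ batch: inout llama_batch, _ token: llama_token, _ pos: llama_pos, logits: Bool) {
    let i = Int(batch.n_tokens)
    batch.token[i]  = token
    batch.pos[i]    = pos
    batch.seq_id[i] = 0              // this example only ever uses sequence 0
    batch.logits[i] = logits ? 1 : 0 // request logits only where we will sample
    batch.n_tokens += 1
}
```

With that, the prompt setup reduces to one `llama_batch_add(&batch, id, llama_pos(i), logits: false)` call per token plus setting the last entry's `logits` flag to 1, exactly as the patch does by hand.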