From 9629448716a649e519c267cd66dd5bfc5f225c02 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 17 Dec 2023 11:46:13 +0200 Subject: [PATCH] llama.swiftui : UX improvements --- .../llama.cpp.swift/LibLlama.swift | 7 +++- .../llama.swiftui/Models/LlamaState.swift | 10 ++++- .../llama.swiftui/UI/ContentView.swift | 39 +++++++++++++++---- .../llama.swiftui/UI/DownloadButton.swift | 3 +- 4 files changed, 48 insertions(+), 11 deletions(-) diff --git a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift index 3bd144c0f..7114fd7ec 100644 --- a/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift +++ b/examples/llama.swiftui/llama.cpp.swift/LibLlama.swift @@ -30,7 +30,7 @@ actor LlamaContext { /// This variable is used to store temporarily invalid cchars private var temporary_invalid_cchars: [CChar] - var n_len: Int32 = 512 + var n_len: Int32 = 64 var n_cur: Int32 = 0 var n_decode: Int32 = 0 @@ -241,6 +241,8 @@ actor LlamaContext { let t_tg_end = ggml_time_us() + llama_kv_cache_clear(context) + let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0 let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0 @@ -254,11 +256,12 @@ actor LlamaContext { func clear() { tokens_list.removeAll() temporary_invalid_cchars.removeAll() + llama_kv_cache_clear(context) } private func tokenize(text: String, add_bos: Bool) -> [llama_token] { let utf8Count = text.utf8.count - let n_tokens = utf8Count + (add_bos ? 1 : 0) + let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1 let tokens = UnsafeMutablePointer.allocate(capacity: n_tokens) let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false) diff --git a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift index 5d7a81e0d..2e80eca7d 100644 --- a/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift +++ b/examples/llama.swiftui/llama.swiftui/Models/LlamaState.swift @@ -34,7 +34,6 @@ class LlamaState: ObservableObject { return } - messageLog += "Attempting to complete text...\n" await llamaContext.completion_init(text: text) messageLog += "\(text)" @@ -58,4 +57,13 @@ class LlamaState: ObservableObject { let result = await llamaContext.bench() messageLog += "\(result)" } + + func clear() async { + guard let llamaContext else { + return + } + + await llamaContext.clear() + messageLog = "" + } } diff --git a/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift b/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift index 66fa4c561..d4985a983 100644 --- a/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift +++ b/examples/llama.swiftui/llama.swiftui/UI/ContentView.swift @@ -23,22 +23,26 @@ struct ContentView: View { var body: some View { VStack { - // automatically scroll to bottom of text view ScrollView(.vertical, showsIndicators: true) { Text(llamaState.messageLog) + .font(.system(size: 12)) + .frame(maxWidth: .infinity, alignment: .leading) + .padding() + .onTapGesture { + UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil) + } } TextEditor(text: $multiLineText) - .frame(height: 200) + .frame(height: 80) .padding() .border(Color.gray, width: 0.5) - // add two buttons "Send" and "Bench" next to each other HStack { Button("Send") { sendText() } - .padding() + .padding(8) .background(Color.blue) .foregroundColor(.white) .cornerRadius(8) @@ -46,7 +50,15 @@ struct ContentView: View { Button("Bench") { bench() } - .padding() + .padding(8) + .background(Color.blue) + .foregroundColor(.white) + .cornerRadius(8) + + Button("Clear") { + clear() + } + .padding(8) .background(Color.blue) .foregroundColor(.white) .cornerRadius(8) @@ -55,20 +67,27 @@ struct ContentView: View { VStack { DownloadButton( llamaState: llamaState, - modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q4_0)", + modelName: "TinyLlama-1.1B (Q4_0)", modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true", filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf" ) + .font(.system(size: 12)) + .padding(.top, 4) + DownloadButton( llamaState: llamaState, - modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q8_0)", + modelName: "TinyLlama-1.1B (Q8_0)", modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true", filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf" ) + .font(.system(size: 12)) + Button("Clear downloaded models") { ContentView.cleanupModelCaches() llamaState.cacheCleared = true } + .padding(8) + .font(.system(size: 12)) } } .padding() @@ -86,6 +105,12 @@ struct ContentView: View { await llamaState.bench() } } + + func clear() { + Task { + await llamaState.clear() + } + } } //#Preview { diff --git a/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift b/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift index ee30a11df..26877e4a7 100644 --- a/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift +++ b/examples/llama.swiftui/llama.swiftui/UI/DownloadButton.swift @@ -31,6 +31,7 @@ struct DownloadButton: View { private func download() { status = "downloading" + print("Downloading model \(modelName) from \(modelUrl)") downloadTask = URLSession.shared.dataTask(with: URL(string: modelUrl)!) { data, response, error in if let error = error { print("Error: \(error.localizedDescription)") @@ -44,12 +45,12 @@ struct DownloadButton: View { if let data = data { do { + print("Writing to \(filename)") let fileURL = DownloadButton.getFileURL(filename: filename) try data.write(to: fileURL) llamaState.cacheCleared = false - print("Downloaded model \(modelName) to \(fileURL)") status = "downloaded" try llamaState.loadModel(modelUrl: fileURL) } catch let err {