llama.swiftui : UX improvements

commit 9629448716
parent d36ca171b6
Author: Georgi Gerganov
Date: 2023-12-17 11:46:13 +02:00
4 changed files with 48 additions and 11 deletions

LibLlama.swift

@@ -30,7 +30,7 @@ actor LlamaContext {
     /// This variable is used to store temporarily invalid cchars
     private var temporary_invalid_cchars: [CChar]

-    var n_len: Int32 = 512
+    var n_len: Int32 = 64
     var n_cur: Int32 = 0

     var n_decode: Int32 = 0
@@ -241,6 +241,8 @@ actor LlamaContext {
         let t_tg_end = ggml_time_us()

+        llama_kv_cache_clear(context)
+
         let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
         let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
@@ -254,11 +256,12 @@ actor LlamaContext {
     func clear() {
         tokens_list.removeAll()
         temporary_invalid_cchars.removeAll()
+        llama_kv_cache_clear(context)
     }

     private func tokenize(text: String, add_bos: Bool) -> [llama_token] {
         let utf8Count = text.utf8.count
-        let n_tokens = utf8Count + (add_bos ? 1 : 0)
+        let n_tokens = utf8Count + (add_bos ? 1 : 0) + 1
         let tokens = UnsafeMutablePointer<llama_token>.allocate(capacity: n_tokens)
         let tokenCount = llama_tokenize(model, text, Int32(utf8Count), tokens, Int32(n_tokens), add_bos, false)
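Note on the "+ 1" in tokenize: it gives the scratch buffer one extra slot beyond the usual worst case (one token per UTF-8 byte plus an optional BOS token), presumably as headroom so llama_tokenize cannot write past the end. A minimal sketch of that sizing logic, reusing the same C bridging call as LibLlama.swift (tokenizeSketch is a hypothetical helper, not code from this commit):

    // Hypothetical sketch of the buffer sizing used by tokenize(text:add_bos:).
    // Worst case: one token per UTF-8 byte, plus an optional BOS token,
    // plus one spare slot ("+ 1") as headroom.
    func tokenizeSketch(model: OpaquePointer, text: String, addBOS: Bool) -> [llama_token] {
        let utf8Count = text.utf8.count
        let maxTokens = utf8Count + (addBOS ? 1 : 0) + 1

        let buffer = UnsafeMutablePointer<llama_token>.allocate(capacity: maxTokens)
        defer { buffer.deallocate() }

        let written = llama_tokenize(model, text, Int32(utf8Count), buffer, Int32(maxTokens), addBOS, false)
        guard written > 0 else { return [] }   // non-positive result: nothing was tokenized

        return Array(UnsafeBufferPointer(start: buffer, count: Int(written)))
    }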

LlamaState.swift

@@ -34,7 +34,6 @@ class LlamaState: ObservableObject {
             return
         }

-        messageLog += "Attempting to complete text...\n"
         await llamaContext.completion_init(text: text)
         messageLog += "\(text)"
@@ -58,4 +57,13 @@ class LlamaState: ObservableObject {
         let result = await llamaContext.bench()
         messageLog += "\(result)"
     }
+
+    func clear() async {
+        guard let llamaContext else {
+            return
+        }
+
+        await llamaContext.clear()
+        messageLog = ""
+    }
 }
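The added clear() resets both the visible log and the underlying llama context. A trimmed-down sketch of how the pieces fit together, assuming the surrounding class shape (LlamaStateSketch is hypothetical and not part of this commit):

    import Combine

    @MainActor
    class LlamaStateSketch: ObservableObject {
        @Published var messageLog = ""
        private var llamaContext: LlamaContext?   // nil until a model has been loaded

        func clear() async {
            guard let llamaContext else {
                return                             // nothing to clear without a loaded model
            }

            await llamaContext.clear()             // actor call: drops tokens and clears the KV cache
            messageLog = ""                        // wipe the log shown in ContentView
        }
    }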

ContentView.swift

@@ -23,22 +23,26 @@ struct ContentView: View {
     var body: some View {
         VStack {
-            // automatically scroll to bottom of text view
             ScrollView(.vertical, showsIndicators: true) {
                 Text(llamaState.messageLog)
+                    .font(.system(size: 12))
+                    .frame(maxWidth: .infinity, alignment: .leading)
+                    .padding()
+                    .onTapGesture {
+                        UIApplication.shared.sendAction(#selector(UIResponder.resignFirstResponder), to: nil, from: nil, for: nil)
+                    }
             }

             TextEditor(text: $multiLineText)
-                .frame(height: 200)
+                .frame(height: 80)
                 .padding()
                 .border(Color.gray, width: 0.5)

-            // add two buttons "Send" and "Bench" next to each other
             HStack {
                 Button("Send") {
                     sendText()
                 }
-                .padding()
+                .padding(8)
                 .background(Color.blue)
                 .foregroundColor(.white)
                 .cornerRadius(8)
@@ -46,7 +50,15 @@ struct ContentView: View {
                 Button("Bench") {
                     bench()
                 }
-                .padding()
+                .padding(8)
+                .background(Color.blue)
+                .foregroundColor(.white)
+                .cornerRadius(8)
+
+                Button("Clear") {
+                    clear()
+                }
+                .padding(8)
                 .background(Color.blue)
                 .foregroundColor(.white)
                 .cornerRadius(8)
@@ -55,20 +67,27 @@ struct ContentView: View {
             VStack {
                 DownloadButton(
                     llamaState: llamaState,
-                    modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q4_0)",
+                    modelName: "TinyLlama-1.1B (Q4_0)",
                     modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q4_0.gguf?download=true",
                     filename: "tinyllama-1.1b-1t-openorca.Q4_0.gguf"
                 )
+                .font(.system(size: 12))
+                .padding(.top, 4)

                 DownloadButton(
                     llamaState: llamaState,
-                    modelName: "TheBloke / TinyLlama-1.1B-1T-OpenOrca-GGUF (Q8_0)",
+                    modelName: "TinyLlama-1.1B (Q8_0)",
                     modelUrl: "https://huggingface.co/TheBloke/TinyLlama-1.1B-1T-OpenOrca-GGUF/resolve/main/tinyllama-1.1b-1t-openorca.Q8_0.gguf?download=true",
                     filename: "tinyllama-1.1b-1t-openorca.Q8_0.gguf"
                 )
+                .font(.system(size: 12))

                 Button("Clear downloaded models") {
                     ContentView.cleanupModelCaches()
                     llamaState.cacheCleared = true
                 }
+                .padding(8)
+                .font(.system(size: 12))
             }
         }
         .padding()
@@ -86,6 +105,12 @@ struct ContentView: View {
             await llamaState.bench()
         }
     }
+
+    func clear() {
+        Task {
+            await llamaState.clear()
+        }
+    }
 }

 //#Preview {
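The new "Clear" button follows the same pattern as "Send" and "Bench": a synchronous button action that hops into a Task to await the async call on LlamaState. A minimal sketch of just that button (ClearButtonSketch is hypothetical, not part of the commit):

    import SwiftUI

    struct ClearButtonSketch: View {
        @ObservedObject var llamaState: LlamaState

        var body: some View {
            Button("Clear") {
                Task {
                    await llamaState.clear()   // resets messageLog and the llama context
                }
            }
            .padding(8)                        // compact padding, matching the other buttons
            .background(Color.blue)
            .foregroundColor(.white)
            .cornerRadius(8)
        }
    }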

DownloadButton.swift

@@ -31,6 +31,7 @@ struct DownloadButton: View {
     private func download() {
         status = "downloading"
+        print("Downloading model \(modelName) from \(modelUrl)")
         downloadTask = URLSession.shared.dataTask(with: URL(string: modelUrl)!) { data, response, error in
             if let error = error {
                 print("Error: \(error.localizedDescription)")
@@ -44,12 +45,12 @@ struct DownloadButton: View {
             if let data = data {
                 do {
-                    print("Writing to \(filename)")
                     let fileURL = DownloadButton.getFileURL(filename: filename)
                     try data.write(to: fileURL)
                     llamaState.cacheCleared = false
+                    print("Downloaded model \(modelName) to \(fileURL)")
                     status = "downloaded"
                     try llamaState.loadModel(modelUrl: fileURL)
                 } catch let err {
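The logging change replaces the vague "Writing to <filename>" with one message when the download starts and one when the file has been written and loaded. A sketch of the write-and-log step in isolation (saveModel is a hypothetical helper, not part of the commit):

    import Foundation

    // Hypothetical helper: write the downloaded bytes into the app's Documents
    // directory and log the final location, mirroring the message added here.
    func saveModel(data: Data, modelName: String, filename: String) throws -> URL {
        let documents = FileManager.default.urls(for: .documentDirectory, in: .userDomainMask)[0]
        let fileURL = documents.appendingPathComponent(filename)

        try data.write(to: fileURL)
        print("Downloaded model \(modelName) to \(fileURL)")

        return fileURL
    }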