llama.swiftui : initial bench functionality

This commit is contained in:
Georgi Gerganov 2023-12-15 16:39:16 +02:00
parent afd336f7a6
commit 6a8680204c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 114 additions and 29 deletions

View file

@ -6,6 +6,22 @@ enum LlamaError: Error {
case couldNotInitializeContext
}
func llama_batch_clear(_ batch: inout llama_batch) {
batch.n_tokens = 0
}
func llama_batch_add(_ batch: inout llama_batch, _ id: llama_token, _ pos: llama_pos, _ seq_ids: [llama_seq_id], _ logits: Bool) {
batch.token [Int(batch.n_tokens)] = id
batch.pos [Int(batch.n_tokens)] = pos
batch.n_seq_id[Int(batch.n_tokens)] = Int32(seq_ids.count)
for i in 0..<seq_ids.count {
batch.seq_id[Int(batch.n_tokens)]![Int(i)] = seq_ids[i]
}
batch.logits [Int(batch.n_tokens)] = logits ? 1 : 0
batch.n_tokens += 1
}
actor LlamaContext {
private var model: OpaquePointer
private var context: OpaquePointer
@ -16,6 +32,7 @@ actor LlamaContext {
var n_len: Int32 = 512
var n_cur: Int32 = 0
var n_decode: Int32 = 0
init(model: OpaquePointer, context: OpaquePointer) {
@ -27,12 +44,13 @@ actor LlamaContext {
}
deinit {
llama_batch_free(batch)
llama_free(context)
llama_free_model(model)
llama_backend_free()
}
static func createContext(path: String) throws -> LlamaContext {
static func create_context(path: String) throws -> LlamaContext {
llama_backend_init(false)
let model_params = llama_model_default_params()
@ -41,11 +59,15 @@ actor LlamaContext {
print("Could not load model at \(path)")
throw LlamaError.couldNotInitializeContext
}
let n_threads = max(1, min(8, ProcessInfo.processInfo.processorCount - 2))
print("Using \(n_threads) threads")
var ctx_params = llama_context_default_params()
ctx_params.seed = 1234
ctx_params.seed = 1234
ctx_params.n_ctx = 2048
ctx_params.n_threads = 8
ctx_params.n_threads_batch = 8
ctx_params.n_threads = UInt32(n_threads)
ctx_params.n_threads_batch = UInt32(n_threads)
let context = llama_new_context_with_model(model, ctx_params)
guard let context else {
@ -56,6 +78,26 @@ actor LlamaContext {
return LlamaContext(model: model, context: context)
}
func model_info() -> String {
let result = UnsafeMutablePointer<Int8>.allocate(capacity: 256)
result.initialize(repeating: Int8(0), count: 256)
defer {
result.deallocate()
}
// TODO: this is probably very stupid way to get the string from C
let nChars = llama_model_desc(model, result, 256)
let bufferPointer = UnsafeBufferPointer(start: result, count: Int(nChars))
var SwiftString = ""
for char in bufferPointer {
SwiftString.append(Character(UnicodeScalar(UInt8(char))))
}
return SwiftString
}
func get_n_tokens() -> Int32 {
return batch.n_tokens;
}
@ -79,16 +121,11 @@ actor LlamaContext {
print(String(cString: token_to_piece(token: id) + [0]))
}
// batch = llama_batch_init(512, 0) // done in init()
batch.n_tokens = Int32(tokens_list.count)
llama_batch_clear(&batch)
for i1 in 0..<batch.n_tokens {
for i1 in 0..<tokens_list.count {
let i = Int(i1)
batch.token[i] = tokens_list[i]
batch.pos[i] = i1
batch.n_seq_id[Int(i)] = 1
batch.seq_id[Int(i)]![0] = 0
batch.logits[i] = 0
llama_batch_add(&batch, tokens_list[i], Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
@ -141,18 +178,11 @@ actor LlamaContext {
print(new_token_str)
// tokens_list.append(new_token_id)
batch.n_tokens = 0
batch.token[Int(batch.n_tokens)] = new_token_id
batch.pos[Int(batch.n_tokens)] = n_cur
batch.n_seq_id[Int(batch.n_tokens)] = 1
batch.seq_id[Int(batch.n_tokens)]![0] = 0
batch.logits[Int(batch.n_tokens)] = 1 // true
batch.n_tokens += 1
llama_batch_clear(&batch)
llama_batch_add(&batch, new_token_id, n_cur, [0], true)
n_decode += 1
n_cur += 1
n_cur += 1
if llama_decode(context, batch) != 0 {
print("failed to evaluate llama!")
@ -161,8 +191,60 @@ actor LlamaContext {
return new_token_str
}
func bench() -> String{
return "bench not implemented"
func bench() -> String {
let pp = 512
let tg = 128
let pl = 1
// bench prompt processing
llama_batch_clear(&batch)
let n_tokens = pp
for i in 0..<n_tokens {
llama_batch_add(&batch, 0, Int32(i), [0], false)
}
batch.logits[Int(batch.n_tokens) - 1] = 1 // true
llama_kv_cache_clear(context)
let t_pp_start = ggml_time_us()
if llama_decode(context, batch) != 0 {
print("llama_decode() failed during prompt")
}
let t_pp_end = ggml_time_us()
// bench text generation
llama_kv_cache_clear(context)
let t_tg_start = ggml_time_us()
for i in 0..<tg {
llama_batch_clear(&batch)
for j in 0..<pl {
llama_batch_add(&batch, 0, Int32(i), [Int32(j)], true)
}
if llama_decode(context, batch) != 0 {
print("llama_decode() failed during text generation")
}
}
let t_tg_end = ggml_time_us()
let t_pp = Double(t_pp_end - t_pp_start) / 1000000.0
let t_tg = Double(t_tg_end - t_tg_start) / 1000000.0
let speed_pp = Double(pp) / t_pp
let speed_tg = Double(pl*tg) / t_tg
return String(format: "PP 512 speed: %7.2f t/s\n", speed_pp) +
String(format: "TG 128 speed: %7.2f t/s\n", speed_tg)
}
func clear() {

View file

@ -6,7 +6,7 @@ class LlamaState: ObservableObject {
private var llamaContext: LlamaContext?
private var modelUrl: URL? {
Bundle.main.url(forResource: "ggml-model-q8_0", withExtension: "gguf", subdirectory: "models")
Bundle.main.url(forResource: "ggml-model", withExtension: "gguf", subdirectory: "models")
// Bundle.main.url(forResource: "llama-2-7b-chat", withExtension: "Q2_K.gguf", subdirectory: "models")
}
init() {
@ -20,7 +20,7 @@ class LlamaState: ObservableObject {
private func loadModel() throws {
messageLog += "Loading model...\n"
if let modelUrl {
llamaContext = try LlamaContext.createContext(path: modelUrl.path())
llamaContext = try LlamaContext.create_context(path: modelUrl.path())
messageLog += "Loaded model \(modelUrl.lastPathComponent)\n"
} else {
messageLog += "Could not locate model\n"
@ -49,7 +49,10 @@ class LlamaState: ObservableObject {
return
}
messageLog += "Model info: "
messageLog += await llamaContext.model_info() + "\n"
messageLog += "Running benchmark...\n"
await llamaContext.bench() // heat up
let result = await llamaContext.bench()
messageLog += "\(result)"
}

View file

@ -53,6 +53,6 @@ struct ContentView: View {
}
}
#Preview {
ContentView()
}
//#Preview {
// ContentView()
//}