Sync bench code

This commit is contained in:
Neuman Vong 2024-01-16 11:10:35 +11:00
parent 943bba2e5d
commit abfc5188e9

View file

@ -10,6 +10,11 @@ import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.launch import kotlinx.coroutines.launch
class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
companion object {
@JvmStatic
private val NanosPerSecond = 1_000_000_000.0
}
private val tag: String? = this::class.simpleName private val tag: String? = this::class.simpleName
var messages by mutableStateOf(listOf("Initializing...")) var messages by mutableStateOf(listOf("Initializing..."))
@ -39,9 +44,9 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
messages += "" messages += ""
viewModelScope.launch { viewModelScope.launch {
llm.send(if (text.last() == '\n') text else text + "\n") llm.send(text)
.catch { .catch {
Log.e(tag, "send() flow failed", it) Log.e(tag, "send() failed", it)
messages += it.message!! messages += it.message!!
} }
.collect { messages = messages.dropLast(1) + (messages.last() + it) } .collect { messages = messages.dropLast(1) + (messages.last() + it) }
@ -51,8 +56,23 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) { fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
viewModelScope.launch { viewModelScope.launch {
try { try {
llm.bench(pp, tg, pl, nr) val start = System.nanoTime()
val warmupResult = llm.bench(pp, tg, pl, nr)
val end = System.nanoTime()
messages += warmupResult
val warmup = (end - start).toDouble() / NanosPerSecond
messages += "Warm up time: $warmup seconds, please wait..."
if (warmup > 5.0) {
messages += "Warm up took too long, aborting benchmark"
return@launch
}
messages += llm.bench(512, 128, 1, 3)
} catch (exc: IllegalStateException) { } catch (exc: IllegalStateException) {
Log.e(tag, "bench() failed", exc)
messages += exc.message!! messages += exc.message!!
} }
} }
@ -64,6 +84,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
llm.load(pathToModel) llm.load(pathToModel)
messages += "Loaded $pathToModel" messages += "Loaded $pathToModel"
} catch (exc: IllegalStateException) { } catch (exc: IllegalStateException) {
Log.e(tag, "load() failed", exc)
messages += exc.message!! messages += exc.message!!
} }
} }