diff --git a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt index efe3600a6..be95e2221 100644 --- a/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt +++ b/examples/llama.android/app/src/main/java/com/example/llama/MainViewModel.kt @@ -10,6 +10,11 @@ import kotlinx.coroutines.flow.catch import kotlinx.coroutines.launch class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { + companion object { + @JvmStatic + private val NanosPerSecond = 1_000_000_000.0 + } + private val tag: String? = this::class.simpleName var messages by mutableStateOf(listOf("Initializing...")) @@ -39,9 +44,9 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { messages += "" viewModelScope.launch { - llm.send(if (text.last() == '\n') text else text + "\n") + llm.send(text) .catch { - Log.e(tag, "send() flow failed", it) + Log.e(tag, "send() failed", it) messages += it.message!! } .collect { messages = messages.dropLast(1) + (messages.last() + it) } @@ -51,8 +56,23 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) { viewModelScope.launch { try { - llm.bench(pp, tg, pl, nr) + val start = System.nanoTime() + val warmupResult = llm.bench(pp, tg, pl, nr) + val end = System.nanoTime() + + messages += warmupResult + + val warmup = (end - start).toDouble() / NanosPerSecond + messages += "Warm up time: $warmup seconds, please wait..." + + if (warmup > 5.0) { + messages += "Warm up took too long, aborting benchmark" + return@launch + } + + messages += llm.bench(512, 128, 1, 3) } catch (exc: IllegalStateException) { + Log.e(tag, "bench() failed", exc) messages += exc.message!! } } @@ -64,6 +84,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() { llm.load(pathToModel) messages += "Loaded $pathToModel" } catch (exc: IllegalStateException) { + Log.e(tag, "load() failed", exc) messages += exc.message!! } }