Sync bench code
This commit is contained in:
parent
943bba2e5d
commit
abfc5188e9
1 changed files with 24 additions and 3 deletions
|
@ -10,6 +10,11 @@ import kotlinx.coroutines.flow.catch
|
||||||
import kotlinx.coroutines.launch
|
import kotlinx.coroutines.launch
|
||||||
|
|
||||||
class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
|
class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
|
||||||
|
companion object {
|
||||||
|
@JvmStatic
|
||||||
|
private val NanosPerSecond = 1_000_000_000.0
|
||||||
|
}
|
||||||
|
|
||||||
private val tag: String? = this::class.simpleName
|
private val tag: String? = this::class.simpleName
|
||||||
|
|
||||||
var messages by mutableStateOf(listOf("Initializing..."))
|
var messages by mutableStateOf(listOf("Initializing..."))
|
||||||
|
@ -39,9 +44,9 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
|
||||||
messages += ""
|
messages += ""
|
||||||
|
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
llm.send(if (text.last() == '\n') text else text + "\n")
|
llm.send(text)
|
||||||
.catch {
|
.catch {
|
||||||
Log.e(tag, "send() flow failed", it)
|
Log.e(tag, "send() failed", it)
|
||||||
messages += it.message!!
|
messages += it.message!!
|
||||||
}
|
}
|
||||||
.collect { messages = messages.dropLast(1) + (messages.last() + it) }
|
.collect { messages = messages.dropLast(1) + (messages.last() + it) }
|
||||||
|
@ -51,8 +56,23 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
|
||||||
fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
|
fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1) {
|
||||||
viewModelScope.launch {
|
viewModelScope.launch {
|
||||||
try {
|
try {
|
||||||
llm.bench(pp, tg, pl, nr)
|
val start = System.nanoTime()
|
||||||
|
val warmupResult = llm.bench(pp, tg, pl, nr)
|
||||||
|
val end = System.nanoTime()
|
||||||
|
|
||||||
|
messages += warmupResult
|
||||||
|
|
||||||
|
val warmup = (end - start).toDouble() / NanosPerSecond
|
||||||
|
messages += "Warm up time: $warmup seconds, please wait..."
|
||||||
|
|
||||||
|
if (warmup > 5.0) {
|
||||||
|
messages += "Warm up took too long, aborting benchmark"
|
||||||
|
return@launch
|
||||||
|
}
|
||||||
|
|
||||||
|
messages += llm.bench(512, 128, 1, 3)
|
||||||
} catch (exc: IllegalStateException) {
|
} catch (exc: IllegalStateException) {
|
||||||
|
Log.e(tag, "bench() failed", exc)
|
||||||
messages += exc.message!!
|
messages += exc.message!!
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -64,6 +84,7 @@ class MainViewModel(private val llm: Llm = Llm.instance()): ViewModel() {
|
||||||
llm.load(pathToModel)
|
llm.load(pathToModel)
|
||||||
messages += "Loaded $pathToModel"
|
messages += "Loaded $pathToModel"
|
||||||
} catch (exc: IllegalStateException) {
|
} catch (exc: IllegalStateException) {
|
||||||
|
Log.e(tag, "load() failed", exc)
|
||||||
messages += exc.message!!
|
messages += exc.message!!
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue