update android example

slaren 2024-10-09 16:47:24 +02:00
parent 3c0b8628cd
commit 6ea0304b20


@@ -186,11 +186,11 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
     for (nri = 0; nri < nr; nri++) {
         LOGi("Benchmark prompt processing (pp)");
-        llama_batch_clear(*batch);
+        common_batch_clear(*batch);
         const int n_tokens = pp;
         for (i = 0; i < n_tokens; i++) {
-            llama_batch_add(*batch, 0, i, { 0 }, false);
+            common_batch_add(*batch, 0, i, { 0 }, false);
         }
         batch->logits[batch->n_tokens - 1] = true;
@@ -210,9 +210,9 @@ Java_android_llama_cpp_LLamaAndroid_bench_1model(
         const auto t_tg_start = ggml_time_us();
         for (i = 0; i < tg; i++) {
-            llama_batch_clear(*batch);
+            common_batch_clear(*batch);
             for (j = 0; j < pl; j++) {
-                llama_batch_add(*batch, 0, i, { j }, true);
+                common_batch_add(*batch, 0, i, { j }, true);
             }
             LOGi("llama_decode() text generation: %d", i);
@@ -357,7 +357,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     const auto context = reinterpret_cast<llama_context *>(context_pointer);
     const auto batch = reinterpret_cast<llama_batch *>(batch_pointer);
-    const auto tokens_list = llama_tokenize(context, text, 1);
+    const auto tokens_list = common_tokenize(context, text, 1);
     auto n_ctx = llama_n_ctx(context);
     auto n_kv_req = tokens_list.size() + (n_len - tokens_list.size());
@@ -369,14 +369,14 @@ Java_android_llama_cpp_LLamaAndroid_completion_1init(
     }
     for (auto id : tokens_list) {
-        LOGi("%s", llama_token_to_piece(context, id).c_str());
+        LOGi("%s", common_token_to_piece(context, id).c_str());
     }
-    llama_batch_clear(*batch);
+    common_batch_clear(*batch);
     // evaluate the initial prompt
     for (auto i = 0; i < tokens_list.size(); i++) {
-        llama_batch_add(*batch, tokens_list[i], i, { 0 }, false);
+        common_batch_add(*batch, tokens_list[i], i, { 0 }, false);
     }
     // llama_decode will output logits only for the last token of the prompt
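The completion_1init hunks apply the same rename to tokenization (common_tokenize), token printing (common_token_to_piece) and the initial prompt batch. A hedged sketch of that prompt-evaluation step, assuming the common/common.h signatures (eval_prompt is an illustrative helper, not code from this commit):

#include "common.h"   // common_tokenize, common_batch_clear, common_batch_add
#include "llama.h"
#include <string>
#include <vector>

// Sketch: tokenize the prompt and decode it, requesting logits only for the
// last prompt token, as completion_init does above.
static bool eval_prompt(llama_context * ctx, llama_batch & batch, const std::string & text) {
    // third argument = add_special, matching the `1` passed in the example
    const std::vector<llama_token> tokens = common_tokenize(ctx, text, true);

    common_batch_clear(batch);
    for (size_t i = 0; i < tokens.size(); i++) {
        common_batch_add(batch, tokens[i], (llama_pos) i, { 0 }, false);
    }
    // llama_decode will output logits only for the last token of the prompt
    batch.logits[batch.n_tokens - 1] = true;

    return llama_decode(ctx, batch) == 0;
}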
@@ -419,7 +419,7 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         return nullptr;
     }
-    auto new_token_chars = llama_token_to_piece(context, new_token_id);
+    auto new_token_chars = common_token_to_piece(context, new_token_id);
     cached_token_chars += new_token_chars;
     jstring new_token = nullptr;
@@ -431,8 +431,8 @@ Java_android_llama_cpp_LLamaAndroid_completion_1loop(
         new_token = env->NewStringUTF("");
     }
-    llama_batch_clear(*batch);
-    llama_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
+    common_batch_clear(*batch);
+    common_batch_add(*batch, new_token_id, n_cur, { 0 }, true);
     env->CallVoidMethod(intvar_ncur, la_int_var_inc);
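completion_1loop follows the same per-token pattern: the sampled token is converted to text with common_token_to_piece and fed back through a single-token batch. A minimal sketch of one generation step under the same assumed signatures (generation_step is an illustrative name; the example's sampling and UTF-8 buffering are omitted):

#include "common.h"   // common_token_to_piece, common_batch_clear, common_batch_add
#include "llama.h"
#include <cstdio>
#include <string>

// Sketch: feed one freshly sampled token back into the model, as the
// completion_loop hunk above does.
static std::string generation_step(llama_context * ctx, llama_batch & batch,
                                   llama_token new_token_id, int n_cur) {
    // text piece for the sampled token (may be an incomplete UTF-8 sequence)
    const std::string piece = common_token_to_piece(ctx, new_token_id);

    common_batch_clear(batch);
    // single token at position n_cur, sequence 0, logits requested so the
    // next token can be sampled after llama_decode
    common_batch_add(batch, new_token_id, n_cur, { 0 }, true);

    if (llama_decode(ctx, batch) != 0) {
        fprintf(stderr, "llama_decode() failed\n");
    }
    return piece;
}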