From 5f546b860eb5cfcc236a12f7ba1cfef62ceae5e0 Mon Sep 17 00:00:00 2001
From: dhruvanand24
Date: Fri, 3 Jan 2025 16:12:55 +0530
Subject: [PATCH 1/2] Created JNI Binding for applying chat template.

Created 2 more utility function- mapListToJSONString and format_chat.
Added Function to apply chat template in LLamaAndroid.kt
---
 .../llama/src/main/cpp/llama-android.cpp      | 133 ++++++++++++++++++
 .../java/android/llama/cpp/LLamaAndroid.kt    |  16 +++
 2 files changed, 149 insertions(+)

diff --git a/examples/llama.android/llama/src/main/cpp/llama-android.cpp b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
index 66ec2aeeb..889caccb1 100644
--- a/examples/llama.android/llama/src/main/cpp/llama-android.cpp
+++ b/examples/llama.android/llama/src/main/cpp/llama-android.cpp
@@ -4,6 +4,7 @@
 #include <math.h>
 #include <string>
 #include <unistd.h>
+#include <nlohmann/json.hpp>
 
 #include "llama.h"
 #include "common.h"
@@ -72,6 +73,121 @@ bool is_valid_utf8(const char * string) {
     return true;
 }
 
+
+using json = nlohmann::ordered_json;
+
+// Read body[key] as T; fall back to default_value when the key is missing,
+// null, or of an incompatible type (type_error is swallowed deliberately).
+template <typename T>
+static T json_value(const json & body, const std::string & key, const T & default_value) {
+    // Fallback null to default value
+    if (body.contains(key) && !body.at(key).is_null()) {
+        try {
+            return body.at(key);
+        } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
+            return default_value;
+        }
+    } else {
+        return default_value;
+    }
+}
+
+// Serialize a Java array of Map objects (each expected to carry "role" and
+// "content" string entries) into a JSON-array string for format_chat().
+// Malformed elements are logged and skipped rather than aborting the whole
+// conversion.
+// NOTE(review): FindClass/GetMethodID and the NewStringUTF keys are created
+// on every iteration and never DeleteLocalRef'd — consider hoisting the class
+// lookup out of the loop and releasing the per-iteration local refs.
+std::string mapListToJSONString(JNIEnv *env, jobjectArray allMessages) {
+    json jsonArray = json::array();
+
+    jsize arrayLength = env->GetArrayLength(allMessages);
+    for (jsize i = 0; i < arrayLength; ++i) {
+        // Get the individual message from the array
+        jobject messageObj = env->GetObjectArrayElement(allMessages, i);
+        if (!messageObj) {
+            LOGe("Error: Received null jobject at index %d", i);
+            continue;
+        }
+
+        // Check if the object is a Map
+        jclass mapClass = env->FindClass("java/util/Map");
+        if (!env->IsInstanceOf(messageObj, mapClass)) {
+            LOGe("Error: Object is not a Map at index %d", i);
+            env->DeleteLocalRef(messageObj);
+            continue;
+        }
+
+        // Get Map methods
+        jmethodID getMethod = env->GetMethodID(mapClass, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
+        jmethodID keySetMethod = env->GetMethodID(mapClass, "keySet", "()Ljava/util/Set;");
+        if (!getMethod || !keySetMethod) {
+            LOGe("Error: Could not find Map methods");
+            env->DeleteLocalRef(messageObj);
+            continue;
+        }
+
+        // Create a JSON object for this map
+        json jsonMsg;
+
+        // Get role
+        jstring roleKey = env->NewStringUTF("role");
+        jobject roleObj = env->CallObjectMethod(messageObj, getMethod, roleKey);
+        if (roleObj) {
+            const char* roleStr = env->GetStringUTFChars((jstring)roleObj, nullptr);
+            jsonMsg["role"] = roleStr;
+            env->ReleaseStringUTFChars((jstring)roleObj, roleStr);
+        }
+
+        // Get content
+        jstring contentKey = env->NewStringUTF("content");
+        jobject contentObj = env->CallObjectMethod(messageObj, getMethod, contentKey);
+        if (contentObj) {
+            const char* contentStr = env->GetStringUTFChars((jstring)contentObj, nullptr);
+            jsonMsg["content"] = contentStr;
+            env->ReleaseStringUTFChars((jstring)contentObj, contentStr);
+        }
+
+        // Add to array if both role and content were successfully extracted
+        if (!jsonMsg.empty()) {
+            jsonArray.push_back(jsonMsg);
+        }
+
+        // Clean up local references
+        env->DeleteLocalRef(messageObj);
+    }
+
+    return jsonArray.dump();
+}
+
+// Format given chat.
+// If tmpl is empty, we take the template from model metadata
+inline std::string format_chat(const llama_model *model, const std::string &tmpl, const std::vector<json> &messages) {
+    std::vector<common_chat_msg> chat;
+
+    for (size_t i = 0; i < messages.size(); ++i) {
+        const auto &curr_msg = messages[i];
+
+        std::string role = json_value(curr_msg, "role", std::string(""));
+        std::string content;
+
+        if (curr_msg.contains("content")) {
+            if (curr_msg["content"].is_string()) {
+                content = curr_msg["content"].get<std::string>();
+            } else if (curr_msg["content"].is_array()) {
+                // Multi-part content: concatenate the "text" fields, each on
+                // its own line (note: the first part also gets a leading '\n').
+                for (const auto &part : curr_msg["content"]) {
+                    if (part.contains("text")) {
+                        content += "\n" + part["text"].get<std::string>();
+                    }
+                }
+            } else {
+                throw std::runtime_error("Invalid 'content' type.");
+            }
+        } else {
+            throw std::runtime_error("Missing 'content'.");
+        }
+
+        chat.push_back({role, content});
+    }
+
+    const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
+    LOGi("formatted_chat: '%s'\n", formatted_chat.c_str());
+
+    return formatted_chat;
+}
+
 static void log_callback(ggml_log_level level, const char * fmt, void * data) {
     if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
     else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
@@ -447,3 +563,20 @@
 JNIEXPORT void JNICALL
 Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
     llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
 }
+
+// JNI entry point: converts the Java message maps to JSON, applies the
+// model's chat template (empty tmpl => use model metadata), and returns the
+// formatted prompt. On any failure an empty string is returned and the error
+// is logged, so Java callers never see a native exception.
+extern "C"
+JNIEXPORT jstring JNICALL
+Java_android_llama_cpp_LLamaAndroid_apply_1chat_1template(JNIEnv *env, jobject, jobjectArray allMessages, jlong model){
+    try {
+        // Convert the messages to JSON
+        std::string parsedData = mapListToJSONString(env, allMessages);
+        // Parse and format
+        std::vector<json> jsonMessages = json::parse(parsedData);
+        const auto formattedPrompts = format_chat(reinterpret_cast<llama_model *>(model), "", jsonMessages);
+
+        return env->NewStringUTF(formattedPrompts.c_str());
+    } catch (const std::exception &e) {
+        LOGe("Error processing data: %s", e.what());
+        return env->NewStringUTF("");
+    }
+}
diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index cf520e459..4693a8e29 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -78,6 +78,8 @@ class LLamaAndroid {
 
     private external fun kv_cache_clear(context: Long)
 
+    private external fun apply_chat_template(allmessages: Array<Map<String, String>>, model: Long): String
+
     suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
         return withContext(runLoop) {
             when (val state = threadLocalState.get()) {
@@ -153,6 +155,20 @@ class LLamaAndroid {
         }
     }
 
+    suspend fun applyChatTemplate(messages: List<Map<String, String>>): String {
+        var data = ""
+        withContext(runLoop){
+            when(val state = threadLocalState.get()){
+                is State.Loaded -> {
+                    val arrayMessages = messages.toTypedArray() //Convert list to array for JNI compatibility
+                    data = apply_chat_template(allmessages = arrayMessages, model = state.model)
+                }
+                else -> {}
+            }
+        }
+        return data
+    }
+
     companion object {
         private class IntVar(value: Int) {
             @Volatile

From 996dc4cdd2d68be69d774f5dca3794b55fdf5335 Mon Sep 17 00:00:00 2001
From: dhruvanand24
Date: Fri, 3 Jan 2025 16:15:47 +0530
Subject: [PATCH 2/2] Added Relevant comments

---
 .../llama/src/main/java/android/llama/cpp/LLamaAndroid.kt | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
index 4693a8e29..1e1964b53 100644
--- a/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
+++ b/examples/llama.android/llama/src/main/java/android/llama/cpp/LLamaAndroid.kt
@@ -154,7 +154,7 @@ class LLamaAndroid {
         }
     }
-
+    // call this function before
sending the message using the send function
     suspend fun applyChatTemplate(messages: List<Map<String, String>>): String {
         var data = ""
         withContext(runLoop){