This commit is contained in:
Dhruv Anand 2025-01-29 01:50:00 +08:00 committed by GitHub
commit 6db379dc81
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 149 additions and 0 deletions

View file

@ -4,6 +4,7 @@
#include <math.h>
#include <string>
#include <unistd.h>
#include <json.hpp>
#include "llama.h"
#include "common.h"
@ -72,6 +73,121 @@ bool is_valid_utf8(const char * string) {
return true;
}
using json = nlohmann::ordered_json;
// Reads `key` from `body` converted to T.
// Returns `default_value` when the key is absent, when its value is JSON null,
// or when the conversion to T raises a nlohmann type_error.
template <typename T>
static T json_value(const json & body, const std::string & key, const T & default_value) {
    // A missing key and an explicit null are both treated as "not provided".
    if (!body.contains(key) || body.at(key).is_null()) {
        return default_value;
    }
    try {
        return body.at(key);
    } catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
        // Value exists but has an incompatible type: fall back instead of propagating.
        return default_value;
    }
}
std::string mapListToJSONString(JNIEnv *env, jobjectArray allMessages) {
json jsonArray = json::array();
jsize arrayLength = env->GetArrayLength(allMessages);
for (jsize i = 0; i < arrayLength; ++i) {
// Get the individual message from the array
jobject messageObj = env->GetObjectArrayElement(allMessages, i);
if (!messageObj) {
LOGe("Error: Received null jobject at index %d", i);
continue;
}
// Check if the object is a Map
jclass mapClass = env->FindClass("java/util/Map");
if (!env->IsInstanceOf(messageObj, mapClass)) {
LOGe("Error: Object is not a Map at index %d", i);
env->DeleteLocalRef(messageObj);
continue;
}
// Get Map methods
jmethodID getMethod = env->GetMethodID(mapClass, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
jmethodID keySetMethod = env->GetMethodID(mapClass, "keySet", "()Ljava/util/Set;");
if (!getMethod || !keySetMethod) {
LOGe("Error: Could not find Map methods");
env->DeleteLocalRef(messageObj);
continue;
}
// Create a JSON object for this map
json jsonMsg;
// Get role
jstring roleKey = env->NewStringUTF("role");
jobject roleObj = env->CallObjectMethod(messageObj, getMethod, roleKey);
if (roleObj) {
const char* roleStr = env->GetStringUTFChars((jstring)roleObj, nullptr);
jsonMsg["role"] = roleStr;
env->ReleaseStringUTFChars((jstring)roleObj, roleStr);
}
// Get content
jstring contentKey = env->NewStringUTF("content");
jobject contentObj = env->CallObjectMethod(messageObj, getMethod, contentKey);
if (contentObj) {
const char* contentStr = env->GetStringUTFChars((jstring)contentObj, nullptr);
jsonMsg["content"] = contentStr;
env->ReleaseStringUTFChars((jstring)contentObj, contentStr);
}
// Add to array if both role and content were successfully extracted
if (!jsonMsg.empty()) {
jsonArray.push_back(jsonMsg);
}
// Clean up local references
env->DeleteLocalRef(messageObj);
}
return jsonArray.dump();
}
// Format given chat. If tmpl is empty, we take the template from model metadata
inline std::string format_chat(const llama_model *model, const std::string &tmpl, const std::vector<json> &messages) {
std::vector<common_chat_msg> chat;
for (size_t i = 0; i < messages.size(); ++i) {
const auto &curr_msg = messages[i];
std::string role = json_value(curr_msg, "role", std::string(""));
std::string content;
if (curr_msg.contains("content")) {
if (curr_msg["content"].is_string()) {
content = curr_msg["content"].get<std::string>();
} else if (curr_msg["content"].is_array()) {
for (const auto &part : curr_msg["content"]) {
if (part.contains("text")) {
content += "\n" + part["text"].get<std::string>();
}
}
} else {
throw std::runtime_error("Invalid 'content' type.");
}
} else {
throw std::runtime_error("Missing 'content'.");
}
chat.push_back({role, content});
}
const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
LOGi("formatted_chat: '%s'\n", formatted_chat.c_str());
return formatted_chat;
}
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
@ -450,3 +566,20 @@ JNIEXPORT void JNICALL
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
}
// JNI entry point: turns an array of Java message maps into a prompt string.
// The messages are serialized to JSON via mapListToJSONString, parsed back into
// JSON values and passed to format_chat with an empty template string.
// `model` is a llama_model pointer carried as a jlong.
// Returns the formatted prompt, or an empty Java string if a std::exception was caught.
extern "C"
JNIEXPORT jstring JNICALL
Java_android_llama_cpp_LLamaAndroid_apply_1chat_1template(JNIEnv *env, jobject, jobjectArray allMessages, jlong model){
    const auto *llamaModel = reinterpret_cast<const llama_model *>(model);
    try {
        // Java maps -> JSON text -> vector of JSON messages
        const std::string serialized = mapListToJSONString(env, allMessages);
        std::vector<json> chatMessages = json::parse(serialized);

        const std::string prompt = format_chat(llamaModel, "", chatMessages);
        return env->NewStringUTF(prompt.c_str());
    } catch (const std::exception &e) {
        LOGe("Error processing data: %s", e.what());
    }
    // Reached only after a failure above.
    return env->NewStringUTF("");
}

View file

@ -79,6 +79,8 @@ class LLamaAndroid {
// Native binding: the C++ side casts `context` to a llama_context pointer and
// passes it to llama_kv_cache_clear.
private external fun kv_cache_clear(context: Long)
// Native binding implemented by Java_android_llama_cpp_LLamaAndroid_apply_1chat_1template.
// The native side reads the "role" and "content" entries of each map and treats
// `model` as a llama_model pointer. Returns the formatted prompt, or "" when the
// native code caught an error while processing the messages.
private external fun apply_chat_template(allmessages: Array<Map<String, String>>, model: Long): String
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
return withContext(runLoop) {
when (val state = threadLocalState.get()) {
@ -153,6 +155,20 @@ class LLamaAndroid {
}
}
}
/**
 * Applies the native chat template to [messages] and returns the resulting prompt.
 *
 * Call this function before sending the message using the send function.
 * Runs on [runLoop]; returns an empty string when the current state is not [State.Loaded].
 */
suspend fun applyChatTemplate(messages: List<Map<String, String>>): String =
    withContext(runLoop) {
        val state = threadLocalState.get()
        if (state is State.Loaded) {
            // Convert list to array for JNI compatibility
            apply_chat_template(allmessages = messages.toTypedArray(), model = state.model)
        } else {
            ""
        }
    }
companion object {
private class IntVar(value: Int) {