Merge 996dc4cdd2
into cae9fb4361
This commit is contained in:
commit
6db379dc81
2 changed files with 149 additions and 0 deletions
|
@ -4,6 +4,7 @@
|
|||
#include <math.h>
|
||||
#include <string>
|
||||
#include <unistd.h>
|
||||
#include <json.hpp>
|
||||
#include "llama.h"
|
||||
#include "common.h"
|
||||
|
||||
|
@ -72,6 +73,121 @@ bool is_valid_utf8(const char * string) {
|
|||
return true;
|
||||
}
|
||||
|
||||
|
||||
using json = nlohmann::ordered_json;
|
||||
|
||||
template <typename T>
|
||||
static T json_value(const json & body, const std::string & key, const T & default_value) {
|
||||
// Fallback null to default value
|
||||
if (body.contains(key) && !body.at(key).is_null()) {
|
||||
try {
|
||||
return body.at(key);
|
||||
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
|
||||
|
||||
return default_value;
|
||||
}
|
||||
} else {
|
||||
return default_value;
|
||||
}
|
||||
}
|
||||
|
||||
std::string mapListToJSONString(JNIEnv *env, jobjectArray allMessages) {
|
||||
json jsonArray = json::array();
|
||||
|
||||
jsize arrayLength = env->GetArrayLength(allMessages);
|
||||
for (jsize i = 0; i < arrayLength; ++i) {
|
||||
// Get the individual message from the array
|
||||
jobject messageObj = env->GetObjectArrayElement(allMessages, i);
|
||||
if (!messageObj) {
|
||||
LOGe("Error: Received null jobject at index %d", i);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Check if the object is a Map
|
||||
jclass mapClass = env->FindClass("java/util/Map");
|
||||
if (!env->IsInstanceOf(messageObj, mapClass)) {
|
||||
LOGe("Error: Object is not a Map at index %d", i);
|
||||
env->DeleteLocalRef(messageObj);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Get Map methods
|
||||
jmethodID getMethod = env->GetMethodID(mapClass, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
|
||||
jmethodID keySetMethod = env->GetMethodID(mapClass, "keySet", "()Ljava/util/Set;");
|
||||
if (!getMethod || !keySetMethod) {
|
||||
LOGe("Error: Could not find Map methods");
|
||||
env->DeleteLocalRef(messageObj);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Create a JSON object for this map
|
||||
json jsonMsg;
|
||||
|
||||
// Get role
|
||||
jstring roleKey = env->NewStringUTF("role");
|
||||
jobject roleObj = env->CallObjectMethod(messageObj, getMethod, roleKey);
|
||||
if (roleObj) {
|
||||
const char* roleStr = env->GetStringUTFChars((jstring)roleObj, nullptr);
|
||||
jsonMsg["role"] = roleStr;
|
||||
env->ReleaseStringUTFChars((jstring)roleObj, roleStr);
|
||||
}
|
||||
|
||||
// Get content
|
||||
jstring contentKey = env->NewStringUTF("content");
|
||||
jobject contentObj = env->CallObjectMethod(messageObj, getMethod, contentKey);
|
||||
if (contentObj) {
|
||||
const char* contentStr = env->GetStringUTFChars((jstring)contentObj, nullptr);
|
||||
jsonMsg["content"] = contentStr;
|
||||
env->ReleaseStringUTFChars((jstring)contentObj, contentStr);
|
||||
}
|
||||
|
||||
// Add to array if both role and content were successfully extracted
|
||||
if (!jsonMsg.empty()) {
|
||||
jsonArray.push_back(jsonMsg);
|
||||
}
|
||||
|
||||
// Clean up local references
|
||||
env->DeleteLocalRef(messageObj);
|
||||
}
|
||||
|
||||
return jsonArray.dump();
|
||||
}
|
||||
|
||||
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||
inline std::string format_chat(const llama_model *model, const std::string &tmpl, const std::vector<json> &messages) {
|
||||
std::vector<common_chat_msg> chat;
|
||||
|
||||
for (size_t i = 0; i < messages.size(); ++i) {
|
||||
const auto &curr_msg = messages[i];
|
||||
|
||||
std::string role = json_value(curr_msg, "role", std::string(""));
|
||||
std::string content;
|
||||
|
||||
if (curr_msg.contains("content")) {
|
||||
if (curr_msg["content"].is_string()) {
|
||||
content = curr_msg["content"].get<std::string>();
|
||||
} else if (curr_msg["content"].is_array()) {
|
||||
for (const auto &part : curr_msg["content"]) {
|
||||
if (part.contains("text")) {
|
||||
content += "\n" + part["text"].get<std::string>();
|
||||
}
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Invalid 'content' type.");
|
||||
}
|
||||
} else {
|
||||
throw std::runtime_error("Missing 'content'.");
|
||||
}
|
||||
|
||||
chat.push_back({role, content});
|
||||
}
|
||||
|
||||
const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
|
||||
LOGi("formatted_chat: '%s'\n", formatted_chat.c_str());
|
||||
|
||||
return formatted_chat;
|
||||
}
|
||||
|
||||
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
|
||||
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
|
||||
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
|
||||
|
@ -450,3 +566,20 @@ JNIEXPORT void JNICALL
|
|||
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
|
||||
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
|
||||
}
|
||||
|
||||
extern "C"
|
||||
JNIEXPORT jstring JNICALL
|
||||
Java_android_llama_cpp_LLamaAndroid_apply_1chat_1template(JNIEnv *env, jobject, jobjectArray allMessages, jlong model){
|
||||
try {
|
||||
// Convert the messages to JSON
|
||||
std::string parsedData = mapListToJSONString(env, allMessages);
|
||||
// Parse and format
|
||||
std::vector<json> jsonMessages = json::parse(parsedData);
|
||||
const auto formattedPrompts = format_chat(reinterpret_cast<const llama_model *>(model), "", jsonMessages);
|
||||
|
||||
return env->NewStringUTF(formattedPrompts.c_str());
|
||||
} catch (const std::exception &e) {
|
||||
LOGe("Error processing data: %s", e.what());
|
||||
return env->NewStringUTF("");
|
||||
}
|
||||
}
|
||||
|
|
|
@ -79,6 +79,8 @@ class LLamaAndroid {
|
|||
|
||||
private external fun kv_cache_clear(context: Long)
|
||||
|
||||
private external fun apply_chat_template(allmessages: Array<Map<String, String>>, model: Long): String
|
||||
|
||||
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
|
||||
return withContext(runLoop) {
|
||||
when (val state = threadLocalState.get()) {
|
||||
|
@ -153,6 +155,20 @@ class LLamaAndroid {
|
|||
}
|
||||
}
|
||||
}
|
||||
// call this function before sending the message using send function
|
||||
suspend fun applyChatTemplate(messages: List<Map<String, String>>): String {
|
||||
var data = ""
|
||||
withContext(runLoop){
|
||||
when(val state = threadLocalState.get()){
|
||||
is State.Loaded -> {
|
||||
val arrayMessages = messages.toTypedArray() //Convert list to array for JNI compatibility
|
||||
data = apply_chat_template(allmessages = arrayMessages, model = state.model)
|
||||
}
|
||||
else -> {}
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
companion object {
|
||||
private class IntVar(value: Int) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue