Merge 996dc4cdd2
into cae9fb4361
This commit is contained in:
commit
6db379dc81
2 changed files with 149 additions and 0 deletions
|
@ -4,6 +4,7 @@
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
#include <string>
|
#include <string>
|
||||||
#include <unistd.h>
|
#include <unistd.h>
|
||||||
|
#include <json.hpp>
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
|
|
||||||
|
@ -72,6 +73,121 @@ bool is_valid_utf8(const char * string) {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
using json = nlohmann::ordered_json;
|
||||||
|
|
||||||
|
template <typename T>
|
||||||
|
static T json_value(const json & body, const std::string & key, const T & default_value) {
|
||||||
|
// Fallback null to default value
|
||||||
|
if (body.contains(key) && !body.at(key).is_null()) {
|
||||||
|
try {
|
||||||
|
return body.at(key);
|
||||||
|
} catch (NLOHMANN_JSON_NAMESPACE::detail::type_error const &) {
|
||||||
|
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return default_value;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
std::string mapListToJSONString(JNIEnv *env, jobjectArray allMessages) {
|
||||||
|
json jsonArray = json::array();
|
||||||
|
|
||||||
|
jsize arrayLength = env->GetArrayLength(allMessages);
|
||||||
|
for (jsize i = 0; i < arrayLength; ++i) {
|
||||||
|
// Get the individual message from the array
|
||||||
|
jobject messageObj = env->GetObjectArrayElement(allMessages, i);
|
||||||
|
if (!messageObj) {
|
||||||
|
LOGe("Error: Received null jobject at index %d", i);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the object is a Map
|
||||||
|
jclass mapClass = env->FindClass("java/util/Map");
|
||||||
|
if (!env->IsInstanceOf(messageObj, mapClass)) {
|
||||||
|
LOGe("Error: Object is not a Map at index %d", i);
|
||||||
|
env->DeleteLocalRef(messageObj);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get Map methods
|
||||||
|
jmethodID getMethod = env->GetMethodID(mapClass, "get", "(Ljava/lang/Object;)Ljava/lang/Object;");
|
||||||
|
jmethodID keySetMethod = env->GetMethodID(mapClass, "keySet", "()Ljava/util/Set;");
|
||||||
|
if (!getMethod || !keySetMethod) {
|
||||||
|
LOGe("Error: Could not find Map methods");
|
||||||
|
env->DeleteLocalRef(messageObj);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a JSON object for this map
|
||||||
|
json jsonMsg;
|
||||||
|
|
||||||
|
// Get role
|
||||||
|
jstring roleKey = env->NewStringUTF("role");
|
||||||
|
jobject roleObj = env->CallObjectMethod(messageObj, getMethod, roleKey);
|
||||||
|
if (roleObj) {
|
||||||
|
const char* roleStr = env->GetStringUTFChars((jstring)roleObj, nullptr);
|
||||||
|
jsonMsg["role"] = roleStr;
|
||||||
|
env->ReleaseStringUTFChars((jstring)roleObj, roleStr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get content
|
||||||
|
jstring contentKey = env->NewStringUTF("content");
|
||||||
|
jobject contentObj = env->CallObjectMethod(messageObj, getMethod, contentKey);
|
||||||
|
if (contentObj) {
|
||||||
|
const char* contentStr = env->GetStringUTFChars((jstring)contentObj, nullptr);
|
||||||
|
jsonMsg["content"] = contentStr;
|
||||||
|
env->ReleaseStringUTFChars((jstring)contentObj, contentStr);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to array if both role and content were successfully extracted
|
||||||
|
if (!jsonMsg.empty()) {
|
||||||
|
jsonArray.push_back(jsonMsg);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Clean up local references
|
||||||
|
env->DeleteLocalRef(messageObj);
|
||||||
|
}
|
||||||
|
|
||||||
|
return jsonArray.dump();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Format given chat. If tmpl is empty, we take the template from model metadata
|
||||||
|
inline std::string format_chat(const llama_model *model, const std::string &tmpl, const std::vector<json> &messages) {
|
||||||
|
std::vector<common_chat_msg> chat;
|
||||||
|
|
||||||
|
for (size_t i = 0; i < messages.size(); ++i) {
|
||||||
|
const auto &curr_msg = messages[i];
|
||||||
|
|
||||||
|
std::string role = json_value(curr_msg, "role", std::string(""));
|
||||||
|
std::string content;
|
||||||
|
|
||||||
|
if (curr_msg.contains("content")) {
|
||||||
|
if (curr_msg["content"].is_string()) {
|
||||||
|
content = curr_msg["content"].get<std::string>();
|
||||||
|
} else if (curr_msg["content"].is_array()) {
|
||||||
|
for (const auto &part : curr_msg["content"]) {
|
||||||
|
if (part.contains("text")) {
|
||||||
|
content += "\n" + part["text"].get<std::string>();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Invalid 'content' type.");
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw std::runtime_error("Missing 'content'.");
|
||||||
|
}
|
||||||
|
|
||||||
|
chat.push_back({role, content});
|
||||||
|
}
|
||||||
|
|
||||||
|
const auto formatted_chat = common_chat_apply_template(model, tmpl, chat, true);
|
||||||
|
LOGi("formatted_chat: '%s'\n", formatted_chat.c_str());
|
||||||
|
|
||||||
|
return formatted_chat;
|
||||||
|
}
|
||||||
|
|
||||||
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
|
static void log_callback(ggml_log_level level, const char * fmt, void * data) {
|
||||||
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
|
if (level == GGML_LOG_LEVEL_ERROR) __android_log_print(ANDROID_LOG_ERROR, TAG, fmt, data);
|
||||||
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
|
else if (level == GGML_LOG_LEVEL_INFO) __android_log_print(ANDROID_LOG_INFO, TAG, fmt, data);
|
||||||
|
@ -450,3 +566,20 @@ JNIEXPORT void JNICALL
|
||||||
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
|
Java_android_llama_cpp_LLamaAndroid_kv_1cache_1clear(JNIEnv *, jobject, jlong context) {
|
||||||
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
|
llama_kv_cache_clear(reinterpret_cast<llama_context *>(context));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
extern "C"
|
||||||
|
JNIEXPORT jstring JNICALL
|
||||||
|
Java_android_llama_cpp_LLamaAndroid_apply_1chat_1template(JNIEnv *env, jobject, jobjectArray allMessages, jlong model){
|
||||||
|
try {
|
||||||
|
// Convert the messages to JSON
|
||||||
|
std::string parsedData = mapListToJSONString(env, allMessages);
|
||||||
|
// Parse and format
|
||||||
|
std::vector<json> jsonMessages = json::parse(parsedData);
|
||||||
|
const auto formattedPrompts = format_chat(reinterpret_cast<const llama_model *>(model), "", jsonMessages);
|
||||||
|
|
||||||
|
return env->NewStringUTF(formattedPrompts.c_str());
|
||||||
|
} catch (const std::exception &e) {
|
||||||
|
LOGe("Error processing data: %s", e.what());
|
||||||
|
return env->NewStringUTF("");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -79,6 +79,8 @@ class LLamaAndroid {
|
||||||
|
|
||||||
private external fun kv_cache_clear(context: Long)
|
private external fun kv_cache_clear(context: Long)
|
||||||
|
|
||||||
|
private external fun apply_chat_template(allmessages: Array<Map<String, String>>, model: Long): String
|
||||||
|
|
||||||
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
|
suspend fun bench(pp: Int, tg: Int, pl: Int, nr: Int = 1): String {
|
||||||
return withContext(runLoop) {
|
return withContext(runLoop) {
|
||||||
when (val state = threadLocalState.get()) {
|
when (val state = threadLocalState.get()) {
|
||||||
|
@ -153,6 +155,20 @@ class LLamaAndroid {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// call this function before sending the message using send function
|
||||||
|
suspend fun applyChatTemplate(messages: List<Map<String, String>>): String {
|
||||||
|
var data = ""
|
||||||
|
withContext(runLoop){
|
||||||
|
when(val state = threadLocalState.get()){
|
||||||
|
is State.Loaded -> {
|
||||||
|
val arrayMessages = messages.toTypedArray() //Convert list to array for JNI compatibility
|
||||||
|
data = apply_chat_template(allmessages = arrayMessages, model = state.model)
|
||||||
|
}
|
||||||
|
else -> {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
companion object {
|
companion object {
|
||||||
private class IntVar(value: Int) {
|
private class IntVar(value: Int) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue