From 22da7bc379b491680a7db25600c14f8addfbc93d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E4=B8=BA?=
Date: Wed, 6 Nov 2024 11:20:36 +0800
Subject: [PATCH] add returned string (pure c const char* type) for omni-vlm
 inference api

---
 examples/omni-vlm/omni-vlm-wrapper.cpp | 22 +++++++++++++++++-----
 examples/omni-vlm/omni-vlm-wrapper.h   |  2 +-
 examples/omni-vlm/omni_vlm_cpp.py      |  2 +-
 examples/omni-vlm/omni_vlm_demo.py     |  6 ++++--
 4 files changed, 23 insertions(+), 9 deletions(-)

diff --git a/examples/omni-vlm/omni-vlm-wrapper.cpp b/examples/omni-vlm/omni-vlm-wrapper.cpp
index 81178205e..8c8cac786 100644
--- a/examples/omni-vlm/omni-vlm-wrapper.cpp
+++ b/examples/omni-vlm/omni-vlm-wrapper.cpp
@@ -24,6 +24,8 @@ struct omnivlm_context {
     struct llama_model * model = NULL;
 };
 
+void* internal_chars = nullptr;
+
 static struct gpt_params params;
 static struct llama_model* model;
 static struct omnivlm_context* ctx_omnivlm;
@@ -128,7 +130,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
     return ret.c_str();
 }
 
-static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
+static const char* process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
     int n_past = 0;
 
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
@@ -172,11 +174,11 @@ static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_ima
     std::string response = "";
     for (int i = 0; i < max_tgt_len; i++) {
         const char * tmp = sample(ctx_sampling, ctx_omnivlm->ctx_llama, &n_past);
-        response += tmp;
         if (strcmp(tmp, "<|im_end|>") == 0) break;
         if (strcmp(tmp, "") == 0) break;
         // if (strstr(tmp, "###")) break; // Yi-VL behavior
-        printf("%s", tmp);
+        // printf("%s", tmp);
+        response += tmp;
         // if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
         // if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
         // if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
@@ -186,6 +188,13 @@ static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_ima
 
     llama_sampling_free(ctx_sampling);
     printf("\n");
+
+    // const char* ret_char_ptr = (const char*)(malloc(sizeof(char)*response.size()));
+    if(internal_chars != nullptr) { free(internal_chars); }
+    internal_chars = malloc(sizeof(char)*(response.size()+1));
+    strncpy((char*)(internal_chars), response.c_str(), response.size());
+    ((char*)(internal_chars))[response.size()] = '\0';
+    return (const char*)(internal_chars);
 }
 
 static void omnivlm_free(struct omnivlm_context * ctx_omnivlm) {
@@ -225,7 +234,7 @@ void omnivlm_init(const char* llm_model_path, const char* projector_model_path)
     ctx_omnivlm = omnivlm_init_context(&params, model);
 }
 
-void omnivlm_inference(const char *prompt, const char *imag_path) {
+const char* omnivlm_inference(const char *prompt, const char *imag_path) {
     std::string image = imag_path;
     params.prompt = prompt;
     auto * image_embed = load_image(ctx_omnivlm, &params, image);
@@ -234,13 +243,16 @@ void omnivlm_inference(const char *prompt, const char *imag_path) {
         throw std::runtime_error("failed to load image " + image);
     }
     // process the prompt
-    process_prompt(ctx_omnivlm, image_embed, &params, params.prompt);
+    const char* ret_chars = process_prompt(ctx_omnivlm, image_embed, &params, params.prompt);
 
     // llama_perf_print(ctx_omnivlm->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
     omnivlm_image_embed_free(image_embed);
+
+    return ret_chars;
 }
 
 void omnivlm_free() {
+    if(internal_chars != nullptr) { free(internal_chars); }
     ctx_omnivlm->model = NULL;
     omnivlm_free(ctx_omnivlm);
     llama_free_model(model);
diff --git a/examples/omni-vlm/omni-vlm-wrapper.h b/examples/omni-vlm/omni-vlm-wrapper.h
index 4ab2c234c..ff37a2550 100644
--- a/examples/omni-vlm/omni-vlm-wrapper.h
+++ b/examples/omni-vlm/omni-vlm-wrapper.h
@@ -22,7 +22,7 @@ extern "C" {
 
 OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path);
 
-OMNIVLM_API void omnivlm_inference(const char* prompt, const char* imag_path);
+OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);
 
 OMNIVLM_API void omnivlm_free();
 
diff --git a/examples/omni-vlm/omni_vlm_cpp.py b/examples/omni-vlm/omni_vlm_cpp.py
index 6f23f7c4c..e623ff53b 100644
--- a/examples/omni-vlm/omni_vlm_cpp.py
+++ b/examples/omni-vlm/omni_vlm_cpp.py
@@ -73,7 +73,7 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):
 
 
 _lib.omnivlm_inference.argtypes = [omni_char_p, omni_char_p]
-_lib.omnivlm_inference.restype = None
+_lib.omnivlm_inference.restype = omni_char_p
 
 
 def omnivlm_free():
diff --git a/examples/omni-vlm/omni_vlm_demo.py b/examples/omni-vlm/omni_vlm_demo.py
index 4f8c5998f..0384631e0 100644
--- a/examples/omni-vlm/omni_vlm_demo.py
+++ b/examples/omni-vlm/omni_vlm_demo.py
@@ -20,7 +20,7 @@ class NexaOmniVlmInference:
     def inference(self, prompt: str, image_path: str):
         prompt = ctypes.c_char_p(prompt.encode("utf-8"))
         image_path = ctypes.c_char_p(image_path.encode("utf-8"))
-        omni_vlm_cpp.omnivlm_inference(prompt, image_path)
+        return omni_vlm_cpp.omnivlm_inference(prompt, image_path)
 
     def __del__(self):
         omni_vlm_cpp.omnivlm_free()
@@ -52,4 +52,6 @@ if __name__ == "__main__":
         while not os.path.exists(image_path):
            print("ERROR: can not find image in your input path, please check and input agian.")
            image_path = input()
-        omni_vlm_obj.inference(prompt, image_path)
+        response = omni_vlm_obj.inference(prompt, image_path)
+        print("\tresponse:")
+        print(response.decode('utf-8'))
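
Below is a minimal sketch of how a C caller could consume the new return value;
the model/projector paths, prompt, and image file are placeholders chosen for
illustration only.

    #include <stdio.h>
    #include "omni-vlm-wrapper.h"

    int main(void) {
        /* hypothetical model/projector paths, for illustration only */
        omnivlm_init("llm-model.gguf", "projector-model.gguf");

        /* NUL-terminated C string owned by the wrapper */
        const char * response = omnivlm_inference("Describe this image.", "example.png");
        printf("response: %s\n", response);

        /* releases the wrapper's internal buffer along with the model */
        omnivlm_free();
        return 0;
    }

Note that the returned pointer refers to the wrapper-owned internal_chars
buffer: it is reallocated on the next omnivlm_inference() call and freed by
omnivlm_free(), so callers should copy the string if it needs to outlive either.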