From 1487d32b46a210c5619886af8fe24c93091f7ca0 Mon Sep 17 00:00:00 2001
From: T <3923106166@qq.com>
Date: Mon, 6 Jan 2025 18:10:50 +0800
Subject: [PATCH] Add streaming for omnivlm (#39)

* omni vlm add streaming

* omni vlm add streaming
---
 examples/omni-vlm/omni-vlm-wrapper.cpp | 127 ++++++++++++++++++++++++-
 examples/omni-vlm/omni-vlm-wrapper.h   |  12 ++-
 2 files changed, 136 insertions(+), 3 deletions(-)

diff --git a/examples/omni-vlm/omni-vlm-wrapper.cpp b/examples/omni-vlm/omni-vlm-wrapper.cpp
index d03aa9622..ea3326294 100644
--- a/examples/omni-vlm/omni-vlm-wrapper.cpp
+++ b/examples/omni-vlm/omni-vlm-wrapper.cpp
@@ -15,10 +15,10 @@
 #include <cstdio>
 #include <cstdlib>
 #include <vector>
+#include <memory>
 
 #include "omni-vlm-wrapper.h"
 
-
 struct omnivlm_context {
     struct clip_ctx * ctx_clip = NULL;
     struct llama_context * ctx_llama = NULL;
@@ -30,6 +30,53 @@ void* internal_chars = nullptr;
 static struct common_params params;
 static struct llama_model* model;
 static struct omnivlm_context* ctx_omnivlm;
+static std::unique_ptr<omni_streaming_sample> g_oss = nullptr;
+
+static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past);
+static void omnivlm_free(struct omnivlm_context * ctx_omnivlm);
+
+struct omni_streaming_sample {
+    struct common_sampler * ctx_sampling_;
+    std::string image_;
+    std::string ret_str_;
+    int32_t n_past_;
+    int32_t dec_cnt_;
+
+    omni_streaming_sample() = delete;
+    omni_streaming_sample(const std::string& image)
+        :image_(image) {
+        n_past_ = 0;
+        dec_cnt_ = 0;
+        params.sparams.top_k = 1;
+        params.sparams.top_p = 1.0f;
+        ctx_sampling_ = common_sampler_init(model, params.sparams);
+    }
+
+    int32_t sample() {
+        const llama_token id = common_sampler_sample(ctx_sampling_, ctx_omnivlm->ctx_llama, -1);
+        common_sampler_accept(ctx_sampling_, id, true);
+        if (llama_token_is_eog(llama_get_model(ctx_omnivlm->ctx_llama), id)) {
+            ret_str_ = "</s>";
+        } else {
+            ret_str_ = common_token_to_piece(ctx_omnivlm->ctx_llama, id);
+        }
+        eval_id(ctx_omnivlm->ctx_llama, id, &n_past_);
+
+        ++dec_cnt_;
+        return id;
+    }
+
+    ~omni_streaming_sample() {
+        common_sampler_free(ctx_sampling_);
+        if(ctx_omnivlm != nullptr) {
+            ctx_omnivlm->model = nullptr;
+            omnivlm_free(ctx_omnivlm);
+            free(ctx_omnivlm);
+            ctx_omnivlm = nullptr;
+        }
+    }
+};
+
 
 
 static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) {
@@ -286,3 +333,81 @@ void omnivlm_free() {
     }
     llama_free_model(model);
 }
+
+
+struct omni_streaming_sample* omnivlm_inference_streaming(const char *prompt, const char *imag_path) {
+    if (g_oss) {
+        g_oss.reset();
+    }
+    g_oss = std::make_unique<omni_streaming_sample>(std::string(imag_path));
+
+    ctx_omnivlm = omnivlm_init_context(&params, model);
+
+    params.prompt = prompt;
+
+    if (params.omni_vlm_version == "vlm-81-ocr") {
+        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n <|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
+    } else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
+        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
+    } else {
+        LOG_ERR("%s : error: you set wrong vlm version info:'%s'.\n", __func__, params.omni_vlm_version.c_str());
+        throw std::runtime_error("You set wrong vlm_version info strings.");
+    }
+
+    return g_oss.get();
+}
+
+int32_t sample(omni_streaming_sample* oss) {
+    const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
+    int32_t ret_id;
+    if(oss->n_past_ == 0) {
+        auto * image_embed = load_image(ctx_omnivlm, &params, oss->image_);
+        if (!image_embed) {
+            LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, oss->image_.c_str());
+            throw std::runtime_error("failed to load image " + oss->image_);
+        }
+
+        size_t image_pos = params.prompt.find("<|image_pad|>");
+        std::string system_prompt, user_prompt;
+
+        system_prompt = params.prompt.substr(0, image_pos);
+        user_prompt = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
+        if (params.verbose_prompt) {
+            auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
+            }
+        }
+        if (params.verbose_prompt) {
+            auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, user_prompt, true, true);
+            for (int i = 0; i < (int) tmp.size(); i++) {
+                LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
+            }
+        }
+
+        eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, &(oss->n_past_), true);
+        omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, &(oss->n_past_));
+        eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, &(oss->n_past_), false);
+
+        omnivlm_image_embed_free(image_embed);
+
+        ret_id = oss->sample();
+        if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
+            ret_id = -1;
+        }
+    } else {
+        if(oss->dec_cnt_ == max_tgt_len) {
+            ret_id = -2;
+        } else {
+            ret_id = oss->sample();
+            if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>") {
+                ret_id = -1;
+            }
+        }
+    }
+    return ret_id;
+}
+
+const char* get_str(omni_streaming_sample* oss) {
+    return oss->ret_str_.c_str();
+}
\ No newline at end of file
diff --git a/examples/omni-vlm/omni-vlm-wrapper.h b/examples/omni-vlm/omni-vlm-wrapper.h
index 22cc40533..3cc495a75 100644
--- a/examples/omni-vlm/omni-vlm-wrapper.h
+++ b/examples/omni-vlm/omni-vlm-wrapper.h
@@ -1,6 +1,6 @@
-
 #ifndef OMNIVLMWRAPPER_H
 #define OMNIVLMWRAPPER_H
+#include <stdint.h>
 
 #ifdef LLAMA_SHARED
 #    if defined(_WIN32) && !defined(__MINGW32__)
@@ -20,14 +20,22 @@
 extern "C" {
 #endif
 
+struct omni_streaming_sample;
+
 OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version);
 
 OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);
 
+OMNIVLM_API struct omni_streaming_sample* omnivlm_inference_streaming(const char* prompt, const char* imag_path);
+
+OMNIVLM_API int32_t sample(struct omni_streaming_sample *);
+
+OMNIVLM_API const char* get_str(struct omni_streaming_sample *);
+
 OMNIVLM_API void omnivlm_free();
 
 #ifdef __cplusplus
 }
 #endif
-#endif
+#endif
\ No newline at end of file
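
Usage sketch (not part of the patch): the new streaming API is meant to be driven in a pull
loop, following the return-code convention in sample() above: a non-negative token id while
decoding, -1 once an end-of-generation piece ("<|im_end|>" or "</s>") is produced, -2 once
n_predict tokens have been decoded. The model, projector, and image paths and the prompt
below are placeholders, not files shipped with this change:

    // stream_demo.cpp -- illustrative only; link against the omni-vlm wrapper library.
    #include <cstdio>
    #include "omni-vlm-wrapper.h"

    int main() {
        // Placeholder model files and version tag; substitute your own.
        omnivlm_init("nano-vlm.gguf", "nano-vlm-projector.gguf", "vlm-81-instruct");

        // Start a streaming generation; returns an opaque handle owned by the wrapper.
        struct omni_streaming_sample * oss =
            omnivlm_inference_streaming("Describe this image.", "example.png");

        // The first call to sample() evaluates the prompt and image embedding, then
        // each call decodes one token; stop on -1 (end token) or -2 (length cap).
        int32_t id;
        while ((id = sample(oss)) >= 0) {
            fputs(get_str(oss), stdout);  // text piece for the token just sampled
            fflush(stdout);
        }

        omnivlm_free();
        return 0;
    }

Note that the handle returned by omnivlm_inference_streaming() is owned by the wrapper's
global std::unique_ptr, so the caller never frees it; starting a new streaming request
resets the previous one.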