Add streaming for omnivlm (#39)

* omni vlm add streaming

* omni vlm add streaming
This commit is contained in:
T 2025-01-06 18:10:50 +08:00 committed by GitHub
parent 5962b506ba
commit 1487d32b46
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 136 additions and 3 deletions

View file

@ -15,10 +15,10 @@
#include <vector> #include <vector>
#include <string> #include <string>
#include <iostream> #include <iostream>
#include <memory>
#include "omni-vlm-wrapper.h" #include "omni-vlm-wrapper.h"
struct omnivlm_context { struct omnivlm_context {
struct clip_ctx * ctx_clip = NULL; struct clip_ctx * ctx_clip = NULL;
struct llama_context * ctx_llama = NULL; struct llama_context * ctx_llama = NULL;
@ -30,6 +30,53 @@ void* internal_chars = nullptr;
static struct common_params params; static struct common_params params;
static struct llama_model* model; static struct llama_model* model;
static struct omnivlm_context* ctx_omnivlm; static struct omnivlm_context* ctx_omnivlm;
static std::unique_ptr<struct omni_streaming_sample> g_oss = nullptr;
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past);
static void omnivlm_free(struct omnivlm_context * ctx_omnivlm);
struct omni_streaming_sample {
struct common_sampler * ctx_sampling_;
std::string image_;
std::string ret_str_;
int32_t n_past_;
int32_t dec_cnt_;
omni_streaming_sample() = delete;
omni_streaming_sample(const std::string& image)
:image_(image) {
n_past_ = 0;
dec_cnt_ = 0;
params.sparams.top_k = 1;
params.sparams.top_p = 1.0f;
ctx_sampling_ = common_sampler_init(model, params.sparams);
}
int32_t sample() {
const llama_token id = common_sampler_sample(ctx_sampling_, ctx_omnivlm->ctx_llama, -1);
common_sampler_accept(ctx_sampling_, id, true);
if (llama_token_is_eog(llama_get_model(ctx_omnivlm->ctx_llama), id)) {
ret_str_ = "</s>";
} else {
ret_str_ = common_token_to_piece(ctx_omnivlm->ctx_llama, id);
}
eval_id(ctx_omnivlm->ctx_llama, id, &n_past_);
++dec_cnt_;
return id;
}
~omni_streaming_sample() {
common_sampler_free(ctx_sampling_);
if(ctx_omnivlm != nullptr) {
ctx_omnivlm->model = nullptr;
omnivlm_free(ctx_omnivlm);
free(ctx_omnivlm);
ctx_omnivlm = nullptr;
}
}
};
static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) { static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) {
@ -286,3 +333,81 @@ void omnivlm_free() {
} }
llama_free_model(model); llama_free_model(model);
} }
struct omni_streaming_sample* omnivlm_inference_streaming(const char *prompt, const char *imag_path) {
if (g_oss) {
g_oss.reset();
}
g_oss = std::make_unique<omni_streaming_sample>(std::string(imag_path));
ctx_omnivlm = omnivlm_init_context(&params, model);
params.prompt = prompt;
if (params.omni_vlm_version == "vlm-81-ocr") {
params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n <|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
} else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
} else {
LOG_ERR("%s : error: you set wrong vlm version info:'%s'.\n", __func__, params.omni_vlm_version.c_str());
throw std::runtime_error("You set wrong vlm_version info strings.");
}
return g_oss.get();
}
int32_t sample(omni_streaming_sample* oss) {
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
int32_t ret_id;
if(oss->n_past_ == 0) {
auto * image_embed = load_image(ctx_omnivlm, &params, oss->image_);
if (!image_embed) {
LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, oss->image_.c_str());
throw std::runtime_error("failed to load image " + oss->image_);
}
size_t image_pos = params.prompt.find("<|image_pad|>");
std::string system_prompt, user_prompt;
system_prompt = params.prompt.substr(0, image_pos);
user_prompt = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
if (params.verbose_prompt) {
auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true);
for (int i = 0; i < (int) tmp.size(); i++) {
LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
}
}
if (params.verbose_prompt) {
auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, user_prompt, true, true);
for (int i = 0; i < (int) tmp.size(); i++) {
LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
}
}
eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, &(oss->n_past_), true);
omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, &(oss->n_past_));
eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, &(oss->n_past_), false);
omnivlm_image_embed_free(image_embed);
ret_id = oss->sample();
if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>" ) {
ret_id = -1;
}
} else {
if(oss->dec_cnt_ == max_tgt_len) {
ret_id = -2;
} else {
ret_id = oss->sample();
if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>" ) {
ret_id = -1;
}
}
}
return ret_id;
}
const char* get_str(omni_streaming_sample* oss) {
return oss->ret_str_.c_str();
}

View file

@ -1,6 +1,6 @@
#ifndef OMNIVLMWRAPPER_H #ifndef OMNIVLMWRAPPER_H
#define OMNIVLMWRAPPER_H #define OMNIVLMWRAPPER_H
#include <stdint.h>
#ifdef LLAMA_SHARED #ifdef LLAMA_SHARED
# if defined(_WIN32) && !defined(__MINGW32__) # if defined(_WIN32) && !defined(__MINGW32__)
@ -20,14 +20,22 @@
extern "C" { extern "C" {
#endif #endif
struct omni_streaming_sample;
OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version); OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version);
OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path); OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);
OMNIVLM_API struct omni_streaming_sample* omnivlm_inference_streaming(const char* prompt, const char* imag_path);
OMNIVLM_API int32_t sample(struct omni_streaming_sample *);
OMNIVLM_API const char* get_str(struct omni_streaming_sample *);
OMNIVLM_API void omnivlm_free(); OMNIVLM_API void omnivlm_free();
#ifdef __cplusplus #ifdef __cplusplus
} }
#endif #endif
#endif #endif