Add streaming for omnivlm (#39)
* omni vlm add streaming * omni vlm add streaming
This commit is contained in:
parent
5962b506ba
commit
1487d32b46
2 changed files with 136 additions and 3 deletions
|
@ -15,10 +15,10 @@
|
|||
#include <vector>
|
||||
#include <string>
|
||||
#include <iostream>
|
||||
#include <memory>
|
||||
|
||||
#include "omni-vlm-wrapper.h"
|
||||
|
||||
|
||||
struct omnivlm_context {
|
||||
struct clip_ctx * ctx_clip = NULL;
|
||||
struct llama_context * ctx_llama = NULL;
|
||||
|
@ -30,6 +30,53 @@ void* internal_chars = nullptr;
|
|||
static struct common_params params;
|
||||
static struct llama_model* model;
|
||||
static struct omnivlm_context* ctx_omnivlm;
|
||||
static std::unique_ptr<struct omni_streaming_sample> g_oss = nullptr;
|
||||
|
||||
static bool eval_id(struct llama_context * ctx_llama, int id, int * n_past);
|
||||
static void omnivlm_free(struct omnivlm_context * ctx_omnivlm);
|
||||
|
||||
struct omni_streaming_sample {
|
||||
struct common_sampler * ctx_sampling_;
|
||||
std::string image_;
|
||||
std::string ret_str_;
|
||||
int32_t n_past_;
|
||||
int32_t dec_cnt_;
|
||||
|
||||
omni_streaming_sample() = delete;
|
||||
omni_streaming_sample(const std::string& image)
|
||||
:image_(image) {
|
||||
n_past_ = 0;
|
||||
dec_cnt_ = 0;
|
||||
params.sparams.top_k = 1;
|
||||
params.sparams.top_p = 1.0f;
|
||||
ctx_sampling_ = common_sampler_init(model, params.sparams);
|
||||
}
|
||||
|
||||
int32_t sample() {
|
||||
const llama_token id = common_sampler_sample(ctx_sampling_, ctx_omnivlm->ctx_llama, -1);
|
||||
common_sampler_accept(ctx_sampling_, id, true);
|
||||
if (llama_token_is_eog(llama_get_model(ctx_omnivlm->ctx_llama), id)) {
|
||||
ret_str_ = "</s>";
|
||||
} else {
|
||||
ret_str_ = common_token_to_piece(ctx_omnivlm->ctx_llama, id);
|
||||
}
|
||||
eval_id(ctx_omnivlm->ctx_llama, id, &n_past_);
|
||||
|
||||
++dec_cnt_;
|
||||
return id;
|
||||
}
|
||||
|
||||
~omni_streaming_sample() {
|
||||
common_sampler_free(ctx_sampling_);
|
||||
if(ctx_omnivlm != nullptr) {
|
||||
ctx_omnivlm->model = nullptr;
|
||||
omnivlm_free(ctx_omnivlm);
|
||||
free(ctx_omnivlm);
|
||||
ctx_omnivlm = nullptr;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
|
||||
static struct omni_image_embed * load_image(omnivlm_context * ctx_omnivlm, common_params * params, const std::string & fname) {
|
||||
|
||||
|
@ -286,3 +333,81 @@ void omnivlm_free() {
|
|||
}
|
||||
llama_free_model(model);
|
||||
}
|
||||
|
||||
|
||||
struct omni_streaming_sample* omnivlm_inference_streaming(const char *prompt, const char *imag_path) {
|
||||
if (g_oss) {
|
||||
g_oss.reset();
|
||||
}
|
||||
g_oss = std::make_unique<omni_streaming_sample>(std::string(imag_path));
|
||||
|
||||
ctx_omnivlm = omnivlm_init_context(¶ms, model);
|
||||
|
||||
params.prompt = prompt;
|
||||
|
||||
if (params.omni_vlm_version == "vlm-81-ocr") {
|
||||
params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n <|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
|
||||
} else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
|
||||
params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n\n<|vision_start|><|image_pad|><|vision_end|>" + params.prompt + "<|im_end|>";
|
||||
} else {
|
||||
LOG_ERR("%s : error: you set wrong vlm version info:'%s'.\n", __func__, params.omni_vlm_version.c_str());
|
||||
throw std::runtime_error("You set wrong vlm_version info strings.");
|
||||
}
|
||||
|
||||
return g_oss.get();
|
||||
}
|
||||
|
||||
int32_t sample(omni_streaming_sample* oss) {
|
||||
const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
|
||||
int32_t ret_id;
|
||||
if(oss->n_past_ == 0) {
|
||||
auto * image_embed = load_image(ctx_omnivlm, ¶ms, oss->image_);
|
||||
if (!image_embed) {
|
||||
LOG_ERR("%s: failed to load image %s. Terminating\n\n", __func__, oss->image_.c_str());
|
||||
throw std::runtime_error("failed to load image " + oss->image_);
|
||||
}
|
||||
|
||||
size_t image_pos = params.prompt.find("<|image_pad|>");
|
||||
std::string system_prompt, user_prompt;
|
||||
|
||||
system_prompt = params.prompt.substr(0, image_pos);
|
||||
user_prompt = params.prompt.substr(image_pos + std::string("<|image_pad|>").length());
|
||||
if (params.verbose_prompt) {
|
||||
auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
if (params.verbose_prompt) {
|
||||
auto tmp = ::common_tokenize(ctx_omnivlm->ctx_llama, user_prompt, true, true);
|
||||
for (int i = 0; i < (int) tmp.size(); i++) {
|
||||
LOG_ERR("%6d -> '%s'\n", tmp[i], common_token_to_piece(ctx_omnivlm->ctx_llama, tmp[i]).c_str());
|
||||
}
|
||||
}
|
||||
|
||||
eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params.n_batch, &(oss->n_past_), true);
|
||||
omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params.n_batch, &(oss->n_past_));
|
||||
eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params.n_batch, &(oss->n_past_), false);
|
||||
|
||||
omnivlm_image_embed_free(image_embed);
|
||||
|
||||
ret_id = oss->sample();
|
||||
if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>" ) {
|
||||
ret_id = -1;
|
||||
}
|
||||
} else {
|
||||
if(oss->dec_cnt_ == max_tgt_len) {
|
||||
ret_id = -2;
|
||||
} else {
|
||||
ret_id = oss->sample();
|
||||
if (oss->ret_str_ == "<|im_end|>" || oss->ret_str_ == "</s>" ) {
|
||||
ret_id = -1;
|
||||
}
|
||||
}
|
||||
}
|
||||
return ret_id;
|
||||
}
|
||||
|
||||
const char* get_str(omni_streaming_sample* oss) {
|
||||
return oss->ret_str_.c_str();
|
||||
}
|
|
@ -1,6 +1,6 @@
|
|||
|
||||
#ifndef OMNIVLMWRAPPER_H
|
||||
#define OMNIVLMWRAPPER_H
|
||||
#include <stdint.h>
|
||||
|
||||
#ifdef LLAMA_SHARED
|
||||
# if defined(_WIN32) && !defined(__MINGW32__)
|
||||
|
@ -20,14 +20,22 @@
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
struct omni_streaming_sample;
|
||||
|
||||
OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version);
|
||||
|
||||
OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);
|
||||
|
||||
OMNIVLM_API struct omni_streaming_sample* omnivlm_inference_streaming(const char* prompt, const char* imag_path);
|
||||
|
||||
OMNIVLM_API int32_t sample(struct omni_streaming_sample *);
|
||||
|
||||
OMNIVLM_API const char* get_str(struct omni_streaming_sample *);
|
||||
|
||||
OMNIVLM_API void omnivlm_free();
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
#endif
|
Loading…
Add table
Add a link
Reference in a new issue