Merge pull request #10 from NexaAI/weili/master-release

Add a returned string (pure C const char* type) for the omni-vlm inference API
Zack Li · 2024-11-05 19:41:03 -08:00 · committed by GitHub
commit 5574bda471
4 changed files with 23 additions and 9 deletions


@@ -24,6 +24,8 @@ struct omnivlm_context {
     struct llama_model * model = NULL;
 };
+void* internal_chars = nullptr;
 static struct gpt_params params;
 static struct llama_model* model;
 static struct omnivlm_context* ctx_omnivlm;
@@ -128,7 +130,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
     return ret.c_str();
 }
-static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
+static const char* process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
     int n_past = 0;
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
@@ -172,11 +174,11 @@ static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_ima
     std::string response = "";
     for (int i = 0; i < max_tgt_len; i++) {
         const char * tmp = sample(ctx_sampling, ctx_omnivlm->ctx_llama, &n_past);
-        response += tmp;
         if (strcmp(tmp, "<|im_end|>") == 0) break;
         if (strcmp(tmp, "</s>") == 0) break;
         // if (strstr(tmp, "###")) break; // Yi-VL behavior
-        printf("%s", tmp);
+        // printf("%s", tmp);
+        response += tmp;
         // if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
         // if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
         // if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
@@ -186,6 +188,13 @@ static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_ima
     llama_sampling_free(ctx_sampling);
     printf("\n");
+    // const char* ret_char_ptr = (const char*)(malloc(sizeof(char)*response.size()));
+    if(internal_chars != nullptr) { free(internal_chars); }
+    internal_chars = malloc(sizeof(char)*(response.size()+1));
+    strncpy((char*)(internal_chars), response.c_str(), response.size());
+    ((char*)(internal_chars))[response.size()] = '\0';
+    return (const char*)(internal_chars);
 }
 static void omnivlm_free(struct omnivlm_context * ctx_omnivlm) {
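
The added block hands the finished response off to a heap buffer that outlives process_prompt(): any previous result is freed, response.size()+1 bytes are allocated, and the text is copied with an explicit terminating '\0' (strncpy with a count of response.size() does not write one itself). A minimal sketch of the same handoff in isolation, assuming nothing beyond the standard library; the names hand_off and result_buf are illustrative and not part of this PR:

#include <cstdlib>
#include <cstring>
#include <string>

// Module-owned buffer, mirroring the PR's internal_chars (name is illustrative).
static void* result_buf = nullptr;

// Sketch of the same handoff: std::string::c_str() is already NUL-terminated,
// so copying size()+1 bytes with memcpy also copies the terminator.
static const char* hand_off(const std::string& response) {
    void* fresh = std::malloc(response.size() + 1);
    if (fresh == nullptr) { return nullptr; }               // allocation failure
    std::memcpy(fresh, response.c_str(), response.size() + 1);
    std::free(result_buf);                                  // free(nullptr) is a no-op
    result_buf = fresh;
    return static_cast<const char*>(result_buf);
}
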
@@ -225,7 +234,7 @@ void omnivlm_init(const char* llm_model_path, const char* projector_model_path)
     ctx_omnivlm = omnivlm_init_context(&params, model);
 }
-void omnivlm_inference(const char *prompt, const char *imag_path) {
+const char* omnivlm_inference(const char *prompt, const char *imag_path) {
     std::string image = imag_path;
     params.prompt = prompt;
     auto * image_embed = load_image(ctx_omnivlm, &params, image);
@@ -234,13 +243,16 @@ void omnivlm_inference(const char *prompt, const char *imag_path) {
         throw std::runtime_error("failed to load image " + image);
     }
     // process the prompt
-    process_prompt(ctx_omnivlm, image_embed, &params, params.prompt);
+    const char* ret_chars = process_prompt(ctx_omnivlm, image_embed, &params, params.prompt);
     // llama_perf_print(ctx_omnivlm->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
     omnivlm_image_embed_free(image_embed);
+    return ret_chars;
 }
 void omnivlm_free() {
+    if(internal_chars != nullptr) { free(internal_chars); }
     ctx_omnivlm->model = NULL;
     omnivlm_free(ctx_omnivlm);
     llama_free_model(model);
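
With this change the native entry point returns the generated text directly. The returned pointer aliases the module-level internal_chars buffer, so it stays valid only until the next omnivlm_inference() call or omnivlm_free(); a caller that needs the text afterwards should copy it first. A minimal C++ caller sketch under that assumption (the prototypes are restated here rather than taken from the project header, and the model, projector, and image paths are placeholders):

#include <cstdio>
#include <string>

// Prototypes matching the public header changed in this PR (OMNIVLM_API macro omitted).
extern "C" {
    void        omnivlm_init(const char* llm_model_path, const char* projector_model_path);
    const char* omnivlm_inference(const char* prompt, const char* imag_path);
    void        omnivlm_free();
}

int main() {
    omnivlm_init("llm.gguf", "projector.gguf");        // placeholder model paths
    const char* raw = omnivlm_inference("Describe this image.", "example.jpg");  // placeholder image
    std::string response = raw ? raw : "";             // copy out: the buffer is reused on the next call
    std::printf("%s\n", response.c_str());
    omnivlm_free();                                    // also frees the internal result buffer
    return 0;
}

The Python binding below relies on the same rule: with restype set to omni_char_p (presumably ctypes.c_char_p), ctypes copies the result into a Python bytes object at call time.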


@@ -22,7 +22,7 @@ extern "C" {
 OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path);
-OMNIVLM_API void omnivlm_inference(const char* prompt, const char* imag_path);
+OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);
 OMNIVLM_API void omnivlm_free();


@@ -73,7 +73,7 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):
 _lib.omnivlm_inference.argtypes = [omni_char_p, omni_char_p]
-_lib.omnivlm_inference.restype = None
+_lib.omnivlm_inference.restype = omni_char_p
 def omnivlm_free():


@@ -20,7 +20,7 @@ class NexaOmniVlmInference:
     def inference(self, prompt: str, image_path: str):
         prompt = ctypes.c_char_p(prompt.encode("utf-8"))
         image_path = ctypes.c_char_p(image_path.encode("utf-8"))
-        omni_vlm_cpp.omnivlm_inference(prompt, image_path)
+        return omni_vlm_cpp.omnivlm_inference(prompt, image_path)
     def __del__(self):
         omni_vlm_cpp.omnivlm_free()
@@ -52,4 +52,6 @@ if __name__ == "__main__":
     while not os.path.exists(image_path):
         print("ERROR: can not find image in your input path, please check and input agian.")
         image_path = input()
-    omni_vlm_obj.inference(prompt, image_path)
+    response = omni_vlm_obj.inference(prompt, image_path)
+    print("\tresponse:")
+    print(response.decode('utf-8'))