Merge pull request #10 from NexaAI/weili/master-release
add returned string (pure c const char* type) for omni-vlm inference api
Commit 5574bda471
4 changed files with 23 additions and 9 deletions
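With this change, omnivlm_inference() returns the generated text as a C string instead of only printing it. A minimal caller sketch of the updated API, with the declarations copied from the header change below; the model, projector, and image paths and the prompt are placeholders, and the functions are redeclared here only because the header's file name is not shown in this diff:

// Hedged sketch of a caller of the updated API; paths and prompt are placeholders.
#include <cstdio>

extern "C" {
    void        omnivlm_init(const char* llm_model_path, const char* projector_model_path);
    const char* omnivlm_inference(const char* prompt, const char* imag_path);
    void        omnivlm_free();
}

int main() {
    omnivlm_init("model.gguf", "projector.gguf");                              // placeholder paths
    const char* out = omnivlm_inference("Describe this image.", "example.jpg"); // placeholder inputs
    printf("%s\n", out);   // pointer stays valid until the next inference call or omnivlm_free()
    omnivlm_free();
    return 0;
}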
@@ -24,6 +24,8 @@ struct omnivlm_context {
     struct llama_model * model = NULL;
 };
 
+void* internal_chars = nullptr;
+
 static struct gpt_params params;
 static struct llama_model* model;
 static struct omnivlm_context* ctx_omnivlm;
@@ -128,7 +130,7 @@ static const char * sample(struct llama_sampling_context * ctx_sampling,
     return ret.c_str();
 }
 
-static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
+static const char* process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_image_embed * image_embed, gpt_params * params, const std::string & prompt) {
     int n_past = 0;
 
     const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
@@ -172,11 +174,11 @@ static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_ima
     std::string response = "";
     for (int i = 0; i < max_tgt_len; i++) {
         const char * tmp = sample(ctx_sampling, ctx_omnivlm->ctx_llama, &n_past);
-        response += tmp;
         if (strcmp(tmp, "<|im_end|>") == 0) break;
         if (strcmp(tmp, "</s>") == 0) break;
         // if (strstr(tmp, "###")) break; // Yi-VL behavior
-        printf("%s", tmp);
+        // printf("%s", tmp);
+        response += tmp;
         // if (strstr(response.c_str(), "<|im_end|>")) break; // Yi-34B llava-1.6 - for some reason those decode not as the correct token (tokenizer works)
         // if (strstr(response.c_str(), "<|im_start|>")) break; // Yi-34B llava-1.6
         // if (strstr(response.c_str(), "USER:")) break; // mistral llava-1.6
@@ -186,6 +188,13 @@ static void process_prompt(struct omnivlm_context * ctx_omnivlm, struct omni_ima
 
     llama_sampling_free(ctx_sampling);
     printf("\n");
+
+    // const char* ret_char_ptr = (const char*)(malloc(sizeof(char)*response.size()));
+    if(internal_chars != nullptr) { free(internal_chars); }
+    internal_chars = malloc(sizeof(char)*(response.size()+1));
+    strncpy((char*)(internal_chars), response.c_str(), response.size());
+    ((char*)(internal_chars))[response.size()] = '\0';
+    return (const char*)(internal_chars);
 }
 
 static void omnivlm_free(struct omnivlm_context * ctx_omnivlm) {
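Note on the returned pointer: process_prompt() copies the response into the file-scope internal_chars buffer and returns a pointer into it. That buffer is freed at the start of the next inference call and in omnivlm_free(), so callers that need the text beyond that point should copy it. A small caller-side sketch (run_once is a hypothetical helper, not part of this patch):

// Hedged sketch: copy the library-owned string before the next call invalidates it.
#include <string>

extern "C" const char* omnivlm_inference(const char* prompt, const char* imag_path);

std::string run_once(const char* prompt, const char* image_path) {
    const char* raw = omnivlm_inference(prompt, image_path); // points into internal_chars
    return std::string(raw);                                 // deep copy owned by the caller
}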
@@ -225,7 +234,7 @@ void omnivlm_init(const char* llm_model_path, const char* projector_model_path)
     ctx_omnivlm = omnivlm_init_context(&params, model);
 }
 
-void omnivlm_inference(const char *prompt, const char *imag_path) {
+const char* omnivlm_inference(const char *prompt, const char *imag_path) {
     std::string image = imag_path;
     params.prompt = prompt;
     auto * image_embed = load_image(ctx_omnivlm, &params, image);
|
@ -234,13 +243,16 @@ void omnivlm_inference(const char *prompt, const char *imag_path) {
|
|||
throw std::runtime_error("failed to load image " + image);
|
||||
}
|
||||
// process the prompt
|
||||
process_prompt(ctx_omnivlm, image_embed, ¶ms, params.prompt);
|
||||
const char* ret_chars = process_prompt(ctx_omnivlm, image_embed, ¶ms, params.prompt);
|
||||
|
||||
// llama_perf_print(ctx_omnivlm->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
|
||||
omnivlm_image_embed_free(image_embed);
|
||||
|
||||
return ret_chars;
|
||||
}
|
||||
|
||||
void omnivlm_free() {
|
||||
if(internal_chars != nullptr) { free(internal_chars); }
|
||||
ctx_omnivlm->model = NULL;
|
||||
omnivlm_free(ctx_omnivlm);
|
||||
llama_free_model(model);
|
||||
|
|
|
@@ -22,7 +22,7 @@ extern "C" {
 
 OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path);
 
-OMNIVLM_API void omnivlm_inference(const char* prompt, const char* imag_path);
+OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path);
 
 OMNIVLM_API void omnivlm_free();
 
@@ -73,7 +73,7 @@ def omnivlm_inference(prompt: omni_char_p, image_path: omni_char_p):
 
 
 _lib.omnivlm_inference.argtypes = [omni_char_p, omni_char_p]
-_lib.omnivlm_inference.restype = None
+_lib.omnivlm_inference.restype = omni_char_p
 
 
 def omnivlm_free():
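With the restype changed from None to omni_char_p, ctypes converts the returned C pointer into a Python bytes object (assuming omni_char_p is ctypes.c_char_p), which is why the demo script below decodes the result with .decode('utf-8').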
@@ -20,7 +20,7 @@ class NexaOmniVlmInference:
     def inference(self, prompt: str, image_path: str):
         prompt = ctypes.c_char_p(prompt.encode("utf-8"))
         image_path = ctypes.c_char_p(image_path.encode("utf-8"))
-        omni_vlm_cpp.omnivlm_inference(prompt, image_path)
+        return omni_vlm_cpp.omnivlm_inference(prompt, image_path)
 
     def __del__(self):
         omni_vlm_cpp.omnivlm_free()
@@ -52,4 +52,6 @@ if __name__ == "__main__":
     while not os.path.exists(image_path):
         print("ERROR: can not find image in your input path, please check and input agian.")
         image_path = input()
-    omni_vlm_obj.inference(prompt, image_path)
+    response = omni_vlm_obj.inference(prompt, image_path)
+    print("\tresponse:")
+    print(response.decode('utf-8'))