From 7cf07df5e20c12624376656ce81c06b621cbb3a6 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E4=B8=BA?=
Date: Mon, 11 Nov 2024 19:41:26 +0800
Subject: [PATCH] reset model in every inference step to avoid nonsense
 output.

---
 examples/omni-vlm/omni-vlm-cli.cpp     | 10 +++++++---
 examples/omni-vlm/omni-vlm-wrapper.cpp |  6 +++++-
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/examples/omni-vlm/omni-vlm-cli.cpp b/examples/omni-vlm/omni-vlm-cli.cpp
index 167816360..40e65b339 100644
--- a/examples/omni-vlm/omni-vlm-cli.cpp
+++ b/examples/omni-vlm/omni-vlm-cli.cpp
@@ -12,6 +12,10 @@
 #include <cstdio>
 #include <cstdlib>
 #include <vector>
+// #include <iostream>
+//
+// using std::cout;
+// using std::endl;
 
 static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
     int N = (int) tokens.size();
@@ -283,9 +287,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    auto * ctx_omnivlm = omnivlm_init_context(&params, model);
 
     for (auto & image : params.image) {
+        auto * ctx_omnivlm = omnivlm_init_context(&params, model);
         auto * image_embed = load_image(ctx_omnivlm, &params, image);
         if (!image_embed) {
             LOG_TEE("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
@@ -296,9 +300,9 @@ int main(int argc, char ** argv) {
         llama_print_timings(ctx_omnivlm->ctx_llama);
 
         omnivlm_image_embed_free(image_embed);
+        ctx_omnivlm->model = NULL;
+        omnivlm_free(ctx_omnivlm);
     }
-    ctx_omnivlm->model = NULL;
-    omnivlm_free(ctx_omnivlm);
 
     llama_free_model(model);
 
diff --git a/examples/omni-vlm/omni-vlm-wrapper.cpp b/examples/omni-vlm/omni-vlm-wrapper.cpp
index 9087fabb9..9a74e73d1 100644
--- a/examples/omni-vlm/omni-vlm-wrapper.cpp
+++ b/examples/omni-vlm/omni-vlm-wrapper.cpp
@@ -244,10 +244,11 @@ void omnivlm_init(const char* llm_model_path, const char* projector_model_path,
         fprintf(stderr, "%s: error: failed to init omnivlm model\n", __func__);
         throw std::runtime_error("Failed to init omnivlm model");
     }
-    ctx_omnivlm = omnivlm_init_context(&params, model);
 }
 
 const char* omnivlm_inference(const char *prompt, const char *imag_path) {
+    ctx_omnivlm = omnivlm_init_context(&params, model);
+
     std::string image = imag_path;
     params.prompt = prompt;
 
@@ -270,6 +271,9 @@ const char* omnivlm_inference(const char *prompt, const char *imag_path) {
     // llama_perf_print(ctx_omnivlm->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
 
     omnivlm_image_embed_free(image_embed);
+    ctx_omnivlm->model = nullptr;
+    omnivlm_free(ctx_omnivlm);
+    ctx_omnivlm = nullptr;
 
     return ret_chars;
 }
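
Reviewer note: below is a minimal C++ sketch of the call pattern this patch fixes, for context. The declarations are restated so the sketch is self-contained; the real ones live in the omni-vlm wrapper header, and omnivlm_init takes additional parameters beyond the two model paths (its full signature is truncated in the hunk header above), so treat these declarations as assumptions rather than the actual API.

// Sketch: driving the wrapper twice in one process. Before this patch,
// omnivlm_init() created a single ctx_omnivlm that every
// omnivlm_inference() call reused, so a second call decoded on top of the
// first call's KV-cache state and could emit nonsense output. After the
// patch, each omnivlm_inference() call creates and frees its own context.
#include <cstdio>

// Assumed declarations (see the omni-vlm wrapper header for the real ones;
// omnivlm_init's extra parameters are omitted here).
void omnivlm_init(const char * llm_model_path, const char * projector_model_path);
const char * omnivlm_inference(const char * prompt, const char * imag_path);

int main() {
    omnivlm_init("llm.gguf", "projector.gguf");
    // Two back-to-back inferences; with this patch each runs on a fresh
    // llama context, so the second output is independent of the first.
    std::printf("%s\n", omnivlm_inference("Describe this image.", "first.png"));
    std::printf("%s\n", omnivlm_inference("Describe this image.", "second.png"));
    return 0;
}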