reset model in every inference step to avoid nonsense output.

李为 2024-11-11 19:41:26 +08:00
parent d04e354f2f
commit 7cf07df5e2
2 changed files with 12 additions and 4 deletions
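
In plain terms, the change makes every inference start from a freshly created omnivlm context and tears it down again once the output is produced, instead of reusing one long-lived context whose leftover state could bleed into the next generation and produce nonsense. Below is a minimal sketch of the resulting per-image loop in the CLI, using only the functions visible in this diff and assuming `model` and `params` are already set up; the actual generation step (not shown in the diff) is left as a comment:

// Sketch only: the per-inference reset pattern adopted by this commit.
for (auto & image : params.image) {
    // fresh context for every image, so no cached state carries over
    auto * ctx_omnivlm = omnivlm_init_context(&params, model);

    auto * image_embed = load_image(ctx_omnivlm, &params, image);
    // ... run the prompt / generation against ctx_omnivlm here ...
    omnivlm_image_embed_free(image_embed);

    // detach the shared llama model before freeing the context so that
    // llama_free_model() below remains the single point of release
    ctx_omnivlm->model = NULL;
    omnivlm_free(ctx_omnivlm);
}
llama_free_model(model);

The trade-off is that context setup now runs on every call, which adds per-inference overhead but guarantees each generation starts from a clean state.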

View file

@@ -12,6 +12,10 @@
 #include <cstdlib>
 #include <cstring>
 #include <vector>
+// #include <iostream>
+//
+// using std::cout;
+// using std::endl;
 
 static bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
     int N = (int) tokens.size();
@@ -283,9 +287,9 @@ int main(int argc, char ** argv) {
         return 1;
     }
-    auto * ctx_omnivlm = omnivlm_init_context(&params, model);
     for (auto & image : params.image) {
+        auto * ctx_omnivlm = omnivlm_init_context(&params, model);
         auto * image_embed = load_image(ctx_omnivlm, &params, image);
         if (!image_embed) {
             LOG_TEE("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str());
@@ -296,9 +300,9 @@ int main(int argc, char ** argv) {
         llama_print_timings(ctx_omnivlm->ctx_llama);
         omnivlm_image_embed_free(image_embed);
-    }
+        ctx_omnivlm->model = NULL;
         omnivlm_free(ctx_omnivlm);
+    }
     llama_free_model(model);

View file

@@ -244,10 +244,11 @@ void omnivlm_init(const char* llm_model_path, const char* projector_model_path,
         fprintf(stderr, "%s: error: failed to init omnivlm model\n", __func__);
         throw std::runtime_error("Failed to init omnivlm model");
     }
-    ctx_omnivlm = omnivlm_init_context(&params, model);
 }
 
 const char* omnivlm_inference(const char *prompt, const char *imag_path) {
+    ctx_omnivlm = omnivlm_init_context(&params, model);
     std::string image = imag_path;
     params.prompt = prompt;
@@ -270,6 +271,9 @@ const char* omnivlm_inference(const char *prompt, const char *imag_path) {
     // llama_perf_print(ctx_omnivlm->ctx_llama, LLAMA_PERF_TYPE_CONTEXT);
     omnivlm_image_embed_free(image_embed);
+    ctx_omnivlm->model = nullptr;
+    omnivlm_free(ctx_omnivlm);
+    ctx_omnivlm = nullptr;
 
     return ret_chars;
 }
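
A note on the added ctx_omnivlm->model = nullptr; lines: the llama model outlives each per-inference context (the CLI releases it once via llama_free_model()), and nulling the pointer before omnivlm_free() indicates the free routine would otherwise release the model as well. Below is a small self-contained sketch of that ownership pattern, with hypothetical Model/Ctx types and a ctx_free() stand-in rather than the real structs and omnivlm_free():

// Hypothetical types; only the detach-before-free pattern mirrors the commit.
struct Model { /* shared weights */ };
struct Ctx   { Model * model; /* per-inference state */ };

// Stand-in for the context free routine: releases the model only if still attached.
static void ctx_free(Ctx * c) {
    delete c->model;   // no-op when the model was detached (nullptr)
    delete c;
}

int main() {
    Model * model = new Model();
    for (int i = 0; i < 3; ++i) {
        Ctx * ctx = new Ctx{model};
        // ... one inference ...
        ctx->model = nullptr;  // detach so ctx_free() does not delete the shared model
        ctx_free(ctx);
    }
    delete model;              // released exactly once, by its owner
}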