From 3c10d9f3de3e79bdba18f3745e3cc56b0aa45e67 Mon Sep 17 00:00:00 2001
From: Damian Stewart
Date: Fri, 13 Oct 2023 12:21:44 +0200
Subject: [PATCH] add external llava API

---
 examples/llava/llava.cpp | 121 ++++++++++++++++++++++++---------------
 examples/llava/llava.h   |  31 ++++
 2 files changed, 106 insertions(+), 46 deletions(-)
 create mode 100644 examples/llava/llava.h

diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp
index c55d4f165..22e625236 100644
--- a/examples/llava/llava.cpp
+++ b/examples/llava/llava.cpp
@@ -2,6 +2,7 @@
 #include "llava-utils.h"
 #include "common.h"
 #include "llama.h"
+#include "llava.h"
 
 #include <cstdio>
 #include <cstdlib>
@@ -34,27 +35,13 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
     return true;
 }
 
-int main(int argc, char ** argv) {
-    ggml_time_init();
+struct llava_context * llava_init(gpt_params * params) {
 
-    gpt_params params;
+    const char * clip_path = params->mmproj.c_str();
+    const char * img_path = params->image.c_str();
 
-    if (!gpt_params_parse(argc, argv, params)) {
-        show_additional_info(argc, argv);
-        return 1;
-    }
-
-    if (params.mmproj.empty() || params.image.empty()) {
-        gpt_print_usage(argc, argv, params);
-        show_additional_info(argc, argv);
-        return 1;
-    }
-
-    const char * clip_path = params.mmproj.c_str();
-    const char * img_path = params.image.c_str();
-
-    if (params.prompt.empty()) {
-        params.prompt = "describe the image in detail.";
+    if (params->prompt.empty()) {
+        params->prompt = "describe the image in detail.";
     }
 
     auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
@@ -65,47 +52,48 @@ int main(int argc, char ** argv) {
     if (!clip_image_load_from_file(img_path, &img)) {
         fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
         clip_free(ctx_clip);
-        return 1;
+        return NULL;
     }
 
     float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
     if (!image_embd) {
         fprintf(stderr, "Unable to allocate memory for image embeddings\n");
-        return 1;
+        return NULL;
     }
 
     int n_img_embd;
     int n_img_pos;
     float t_img_enc_ms;
-    if (!encode_image_with_clip(ctx_clip, params.n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
+    if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
         fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
         clip_free(ctx_clip);
-        return 1;
+        return NULL;
     }
 
     // we get the embeddings, free up the memory required for CLIP
     clip_free(ctx_clip);
+    ctx_clip = NULL;
 
-    llama_backend_init(params.numa);
+    llama_backend_init(params->numa);
 
     llama_model_params model_params = llama_model_default_params();
-    llama_model * model = llama_load_model_from_file(params.model.c_str(), model_params);
+    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
     if (model == NULL) {
         fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return 1;
+        return NULL;
     }
 
     llama_context_params ctx_params = llama_context_default_params();
 
-    ctx_params.n_ctx           = params.n_ctx < 2048 ? 2048 : params.n_ctx; // we need a longer context size to process image embeddings
-    ctx_params.n_threads       = params.n_threads;
-    ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch;
+    ctx_params.n_ctx           = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
+    ctx_params.n_threads       = params->n_threads;
+    ctx_params.n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch;
 
     llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
 
     if (ctx_llama == NULL) {
         fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return 1;
+        return NULL;
     }
 
     // make sure that the correct mmproj was used, i.e., compare apples to apples
@@ -118,28 +106,49 @@ int main(int argc, char ** argv) {
         llama_backend_free();
         free(image_embd);
 
-        return 1;
+        return NULL;
     }
 
-    // process the prompt
-    // llava chat format is "USER: <image>\n<prompt>\nASSISTANT:"
+    {
+        printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
+    }
+
+    auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
+
+    ctx_llava->ctx_llama = ctx_llama;
+    ctx_llava->ctx_clip = ctx_clip;
+    ctx_llava->model = model;
+    ctx_llava->image_embd = image_embd;
+    ctx_llava->n_img_pos = n_img_pos;
+    return ctx_llava;
+
+}
+
+void llava_free(struct llava_context * ctx_llava) {
+    llama_free(ctx_llava->ctx_llama);
+    llama_free_model(ctx_llava->model);
+    llama_backend_free();
+    free(ctx_llava->image_embd);
+}
+
+void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt) {
 
     int n_past = 0;
 
-    const int max_tgt_len = params.n_predict < 0 ? 256 : params.n_predict;
+    const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
 
     // GG: are we sure that the should be a trailing whitespace at the end of this string?
-    eval_string(ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params.n_batch, &n_past);
-    eval_image_embd(ctx_llama, image_embd, n_img_pos, params.n_batch, &n_past);
-    eval_string(ctx_llama, params.prompt.c_str(), params.n_batch, &n_past);
-    eval_string(ctx_llama, "\nASSISTANT:", params.n_batch, &n_past);
+    eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
+    eval_image_embd(ctx_llava->ctx_llama, ctx_llava->image_embd, ctx_llava->n_img_pos, params->n_batch, &n_past);
+    eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
+    eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
 
     // generate the response
 
     printf("\n");
 
     for (int i = 0; i < max_tgt_len; i++) {
-        const char * tmp = sample(ctx_llama, params, &n_past);
+        const char * tmp = sample(ctx_llava->ctx_llama, *params, &n_past);
         if (strcmp(tmp, "</s>") == 0) break;
 
         printf("%s", tmp);
@@ -148,16 +157,36 @@ int main(int argc, char ** argv) {
 
     printf("\n");
 
-    {
-        printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
+}
+
+
+int main(int argc, char ** argv) {
+    ggml_time_init();
+
+    gpt_params params;
+
+    if (!gpt_params_parse(argc, argv, params)) {
+        show_additional_info(argc, argv);
+        return 1;
+    }
+    if (params.mmproj.empty() || params.image.empty()) {
+        gpt_print_usage(argc, argv, params);
+        show_additional_info(argc, argv);
+        return 1;
     }
 
-    llama_print_timings(ctx_llama);
+    auto ctx_llava = llava_init(&params);
+    if (ctx_llava == NULL) {
+        fprintf(stderr, "%s: error: failed to init llava\n", __func__);
+        return 1;
+    }
 
-    llama_free(ctx_llama);
-    llama_free_model(model);
-    llama_backend_free();
-    free(image_embd);
+    // process the prompt
+    // llava chat format is "USER: <image>\n<prompt>\nASSISTANT:"
+    llava_process_prompt(ctx_llava, &params, params.prompt.c_str());
+    llama_print_timings(ctx_llava->ctx_llama);
+
+    llava_free(ctx_llava);
 
     return 0;
 }
diff --git a/examples/llava/llava.h b/examples/llava/llava.h
new file mode 100644
index 000000000..4f229a08c
--- /dev/null
+++ b/examples/llava/llava.h
@@ -0,0 +1,31 @@
+#ifndef LLAVA_H
+#define LLAVA_H
+
+#include "ggml.h"
+
+struct clip_ctx;
+
+#ifdef __cplusplus
+extern "C" {
+#endif
+
+struct llava_context {
+    struct clip_ctx * ctx_clip = NULL;
+    struct llama_context * ctx_llama = NULL;
+    struct llama_model * model = NULL;
+
+    int n_img_pos = 0;
+    float * image_embd = NULL;
+};
+
+struct llava_context * llava_init(gpt_params * params);
+void llava_free(struct llava_context * ctx_llava);
+
+void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt);
+
+
+#ifdef __cplusplus
+}
+#endif
+
+#endif
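
Usage note: the patch turns the example's main() into a three-call API, which an out-of-tree caller can drive roughly as follows. This is a minimal sketch mirroring the rewritten main() above; the wrapper name run_llava is hypothetical, and it assumes examples/llava is on the include path and that gpt_params has been populated (model, mmproj, image, prompt) the same way gpt_params_parse fills it in the example binary.

    // Hypothetical caller of the new llava API (sketch, not part of the patch).
    #include "common.h"
    #include "llava.h"

    #include <cstdio>

    int run_llava(gpt_params & params) {
        // llava_init loads the CLIP mmproj, encodes params.image into
        // embeddings, then loads the llama model; returns NULL on any failure.
        struct llava_context * ctx_llava = llava_init(&params);
        if (ctx_llava == NULL) {
            fprintf(stderr, "failed to init llava\n");
            return 1;
        }

        // Evaluates the system prompt, the image embeddings, and the given
        // prompt, then samples and prints up to n_predict response tokens.
        llava_process_prompt(ctx_llava, &params, params.prompt.c_str());

        // Frees the llama context, the model, and the image embeddings,
        // and shuts down the llama backend.
        llava_free(ctx_llava);
        return 0;
    }

Two design consequences are visible in the diff: llava_init() frees the clip_ctx immediately after encoding and stores NULL in ctx_llava->ctx_clip, and llava_free() calls llama_backend_free(), so a host application that keeps other llama contexts alive would need to defer llava_free() or manage backend teardown itself.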