diff --git a/examples/llava/llava-cli.cpp b/examples/llava/llava-cli.cpp
index 5c61c67fd..603b552ca 100644
--- a/examples/llava/llava-cli.cpp
+++ b/examples/llava/llava-cli.cpp
@@ -7,6 +7,13 @@
 #include "llava.h"
 #include "llava-utils.h"
 
+struct llava_context {
+    struct clip_ctx * ctx_clip = NULL;
+    struct llama_context * ctx_llama = NULL;
+    struct llama_model * model = NULL;
+};
+
+
 static void show_additional_info(int /*argc*/, char ** argv) {
     printf("\n example usage: %s -m <llava-v1.5-7b/ggml-model-q5_k.gguf> --mmproj <llava-v1.5-7b/mmproj-model-f16.gguf> --image <path/to/an/image.jpg> [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]);
     printf("  note: a lower temperature value like 0.1 is recommended for better quality.\n");
@@ -25,7 +32,7 @@ static bool load_image(llava_context * ctx_llava, gpt_params * params, float **i
             fprintf(stderr, "%s: can't load image from prompt\n", __func__);
             return false;
         }
-        prompt = remove_image_from_prompt(prompt);
+        params->prompt = remove_image_from_prompt(prompt);
     } else {
         if (!clip_image_load_from_file(params->image.c_str(), &img)) {
             fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
@@ -49,8 +56,7 @@ static void process_prompt(struct llava_context * ctx_llava, float * image_embd,
     // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
     // GG: are we sure that there should be a trailing whitespace at the end of this string?
     eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
-    printf("embedding image, n_img_pos is %d\n", n_img_pos);
-    eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
+    llava_eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
     eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
@@ -70,6 +76,61 @@ static void process_prompt(struct llava_context * ctx_llava, float * image_embd,
 }
 
+
+static struct llava_context * llava_init(gpt_params * params) {
+
+    const char * clip_path = params->mmproj.c_str();
+
+    auto prompt = params->prompt;
+    if (prompt.empty()) {
+        prompt = "describe the image in detail.";
+    }
+
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
+
+    llama_backend_init(params->numa);
+
+    llama_model_params model_params = llama_model_default_params();
+    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
+    if (model == NULL) {
+        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
+        return NULL;
+    }
+
+    llama_context_params ctx_params = llama_context_default_params();
+
+    ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
+    ctx_params.n_threads = params->n_threads;
+    ctx_params.n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch;
+
+    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
+
+    if (ctx_llama == NULL) {
+        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
+        return NULL;
+    }
+
+    auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
+
+    ctx_llava->ctx_llama = ctx_llama;
+    ctx_llava->ctx_clip = ctx_clip;
+    ctx_llava->model = model;
+    return ctx_llava;
+}
+
+
+static void llava_free(struct llava_context * ctx_llava) {
+    if (ctx_llava->ctx_clip) {
+        clip_free(ctx_llava->ctx_clip);
+        ctx_llava->ctx_clip = NULL;
+    }
+
+    llama_free(ctx_llava->ctx_llama);
+    llama_free_model(ctx_llava->model);
+    llama_backend_free();
+}
+
+
 int main(int argc, char ** argv) {
     ggml_time_init();
 
     gpt_params params;
diff --git a/llava/llava-utils.h b/llava/llava-utils.h
index 3b4fa96cc..53beefd26 100644
--- a/llava/llava-utils.h
+++ b/llava/llava-utils.h
@@ -11,24 +11,6 @@
 #include <cstdio>
 #include <vector>
 
-inline bool eval_image_embd(llama_context * ctx_llama, float * embd, int N, int n_batch, int * n_past) {
-    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
-
-    for (int i = 0; i < N; i += n_batch) {
-        int n_eval = N - i;
-        if (n_eval > n_batch) {
-            n_eval = n_batch;
-        }
-        llama_batch batch = {int32_t(n_eval), nullptr, (embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
-        if (llama_decode(ctx_llama, batch)) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
-            return false;
-        }
-        *n_past += n_eval;
-    }
-    return true;
-}
-
 inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
     int N = (int) tokens.size();
     for (int i = 0; i < N; i += n_batch) {
@@ -37,7 +19,7 @@ inline bool eval_tokens(struct llama_context * ctx_llama, std::vector<llama_token> tokens, int n_batch, int * n_past) {
             n_eval = n_batch;
         }
         if (llama_decode(ctx_llama, llama_batch_get_one(&tokens[i], n_eval, *n_past, 0))) {
-            fprintf(stderr, "%s : failed to eval\n", __func__);
+            fprintf(stderr, "%s : failed to eval. token %d/%d (batch size %d, n_past %d)\n", __func__, i, N, n_batch, *n_past);
             return false;
         }
         *n_past += n_eval;
diff --git a/llava/llava.cpp b/llava/llava.cpp
--- a/llava/llava.cpp
+++ b/llava/llava.cpp
@@ -9,53 +9,17 @@
-struct llava_context * llava_init(gpt_params * params) {
-
-    const char * clip_path = params->mmproj.c_str();
+bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past) {
+    int n_embd = llama_n_embd(llama_get_model(ctx_llama));
 
-    auto prompt = params->prompt;
-    if (prompt.empty()) {
-        prompt = "describe the image in detail.";
+    for (int i = 0; i < n_image_pos; i += n_batch) {
+        int n_eval = n_image_pos - i;
+        if (n_eval > n_batch) {
+            n_eval = n_batch;
+        }
+        llama_batch batch = {int32_t(n_eval), nullptr, (image_embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
+        if (llama_decode(ctx_llama, batch)) {
+            fprintf(stderr, "%s : failed to eval\n", __func__);
+            return false;
+        }
+        *n_past += n_eval;
     }
-
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
-
-    llama_backend_init(params->numa);
-
-    llama_model_params model_params = llama_model_default_params();
-    llama_model * model = llama_load_model_from_file(params->model.c_str(), model_params);
-    if (model == NULL) {
-        fprintf(stderr , "%s: error: unable to load model\n" , __func__);
-        return NULL;
-    }
-
-    llama_context_params ctx_params = llama_context_default_params();
-
-    ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
-    ctx_params.n_threads = params->n_threads;
-    ctx_params.n_threads_batch = params->n_threads_batch == -1 ? params->n_threads : params->n_threads_batch;
-
-    llama_context * ctx_llama = llama_new_context_with_model(model, ctx_params);
-
-    if (ctx_llama == NULL) {
-        fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__);
-        return NULL;
-    }
-
-    auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
-
-    ctx_llava->ctx_llama = ctx_llama;
-    ctx_llava->ctx_clip = ctx_clip;
-    ctx_llava->model = model;
-    return ctx_llava;
+    return true;
 }
-
-
-void llava_free(struct llava_context * ctx_llava) {
-    if (ctx_llava->ctx_clip) {
-        clip_free(ctx_llava->ctx_clip);
-        ctx_llava->ctx_clip = NULL;
-    }
-
-    llama_free(ctx_llava->ctx_llama);
-    llama_free_model(ctx_llava->model);
-    llama_backend_free();
-}
-
diff --git a/llava/llava.h b/llava/llava.h
index ba67103a4..de3875e03 100644
--- a/llava/llava.h
+++ b/llava/llava.h
@@ -10,18 +10,14 @@ struct clip_ctx;
 #ifdef __cplusplus
 extern "C" {
 #endif
 
-struct llava_context {
-    struct clip_ctx * ctx_clip = NULL;
-    struct llama_context * ctx_llama = NULL;
-    struct llama_model * model = NULL;
-};
-
-struct llava_context * llava_init(gpt_params * params);
-void llava_free(struct llava_context * ctx_llava);
-
-/** build a llava image embedding from the passed-in clip image `img`. result is returned as image_embd_out, size n_image_pos_out */
+/** using ctx_clip, build a llava image embedding from the passed-in image `img` (see clip.h for methods to load img).
+  * result is returned as image_embd_out, size n_image_pos_out */
 LLAMA_API bool llava_build_img_embed(const struct llama_context * ctx_llama, struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out);
 
+/** write the image represented by image_embd (size n_image_pos) into the llama context with batch size n_batch,
+  * starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
+LLAMA_API bool llava_eval_image_embd(struct llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past);
+
 #ifdef __cplusplus
 }
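
Taken together, the llava.h changes reduce the exported surface to a two-call image path: llava_build_img_embed turns a clip_image_u8 into an embedding buffer, and llava_eval_image_embd decodes that buffer into a llama_context. Below is a minimal consumer sketch, not part of the patch: write_image_to_ctx is a hypothetical helper name, the setup assumes the llava_context struct and llava_init() now living in llava-cli.cpp above, the loading call mirrors the clip_image_load_from_file usage shown there, and releasing the output buffer with free() assumes the library allocates it with malloc, which this patch does not confirm.

// hypothetical downstream helper, not part of this patch
#include "clip.h"
#include "llava.h"

#include <cstdlib>

static bool write_image_to_ctx(struct llava_context * ctx_llava, const char * image_path,
                               int n_threads, int n_batch, int * n_past) {
    clip_image_u8 img;
    if (!clip_image_load_from_file(image_path, &img)) {
        return false; // not an image file
    }

    float * image_embd = NULL;
    int n_image_pos = 0;

    // step 1: build the llava embedding from the pixels (fills image_embd / n_image_pos)
    if (!llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, n_threads, &img,
                               &image_embd, &n_image_pos)) {
        return false;
    }

    // step 2: decode the embedding into the llama context in n_batch-sized chunks;
    // on success, *n_past points just past the image, ready for the text prompt
    const bool ok = llava_eval_image_embd(ctx_llava->ctx_llama, image_embd, n_image_pos,
                                          n_batch, n_past);
    free(image_embd); // assumption: the buffer is malloc'd by the library
    return ok;
}

Moving the old eval_image_embd out of the header-only llava-utils.h and exporting it as llava_eval_image_embd lets downstream projects drive the whole image path through the shared library; llava_init and llava_free move into the example instead, presumably because they wrap gpt_params from common rather than anything llava-specific.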