refactor image load out of llava init
This commit is contained in:
parent
8224ca5775
commit
c6932085fe
3 changed files with 45 additions and 46 deletions
|
@ -5,6 +5,8 @@
|
||||||
#include "common.h"
|
#include "common.h"
|
||||||
#include "llama.h"
|
#include "llama.h"
|
||||||
|
|
||||||
|
#include "base64.hpp"
|
||||||
|
|
||||||
#include <cstdio>
|
#include <cstdio>
|
||||||
#include <cstdlib>
|
#include <cstdlib>
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
|
@ -15,7 +15,8 @@ static void show_additional_info(int /*argc*/, char ** argv) {
|
||||||
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
|
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_embd, int * n_img_pos, float * t_img_enc_ms) {
|
static bool encode_image_with_clip(llava_context * ctx_llava, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_embd, int * n_img_pos) {
|
||||||
|
auto ctx_clip = ctx_llava->ctx_clip;
|
||||||
clip_image_f32 img_res;
|
clip_image_f32 img_res;
|
||||||
if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) {
|
if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) {
|
||||||
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
|
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
|
||||||
|
@ -26,6 +27,14 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||||
*n_img_pos = clip_n_patches(ctx_clip);
|
*n_img_pos = clip_n_patches(ctx_clip);
|
||||||
*n_img_embd = clip_n_mmproj_embd(ctx_clip);
|
*n_img_embd = clip_n_mmproj_embd(ctx_clip);
|
||||||
|
|
||||||
|
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
||||||
|
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llava->ctx_llama));
|
||||||
|
if (*n_img_embd != n_llama_embd) {
|
||||||
|
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, *n_img_embd, n_llama_embd);
|
||||||
|
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
const int64_t t_img_enc_start_us = ggml_time_us();
|
const int64_t t_img_enc_start_us = ggml_time_us();
|
||||||
if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) {
|
if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) {
|
||||||
fprintf(stderr, "Unable to encode image\n");
|
fprintf(stderr, "Unable to encode image\n");
|
||||||
|
@ -33,12 +42,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
const int64_t t_img_enc_end_us = ggml_time_us();
|
const int64_t t_img_enc_end_us = ggml_time_us();
|
||||||
*t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
|
float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
|
||||||
|
|
||||||
|
{
|
||||||
|
printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / *n_img_pos);
|
||||||
|
}
|
||||||
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
bool llava_build_img_embed(struct llava_context * ctx_llava, const clip_image_u8 * img) {
|
static bool llava_build_img_embed(struct llava_context * ctx_llava, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out) {
|
||||||
|
|
||||||
|
auto ctx_clip = ctx_llava->ctx_clip;
|
||||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||||
if (!image_embd) {
|
if (!image_embd) {
|
||||||
fprintf(stderr, "Unable to allocate memory for image embeddings\n");
|
fprintf(stderr, "Unable to allocate memory for image embeddings\n");
|
||||||
|
@ -46,24 +61,22 @@ bool llava_build_img_embed(struct llava_context * ctx_llava, const clip_image_u8
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int n_image_pos;
|
||||||
int n_img_embd;
|
int n_img_embd;
|
||||||
int n_img_pos;
|
if (!encode_image_with_clip(ctx_llava, n_threads, img, image_embd, &n_img_embd, &n_image_pos)) {
|
||||||
float t_img_enc_ms;
|
|
||||||
if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
|
|
||||||
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
|
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
|
||||||
free(image_embd);
|
free(image_embd);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
*image_embd_out = image_embd;
|
||||||
ctx_llava->image_embd = image_embd;
|
*n_image_pos_out = n_image_pos;
|
||||||
retur true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
struct llava_context * llava_init(gpt_params * params) {
|
struct llava_context * llava_init(gpt_params * params) {
|
||||||
|
|
||||||
const char * clip_path = params->mmproj.c_str();
|
const char * clip_path = params->mmproj.c_str();
|
||||||
const char * img_path = params->image.c_str();
|
|
||||||
|
|
||||||
auto prompt = params->prompt;
|
auto prompt = params->prompt;
|
||||||
if (prompt.empty()) {
|
if (prompt.empty()) {
|
||||||
|
@ -94,55 +107,36 @@ struct llava_context * llava_init(gpt_params * params) {
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
|
||||||
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
|
|
||||||
if (n_img_embd != n_llama_embd) {
|
|
||||||
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd);
|
|
||||||
|
|
||||||
llama_free(ctx_llama);
|
|
||||||
llama_free_model(model);
|
|
||||||
llama_backend_free();
|
|
||||||
free(image_embd);
|
|
||||||
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
|
auto ctx_llava = (struct llava_context *)malloc(sizeof(llava_context));
|
||||||
|
|
||||||
ctx_llava->ctx_llama = ctx_llama;
|
ctx_llava->ctx_llama = ctx_llama;
|
||||||
ctx_llava->ctx_clip = ctx_clip;
|
ctx_llava->ctx_clip = ctx_clip;
|
||||||
ctx_llava->model = model;
|
ctx_llava->model = model;
|
||||||
ctx_llava->image_embd = image_embd;
|
|
||||||
ctx_llava->n_img_pos = n_img_pos;
|
|
||||||
return ctx_llava;
|
return ctx_llava;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llava_free(struct llava_context * ctx_llava) {
|
void llava_free(struct llava_context * ctx_llava) {
|
||||||
if (ctx_llava->ctx_clip) {
|
if (ctx_llava->ctx_clip) {
|
||||||
clip_free(ctx_clip);
|
clip_free(ctx_llava->ctx_clip);
|
||||||
ctx_llava->ctx_clip = NULL;
|
ctx_llava->ctx_clip = NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
llama_free(ctx_llava->ctx_llama);
|
llama_free(ctx_llava->ctx_llama);
|
||||||
llama_free_model(ctx_llava->model);
|
llama_free_model(ctx_llava->model);
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
free(ctx_llava->image_embd);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt) {
|
static void llava_process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) {
|
||||||
int n_past = 0;
|
int n_past = 0;
|
||||||
|
|
||||||
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
|
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
|
||||||
|
|
||||||
|
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
|
||||||
// GG: are we sure that the should be a trailing whitespace at the end of this string?
|
// GG: are we sure that the should be a trailing whitespace at the end of this string?
|
||||||
eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
|
eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
|
||||||
eval_image_embd(ctx_llava->ctx_llama, ctx_llava->image_embd, ctx_llava->n_img_pos, params->n_batch, &n_past);
|
eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
|
||||||
eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
|
eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
|
||||||
eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
|
eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
|
||||||
|
|
||||||
|
@ -186,31 +180,34 @@ int main(int argc, char ** argv) {
|
||||||
|
|
||||||
// load and preprocess the image
|
// load and preprocess the image
|
||||||
clip_image_u8 img;
|
clip_image_u8 img;
|
||||||
|
auto prompt = params.prompt;
|
||||||
if (prompt_contains_image(prompt)) {
|
if (prompt_contains_image(prompt)) {
|
||||||
if (img_path) {
|
if (!params.image.empty()) {
|
||||||
printf("using base64 encoded image instead of command line image path\n");
|
printf("using base64 encoded image instead of command line image path\n");
|
||||||
}
|
}
|
||||||
if (!get_image_from_prompt(prompt, &img)) {
|
if (!get_image_from_prompt(prompt, &img)) {
|
||||||
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
|
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
|
||||||
clip_free(ctx_clip);
|
llava_free(ctx_llava);
|
||||||
return NULL;
|
return 1;
|
||||||
}
|
}
|
||||||
prompt = remove_image_from_prompt(prompt);
|
prompt = remove_image_from_prompt(prompt);
|
||||||
} else {
|
} else {
|
||||||
if (!clip_image_load_from_file(img_path, &img)) {
|
if (!clip_image_load_from_file(params.image.c_str(), &img)) {
|
||||||
fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
|
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params.image.c_str());
|
||||||
clip_free(ctx_clip);
|
llava_free(ctx_llava);
|
||||||
return NULL;
|
return 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
llava_build_img_embed(ctx_llava, &img);
|
float * image_embd;
|
||||||
|
int n_image_pos;
|
||||||
|
llava_build_img_embed(ctx_llava, params.n_threads, &img, &image_embd, &n_image_pos);
|
||||||
|
|
||||||
// process the prompt
|
// process the prompt
|
||||||
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
|
llava_process_prompt(ctx_llava, image_embd, n_image_pos, ¶ms, params.prompt.c_str());
|
||||||
llava_process_prompt(ctx_llava, ¶ms, params.prompt.c_str());
|
|
||||||
|
|
||||||
llama_print_timings(ctx_llava->ctx_llama);
|
llama_print_timings(ctx_llava->ctx_llama);
|
||||||
|
|
||||||
|
free(image_embd);
|
||||||
llava_free(ctx_llava);
|
llava_free(ctx_llava);
|
||||||
return 0;
|
return 0;
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,14 +14,14 @@ struct llava_context {
|
||||||
struct llama_context * ctx_llama = NULL;
|
struct llama_context * ctx_llama = NULL;
|
||||||
struct llama_model * model = NULL;
|
struct llama_model * model = NULL;
|
||||||
|
|
||||||
int n_img_pos = 0;
|
// int n_img_pos = 0;
|
||||||
float * image_embd = NULL;
|
// float * image_embd = NULL;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llava_context * llava_init(gpt_params * params);
|
struct llava_context * llava_init(gpt_params * params);
|
||||||
void llava_free(struct llava_context * ctx_llava);
|
void llava_free(struct llava_context * ctx_llava);
|
||||||
|
|
||||||
void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt);
|
//void llava_process_prompt(struct llava_context * ctx_llava, gpt_params * params, const char * prompt);
|
||||||
|
|
||||||
|
|
||||||
#ifdef __cplusplus
|
#ifdef __cplusplus
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue