This commit is contained in:
Damian Stewart 2023-10-14 17:20:21 +02:00
parent b9f533b997
commit f21af512cd
4 changed files with 16 additions and 14 deletions

View file

@ -13,7 +13,7 @@ static void show_additional_info(int /*argc*/, char ** argv) {
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
} }
static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_image_pos) { static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_img_pos) {
// load and preprocess the image // load and preprocess the image
clip_image_u8 img; clip_image_u8 img;
auto prompt = params->prompt; auto prompt = params->prompt;
@ -32,7 +32,7 @@ static bool load_image(llava_context * ctx_llava, gpt_params * params, float **i
return false; return false;
} }
} }
bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, &img, image_embd, n_image_pos); bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, &img, image_embd, n_img_pos);
if (!image_embed_result) { if (!image_embed_result) {
fprintf(stderr, "%s: coulnd't embed the image\n", __func__); fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
return false; return false;
@ -49,6 +49,7 @@ static void process_prompt(struct llava_context * ctx_llava, float * image_embd,
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:" // llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
// GG: are we sure that the should be a trailing whitespace at the end of this string? // GG: are we sure that the should be a trailing whitespace at the end of this string?
eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past); eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
printf("embedding image, n_img_pos is %d\n", n_img_pos);
eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past); eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past); eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past); eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);

View file

@ -2,6 +2,7 @@
#define CLIP_H #define CLIP_H
#include "ggml.h" #include "ggml.h"
#include "llama.h"
struct clip_ctx; struct clip_ctx;
@ -57,8 +58,8 @@ struct clip_image_f32_batch {
struct clip_image_u8 * make_clip_image_u8(); struct clip_image_u8 * make_clip_image_u8();
struct clip_image_f32 * make_clip_image_f32(); struct clip_image_f32 * make_clip_image_f32();
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img); LLAMA_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img); LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img);
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square); bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec); bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);

View file

@ -10,7 +10,7 @@
#include "base64.hpp" #include "base64.hpp"
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_embd, int * n_img_pos) { static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_image_embd, int * n_img_pos) {
clip_image_f32 img_res; clip_image_f32 img_res;
if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) { if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) {
fprintf(stderr, "%s: unable to preprocess image\n", __func__); fprintf(stderr, "%s: unable to preprocess image\n", __func__);
@ -19,7 +19,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
} }
*n_img_pos = clip_n_patches(ctx_clip); *n_img_pos = clip_n_patches(ctx_clip);
*n_img_embd = clip_n_mmproj_embd(ctx_clip); *n_image_embd = clip_n_mmproj_embd(ctx_clip);
const int64_t t_img_enc_start_us = ggml_time_us(); const int64_t t_img_enc_start_us = ggml_time_us();
if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) { if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) {
@ -37,7 +37,7 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
return true; return true;
} }
bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out) { bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)); float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
if (!image_embd) { if (!image_embd) {
@ -46,23 +46,23 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,
return false; return false;
} }
int n_image_pos; int n_img_pos;
int n_img_embd; int n_image_embd;
if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_embd, &n_image_pos)) { if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_image_embd, &n_img_pos)) {
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__); fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
free(image_embd); free(image_embd);
return false; return false;
} }
// make sure that the correct mmproj was used, i.e., compare apples to apples // make sure that the correct mmproj was used, i.e., compare apples to apples
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama)); int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
if (n_img_embd != n_llama_embd) { if (n_image_embd != n_llama_embd) {
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_img_embd, n_llama_embd); printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
free(image_embd); free(image_embd);
return false; return false;
} }
*image_embd_out = image_embd; *image_embd_out = image_embd;
*n_image_pos_out = n_image_pos; *n_img_pos_out = n_img_pos;
return true; return true;
} }

View file

@ -20,7 +20,7 @@ struct llava_context * llava_init(gpt_params * params);
void llava_free(struct llava_context * ctx_llava); void llava_free(struct llava_context * ctx_llava);
/** build a llava image embedding from the passed-in clip image `img`. result is returned as image_embd_out, size n_image_pos_out */ /** build a llava image embedding from the passed-in clip image `img`. result is returned as image_embd_out, size n_image_pos_out */
bool llava_build_img_embed(const struct llama_context * ctx_llama, struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out); LLAMA_API bool llava_build_img_embed(const struct llama_context * ctx_llama, struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out);
#ifdef __cplusplus #ifdef __cplusplus