cleanup and refactor *again*
This commit is contained in:
parent
e3261ffad3
commit
d64891b6cf
6 changed files with 155 additions and 70 deletions
|
@ -20,46 +20,47 @@ static void show_additional_info(int /*argc*/, char ** argv) {
|
|||
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
|
||||
}
|
||||
|
||||
static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_img_pos) {
|
||||
static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
|
||||
|
||||
// load and preprocess the image
|
||||
clip_image_u8 * img = make_clip_image_u8();
|
||||
llava_image_embed * embed = NULL;
|
||||
auto prompt = params->prompt;
|
||||
if (prompt_contains_image(prompt)) {
|
||||
if (!params->image.empty()) {
|
||||
printf("using base64 encoded image instead of command line image path\n");
|
||||
}
|
||||
if (!clip_image_load_from_prompt(prompt, img)) {
|
||||
embed = llava_image_embed_make_with_prompt_base64(ctx_llava->ctx_clip, params->n_threads, prompt);
|
||||
if (!embed) {
|
||||
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
|
||||
return false;
|
||||
return NULL;
|
||||
}
|
||||
params->prompt = remove_image_from_prompt(prompt);
|
||||
} else {
|
||||
if (!clip_image_load_from_file(params->image.c_str(), img)) {
|
||||
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
|
||||
if (!embed) {
|
||||
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
|
||||
return false;
|
||||
return NULL;
|
||||
}
|
||||
}
|
||||
bool image_embed_result = llava_build_img_embed(ctx_llava->ctx_llama, ctx_llava->ctx_clip, params->n_threads, img, image_embd, n_img_pos);
|
||||
if (!image_embed_result) {
|
||||
clip_image_u8_free(img);
|
||||
fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
return embed;
|
||||
}
|
||||
|
||||
static void process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) {
|
||||
static void process_prompt(struct llava_context * ctx_llava, struct llava_image_embed * image_embed, gpt_params * params, const char * prompt) {
|
||||
int n_past = 0;
|
||||
|
||||
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
|
||||
|
||||
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
|
||||
// GG: are we sure that the should be a trailing whitespace at the end of this string?
|
||||
printf("evaluating system prompt\n");
|
||||
eval_string(ctx_llava->ctx_llama, "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: ", params->n_batch, &n_past);
|
||||
llava_eval_image_embd(ctx_llava->ctx_llama, image_embd, n_img_pos, params->n_batch, &n_past);
|
||||
printf("evaluating image embed\n");
|
||||
llava_eval_image_embed(ctx_llava->ctx_llama, image_embed, params->n_batch, &n_past);
|
||||
printf("evaluating prompt\n");
|
||||
eval_string(ctx_llava->ctx_llama, prompt, params->n_batch, &n_past);
|
||||
eval_string(ctx_llava->ctx_llama, "\nASSISTANT:", params->n_batch, &n_past);
|
||||
printf("awaiting response\n");
|
||||
|
||||
// generate the response
|
||||
|
||||
|
@ -153,16 +154,14 @@ int main(int argc, char ** argv) {
|
|||
return 1;
|
||||
}
|
||||
|
||||
float * image_embd;
|
||||
int n_image_pos;
|
||||
load_image(ctx_llava, ¶ms, &image_embd, &n_image_pos);
|
||||
auto image_embed = load_image(ctx_llava, ¶ms);
|
||||
|
||||
// process the prompt
|
||||
process_prompt(ctx_llava, image_embd, n_image_pos, ¶ms, params.prompt.c_str());
|
||||
process_prompt(ctx_llava, image_embed, ¶ms, params.prompt.c_str());
|
||||
|
||||
llama_print_timings(ctx_llava->ctx_llama);
|
||||
|
||||
free(image_embd);
|
||||
llava_image_embed_free(image_embed);
|
||||
llava_free(ctx_llava);
|
||||
return 0;
|
||||
}
|
||||
|
|
|
@ -678,7 +678,10 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
|
|||
return new_clip;
|
||||
}
|
||||
|
||||
clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); }
|
||||
clip_image_u8 * make_clip_image_u8() {
|
||||
auto img = new clip_image_u8();
|
||||
return img;
|
||||
}
|
||||
clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); }
|
||||
|
||||
void clip_image_u8_free(clip_image_u8 * img) { if (img->data) { delete[] img->data; } delete img; }
|
||||
|
@ -692,18 +695,6 @@ static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_
|
|||
memcpy(img->data, data, img->size);
|
||||
}
|
||||
|
||||
bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img) {
|
||||
int nx, ny, nc;
|
||||
auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
|
||||
if (!data) {
|
||||
fprintf(stderr, "%s: failed to decode image bytes\n", __func__);
|
||||
return false;
|
||||
}
|
||||
build_clip_img_from_data(data, nx, ny, img);
|
||||
stbi_image_free(data);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
|
||||
int nx, ny, nc;
|
||||
auto data = stbi_load(fname, &nx, &ny, &nc, 3);
|
||||
|
@ -716,6 +707,17 @@ bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img) {
|
||||
int nx, ny, nc;
|
||||
auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
|
||||
if (!data) {
|
||||
fprintf(stderr, "%s: failed to decode image bytes\n", __func__);
|
||||
return false;
|
||||
}
|
||||
build_clip_img_from_data(data, nx, ny, img);
|
||||
stbi_image_free(data);
|
||||
return true;
|
||||
}
|
||||
|
||||
// normalize: x = (x - mean) / std
|
||||
// TODO: implement bicubic interpolation instead of linear.
|
||||
|
@ -1065,16 +1067,16 @@ bool clip_model_quantize(const char * fname_inp, const char * fname_out, const i
|
|||
return true;
|
||||
}
|
||||
|
||||
int clip_n_mmproj_embd(struct clip_ctx * ctx) {
|
||||
int clip_n_mmproj_embd(const struct clip_ctx * ctx) {
|
||||
return ctx->vision_model.mm_2_b->ne[0];
|
||||
}
|
||||
|
||||
int clip_n_patches(struct clip_ctx * ctx) {
|
||||
int clip_n_patches(const struct clip_ctx * ctx) {
|
||||
auto & params = ctx->vision_model.hparams;
|
||||
|
||||
return (params.image_size / params.patch_size) * (params.image_size / params.patch_size);
|
||||
}
|
||||
|
||||
size_t clip_embd_nbytes(struct clip_ctx * ctx) {
|
||||
size_t clip_embd_nbytes(const struct clip_ctx * ctx) {
|
||||
return clip_n_patches(ctx) * clip_n_mmproj_embd(ctx) * sizeof(float);
|
||||
}
|
||||
|
|
|
@ -25,9 +25,9 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity);
|
|||
|
||||
void clip_free(struct clip_ctx * ctx);
|
||||
|
||||
size_t clip_embd_nbytes(struct clip_ctx * ctx);
|
||||
int clip_n_patches(struct clip_ctx * ctx);
|
||||
int clip_n_mmproj_embd(struct clip_ctx * ctx);
|
||||
size_t clip_embd_nbytes(const struct clip_ctx * ctx);
|
||||
int clip_n_patches(const struct clip_ctx * ctx);
|
||||
int clip_n_mmproj_embd(const struct clip_ctx * ctx);
|
||||
|
||||
// RGB uint8 image
|
||||
struct clip_image_u8 {
|
||||
|
@ -62,7 +62,7 @@ LLAMA_API void clip_image_u8_free(clip_image_u8 * img);
|
|||
LLAMA_API void clip_image_f32_free(clip_image_f32 * img);
|
||||
LLAMA_API bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
|
||||
/** interpret bytes as an image file with length bytes_length, and use the result to populate img */
|
||||
LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, clip_image_u8 * img);
|
||||
LLAMA_API bool clip_image_load_from_bytes(const unsigned char * bytes, size_t bytes_length, struct clip_image_u8 * img);
|
||||
|
||||
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
|
||||
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);
|
||||
|
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
#include "common.h"
|
||||
#include "llama.h"
|
||||
#include "llava.h"
|
||||
|
||||
#include "base64.hpp"
|
||||
|
||||
|
@ -143,12 +144,12 @@ inline bool prompt_contains_image(const std::string& prompt) {
|
|||
}
|
||||
|
||||
// replaces the base64 image tag in the prompt with `replacement`
|
||||
inline bool clip_image_load_from_prompt(const std::string& prompt, clip_image_u8 * img) {
|
||||
inline llava_image_embed * llava_image_embed_make_with_prompt_base64(struct clip_ctx * ctx_clip, int n_threads, const std::string& prompt) {
|
||||
size_t img_base64_str_start, img_base64_str_end;
|
||||
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
|
||||
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
|
||||
fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
|
||||
return false;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
|
||||
|
@ -157,16 +158,15 @@ inline bool clip_image_load_from_prompt(const std::string& prompt, clip_image_u8
|
|||
|
||||
auto required_bytes = base64::required_encode_size(base64_str.size());
|
||||
auto img_bytes = std::vector<unsigned char>(required_bytes);
|
||||
auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
|
||||
size_t img_bytes_len = img_bytes_end - img_bytes.begin();
|
||||
base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
|
||||
|
||||
auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
|
||||
if (!img_loaded_ok) {
|
||||
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, img_bytes.data(), img_bytes.size());
|
||||
if (!embed) {
|
||||
fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
|
||||
return false;
|
||||
return NULL;
|
||||
}
|
||||
|
||||
return true;
|
||||
return embed;
|
||||
}
|
||||
|
||||
inline std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
|
||||
|
|
111
llava/llava.cpp
111
llava/llava.cpp
|
@ -10,7 +10,7 @@
|
|||
|
||||
#include "base64.hpp"
|
||||
|
||||
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_image_embd, int * n_img_pos) {
|
||||
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_pos) {
|
||||
clip_image_f32 * img_res = make_clip_image_f32();
|
||||
if (!clip_image_preprocess(ctx_clip, img, img_res, /*pad2square =*/ true)) {
|
||||
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
|
||||
|
@ -19,7 +19,6 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
}
|
||||
|
||||
*n_img_pos = clip_n_patches(ctx_clip);
|
||||
*n_image_embd = clip_n_mmproj_embd(ctx_clip);
|
||||
|
||||
const int64_t t_img_enc_start_us = ggml_time_us();
|
||||
bool encoded = clip_image_encode(ctx_clip, n_threads, img_res, image_embd);
|
||||
|
@ -39,7 +38,18 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
|||
return true;
|
||||
}
|
||||
|
||||
bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
|
||||
bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip) {
|
||||
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
||||
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||
auto n_image_embd = clip_n_mmproj_embd(ctx_clip);
|
||||
if (n_image_embd != n_llama_embd) {
|
||||
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool llava_image_embed_make_with_clip_img(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_img_pos_out) {
|
||||
|
||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||
if (!image_embd) {
|
||||
|
@ -49,20 +59,11 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,
|
|||
}
|
||||
|
||||
int n_img_pos;
|
||||
int n_image_embd;
|
||||
if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_image_embd, &n_img_pos)) {
|
||||
if (!encode_image_with_clip(ctx_clip, n_threads, img, image_embd, &n_img_pos)) {
|
||||
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
|
||||
free(image_embd);
|
||||
return false;
|
||||
}
|
||||
// make sure that the correct mmproj was used, i.e., compare apples to apples
|
||||
int n_llama_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||
if (n_image_embd != n_llama_embd) {
|
||||
printf("%s: embedding dim of the multimodal projector (%d) is not equal to that of LLaMA (%d). Make sure that you use the correct mmproj file.\n", __func__, n_image_embd, n_llama_embd);
|
||||
free(image_embd);
|
||||
return false;
|
||||
}
|
||||
|
||||
*image_embd_out = image_embd;
|
||||
*n_img_pos_out = n_img_pos;
|
||||
|
||||
|
@ -71,15 +72,15 @@ bool llava_build_img_embed(const llama_context * ctx_llama, clip_ctx * ctx_clip,
|
|||
|
||||
|
||||
|
||||
bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past) {
|
||||
bool llava_eval_image_embed(llama_context * ctx_llama, const struct llava_image_embed * image_embed, int n_batch, int * n_past) {
|
||||
int n_embd = llama_n_embd(llama_get_model(ctx_llama));
|
||||
|
||||
for (int i = 0; i < n_image_pos; i += n_batch) {
|
||||
int n_eval = n_image_pos - i;
|
||||
for (int i = 0; i < image_embed->n_image_pos; i += n_batch) {
|
||||
int n_eval = image_embed->n_image_pos - i;
|
||||
if (n_eval > n_batch) {
|
||||
n_eval = n_batch;
|
||||
}
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (image_embd+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||
llama_batch batch = {int32_t(n_eval), nullptr, (image_embed->embed+i*n_embd), nullptr, nullptr, nullptr, *n_past, 1, 0, };
|
||||
if (llama_decode(ctx_llama, batch)) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return false;
|
||||
|
@ -88,3 +89,79 @@ bool llava_eval_image_embd(llama_context * ctx_llama, float * image_embd, int n_
|
|||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
|
||||
LLAMA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length)
|
||||
{
|
||||
clip_image_u8 * img = make_clip_image_u8();
|
||||
if (!clip_image_load_from_bytes(image_bytes, image_bytes_length, img)) {
|
||||
clip_image_u8_free(img);
|
||||
fprintf(stderr, "%s: can't load image from bytes, is it a valid image?", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
float* image_embed = NULL;
|
||||
int n_image_pos = 0;
|
||||
bool image_embed_result = llava_image_embed_make_with_clip_img(ctx_clip, n_threads, img, &image_embed, &n_image_pos);
|
||||
if (!image_embed_result) {
|
||||
clip_image_u8_free(img);
|
||||
fprintf(stderr, "%s: coulnd't embed the image\n", __func__);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
clip_image_u8_free(img);
|
||||
auto result = (llava_image_embed*)malloc(sizeof(llava_image_embed));
|
||||
result->embed = image_embed;
|
||||
result->n_image_pos = n_image_pos;
|
||||
return result;
|
||||
}
|
||||
|
||||
static bool load_file_to_bytes(const char* path, unsigned char** bytesOut, long *sizeOut)
|
||||
{
|
||||
auto file = fopen(path, "rb");
|
||||
if (file == NULL) {
|
||||
fprintf(stderr, "%s: can't read file %s\n", __func__, path);
|
||||
return false;
|
||||
}
|
||||
|
||||
fseek(file, 0, SEEK_END);
|
||||
auto fileSize = ftell(file);
|
||||
fseek(file, 0, SEEK_SET);
|
||||
|
||||
auto buffer = (unsigned char *)malloc(fileSize); // Allocate memory to hold the file data
|
||||
if (buffer == NULL) {
|
||||
fprintf(stderr, "%s: failed to alloc %ld bytes for file %s\n", __func__, fileSize, path);
|
||||
perror("Memory allocation error");
|
||||
fclose(file);
|
||||
return false;
|
||||
}
|
||||
fread(buffer, 1, fileSize, file); // Read the file into the buffer
|
||||
fclose(file); // Close the file
|
||||
|
||||
*bytesOut = buffer;
|
||||
*sizeOut = fileSize;
|
||||
return true;
|
||||
|
||||
}
|
||||
|
||||
LLAMA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path)
|
||||
{
|
||||
unsigned char* image_bytes;
|
||||
long image_bytes_length;
|
||||
auto loaded = load_file_to_bytes(image_path, &image_bytes, &image_bytes_length);
|
||||
if (!loaded) {
|
||||
fprintf(stderr, "%s: failed to load %s\n", __func__, image_path);
|
||||
return NULL;
|
||||
}
|
||||
|
||||
auto embed = llava_image_embed_make_with_bytes(ctx_clip, n_threads, image_bytes, image_bytes_length);
|
||||
free(image_bytes);
|
||||
|
||||
return embed;
|
||||
}
|
||||
|
||||
|
||||
LLAMA_API void llava_image_embed_free(struct llava_image_embed * embed) {
|
||||
free(embed->embed);
|
||||
free(embed);
|
||||
}
|
||||
|
|
|
@ -10,13 +10,20 @@ struct clip_ctx;
|
|||
extern "C" {
|
||||
#endif
|
||||
|
||||
/** using ctx_clip, build a llava image embedding from the passed-in image `img` (see clip.h for methods to load img).
|
||||
* result is returned as image_embd_out, size n_image_pos_out */
|
||||
LLAMA_API bool llava_build_img_embed(const struct llama_context * ctx_llama, struct clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float ** image_embd_out, int * n_image_pos_out);
|
||||
struct llava_image_embed {
|
||||
float * embed;
|
||||
int n_image_pos;
|
||||
};
|
||||
|
||||
/** write the image represented by image_embd (size n_image_pos) into the llama context with batch size n_batch,
|
||||
LLAMA_API bool llava_validate_embed_size(const llama_context * ctx_llama, const clip_ctx * ctx_clip);
|
||||
|
||||
LLAMA_API struct llava_image_embed * llava_image_embed_make_with_bytes(struct clip_ctx * ctx_clip, int n_threads, const unsigned char * image_bytes, int image_bytes_length);
|
||||
LLAMA_API struct llava_image_embed * llava_image_embed_make_with_filename(struct clip_ctx * ctx_clip, int n_threads, const char * image_path);
|
||||
LLAMA_API void llava_image_embed_free(struct llava_image_embed * embed);
|
||||
|
||||
/** write the image represented by embed into the llama context with batch size n_batch,
|
||||
* starting at context pos n_past. on completion, n_past points to the next position in the context after the image embed. */
|
||||
LLAMA_API bool llava_eval_image_embd(struct llama_context * ctx_llama, float * image_embd, int n_image_pos, int n_batch, int * n_past);
|
||||
LLAMA_API bool llava_eval_image_embed(struct llama_context * ctx_llama, const struct llava_image_embed * embed, int n_batch, int * n_past);
|
||||
|
||||
|
||||
#ifdef __cplusplus
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue