wip refactor image loading
This commit is contained in:
parent
770dc9da0d
commit
8224ca5775
2 changed files with 91 additions and 83 deletions
|
@ -143,3 +143,55 @@ inline const char * sample(struct llama_context * ctx_llama, gpt_params & params
|
||||||
eval_id(ctx_llama, id, n_past);
|
eval_id(ctx_llama, id, n_past);
|
||||||
return ret.c_str();
|
return ret.c_str();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Opening marker of an inline base64 image tag; the encoded payload starts right after it.
static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
|
||||||
|
// Closing marker of the image tag (a double quote followed by '>').
static const char* IMG_BASE64_TAG_END = "\">";
|
||||||
|
|
||||||
|
// Locates the base64 image tag inside `prompt`.
//   begin_out: offset of IMG_BASE64_TAG_BEGIN, or std::string::npos when absent.
//   end_out:   offset of the first IMG_BASE64_TAG_END found at or after begin_out
//              (the search starts at offset 0 when the opening marker is absent),
//              or std::string::npos when no closing marker is found.
static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
    const size_t tag_open = prompt.find(IMG_BASE64_TAG_BEGIN);

    // fall back to scanning from the start when there is no opening marker
    size_t search_from = 0UL;
    if (tag_open != std::string::npos) {
        search_from = tag_open;
    }

    begin_out = tag_open;
    end_out   = prompt.find(IMG_BASE64_TAG_END, search_from);
}
|
||||||
|
|
||||||
|
// Reports whether `prompt` contains the opening marker of a base64 image tag.
// Only the opening marker is checked; the closing marker position is ignored here.
static bool prompt_contains_image(const std::string& prompt) {
    size_t tag_begin = std::string::npos;
    size_t tag_end   = std::string::npos;
    find_image_tag_in_prompt(prompt, tag_begin, tag_end);

    const bool has_opening_marker = (tag_begin != std::string::npos);
    return has_opening_marker;
}
|
||||||
|
|
||||||
|
// decodes the base64 image embedded in the prompt's image tag and loads it into `img`; returns false if the tag is incomplete or the image cannot be loaded
|
||||||
|
static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) {
|
||||||
|
size_t img_base64_str_start, img_base64_str_end;
|
||||||
|
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
|
||||||
|
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
|
||||||
|
fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
|
||||||
|
auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
|
||||||
|
auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
|
||||||
|
|
||||||
|
auto required_bytes = base64::required_encode_size(base64_str.size());
|
||||||
|
auto img_bytes = std::vector<unsigned char>(required_bytes);
|
||||||
|
auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
|
||||||
|
auto img_bytes_len = img_bytes_end - img_bytes.begin();
|
||||||
|
|
||||||
|
auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
|
||||||
|
if (!img_loaded_ok) {
|
||||||
|
fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
|
||||||
|
size_t begin, end;
|
||||||
|
find_image_tag_in_prompt(prompt, begin, end);
|
||||||
|
if (begin == std::string::npos || end == std::string::npos) {
|
||||||
|
return prompt;
|
||||||
|
}
|
||||||
|
auto pre = prompt.substr(0, begin);
|
||||||
|
auto post = prompt.substr(end+1);
|
||||||
|
return pre + replacement + post;
|
||||||
|
}
|
||||||
|
|
|
@ -37,58 +37,28 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static const char* IMG_BASE64_TAG_BEGIN = "<img src=\"data:image/jpeg;base64,";
|
bool llava_build_img_embed(struct llava_context * ctx_llava, const clip_image_u8 * img) {
|
||||||
static const char* IMG_BASE64_TAG_END = "\">";
|
|
||||||
|
|
||||||
static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) {
|
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
||||||
begin_out = prompt.find(IMG_BASE64_TAG_BEGIN);
|
if (!image_embd) {
|
||||||
end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out);
|
fprintf(stderr, "Unable to allocate memory for image embeddings\n");
|
||||||
}
|
free(image_embd);
|
||||||
|
|
||||||
static bool prompt_contains_image(const std::string& prompt) {
|
|
||||||
size_t begin, end;
|
|
||||||
find_image_tag_in_prompt(prompt, begin, end);
|
|
||||||
return (begin != std::string::npos);
|
|
||||||
}
|
|
||||||
|
|
||||||
// replaces the base64 image tag in the prompt with `replacement`
|
|
||||||
static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) {
|
|
||||||
size_t img_base64_str_start, img_base64_str_end;
|
|
||||||
find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end);
|
|
||||||
if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) {
|
|
||||||
fprintf(stderr, "%s: invalid base64 image tag. must be %s<base64 byte string>%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END);
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN);
|
int n_img_embd;
|
||||||
auto base64_bytes_count = img_base64_str_end - base64_bytes_start;
|
int n_img_pos;
|
||||||
auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count );
|
float t_img_enc_ms;
|
||||||
printf("base64_str: '%s'\n", base64_str.c_str());
|
if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
|
||||||
|
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
|
||||||
auto required_bytes = base64::required_encode_size(base64_str.size());
|
free(image_embd);
|
||||||
auto img_bytes = std::vector<unsigned char>(required_bytes);
|
|
||||||
auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin());
|
|
||||||
auto img_bytes_len = img_bytes_end - img_bytes.begin();
|
|
||||||
|
|
||||||
auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img);
|
|
||||||
if (!img_loaded_ok) {
|
|
||||||
fprintf(stderr, "%s: could not load image from base64 string.\n", __func__);
|
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
return true;
|
ctx_llava->image_embd = image_embd;
|
||||||
|
retur true;
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") {
|
|
||||||
size_t begin, end;
|
|
||||||
find_image_tag_in_prompt(prompt, begin, end);
|
|
||||||
if (begin == std::string::npos || end == std::string::npos) {
|
|
||||||
return prompt;
|
|
||||||
}
|
|
||||||
auto pre = prompt.substr(0, begin);
|
|
||||||
auto post = prompt.substr(end+1);
|
|
||||||
return pre + replacement + post;
|
|
||||||
}
|
|
||||||
|
|
||||||
struct llava_context * llava_init(gpt_params * params) {
|
struct llava_context * llava_init(gpt_params * params) {
|
||||||
|
|
||||||
|
@ -102,46 +72,6 @@ struct llava_context * llava_init(gpt_params * params) {
|
||||||
|
|
||||||
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1);
|
||||||
|
|
||||||
// load and preprocess the image
|
|
||||||
clip_image_u8 img;
|
|
||||||
|
|
||||||
if (prompt_contains_image(prompt)) {
|
|
||||||
if (img_path) {
|
|
||||||
printf("using base64 encoded image instead of command line image path\n");
|
|
||||||
}
|
|
||||||
if (!get_image_from_prompt(prompt, &img)) {
|
|
||||||
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
|
|
||||||
clip_free(ctx_clip);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
prompt = remove_image_from_prompt(prompt);
|
|
||||||
} else {
|
|
||||||
if (!clip_image_load_from_file(img_path, &img)) {
|
|
||||||
fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
|
|
||||||
clip_free(ctx_clip);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
|
|
||||||
if (!image_embd) {
|
|
||||||
fprintf(stderr, "Unable to allocate memory for image embeddings\n");
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
int n_img_embd;
|
|
||||||
int n_img_pos;
|
|
||||||
float t_img_enc_ms;
|
|
||||||
if (!encode_image_with_clip(ctx_clip, params->n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
|
|
||||||
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
|
|
||||||
clip_free(ctx_clip);
|
|
||||||
return NULL;
|
|
||||||
}
|
|
||||||
|
|
||||||
// we get the embeddings, free up the memory required for CLIP
|
|
||||||
clip_free(ctx_clip);
|
|
||||||
ctx_clip = NULL;
|
|
||||||
|
|
||||||
llama_backend_init(params->numa);
|
llama_backend_init(params->numa);
|
||||||
|
|
||||||
llama_model_params model_params = llama_model_default_params();
|
llama_model_params model_params = llama_model_default_params();
|
||||||
|
@ -194,6 +124,11 @@ struct llava_context * llava_init(gpt_params * params) {
|
||||||
}
|
}
|
||||||
|
|
||||||
void llava_free(struct llava_context * ctx_llava) {
|
void llava_free(struct llava_context * ctx_llava) {
|
||||||
|
if (ctx_llava->ctx_clip) {
|
||||||
|
clip_free(ctx_clip);
|
||||||
|
ctx_llava->ctx_clip = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
llama_free(ctx_llava->ctx_llama);
|
llama_free(ctx_llava->ctx_llama);
|
||||||
llama_free_model(ctx_llava->model);
|
llama_free_model(ctx_llava->model);
|
||||||
llama_backend_free();
|
llama_backend_free();
|
||||||
|
@ -249,6 +184,27 @@ int main(int argc, char ** argv) {
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// load and preprocess the image
|
||||||
|
clip_image_u8 img;
|
||||||
|
if (prompt_contains_image(prompt)) {
|
||||||
|
if (img_path) {
|
||||||
|
printf("using base64 encoded image instead of command line image path\n");
|
||||||
|
}
|
||||||
|
if (!get_image_from_prompt(prompt, &img)) {
|
||||||
|
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
|
||||||
|
clip_free(ctx_clip);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
prompt = remove_image_from_prompt(prompt);
|
||||||
|
} else {
|
||||||
|
if (!clip_image_load_from_file(img_path, &img)) {
|
||||||
|
fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
|
||||||
|
clip_free(ctx_clip);
|
||||||
|
return NULL;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
llava_build_img_embed(ctx_llava, &img);
|
||||||
|
|
||||||
// process the prompt
|
// process the prompt
|
||||||
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
|
// llava chat format is "<system_prompt>USER: <image_embeddings>\n<textual_prompt>\nASSISTANT:"
|
||||||
llava_process_prompt(ctx_llava, ¶ms, params.prompt.c_str());
|
llava_process_prompt(ctx_llava, ¶ms, params.prompt.c_str());
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue