This commit is contained in:
Damian Stewart 2023-10-14 12:39:00 +02:00
parent c6932085fe
commit 0889117573

View file

@ -128,7 +128,33 @@ void llava_free(struct llava_context * ctx_llava) {
llama_backend_free(); llama_backend_free();
} }
static void llava_process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) {
// Load the input image and build its embedding.
//
// The image is taken from the prompt when prompt_contains_image() reports an
// inline image (a base64-encoded one, per the log message below); otherwise it
// is loaded from the file path in params->image.
//
// ctx_llava   - llava context forwarded to llava_build_img_embed()
// params      - run parameters; params->prompt is rewritten with the inline
//               image removed when the image came from the prompt
// image_embd  - out-pointer forwarded to llava_build_img_embed()
//               (presumably receives the embedding buffer - TODO confirm ownership)
// n_image_pos - out-pointer forwarded to llava_build_img_embed()
//
// Returns false when the image cannot be loaded, true otherwise.
static bool load_image(llava_context * ctx_llava, gpt_params * params, float **image_embd, int * n_image_pos) {
    clip_image_u8 img;

    // work on a copy so params->prompt can be reassigned below without aliasing the argument
    auto prompt = params->prompt;

    if (prompt_contains_image(prompt)) {
        if (!params->image.empty()) {
            printf("using base64 encoded image instead of command line image path\n");
        }
        if (!get_image_from_prompt(prompt, &img)) {
            fprintf(stderr, "%s: can't load image from prompt\n", __func__);
            return false;
        }
        // store the stripped prompt back into params: the caller reads params->prompt
        // afterwards, and assigning only to the local copy would silently discard it
        params->prompt = remove_image_from_prompt(prompt);
    } else {
        if (!clip_image_load_from_file(params->image.c_str(), &img)) {
            fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
            return false;
        }
    }

    // NOTE(review): any failure status from llava_build_img_embed() is not inspected here - confirm its return type
    llava_build_img_embed(ctx_llava, params->n_threads, &img, image_embd, n_image_pos);
    return true;
}
static void process_prompt(struct llava_context * ctx_llava, float * image_embd, int n_img_pos, gpt_params * params, const char * prompt) {
int n_past = 0; int n_past = 0;
const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict;
@ -156,7 +182,6 @@ static void llava_process_prompt(struct llava_context * ctx_llava, float * image
} }
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
ggml_time_init(); ggml_time_init();
@ -178,32 +203,12 @@ int main(int argc, char ** argv) {
return 1; return 1;
} }
// load and preprocess the image
clip_image_u8 img;
auto prompt = params.prompt;
if (prompt_contains_image(prompt)) {
if (!params.image.empty()) {
printf("using base64 encoded image instead of command line image path\n");
}
if (!get_image_from_prompt(prompt, &img)) {
fprintf(stderr, "%s: can't load image from prompt\n", __func__);
llava_free(ctx_llava);
return 1;
}
prompt = remove_image_from_prompt(prompt);
} else {
if (!clip_image_load_from_file(params.image.c_str(), &img)) {
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params.image.c_str());
llava_free(ctx_llava);
return 1;
}
}
float * image_embd; float * image_embd;
int n_image_pos; int n_image_pos;
llava_build_img_embed(ctx_llava, params.n_threads, &img, &image_embd, &n_image_pos); load_image(ctx_llava, &params, &image_embd, &n_image_pos);
// process the prompt // process the prompt
llava_process_prompt(ctx_llava, image_embd, n_image_pos, &params, params.prompt.c_str()); process_prompt(ctx_llava, image_embd, n_image_pos, &params, params.prompt.c_str());
llama_print_timings(ctx_llava->ctx_llama); llama_print_timings(ctx_llava->ctx_llama);