Update llava-cli.cpp

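Change functions to treat the image as a vector datatype: the image to load is now passed to load_image explicitly instead of being split out of a comma-separated params.image string.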
This commit is contained in:
cpumaxx 2024-04-09 21:07:37 -07:00 committed by GitHub
parent 124e259dc6
commit 8c0a5b0e1e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194

View file

@@ -116,7 +116,7 @@ static void show_additional_info(int /*argc*/, char ** argv) {
fprintf(stderr, " note: a lower temperature value like 0.1 is recommended for better quality.\n");
}
static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params) {
static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_params * params, std::string * image) {
// load and preprocess the image
llava_image_embed * embed = NULL;
@@ -132,9 +132,9 @@ static struct llava_image_embed * load_image(llava_context * ctx_llava, gpt_para
}
params->prompt = remove_image_from_prompt(prompt);
} else {
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, params->image.c_str());
embed = llava_image_embed_make_with_filename(ctx_llava->ctx_clip, params->n_threads, image.c_str());
if (!embed) {
fprintf(stderr, "%s: is %s really an image file?\n", __func__, params->image.c_str());
fprintf(stderr, "%s: is %s really an image file?\n", __func__, image.c_str());
return NULL;
}
}
@@ -281,24 +281,13 @@ int main(int argc, char ** argv) {
return 1;
}
std::stringstream ss(params.image);
std::vector<std::string> imagestack;
while( ss.good() )
{
std::string substr;
getline( ss, substr, ',' );
imagestack.push_back( substr );
}
for (auto & image : imagestack) {
for (auto & image : image) {
auto ctx_llava = llava_init_context(&params, model);
params.image=image;
auto image_embed = load_image(ctx_llava, &params);
auto image_embed = load_image(ctx_llava, &params, &image);
if (!image_embed) {
std::cerr << "error: failed to load image " << params.image << ". Terminating\n\n";
std::cerr << "error: failed to load image " << image << ". Terminating\n\n";
return 1;
}