wip llava python bindings compatibility

This commit is contained in:
Damian Stewart 2023-10-13 10:33:07 +02:00
parent 1e0e873c37
commit 0209d39526
5 changed files with 61 additions and 30 deletions

View file

@ -18,3 +18,6 @@ target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()
unset(TARGET)
llama_build_and_test_executable(test-llava.cpp)

View file

@ -682,25 +682,39 @@ clip_image_u8 * make_clip_image_u8() { return new clip_image_u8(); }
clip_image_f32 * make_clip_image_f32() { return new clip_image_f32(); }
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
int nx, ny, nc;
auto data = stbi_load(fname, &nx, &ny, &nc, 3);
if (!data) {
fprintf(stderr, "%s: failed to load '%s'\n", __func__, fname);
return false;
}
static void build_clip_img_from_data(const stbi_uc * data, int nx, int ny, clip_image_u8 * img) {
img->nx = nx;
img->ny = ny;
img->size = nx * ny * 3;
img->data = new uint8_t[img->size]();
memcpy(img->data, data, img->size);
}
bool clip_image_load_from_bytes(const unsigned char * bytes, int bytes_length, clip_image_u8 * img) {
int nx, ny, nc;
auto data = stbi_load_from_memory(bytes, bytes_length, &nx, &ny, &nc, 3);
if (!data) {
fprintf(stderr, "%s: failed to decode image bytes\n", __func__);
return false;
}
build_clip_img_from_data(data, nx, ny, img);
stbi_image_free(data);
return true;
}
bool clip_image_load_from_file(const char * fname, clip_image_u8 * img) {
int nx, ny, nc;
auto data = stbi_load(fname, &nx, &ny, &nc, 3);
if (!data) {
fprintf(stderr, "%s: failed to load image '%s'\n", __func__, fname);
return false;
}
build_clip_img_from_data(data, nx, ny, img);
stbi_image_free(data);
return true;
}
// normalize: x = (x - mean) / std
// TODO: implement bicubic interpolation instead of linear.
bool clip_image_preprocess(const clip_ctx * ctx, const clip_image_u8 * img, clip_image_f32 * res, const bool pad2square) {

View file

@ -58,6 +58,7 @@ struct clip_image_f32_batch {
struct clip_image_u8 * make_clip_image_u8();
struct clip_image_f32 * make_clip_image_f32();
bool clip_image_load_from_file(const char * fname, struct clip_image_u8 * img);
bool clip_image_load_from_bytes(const unsigned char * bytes, int bytes_length, clip_image_u8 * img);
bool clip_image_preprocess(const struct clip_ctx * ctx, const struct clip_image_u8 * img, struct clip_image_f32 * res, const bool pad2square);
bool clip_image_encode(const struct clip_ctx * ctx, const int n_threads, struct clip_image_f32 * img, float * vec);

View file

@ -12,6 +12,28 @@ static void show_additional_info(int /*argc*/, char ** argv) {
printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n");
}
static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const clip_image_u8 * img, float * image_embd, int * n_img_embd, int * n_img_pos, float * t_img_enc_ms) {
clip_image_f32 img_res;
if (!clip_image_preprocess(ctx_clip, img, &img_res, /*pad2square =*/ true)) {
fprintf(stderr, "%s: unable to preprocess image\n", __func__);
return false;
}
*n_img_pos = clip_n_patches(ctx_clip);
*n_img_embd = clip_n_mmproj_embd(ctx_clip);
const int64_t t_img_enc_start_us = ggml_time_us();
if (!clip_image_encode(ctx_clip, n_threads, &img_res, image_embd)) {
fprintf(stderr, "Unable to encode image\n");
return false;
}
const int64_t t_img_enc_end_us = ggml_time_us();
*t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
return true;
}
int main(int argc, char ** argv) {
ggml_time_init();
@ -39,40 +61,27 @@ int main(int argc, char ** argv) {
// load and preprocess the image
clip_image_u8 img;
clip_image_f32 img_res;
if (!clip_image_load_from_file(img_path, &img)) {
fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path);
clip_free(ctx_clip);
return 1;
}
if (!clip_image_preprocess(ctx_clip, &img, &img_res, /*pad2square =*/ true)) {
fprintf(stderr, "%s: unable to preprocess %s\n", __func__, img_path);
clip_free(ctx_clip);
return 1;
}
int n_img_pos = clip_n_patches(ctx_clip);
int n_img_embd = clip_n_mmproj_embd(ctx_clip);
float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip));
if (!image_embd) {
fprintf(stderr, "Unable to allocate memory for image embeddings\n");
return 1;
}
const int64_t t_img_enc_start_us = ggml_time_us();
if (!clip_image_encode(ctx_clip, params.n_threads, &img_res, image_embd)) {
fprintf(stderr, "Unable to encode image\n");
int n_img_embd;
int n_img_pos;
float t_img_enc_ms;
if (!encode_image_with_clip(ctx_clip, params.n_threads, &img, image_embd, &n_img_embd, &n_img_pos, &t_img_enc_ms)) {
fprintf(stderr, "%s: cannot encode image, aborting\n", __func__);
clip_free(ctx_clip);
return 1;
}
const int64_t t_img_enc_end_us = ggml_time_us();
// we get the embeddings, free up the memory required for CLIP
clip_free(ctx_clip);
@ -140,8 +149,6 @@ int main(int argc, char ** argv) {
printf("\n");
{
const float t_img_enc_ms = (t_img_enc_end_us - t_img_enc_start_us) / 1000.0;
printf("\n%s: image encoded in %8.2f ms by CLIP (%8.2f ms per image patch)\n", __func__, t_img_enc_ms, t_img_enc_ms / n_img_pos);
}

View file

@ -0,0 +1,6 @@
#include <cstdio>
int main(int argc, char ** argv) {
printf("dummy llava test\n");
return 0;
}