From 16c22471e88a9c8bd2049be890642ba496ee496f Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=9D=8E=E4=B8=BA?=
Date: Fri, 8 Nov 2024 20:59:23 +0800
Subject: [PATCH] remove redundant omni-vlm-v2/ folder; all omni-vlm examples
 will be added to the omni-vlm/ folder.

---
 common/common.cpp                          |   6 +
 common/common.h                            |   2 +
 examples/CMakeLists.txt                    |   1 -
 examples/omni-vlm/clip.cpp                 | 133 ++++-----------------
 examples/omni-vlm/clip.h                   |   3 +
 examples/omni-vlm/latex.png                | Bin 0 -> 5931 bytes
 examples/omni-vlm/omni-vlm-cli.cpp         |  28 +++--
 examples/omni-vlm/omni-vlm-wrapper-cli.cpp |  21 +++-
 examples/omni-vlm/omni-vlm-wrapper.cpp     |  43 +++++--
 examples/omni-vlm/omni-vlm-wrapper.h       |   2 +-
 examples/omni-vlm/omni_vlm_cpp.py          |   6 +-
 examples/omni-vlm/omni_vlm_demo.py         |  23 ++--
 12 files changed, 124 insertions(+), 144 deletions(-)
 create mode 100644 examples/omni-vlm/latex.png

diff --git a/common/common.cpp b/common/common.cpp
index 715adf946..e85c498c9 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -1442,6 +1442,11 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
     // End of Parse args for logging parameters
 #endif // LOG_DISABLE_LOGS
 
+    if (arg == "--omni-vlm-version") {
+        CHECK_ARG
+        params.omni_vlm_version = argv[i];
+        return true;
+    }
     return false;
 }
 
@@ -1688,6 +1693,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
                                                                         "layer range to apply the control vector(s) to, start and end inclusive" });
     options.push_back({ "*",           "-m,    --model FNAME",          "model path (default: models/$filename with filename from --hf-file\n"
                                                                         "or --model-url if set, otherwise %s)", DEFAULT_MODEL_PATH });
+    options.push_back({ "*",           "       --omni-vlm-version VERSION_STRING", "omni-vlm version string, one of 'vlm-81-ocr', 'vlm-81-instruct', 'nano-vlm-instruct'\n" "(default: 'vlm-81-ocr')"});
     options.push_back({ "*",           "-md,   --model-draft FNAME",    "draft model for speculative decoding (default: unused)" });
     options.push_back({ "*",           "-mu,   --model-url MODEL_URL",  "model download url (default: unused)" });
     options.push_back({ "*",           "-hfr,  --hf-repo REPO",         "Hugging Face model repository (default: unused)" });
diff --git a/common/common.h b/common/common.h
index f603ba2be..73dab55ca 100644
--- a/common/common.h
+++ b/common/common.h
@@ -265,6 +265,8 @@ struct gpt_params {
     bool spm_infill = false; // suffix/prefix/middle pattern for infill
 
     std::string lora_outfile = "ggml-lora-merged-f16.gguf";
+
+    std::string omni_vlm_version = "vlm-81-ocr";
 };
 
 void gpt_params_parse_from_env(gpt_params & params);
diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt
index 784d6e9de..922c293ef 100644
--- a/examples/CMakeLists.txt
+++ b/examples/CMakeLists.txt
@@ -53,7 +53,6 @@ else()
     # add_subdirectory(speculative)
     # add_subdirectory(tokenize)
     add_subdirectory(omni-vlm)
-    add_subdirectory(omni-vlm-v2)
    add_subdirectory(nexa-omni-audio)
    add_subdirectory(qwen2-audio)
 endif()
diff --git a/examples/omni-vlm/clip.cpp b/examples/omni-vlm/clip.cpp
index 45764f9f3..e8a04ad3a 100644
--- a/examples/omni-vlm/clip.cpp
+++ b/examples/omni-vlm/clip.cpp
@@ -6,6 +6,7 @@
 #include "ggml.h"
 #include "ggml-alloc.h"
 #include "ggml-backend.h"
+#include "common.h"
 
 #ifdef GGML_USE_CUDA
 #include "ggml-cuda.h"
@@ -167,7 +168,11 @@ static std::map<projector_type, std::string> PROJECTOR_TYPE_NAMES = {
     { PROJECTOR_TYPE_RESAMPLER, "resampler"},
 };
 
-
+enum omni_vlm_version_type {
+    VLM_81_OCR,
+    VLM_81_INSTRUCT,
+    NANO_VLM_INSTRUCT,
+};
 //
 // utilities to get data from a gguf file
 //
@@ -294,115 +299,6 @@ static projector_type clip_projector_type_from_string(const std::string & name)
     return PROJECTOR_TYPE_UNKNOWN;
 }
 
-#ifdef CLIP_DEBUG_FUNCTIONS
-static void clip_image_write_image_to_ppm(const clip_image_u8& img, const std::string& filename) {
-    std::ofstream file(filename, std::ios::binary);
-    if (!file.is_open()) {
-        LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
-        return;
-    }
-
-    // PPM header: P6 format, width, height, and max color value
-    file << "P6\n" << img.nx << " " << img.ny << "\n255\n";
-
-    // Write pixel data
-    for (size_t i = 0; i < img.buf.size(); i += 3) {
-        // PPM expects binary data in RGB format, which matches our image buffer
-        file.write(reinterpret_cast<const char *>(&img.buf[i]), 3);
-    }
-
-    file.close();
-}
-
-static void clip_image_save_to_bmp(const clip_image_u8& img, const std::string& filename) {
-    std::ofstream file(filename, std::ios::binary);
-    if (!file.is_open()) {
-        LOG_ERR("Failed to open file for writing: %s\n", filename.c_str());
-        return;
-    }
-
-    int fileSize = 54 + 3 * img.nx * img.ny; // File header + info header + pixel data
-    int bytesPerPixel = 3;
-    int widthInBytes = img.nx * bytesPerPixel;
-    int paddingAmount = (4 - (widthInBytes % 4)) % 4;
-    int stride = widthInBytes + paddingAmount;
-
-    // Bitmap file header
-    unsigned char fileHeader[14] = {
-        'B','M',     // Signature
-        0,0,0,0,     // Image file size in bytes
-        0,0,0,0,     // Reserved
-        54,0,0,0     // Start of pixel array
-    };
-
-    // Total file size
-    fileSize = 54 + (stride * img.ny);
-    fileHeader[2] = (unsigned char)(fileSize);
-    fileHeader[3] = (unsigned char)(fileSize >> 8);
-    fileHeader[4] = (unsigned char)(fileSize >> 16);
-    fileHeader[5] = (unsigned char)(fileSize >> 24);
-
-    // Bitmap information header (BITMAPINFOHEADER)
-    unsigned char infoHeader[40] = {
-        40,0,0,0,   // Size of this header (40 bytes)
-        0,0,0,0,    // Image width
-        0,0,0,0,    // Image height
-        1,0,        // Number of color planes
-        24,0,       // Bits per pixel
-        0,0,0,0,    // No compression
-        0,0,0,0,    // Image size (can be 0 for no compression)
-        0,0,0,0,    // X pixels per meter (not specified)
-        0,0,0,0,    // Y pixels per meter (not specified)
-        0,0,0,0,    // Total colors (color table not used)
-        0,0,0,0     // Important colors (all are important)
-    };
-
-    // Width and height in the information header
-    infoHeader[4] = (unsigned char)(img.nx);
-    infoHeader[5] = (unsigned char)(img.nx >> 8);
-    infoHeader[6] = (unsigned char)(img.nx >> 16);
-    infoHeader[7] = (unsigned char)(img.nx >> 24);
-    infoHeader[8] = (unsigned char)(img.ny);
-    infoHeader[9] = (unsigned char)(img.ny >> 8);
-    infoHeader[10] = (unsigned char)(img.ny >> 16);
-    infoHeader[11] = (unsigned char)(img.ny >> 24);
-
-    // Write file headers
-    file.write(reinterpret_cast<char *>(fileHeader), sizeof(fileHeader));
-    file.write(reinterpret_cast<char *>(infoHeader), sizeof(infoHeader));
-
-    // Pixel data
-    std::vector<unsigned char> padding(3, 0); // Max padding size to be added to each row
-    for (int y = img.ny - 1; y >= 0; --y) { // BMP files are stored bottom-to-top
-        for (int x = 0; x < img.nx; ++x) {
-            // Each pixel
-            size_t pixelIndex = (y * img.nx + x) * 3;
-            unsigned char pixel[3] = {
-                img.buf[pixelIndex + 2], // BMP stores pixels in BGR format
-                img.buf[pixelIndex + 1],
-                img.buf[pixelIndex]
-            };
-            file.write(reinterpret_cast<char *>(pixel), 3);
-        }
-        // Write padding for the row
-        file.write(reinterpret_cast<char *>(padding.data()), paddingAmount);
-    }
-
-    file.close();
-}
-
-// debug function to convert f32 to u8
-static void clip_image_convert_f32_to_u8(const clip_image_f32& src, clip_image_u8& dst) {
-    dst.nx = src.nx;
-    dst.ny = src.ny;
-    dst.buf.resize(3 * src.nx * src.ny);
-    for (size_t i = 0; i < src.buf.size(); ++i) {
-        dst.buf[i] = static_cast<uint8_t>(std::min(std::max(int(src.buf[i] * 255.0f), 0), 255));
-    }
-}
-#endif
-
-
 //
 // clip layers
 //
@@ -564,6 +460,7 @@ struct clip_ctx {
     struct clip_vision_model vision_model;
 
     projector_type proj_type = PROJECTOR_TYPE_MLP;
+    omni_vlm_version_type omni_vlm_ver_type;
 
     float image_mean[3];
     float image_std[3];
@@ -785,6 +682,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32
         embeddings = ggml_add(ctx0, ggml_mul(ctx0, embeddings, model.post_ln_w), model.post_ln_b);
     }
 
+    if (ctx->omni_vlm_ver_type == omni_vlm_version_type::VLM_81_OCR || ctx->omni_vlm_ver_type == omni_vlm_version_type::VLM_81_INSTRUCT) {
+        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0]*9, embeddings->ne[1]/9, 1);
+    }
+
     embeddings = ggml_mul_mat(ctx0, model.mm_0_w, embeddings);
     embeddings = ggml_add(ctx0, embeddings, model.mm_0_b);
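The reshape added above is what gives the 'vlm-81-*' variants their name: every 9 consecutive patch embeddings are packed into one embedding that is 9x wider before the projector, so the LLM sees 9x fewer visual tokens. A minimal standalone sketch of the shape change follows; the concrete sizes (1152-dim embeddings over 729 patches) are assumptions for illustration, not taken from this patch — only the ggml_reshape_3d call itself mirrors the hunk above.

    #include "ggml.h"
    #include <cstdio>

    int main() {
        // no_alloc: we only want tensor metadata (shapes), no data buffers
        struct ggml_init_params ip = { /*mem_size*/ 16*1024*1024, /*mem_buffer*/ NULL, /*no_alloc*/ true };
        struct ggml_context * ctx0 = ggml_init(ip);

        // [n_embd, n_patches] as produced by the vision encoder (sizes assumed)
        struct ggml_tensor * embeddings = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, 1152, 729);

        // same call as the vlm-81-ocr / vlm-81-instruct branch in clip_image_build_graph():
        // 9 consecutive patch embeddings become one 9x-wider embedding
        embeddings = ggml_reshape_3d(ctx0, embeddings, embeddings->ne[0]*9, embeddings->ne[1]/9, 1);

        // prints ne = [10368, 81, 1]: the projector now sees 81 tokens instead of 729
        printf("ne = [%lld, %lld, %lld]\n",
               (long long) embeddings->ne[0], (long long) embeddings->ne[1], (long long) embeddings->ne[2]);
        ggml_free(ctx0);
        return 0;
    }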
@@ -1308,6 +1209,18 @@ struct clip_ctx * clip_model_load(const char * fname, const int verbosity = 1) {
     return new_clip;
 }
 
+void clip_set_omni_vlm_version(struct clip_ctx * ctx_clip, const struct gpt_params * params) {
+    if (params->omni_vlm_version == "vlm-81-ocr") {
+        ctx_clip->omni_vlm_ver_type = omni_vlm_version_type::VLM_81_OCR;
+    } else if (params->omni_vlm_version == "vlm-81-instruct") {
+        ctx_clip->omni_vlm_ver_type = omni_vlm_version_type::VLM_81_INSTRUCT;
+    } else if (params->omni_vlm_version == "nano-vlm-instruct") {
+        ctx_clip->omni_vlm_ver_type = omni_vlm_version_type::NANO_VLM_INSTRUCT;
+    } else {
+        throw std::runtime_error(std::string("invalid omni-vlm version: ") + params->omni_vlm_version);
+    }
+}
+
 void clip_add_load_image_size(struct clip_ctx * ctx_clip, struct clip_image_size * load_image_size) {
     ctx_clip->load_image_size = load_image_size;
 }
diff --git a/examples/omni-vlm/clip.h b/examples/omni-vlm/clip.h
index 78588bdf1..ae004b6df 100644
--- a/examples/omni-vlm/clip.h
+++ b/examples/omni-vlm/clip.h
@@ -42,6 +42,9 @@ struct clip_image_f32_batch {
 CLIP_API struct clip_ctx * clip_model_load    (const char * fname, int verbosity);
 CLIP_API struct clip_ctx * clip_model_load_cpu(const char * fname, int verbosity);
 
+struct gpt_params;
+CLIP_API void clip_set_omni_vlm_version(struct clip_ctx * ctx_clip, const struct gpt_params * params);
+
 CLIP_API void clip_free(struct clip_ctx * ctx);
 
 CLIP_API size_t clip_embd_nbytes(const struct clip_ctx * ctx);
diff --git a/examples/omni-vlm/latex.png b/examples/omni-vlm/latex.png
new file mode 100644
index 0000000000000000000000000000000000000000..b97318fc049a3fd4822604db547ac71520a104e3
Binary files /dev/null and b/examples/omni-vlm/latex.png differ
diff --git a/examples/omni-vlm/omni-vlm-cli.cpp b/examples/omni-vlm/omni-vlm-cli.cpp
index 68e833182..5baa5adc7 100644
--- a/examples/omni-vlm/omni-vlm-cli.cpp
+++ b/examples/omni-vlm/omni-vlm-cli.cpp
@@ -212,8 +212,8 @@ static struct omnivlm_context * omnivlm_init_context(gpt_params * params, llama_
         prompt = "describe the image in detail.";
     }
 
-    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 10);
-
+    auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 0);
+    clip_set_omni_vlm_version(ctx_clip, params);
     llama_context_params ctx_params = llama_context_params_from_gpt_params(*params);
 
     ctx_params.n_ctx = params->n_ctx < 2048 ? 2048 : params->n_ctx; // we need a longer context size to process image embeddings
@@ -249,9 +249,6 @@ int main(int argc, char ** argv) {
 
     gpt_params params;
 
-    // if (!gpt_params_parse(argc, argv, params, LLAMA_EXAMPLE_LLAVA, print_usage)) {
-    //     return 1;
-    // }
     if (!gpt_params_parse(argc, argv, params)) {
         print_usage(argc, argv, params);
         return 1;
@@ -261,8 +258,21 @@ int main(int argc, char ** argv) {
         print_usage(argc, argv, {});
         return 1;
     }
+    if (params.omni_vlm_version != "vlm-81-ocr" && params.prompt.empty()) {
+        LOG_TEE("%s : prompt is empty.\n", __func__);
+        print_usage(argc, argv, {});
+        return 1;
+    }
 
-    params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\nDescribe this image for me\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>";
+    if (params.omni_vlm_version == "vlm-81-ocr") {
+        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n <|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>";
+    } else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") {
+        params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" + params.prompt + "\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>";
+    } else {
+        LOG_TEE("%s : error: invalid omni-vlm version: '%s'.\n", __func__, params.omni_vlm_version.c_str());
+        print_usage(argc, argv, {});
+        return 1;
+    }
 
     auto * model = omnivlm_init(&params);
     if (model == NULL) {
@@ -270,8 +280,12 @@ int main(int argc, char ** argv) {
         return 1;
     }
-
     auto * ctx_omnivlm = omnivlm_init_context(&params, model);
+
+    // temporarily set to greedy decoding.
+    params.sparams.top_k = 1;
+    params.sparams.top_p = 1.0f;
+
     for (auto & image : params.image) {
         auto * image_embed = load_image(ctx_omnivlm, &params, image);
         if (!image_embed) {
#include "omni-vlm-wrapper.h" +#include + + +using std::cout; +using std::endl; int main(int argc, char ** argv) { - const char* llm_model = ""; - const char* mmproj_model = ""; - const char* image_path = ""; + const char* llm_model = ""; + const char* mmproj_model = ""; + const char* image_path = ""; const char* prompt = ""; - omnivlm_init(llm_model, mmproj_model); - omnivlm_inference(prompt, image_path); - omnivlm_inference(prompt, image_path); + omnivlm_init(llm_model, mmproj_model, "vlm-81-ocr"); + + const char* res; + res = omnivlm_inference(prompt, image_path); + cout << "RES: " << res << endl; + res = omnivlm_inference(prompt, image_path); + cout << "RES: " << res << endl; omnivlm_free(); return 0; diff --git a/examples/omni-vlm/omni-vlm-wrapper.cpp b/examples/omni-vlm/omni-vlm-wrapper.cpp index 8c8cac786..c4819be4b 100644 --- a/examples/omni-vlm/omni-vlm-wrapper.cpp +++ b/examples/omni-vlm/omni-vlm-wrapper.cpp @@ -65,7 +65,8 @@ static struct omnivlm_context * omnivlm_init_context(gpt_params * params, llama_ prompt = "describe the image in detail."; } - auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 10); + auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 0); + clip_set_omni_vlm_version(ctx_clip, params); llama_context_params ctx_params = llama_context_params_from_gpt_params(*params); @@ -135,14 +136,14 @@ static const char* process_prompt(struct omnivlm_context * ctx_omnivlm, struct o const int max_tgt_len = params->n_predict < 0 ? 256 : params->n_predict; - std::string full_prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" \ - + prompt + "\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>"; - size_t image_pos = full_prompt.find("<|image_pad|>"); + // std::string full_prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. 
You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" \ + // + prompt + "\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>"; + size_t image_pos = params->prompt.find("<|image_pad|>"); std::string system_prompt, user_prompt; // new templating mode: Provide the full prompt including system message and use as a placeholder for the image - system_prompt = full_prompt.substr(0, image_pos); - user_prompt = full_prompt.substr(image_pos + std::string("<|image_pad|>").length()); + system_prompt = params->prompt.substr(0, image_pos); + user_prompt = params->prompt.substr(image_pos + std::string("<|image_pad|>").length()); if (params->verbose_prompt) { auto tmp = ::llama_tokenize(ctx_omnivlm->ctx_llama, system_prompt, true, true); for (int i = 0; i < (int) tmp.size(); i++) { @@ -157,6 +158,9 @@ static const char* process_prompt(struct omnivlm_context * ctx_omnivlm, struct o } } + params->sparams.top_k = 1; + params->sparams.top_p = 1.0f; + eval_string(ctx_omnivlm->ctx_llama, system_prompt.c_str(), params->n_batch, &n_past, true); omnivlm_eval_image_embed(ctx_omnivlm->ctx_llama, image_embed, params->n_batch, &n_past); eval_string(ctx_omnivlm->ctx_llama, user_prompt.c_str(), params->n_batch, &n_past, false); @@ -217,8 +221,10 @@ static void print_usage(int argc, char ** argv, const gpt_params & params) { } // inference interface definition -void omnivlm_init(const char* llm_model_path, const char* projector_model_path) { - const char* argv = "hello-omni-vlm-wrapper-cli"; +void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version) { + std::cout << "debug0 " << llm_model_path << std::endl; + std::cout << "debug1 " << omni_vlm_version << std::endl; + const char* argv = "omni-wrapper-py"; char* nc_argv = const_cast(argv); if (!gpt_params_parse(1, &nc_argv, params)) { print_usage(1, &nc_argv, {}); @@ -226,6 +232,17 @@ void omnivlm_init(const char* llm_model_path, const char* projector_model_path) } params.model = llm_model_path; params.mmproj = projector_model_path; + params.omni_vlm_version = omni_vlm_version; + + std::string omni_vlm_ver = params.omni_vlm_version; + std::cout << "\t\t DEBUG omni_ver" << std::endl; + std::cout << params.omni_vlm_version << std::endl; + if(omni_vlm_ver != "vlm-81-ocr" && omni_vlm_ver != "vlm-81-instruct" && omni_vlm_ver != "nano-vlm-instruct") { + fprintf(stderr, "%s: error: you set wrong omni_vlm_string: %s\n", __func__, omni_vlm_version); + fprintf(stderr, "%s: Valid omni_vlm_version set is ('vlm-81-ocr', 'vlm-81-instruct', 'nano-vlm-instruct')\n", __func__); + throw std::runtime_error("You set wrong vlm_version info strings."); + } + model = omnivlm_init(¶ms); if (model == nullptr) { fprintf(stderr, "%s: error: failed to init omnivlm model\n", __func__); @@ -237,6 +254,16 @@ void omnivlm_init(const char* llm_model_path, const char* projector_model_path) const char* omnivlm_inference(const char *prompt, const char *imag_path) { std::string image = imag_path; params.prompt = prompt; + + if (params.omni_vlm_version == "vlm-81-ocr") { + params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. You are a helpful assistant.<|im_end|>\n<|im_start|>user\n <|ocr_start|><|vision_start|><|image_pad|><|vision_end|><|ocr_end|><|im_end|>"; + } else if (params.omni_vlm_version == "vlm-81-instruct" || params.omni_vlm_version == "nano-vlm-instruct") { + params.prompt = "<|im_start|>system\nYou are Nano-Omni-VLM, created by Nexa AI. 
You are a helpful assistant.<|im_end|>\n<|im_start|>user\n" + params.prompt + "\n<|vision_start|><|image_pad|><|vision_end|><|im_end|>"; + } else { + LOG_TEE("%s : error: you set wrong vlm version info:'%s'.\n", __func__, params.omni_vlm_version.c_str()); + throw std::runtime_error("You set wrong vlm_version info strings."); + } + auto * image_embed = load_image(ctx_omnivlm, ¶ms, image); if (!image_embed) { LOG_TEE("%s: failed to load image %s. Terminating\n\n", __func__, image.c_str()); diff --git a/examples/omni-vlm/omni-vlm-wrapper.h b/examples/omni-vlm/omni-vlm-wrapper.h index ff37a2550..22cc40533 100644 --- a/examples/omni-vlm/omni-vlm-wrapper.h +++ b/examples/omni-vlm/omni-vlm-wrapper.h @@ -20,7 +20,7 @@ extern "C" { #endif -OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path); +OMNIVLM_API void omnivlm_init(const char* llm_model_path, const char* projector_model_path, const char* omni_vlm_version); OMNIVLM_API const char* omnivlm_inference(const char* prompt, const char* imag_path); diff --git a/examples/omni-vlm/omni_vlm_cpp.py b/examples/omni-vlm/omni_vlm_cpp.py index e623ff53b..81edb6f1d 100644 --- a/examples/omni-vlm/omni_vlm_cpp.py +++ b/examples/omni-vlm/omni_vlm_cpp.py @@ -60,11 +60,11 @@ _lib = _load_shared_library(_lib_base_name, base_path) omni_char_p = ctypes.c_char_p -def omnivlm_init(llm_model_path: omni_char_p, mmproj_model_path: omni_char_p): - return _lib.omnivlm_init(llm_model_path, mmproj_model_path) +def omnivlm_init(llm_model_path: omni_char_p, mmproj_model_path: omni_char_p, vlm_version: omni_char_p): + return _lib.omnivlm_init(llm_model_path, mmproj_model_path, vlm_version) -_lib.omnivlm_init.argtypes = [omni_char_p, omni_char_p] +_lib.omnivlm_init.argtypes = [omni_char_p, omni_char_p, omni_char_p] _lib.omnivlm_init.restype = None diff --git a/examples/omni-vlm/omni_vlm_demo.py b/examples/omni-vlm/omni_vlm_demo.py index 0384631e0..fbed2758f 100644 --- a/examples/omni-vlm/omni_vlm_demo.py +++ b/examples/omni-vlm/omni_vlm_demo.py @@ -11,11 +11,12 @@ class NexaOmniVlmInference: A class used for vision language model inference. 
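With the signature change above, a minimal caller of the C API looks like the following sketch. The model paths and the image file name are placeholders, not taken from this patch; the third argument must be one of the three version strings validated in omnivlm_init().

    #include "omni-vlm-wrapper.h"
    #include <cstdio>

    int main() {
        // placeholder GGUF paths; substitute your exported LLM and projector files
        omnivlm_init("model.gguf", "mmproj.gguf", "vlm-81-instruct");
        // the prompt is wrapped into the version-specific chat template internally
        const char* res = omnivlm_inference("Describe this image for me", "latex.png");
        printf("RES: %s\n", res);
        omnivlm_free();
        return 0;
    }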
""" - def __init__(self, llm_model_path: str, mmproj_model_path: str): + def __init__(self, llm_model_path: str, mmproj_model_path: str, omni_vlm_version: str): self.llm_model = ctypes.c_char_p(llm_model_path.encode("utf-8")) self.mmproj_model = ctypes.c_char_p(mmproj_model_path.encode("utf-8")) + self.omni_vlm_version = ctypes.c_char_p(omni_vlm_version.encode("utf-8")) - omni_vlm_cpp.omnivlm_init(self.llm_model, self.mmproj_model) + omni_vlm_cpp.omnivlm_init(self.llm_model, self.mmproj_model, self.omni_vlm_version) def inference(self, prompt: str, image_path: str): prompt = ctypes.c_char_p(prompt.encode("utf-8")) @@ -34,19 +35,25 @@ if __name__ == "__main__": ) parser.add_argument("--model", type=str, help="Path to the llm model file") parser.add_argument("--mmproj", type=str, help="Path to the mmproj file") + parser.add_argument("--omni-vlm-version", type=str, help="omni-vlm-version info ('vlm-81-ocr', 'vlm-81-instruct', 'nano-vlm-instruct')") # parser.add_argument("--prompt", type=str, help="prompt string.") # parser.add_argument("--image-path", type=str, help="Path to the image.") args = parser.parse_args() - omni_vlm_obj = NexaOmniVlmInference(args.model, args.mmproj) + print("DEBUG") + print(args.omni_vlm_version) + omni_vlm_obj = NexaOmniVlmInference(args.model, args.mmproj, args.omni_vlm_version) # omni_vlm_obj.inference(args.prompt, args.image_path) while True: - print("Input your prompt:") - prompt = input() - if prompt == "": - print("ERROR: you input an empty prompt, try again.") - continue + if args.omni_vlm_version != "vlm-81-ocr": + print("Input your prompt:") + prompt = input() + if prompt == "": + print("ERROR: you input an empty prompt, try again.") + continue + else: + prompt = "" print("Input your image path:") image_path = input() while not os.path.exists(image_path):