From 770dc9da0d5724fc31dba1f6bf0ea93b201b137b Mon Sep 17 00:00:00 2001 From: Damian Stewart Date: Fri, 13 Oct 2023 17:53:17 +0200 Subject: [PATCH] add base64 in-prompt image support --- examples/llava/base64.hpp | 392 ++++++++++++++++++++++++++++++++++++++ examples/llava/llava.cpp | 84 +++++++- 2 files changed, 468 insertions(+), 8 deletions(-) create mode 100644 examples/llava/base64.hpp diff --git a/examples/llava/base64.hpp b/examples/llava/base64.hpp new file mode 100644 index 000000000..9a1923825 --- /dev/null +++ b/examples/llava/base64.hpp @@ -0,0 +1,392 @@ +/* +This is free and unencumbered software released into the public domain. + +Anyone is free to copy, modify, publish, use, compile, sell, or +distribute this software, either in source code form or as a compiled +binary, for any purpose, commercial or non-commercial, and by any +means. + +In jurisdictions that recognize copyright laws, the author or authors +of this software dedicate any and all copyright interest in the +software to the public domain. We make this dedication for the benefit +of the public at large and to the detriment of our heirs and +successors. We intend this dedication to be an overt act of +relinquishment in perpetuity of all present and future rights to this +software under copyright law. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, +EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF +MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. +IN NO EVENT SHALL THE AUTHORS BE LIABLE FOR ANY CLAIM, DAMAGES OR +OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, +ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR +OTHER DEALINGS IN THE SOFTWARE. + +For more information, please refer to +*/ + +#ifndef PUBLIC_DOMAIN_BASE64_HPP_ +#define PUBLIC_DOMAIN_BASE64_HPP_ + +#include +#include +#include +#include + +class base64_error : public std::runtime_error +{ +public: + using std::runtime_error::runtime_error; +}; + +class base64 +{ +public: + enum class alphabet + { + /** the alphabet is detected automatically */ + auto_, + /** the standard base64 alphabet is used */ + standard, + /** like `standard` except that the characters `+` and `/` are replaced by `-` and `_` respectively*/ + url_filename_safe + }; + + enum class decoding_behavior + { + /** if the input is not padded, the remaining bits are ignored */ + moderate, + /** if a padding character is encounter decoding is finished */ + loose + }; + + /** + Encodes all the elements from `in_begin` to `in_end` to `out`. + + @warning The source and destination cannot overlap. The destination must be able to hold at least + `required_encode_size(std::distance(in_begin, in_end))`, otherwise the behavior depends on the output iterator. + + @tparam Input_iterator the source; the returned elements are cast to `std::uint8_t` and should not be greater than + 8 bits + @tparam Output_iterator the destination; the elements written to it are from the type `char` + @param in_begin the beginning of the source + @param in_end the ending of the source + @param out the destination iterator + @param alphabet which alphabet should be used + @returns the iterator to the next element past the last element copied + @throws see `Input_iterator` and `Output_iterator` + */ + template + static Output_iterator encode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out, + alphabet alphabet = alphabet::standard) + { + constexpr auto pad = '='; + const char* alpha = alphabet == alphabet::url_filename_safe + ? "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_" + : "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + + while (in_begin != in_end) { + std::uint8_t i0 = 0, i1 = 0, i2 = 0; + + // first character + i0 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[i0 >> 2 & 0x3f]; + ++out; + + // part of first character and second + if (in_begin != in_end) { + i1 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[((i0 & 0x3) << 4) | (i1 >> 4 & 0x0f)]; + ++out; + } else { + *out = alpha[(i0 & 0x3) << 4]; + ++out; + + // last padding + *out = pad; + ++out; + + // last padding + *out = pad; + ++out; + + break; + } + + // part of second character and third + if (in_begin != in_end) { + i2 = static_cast(*in_begin); + ++in_begin; + + *out = alpha[((i1 & 0xf) << 2) | (i2 >> 6 & 0x03)]; + ++out; + } else { + *out = alpha[(i1 & 0xf) << 2]; + ++out; + + // last padding + *out = pad; + ++out; + + break; + } + + // rest of third + *out = alpha[i2 & 0x3f]; + ++out; + } + + return out; + } + /** + Encodes a string. + + @param str the string that should be encoded + @param alphabet which alphabet should be used + @returns the encoded base64 string + @throws see base64::encode() + */ + static std::string encode(const std::string& str, alphabet alphabet = alphabet::standard) + { + std::string result; + + result.reserve(required_encode_size(str.length()) + 1); + + encode(str.begin(), str.end(), std::back_inserter(result), alphabet); + + return result; + } + /** + Encodes a char array. + + @param buffer the char array + @param size the size of the array + @param alphabet which alphabet should be used + @returns the encoded string + */ + static std::string encode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::standard) + { + std::string result; + + result.reserve(required_encode_size(size) + 1); + + encode(buffer, buffer + size, std::back_inserter(result), alphabet); + + return result; + } + /** + Decodes all the elements from `in_begin` to `in_end` to `out`. `in_begin` may point to the same location as `out`, + in other words: inplace decoding is possible. + + @warning The destination must be able to hold at least `required_decode_size(std::distance(in_begin, in_end))`, + otherwise the behavior depends on the output iterator. + + @tparam Input_iterator the source; the returned elements are cast to `char` + @tparam Output_iterator the destination; the elements written to it are from the type `std::uint8_t` + @param in_begin the beginning of the source + @param in_end the ending of the source + @param out the destination iterator + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the iterator to the next element past the last element copied + @throws base64_error depending on the set behavior + @throws see `Input_iterator` and `Output_iterator` + */ + template + static Output_iterator decode(Input_iterator in_begin, Input_iterator in_end, Output_iterator out, + alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + //constexpr auto pad = '='; + std::uint8_t last = 0; + auto bits = 0; + + while (in_begin != in_end) { + auto c = *in_begin; + ++in_begin; + + if (c == '=') { + break; + } + + auto part = _base64_value(alphabet, c); + + // enough bits for one byte + if (bits + 6 >= 8) { + *out = (last << (8 - bits)) | (part >> (bits - 2)); + ++out; + + bits -= 2; + } else { + bits += 6; + } + + last = part; + } + + // check padding + if (behavior != decoding_behavior::loose) { + while (in_begin != in_end) { + auto c = *in_begin; + ++in_begin; + + if (c != '=') { + throw base64_error("invalid base64 character."); + } + } + } + + return out; + } + /** + Decodes a string. + + @param str the base64 encoded string + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the decoded string + @throws see base64::decode() + */ + static std::string decode(const std::string& str, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + std::string result; + + result.reserve(max_decode_size(str.length())); + + decode(str.begin(), str.end(), std::back_inserter(result), alphabet, behavior); + + return result; + } + /** + Decodes a string. + + @param buffer the base64 encoded buffer + @param size the size of the buffer + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the decoded string + @throws see base64::decode() + */ + static std::string decode(const char* buffer, std::size_t size, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + std::string result; + + result.reserve(max_decode_size(size)); + + decode(buffer, buffer + size, std::back_inserter(result), alphabet, behavior); + + return result; + } + /** + Decodes a string inplace. + + @param[in,out] str the base64 encoded string + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @throws base64::decode_inplace() + */ + static void decode_inplace(std::string& str, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + str.resize(decode(str.begin(), str.end(), str.begin(), alphabet, behavior) - str.begin()); + } + /** + Decodes a char array inplace. + + @param[in,out] str the string array + @param size the length of the array + @param alphabet which alphabet should be used + @param behavior the behavior when an error was detected + @returns the pointer to the next element past the last element decoded + @throws base64::decode_inplace() + */ + static char* decode_inplace(char* str, std::size_t size, alphabet alphabet = alphabet::auto_, + decoding_behavior behavior = decoding_behavior::moderate) + { + return decode(str, str + size, str, alphabet, behavior); + } + /** + Returns the required decoding size for a given size. The value is calculated with the following formula: + + $$ + \lceil \frac{size}{4} \rceil \cdot 3 + $$ + + @param size the size of the encoded input + @returns the size of the resulting decoded buffer; this the absolute maximum + */ + static std::size_t max_decode_size(std::size_t size) noexcept + { + return (size / 4 + (size % 4 ? 1 : 0)) * 3; + } + /** + Returns the required encoding size for a given size. The value is calculated with the following formula: + + $$ + \lceil \frac{size}{3} \rceil \cdot 4 + $$ + + @param size the size of the decoded input + @returns the size of the resulting encoded buffer + */ + static std::size_t required_encode_size(std::size_t size) noexcept + { + return (size / 3 + (size % 3 ? 1 : 0)) * 4; + } + +private: + static std::uint8_t _base64_value(alphabet& alphabet, char c) + { + if (c >= 'A' && c <= 'Z') { + return c - 'A'; + } else if (c >= 'a' && c <= 'z') { + return c - 'a' + 26; + } else if (c >= '0' && c <= '9') { + return c - '0' + 52; + } + + // comes down to alphabet + if (alphabet == alphabet::standard) { + if (c == '+') { + return 62; + } else if (c == '/') { + return 63; + } + } else if (alphabet == alphabet::url_filename_safe) { + if (c == '-') { + return 62; + } else if (c == '_') { + return 63; + } + } // auto detect + else { + if (c == '+') { + alphabet = alphabet::standard; + + return 62; + } else if (c == '/') { + alphabet = alphabet::standard; + + return 63; + } else if (c == '-') { + alphabet = alphabet::url_filename_safe; + + return 62; + } else if (c == '_') { + alphabet = alphabet::url_filename_safe; + + return 63; + } + } + + throw base64_error("invalid base64 character."); + } +}; + +#endif // !PUBLIC_DOMAIN_BASE64_HPP_ diff --git a/examples/llava/llava.cpp b/examples/llava/llava.cpp index 22e625236..bfa2f72a5 100644 --- a/examples/llava/llava.cpp +++ b/examples/llava/llava.cpp @@ -8,6 +8,8 @@ #include #include +#include "base64.hpp" + static void show_additional_info(int /*argc*/, char ** argv) { printf("\n example usage: %s -m --mmproj --image [--temp 0.1] [-p \"describe the image in detail.\"]\n", argv[0]); printf(" note: a lower temperature value like 0.1 is recommended for better quality.\n"); @@ -35,24 +37,90 @@ static bool encode_image_with_clip(clip_ctx * ctx_clip, int n_threads, const cli return true; } +static const char* IMG_BASE64_TAG_BEGIN = ""; + +static void find_image_tag_in_prompt(const std::string& prompt, size_t& begin_out, size_t& end_out) { + begin_out = prompt.find(IMG_BASE64_TAG_BEGIN); + end_out = prompt.find(IMG_BASE64_TAG_END, (begin_out == std::string::npos) ? 0UL : begin_out); +} + +static bool prompt_contains_image(const std::string& prompt) { + size_t begin, end; + find_image_tag_in_prompt(prompt, begin, end); + return (begin != std::string::npos); +} + +// replaces the base64 image tag in the prompt with `replacement` +static bool get_image_from_prompt(const std::string& prompt, clip_image_u8 * img) { + size_t img_base64_str_start, img_base64_str_end; + find_image_tag_in_prompt(prompt, img_base64_str_start, img_base64_str_end); + if (img_base64_str_start == std::string::npos || img_base64_str_end == std::string::npos) { + fprintf(stderr, "%s: invalid base64 image tag. must be %s%s\n", __func__, IMG_BASE64_TAG_BEGIN, IMG_BASE64_TAG_END); + return false; + } + + auto base64_bytes_start = img_base64_str_start + strlen(IMG_BASE64_TAG_BEGIN); + auto base64_bytes_count = img_base64_str_end - base64_bytes_start; + auto base64_str = prompt.substr(base64_bytes_start, base64_bytes_count ); + printf("base64_str: '%s'\n", base64_str.c_str()); + + auto required_bytes = base64::required_encode_size(base64_str.size()); + auto img_bytes = std::vector(required_bytes); + auto img_bytes_end = base64::decode(base64_str.begin(), base64_str.end(), img_bytes.begin()); + auto img_bytes_len = img_bytes_end - img_bytes.begin(); + + auto img_loaded_ok = clip_image_load_from_bytes(img_bytes.data(), img_bytes_len, img); + if (!img_loaded_ok) { + fprintf(stderr, "%s: could not load image from base64 string.\n", __func__); + return false; + } + + return true; +} + +static std::string remove_image_from_prompt(const std::string& prompt, const char * replacement = "") { + size_t begin, end; + find_image_tag_in_prompt(prompt, begin, end); + if (begin == std::string::npos || end == std::string::npos) { + return prompt; + } + auto pre = prompt.substr(0, begin); + auto post = prompt.substr(end+1); + return pre + replacement + post; +} + struct llava_context * llava_init(gpt_params * params) { const char * clip_path = params->mmproj.c_str(); const char * img_path = params->image.c_str(); - if (params->prompt.empty()) { - params->prompt = "describe the image in detail."; + auto prompt = params->prompt; + if (prompt.empty()) { + prompt = "describe the image in detail."; } - + auto ctx_clip = clip_model_load(clip_path, /*verbosity=*/ 1); // load and preprocess the image clip_image_u8 img; - if (!clip_image_load_from_file(img_path, &img)) { - fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path); - clip_free(ctx_clip); - return NULL; + if (prompt_contains_image(prompt)) { + if (img_path) { + printf("using base64 encoded image instead of command line image path\n"); + } + if (!get_image_from_prompt(prompt, &img)) { + fprintf(stderr, "%s: can't load image from prompt\n", __func__); + clip_free(ctx_clip); + return NULL; + } + prompt = remove_image_from_prompt(prompt); + } else { + if (!clip_image_load_from_file(img_path, &img)) { + fprintf(stderr, "%s: is %s really an image file?\n", __func__, img_path); + clip_free(ctx_clip); + return NULL; + } } float * image_embd = (float *)malloc(clip_embd_nbytes(ctx_clip)); @@ -169,7 +237,7 @@ int main(int argc, char ** argv) { show_additional_info(argc, argv); return 1; } - if (params.mmproj.empty() || params.image.empty()) { + if (params.mmproj.empty() || (params.image.empty() && !prompt_contains_image(params.prompt))) { gpt_print_usage(argc, argv, params); show_additional_info(argc, argv); return 1;