add multimodal input - alfa

FSSRepo 2023-10-13 23:36:32 -04:00
parent de35b47908
commit 9f72b44635
3 changed files with 2442 additions and 2249 deletions

File diff suppressed because it is too large.

Changes to the server web UI (file path not shown in this view):

@@ -194,6 +194,7 @@
 import { llama } from '/completion.js';
 import { SchemaConverter } from '/json-schema-to-grammar.mjs';
+let selected_image = false;
 const session = signal({
     prompt: "This is a conversation between User and Llama, a friendly chatbot. Llama is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision.",
@@ -222,6 +223,7 @@
     grammar: '',
     n_probs: 0, // no completion_probabilities,
     slot_id: -1,
+    image_data: [],
     cache_prompt: true
 })
@@ -424,7 +426,7 @@
     transcriptUpdate([...session.value.transcript, ["{{user}}", msg]])
-    const prompt = template(session.value.template, {
+    let prompt = template(session.value.template, {
         message: msg,
         history: session.value.transcript.flatMap(
             ([name, data]) =>
@@ -439,7 +441,9 @@
             )
         ).join("\n"),
     });
+    if (selected_image) {
+        prompt = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: [img-10]${msg}\nASSISTANT:`;
+    }
     await runLlama(prompt, {
         ...params.value,
         stop: ["</s>", template("{{char}}:"), template("{{user}}:")],
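In this alfa version the chat template is bypassed whenever an image has been selected: the prompt is replaced with a hardcoded LLaVA-style template, and the [img-10] tag matches the fixed id 10 that uploadImage (below) assigns to the uploaded image. A sketch of the string the server ends up receiving, using an example message:

    // sketch: the prompt produced by the branch above for an example message
    const prompt = "A chat between a curious human and an artificial intelligence assistant. " +
        "The assistant gives helpful, detailed, and polite answers to the human's questions.\n" +
        "USER: [img-10]What is in this picture?\nASSISTANT:";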
@@ -472,6 +476,24 @@
         transcriptUpdate([]);
     }

+    const uploadImage = (e) => {
+        e.preventDefault();
+        const fileInput = document.getElementById("fileInput");
+        // one-shot listener: without { once: true }, every click would stack
+        // another "change" handler on the same hidden input
+        fileInput.addEventListener("change", function (event) {
+            const selectedFile = event.target.files[0];
+            if (selectedFile) {
+                const reader = new FileReader();
+                reader.onload = function () {
+                    const image_data = reader.result;
+                    // strip the data: URL header (works for any image type, not just PNG);
+                    // the server expects raw base64
+                    params.value = { ...params.value, image_data: [
+                        { data: image_data.slice(image_data.indexOf(',') + 1), id: 10 }] };
+                    selected_image = true;
+                };
+                reader.readAsDataURL(selectedFile);
+            }
+        }, { once: true });
+        fileInput.click();
+    }

 function MessageInput() {
     const message = useSignal("")
@@ -502,6 +524,7 @@
             </div>
             <div class="right">
                 <button type="submit" disabled=${generating.value}>Send</button>
+                <button onclick=${uploadImage} style="margin-left: 10px; margin-right: 10px;">Upload Image</button>
                 <button onclick=${stop} disabled=${!generating.value}>Stop</button>
                 <button onclick=${reset}>Reset</button>
             </div>
@@ -957,7 +980,8 @@
 </head>
 <body>
-    <div id="container"></div>
+    <div id="container">
+        <input type="file" id="fileInput" accept="image/*" style="display: none;"></div>
     <div id="portal"></div>
 </body>
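Taken together, the UI changes attach the selected image as an image_data array in the completion request, next to a prompt carrying the matching [img-10] tag. A minimal sketch of the resulting request, assuming completion.js ultimately issues a plain fetch against the server's /completion endpoint (the base64 payload is truncated, and n_predict is shown only for illustration):

    // sketch: what the web UI effectively POSTs once an image is selected
    const response = await fetch("/completion", {
        method: "POST",
        headers: { "Content-Type": "application/json" },
        body: JSON.stringify({
            prompt: "A chat between a curious human and an artificial intelligence assistant. " +
                "The assistant gives helpful, detailed, and polite answers to the human's questions.\n" +
                "USER: [img-10]What is in this picture?\nASSISTANT:",
            image_data: [{ data: "iVBORw0KGgoAAAANS...", id: 10 }], // raw base64, no data: URL header
            cache_prompt: false, // the server disables prompt caching for multimodal anyway
            n_predict: 128,
        }),
    });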

Changes to the server source (file path not shown in this view):

@@ -74,7 +74,6 @@ static const std::string base64_chars =
     "abcdefghijklmnopqrstuvwxyz"
     "0123456789+/";
-
 static inline bool is_base64(uint8_t c) {
     return (isalnum(c) || (c == '+') || (c == '/'));
 }
@@ -273,6 +272,17 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
     return out;
 }

+#ifdef SERVER_MULTIMODAL_SUPPORT
+struct slot_image {
+    clip_image_u8 img_data;
+    bool request_encode_image = false;
+    float * image_embedding = nullptr;
+    int image_tokens = 0;
+    int id;
+    std::string prefix_prompt = ""; // text that comes before this image in the prompt
+};
+#endif
+
 struct llama_client_slot
 {
     int id;
@@ -317,10 +327,7 @@ struct llama_client_slot
     llama_grammar *grammar = nullptr;

 #ifdef SERVER_MULTIMODAL_SUPPORT
-    clip_image_u8 img_data;
-    bool request_encode_image = false;
-    float* image_embedding = nullptr;
-    int image_tokens = 0;
+    std::vector<slot_image> images;
 #endif

     void reset() {
@@ -806,6 +813,91 @@ struct llama_server_context
         });
         return has_next_token; // continue
     }

+#ifdef SERVER_MULTIMODAL_SUPPORT
+    // encode every image in the slot that is still pending with CLIP
+    bool processImages(llama_client_slot &slot) {
+        for (slot_image &img : slot.images) {
+            if (!img.request_encode_image) {
+                continue;
+            }
+            clip_image_f32 img_res;
+            if (!clip_image_preprocess(clp_ctx, &img.img_data, &img_res, /*pad2square =*/ true)) {
+                LOG_TEE("Error processing the given image\n");
+                clip_free(clp_ctx);
+                return false;
+            }
+            img.image_tokens = clip_n_patches(clp_ctx);
+            img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
+            if (!img.image_embedding) {
+                LOG_TEE("Unable to allocate memory for image embeddings\n");
+                clip_free(clp_ctx);
+                return false;
+            }
+            LOG_TEE("slot %i - encoding image %i\n", slot.id, img.id);
+            if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, img.image_embedding)) {
+                LOG_TEE("Unable to encode image\n");
+                return false;
+            }
+            img.request_encode_image = false;
+        }
+        return slot.images.size() > 0;
+    }
+
+    // evaluate prefix prompts and image embeddings, interleaved, for multiple images
+    bool ingestImages(llama_client_slot &slot, int n_batch) {
+        int image_idx = 0;
+        while (image_idx < (int) slot.images.size()) {
+            slot_image &img = slot.images[image_idx];
+
+            // evaluate the text tokens already queued in the batch (the current prefix prompt)
+            for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
+                const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
+                llama_batch batch_view = {
+                    n_tokens,
+                    batch.token  + i,
+                    nullptr,
+                    batch.pos    + i,
+                    batch.seq_id + i,
+                    batch.logits + i,
+                    0, 0, 0, // unused
+                };
+                if (llama_decode(ctx, batch_view)) {
+                    LOG_TEE("%s : failed to eval\n", __func__);
+                    return false;
+                }
+            }
+
+            // feed the image embedding to the model in n_batch-sized chunks
+            for (int i = 0; i < img.image_tokens; i += n_batch) {
+                int n_eval = img.image_tokens - i;
+                if (n_eval > n_batch) {
+                    n_eval = n_batch;
+                }
+                // embeddings-only batch: token is null, embd points into the CLIP output
+                llama_batch batch_img = { int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
+                if (llama_decode(ctx, batch_img)) {
+                    LOG_TEE("%s : failed to eval image\n", __func__);
+                    return false;
+                }
+                slot.n_past += n_eval;
+            }
+            image_idx++;
+
+            // queue the prefix of the next image, or the suffix prompt after the last image
+            batch.n_tokens = 0;
+            std::vector<llama_token> append_tokens = tokenize(
+                image_idx >= (int) slot.images.size() ? slot.params.input_suffix // no more images: process the suffix prompt
+                                                      : slot.images[image_idx].prefix_prompt, true);
+            for (int i = 0; i < (int) append_tokens.size(); ++i) {
+                batch.token [batch.n_tokens] = append_tokens[i];
+                batch.pos   [batch.n_tokens] = slot.n_past;
+                batch.seq_id[batch.n_tokens] = slot.id;
+                batch.logits[batch.n_tokens] = false;
+                slot.n_past += 1;
+                batch.n_tokens += 1;
+            }
+        }
+        return true;
+    }
+#endif

     bool updateSlots() {
         // update the system prompt; wait until all slots are in the idle state
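ingestImages interleaves text tokens and image embeddings in decode order: the tokens queued in the batch (the current image's prefix prompt) are evaluated first, then the image's CLIP embedding is fed through an embeddings-only llama_batch, and after the last image the remaining text goes in as input_suffix. A sketch of the resulting order, using plain JavaScript data instead of llama_batch calls (illustration only, not server code):

    // sketch: the evaluation order produced for a slot with images
    function decodeOrder(images, inputSuffix) {
        const steps = [];
        for (const img of images) {
            steps.push({ kind: "text", content: img.prefix_prompt });      // tokenized text
            steps.push({ kind: "image", content: "embedding #" + img.id }); // CLIP embedding
        }
        steps.push({ kind: "text", content: inputSuffix }); // text after the last tag
        return steps;
    }

    // "USER: [img-10]What is this?\nASSISTANT:" is evaluated as:
    //   text "USER: "  ->  image embedding #10  ->  text "What is this?\nASSISTANT:"
    console.log(decodeOrder([{ id: 10, prefix_prompt: "USER: " }], "What is this?\nASSISTANT:"));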
@@ -901,31 +993,6 @@ struct llama_server_context
                 slot.state = PROCESSING;
                 slot.command = NONE;
                 std::vector<llama_token> prompt_tokens;
-#ifdef SERVER_MULTIMODAL_SUPPORT
-                bool ingest_image = false;
-                if(slot.request_encode_image) {
-                    ingest_image = true;
-                    clip_image_f32 img_res;
-                    if (!clip_image_preprocess(clp_ctx, &slot.img_data, &img_res, /*pad2square =*/ true)) {
-                        LOG_TEE("Error processing the given image");
-                        clip_free(clp_ctx);
-                        return false;
-                    }
-                    slot.image_tokens = clip_n_patches(clp_ctx);
-                    slot.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
-                    if (!slot.image_embedding) {
-                        LOG_TEE("Unable to allocate memory for image embeddings\n");
-                        clip_free(clp_ctx);
-                        return false;
-                    }
-                    LOG_TEE("slot %i - encoding image\n", slot.id);
-                    if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, slot.image_embedding)) {
-                        LOG_TEE("Unable to encode image\n");
-                        return false;
-                    }
-                    slot.request_encode_image = false;
-                }
-#endif
                 slot.t_start_process_prompt = ggml_time_us();
                 slot.t_start_genereration = 0;
@@ -1003,59 +1070,21 @@ struct llama_server_context
                 });

 #ifdef SERVER_MULTIMODAL_SUPPORT
-                std::vector<llama_token> preffix_tokens = ingest_image ? tokenize(slot.params.input_prefix, true) : prompt_tokens;
-                for (; slot.n_past < preffix_tokens.size(); ++slot.n_past) {
-                    batch.token [batch.n_tokens] = preffix_tokens[slot.n_past];
+                bool ingest_images = processImages(slot); // are there images to handle?
+
+                // process the prefix of the first image (or the whole prompt if there are none)
+                std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
+                for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
+                    batch.token [batch.n_tokens] = prefix_tokens[slot.n_past];
                     batch.pos   [batch.n_tokens] = slot.n_past + num_tokens_system;
                     batch.seq_id[batch.n_tokens] = slot.id;
                     batch.logits[batch.n_tokens] = false;
                     batch.n_tokens += 1;
                 }
-                if(ingest_image) {
-                    // process preffix prompt
-                    for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
-                        const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
-                        llama_batch batch_view = {
-                            n_tokens,
-                            batch.token + i,
-                            nullptr,
-                            batch.pos + i,
-                            batch.seq_id + i,
-                            batch.logits + i,
-                            0, 0, 0, // unused
-                        };
-                        if (llama_decode(ctx, batch_view)) {
-                            LOG_TEE("%s : failed to eval\n", __func__);
-                            return false;
-                        }
-                    }
-                    // process image
-                    for (int i = 0; i < slot.image_tokens; i += n_batch) {
-                        int n_eval = slot.image_tokens - i;
-                        if (n_eval > n_batch) {
-                            n_eval = n_batch;
-                        }
-                        llama_batch batch = {int32_t(n_eval), nullptr, (slot.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
-                        if (llama_decode(ctx, batch)) {
-                            LOG_TEE("%s : failed to eval image\n", __func__);
-                            return false;
-                        }
-                        slot.n_past += n_eval;
-                    }
-                    // process suffix prompt
-                    batch.n_tokens = 0;
-                    std::vector<llama_token> suffix_tokens = tokenize(slot.params.input_suffix, true);
-                    for (int i = 0; i < suffix_tokens.size(); ++i) {
-                        batch.token [batch.n_tokens] = suffix_tokens[i];
-                        batch.pos   [batch.n_tokens] = slot.n_past;
-                        batch.seq_id[batch.n_tokens] = slot.id;
-                        batch.logits[batch.n_tokens] = false;
-                        slot.n_past += 1;
-                        batch.n_tokens += 1;
-                    }
-                }
+
+                if (ingest_images && !ingestImages(slot, n_batch)) {
+                    LOG_TEE("failed processing images\n");
+                    return false;
+                }
 #else
                 for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
@@ -1111,7 +1140,6 @@ struct llama_server_context
         }
-
         for (auto & slot : slots) {
             if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                 continue;
             }
@@ -1728,34 +1756,72 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
             return;
         }

-        std::string data_b64 = json_value(body, "image_data", std::string(""));
-        if(!data_b64.empty()) {
-            if(!slot->prompt.is_array()) {
-                std::string prompt = slot->prompt.get<std::string>();
-                int pos = prompt.find("[(img)]");
-                if(pos == std::string::npos) {
-                    LOG_TEE("Missing image position in prompt\n");
-                    return;
-                } else {
-                    // reuse infill prompt input
-                    slot->params.input_prefix = prompt.substr(0, pos);
-                    slot->params.input_suffix = prompt.substr(pos + 7); // ignore [(img)]
-                    slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
-                }
-            }
-            int width, height, channels;
-            std::vector<uint8_t> image_buffer = base64_decode(data_b64);
-            data_b64.clear();
-            // decode base64
-            auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3);
-            slot->img_data.nx = width;
-            slot->img_data.ny = height;
-            slot->img_data.size = width * height * 3;
-            slot->img_data.data = new uint8_t[slot->img_data.size]();
-            memcpy(slot->img_data.data, data, slot->img_data.size);
-            stbi_image_free(data);
-            LOG_TEE("slot %i - RGB image loaded (%i x %i)\n", slot->id, width, height);
-            slot->request_encode_image = true;
-        }
+        const auto &images_data = body.find("image_data");
+        if (images_data != body.end() && images_data->is_array())
+        {
+            slot->images.clear();
+            for (const auto &img : *images_data)
+            {
+                slot_image img_sl;
+                std::string data_b64 = img["data"].get<std::string>();
+                img_sl.id = img.count("id") != 0 ? img["id"].get<int>() : (int) slot->images.size();
+                int width, height, channels;
+                std::vector<uint8_t> image_buffer = base64_decode(data_b64);
+                data_b64.clear();
+                auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3);
+                if (!data) {
+                    LOG_TEE("slot %i - failed to load image\n", slot->id);
+                    return;
+                }
+                LOG_TEE("slot %i - RGB image %i loaded (%i x %i)\n", slot->id, img_sl.id, width, height);
+                img_sl.img_data.nx = width;
+                img_sl.img_data.ny = height;
+                img_sl.img_data.size = width * height * 3;
+                img_sl.img_data.data = new uint8_t[width * height * 3]();
+                memcpy(img_sl.img_data.data, data, width * height * 3);
+                stbi_image_free(data);
+                img_sl.request_encode_image = true;
+                slot->images.push_back(img_sl);
+            }
+            // split the prompt at each [img-N] tag
+            // example: "system prompt [img-102] user [img-103] describe [img-134]"
+            //   -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]
+            if (slot->images.size() > 0 && !slot->prompt.is_array()) {
+                std::string prompt = slot->prompt.get<std::string>();
+                size_t pos = 0, begin_prefix = 0;
+                std::string pattern = "[img-";
+                while ((pos = prompt.find(pattern, pos)) != std::string::npos) {
+                    size_t end_prefix = pos;
+                    pos += pattern.length();
+                    size_t end_pos = prompt.find("]", pos);
+                    if (end_pos != std::string::npos) {
+                        std::string image_id = prompt.substr(pos, end_pos - pos);
+                        try {
+                            int img_id = std::stoi(image_id);
+                            bool found = false;
+                            for (slot_image &img : slot->images) {
+                                if (img.id == img_id) {
+                                    found = true;
+                                    img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix);
+                                    begin_prefix = end_pos + 1;
+                                    break;
+                                }
+                            }
+                            if (!found) {
+                                LOG_TEE("ERROR: Image with id %i not found.\n", img_id);
+                                slot->images.clear();
+                                return;
+                            }
+                        } catch (const std::invalid_argument& e) {
+                            LOG_TEE("Invalid image id in prompt\n");
+                            slot->images.clear();
+                            return;
+                        }
+                    }
+                }
+                slot->prompt = "";
+                slot->params.input_suffix = prompt.substr(begin_prefix);
+                slot->params.cache_prompt = false; // multimodal doesn't support prompt caching
+            }
+        }
 #endif
 }
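For reference, the segmentation above walks the prompt once, assigns the text before each [img-N] tag to that image's prefix_prompt, and keeps whatever follows the last tag as input_suffix. A rough JavaScript equivalent of the parsing loop (names are illustrative, not part of the server code):

    // sketch: JavaScript equivalent of the server's [img-N] prompt segmentation
    function splitPrompt(prompt, imageIds) {
        const segments = []; // [{ id, prefix }]
        const re = /\[img-(\d+)\]/g;
        let last = 0, m;
        while ((m = re.exec(prompt)) !== null) {
            const id = parseInt(m[1], 10);
            if (!imageIds.includes(id)) throw new Error("image with id " + id + " not found");
            segments.push({ id, prefix: prompt.slice(last, m.index) });
            last = m.index + m[0].length;
        }
        return { segments, suffix: prompt.slice(last) };
    }

    // splitPrompt("system prompt [img-102] user [img-103] describe [img-134]", [102, 103, 134])
    // -> segments [{id: 102, prefix: "system prompt "}, {id: 103, prefix: " user "}, {id: 134, prefix: " describe "}]
    //    suffix ""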