add multimodal input - alfa
parent de35b47908
commit 9f72b44635
3 changed files with 2442 additions and 2249 deletions
File diff suppressed because it is too large
@@ -194,6 +194,7 @@
import { llama } from '/completion.js';
import { SchemaConverter } from '/json-schema-to-grammar.mjs';
let selected_image = false;

const session = signal({
  prompt: "This is a conversation between User and Llama, a friendly chatbot. Llama is helpful, kind, honest, good at writing, and never fails to answer any requests immediately and with precision.",
@@ -222,6 +223,7 @@
  grammar: '',
  n_probs: 0, // no completion_probabilities,
  slot_id: -1,
  image_data: [],
  cache_prompt: true
})
@@ -424,7 +426,7 @@
    transcriptUpdate([...session.value.transcript, ["{{user}}", msg]])

    const prompt = template(session.value.template, {
    let prompt = template(session.value.template, {
      message: msg,
      history: session.value.transcript.flatMap(
        ([name, data]) =>
@@ -439,7 +441,9 @@
        )
      ).join("\n"),
    });

    if(selected_image) {
      prompt = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER: [img-10]${msg}\nASSISTANT:`;
    }
    await runLlama(prompt, {
      ...params.value,
      stop: ["</s>", template("{{char}}:"), template("{{user}}:")],
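When an image has been selected, the prompt built from the session template is discarded and replaced by a fixed LLaVA-style prompt that references the uploaded image through the [img-10] tag; the image bytes themselves travel separately in params.value.image_data (set by uploadImage below). As a worked illustration, for a message of "What is in this picture?" the string sent to the server would be:

// Illustration only: the prompt produced by the branch above for a sample message.
const exampleVisionPrompt =
  "A chat between a curious human and an artificial intelligence assistant. " +
  "The assistant gives helpful, detailed, and polite answers to the human's questions.\n" +
  "USER: [img-10]What is in this picture?\nASSISTANT:";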
@@ -472,6 +476,24 @@
    transcriptUpdate([]);
  }

  const uploadImage = (e) => {
    e.preventDefault();
    document.getElementById("fileInput").click();
    document.getElementById("fileInput").addEventListener("change", function (event) {
      const selectedFile = event.target.files[0];
      if (selectedFile) {
        const reader = new FileReader();
        reader.onload = function () {
          const image_data = reader.result;
          params.value = {...params.value, image_data: [
            { data: image_data.replace('data:image/png;base64,', ''), id: 10 }] }
        };
        selected_image = true;
        reader.readAsDataURL(selectedFile);
      }
    });
  }

  function MessageInput() {
    const message = useSignal("")
@@ -502,6 +524,7 @@
      </div>
      <div class="right">
        <button type="submit" disabled=${generating.value}>Send</button>
        <button onclick=${uploadImage} style="margin-left: 10px;margin-right: 10px;">Upload Image</button>
        <button onclick=${stop} disabled=${!generating.value}>Stop</button>
        <button onclick=${reset}>Reset</button>
      </div>
@@ -957,7 +980,8 @@
</head>

<body>
  <div id="container"></div>
  <div id="container">
    <input type="file" id="fileInput" accept="image/*" style="display: none;"></div>
  <div id="portal"></div>
</body>
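Taken together, the web UI changes above ship an image to the server as a base64 string in the new image_data field, while the prompt refers to it by id. The same path can be exercised without the web page. A minimal sketch in JavaScript, assuming a server built with multimodal support listening on http://localhost:8080, the non-streaming /completion endpoint the web UI already uses, and imageBase64 as a placeholder for the image's base64 data with the data:image/png;base64, prefix stripped:

// Minimal sketch: send one image plus a prompt that references it via [img-10].
const imageBase64 = "...";   // placeholder: base64-encoded image bytes, no data: URL prefix

const response = await fetch("http://localhost:8080/completion", {
  method: "POST",
  headers: { "Content-Type": "application/json" },
  body: JSON.stringify({
    prompt: "A chat between a curious human and an artificial intelligence assistant. " +
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n" +
            "USER: [img-10]Describe the image.\nASSISTANT:",
    image_data: [{ data: imageBase64, id: 10 }],  // id must match the [img-10] tag in the prompt
    n_predict: 128,
  }),
});
const result = await response.json();
console.log(result.content);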
@@ -74,7 +74,6 @@ static const std::string base64_chars =
    "abcdefghijklmnopqrstuvwxyz"
    "0123456789+/";


static inline bool is_base64(uint8_t c) {
    return (isalnum(c) || (c == '+') || (c == '/'));
}
@@ -273,6 +272,17 @@ static json probs_vector_to_json(const llama_context *ctx, const std::vector<com
    return out;
}

#ifdef SERVER_MULTIMODAL_SUPPORT
struct slot_image {
    clip_image_u8 img_data;
    bool request_encode_image = false;
    float* image_embedding = nullptr;
    int image_tokens = 0;
    int id;
    std::string prefix_prompt = ""; // before of this image
};
#endif

struct llama_client_slot
{
    int id;
@@ -317,10 +327,7 @@ struct llama_client_slot
    llama_grammar *grammar = nullptr;

#ifdef SERVER_MULTIMODAL_SUPPORT
    clip_image_u8 img_data;
    bool request_encode_image = false;
    float* image_embedding = nullptr;
    int image_tokens = 0;
    std::vector<slot_image> images;
#endif

    void reset() {
@@ -806,6 +813,91 @@ struct llama_server_context
            });
        return has_next_token; // continue
    }
#ifdef SERVER_MULTIMODAL_SUPPORT
    bool processImages(llama_client_slot &slot) {
        for(slot_image &img : slot.images) {
            if(!img.request_encode_image) {
                continue;
            }
            clip_image_f32 img_res;
            if (!clip_image_preprocess(clp_ctx, &img.img_data, &img_res, /*pad2square =*/ true)) {
                LOG_TEE("Error processing the given image");
                clip_free(clp_ctx);
                return false;
            }
            img.image_tokens = clip_n_patches(clp_ctx);
            img.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
            if (!img.image_embedding) {
                LOG_TEE("Unable to allocate memory for image embeddings\n");
                clip_free(clp_ctx);
                return false;
            }
            LOG_TEE("slot %i - encoding image %i\n", slot.id, img.id);
            if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, img.image_embedding)) {
                LOG_TEE("Unable to encode image\n");
                return false;
            }
            img.request_encode_image = false;
        }
        return slot.images.size() > 0;
    }

    // for multiple images processing
    bool ingestImages(llama_client_slot &slot, int n_batch) {
        int image_idx = 0;
        while(image_idx < slot.images.size()) {
            slot_image img = slot.images[image_idx];

            // process prefix prompt
            for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
                const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
                llama_batch batch_view = {
                    n_tokens,
                    batch.token + i,
                    nullptr,
                    batch.pos + i,
                    batch.seq_id + i,
                    batch.logits + i,
                    0, 0, 0, // unused
                };
                if (llama_decode(ctx, batch_view)) {
                    LOG_TEE("%s : failed to eval\n", __func__);
                    return false;
                }
            }

            // process image with llm
            for (int i = 0; i < img.image_tokens; i += n_batch) {
                int n_eval = img.image_tokens - i;
                if (n_eval > n_batch) {
                    n_eval = n_batch;
                }
                llama_batch batch = {int32_t(n_eval), nullptr, (img.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
                if (llama_decode(ctx, batch)) {
                    LOG_TEE("%s : failed to eval image\n", __func__);
                    return false;
                }
                slot.n_past += n_eval;
            }
            image_idx++;

            // append prefix of next image
            batch.n_tokens = 0;
            std::vector<llama_token> append_tokens = tokenize(
                image_idx >= slot.images.size() ? slot.params.input_suffix : // no more images, then process suffix prompt
                slot.images[image_idx].prefix_prompt, true); // has next image
            for (int i = 0; i < append_tokens.size(); ++i) {
                batch.token [batch.n_tokens] = append_tokens[i];
                batch.pos [batch.n_tokens] = slot.n_past;
                batch.seq_id[batch.n_tokens] = slot.id;
                batch.logits[batch.n_tokens] = false;
                slot.n_past += 1;
                batch.n_tokens += 1;
            }
        }
        return true;
    }
#endif

    bool updateSlots() {
        // update the system prompt wait until all slots are idle state
@@ -901,31 +993,6 @@ struct llama_server_context
        slot.state = PROCESSING;
        slot.command = NONE;
        std::vector<llama_token> prompt_tokens;
#ifdef SERVER_MULTIMODAL_SUPPORT
        bool ingest_image = false;
        if(slot.request_encode_image) {
            ingest_image = true;
            clip_image_f32 img_res;
            if (!clip_image_preprocess(clp_ctx, &slot.img_data, &img_res, /*pad2square =*/ true)) {
                LOG_TEE("Error processing the given image");
                clip_free(clp_ctx);
                return false;
            }
            slot.image_tokens = clip_n_patches(clp_ctx);
            slot.image_embedding = (float *)malloc(clip_embd_nbytes(clp_ctx));
            if (!slot.image_embedding) {
                LOG_TEE("Unable to allocate memory for image embeddings\n");
                clip_free(clp_ctx);
                return false;
            }
            LOG_TEE("slot %i - encoding image\n", slot.id);
            if (!clip_image_encode(clp_ctx, params.n_threads, &img_res, slot.image_embedding)) {
                LOG_TEE("Unable to encode image\n");
                return false;
            }
            slot.request_encode_image = false;
        }
#endif
        slot.t_start_process_prompt = ggml_time_us();
        slot.t_start_genereration = 0;
@@ -1003,59 +1070,21 @@ struct llama_server_context
            });

#ifdef SERVER_MULTIMODAL_SUPPORT
        std::vector<llama_token> preffix_tokens = ingest_image ? tokenize(slot.params.input_prefix, true) : prompt_tokens;
        for (; slot.n_past < preffix_tokens.size(); ++slot.n_past) {
            batch.token [batch.n_tokens] = preffix_tokens[slot.n_past];
        bool ingest_images = processImages(slot); // has images?

        // process the prefix of first image
        std::vector<llama_token> prefix_tokens = ingest_images ? tokenize(slot.images[0].prefix_prompt, true) : prompt_tokens;
        for (; slot.n_past < prefix_tokens.size(); ++slot.n_past) {
            batch.token [batch.n_tokens] = prefix_tokens[slot.n_past];
            batch.pos [batch.n_tokens] = slot.n_past + num_tokens_system;
            batch.seq_id[batch.n_tokens] = slot.id;
            batch.logits[batch.n_tokens] = false;
            batch.n_tokens += 1;
        }

        if(ingest_image) {
            // process preffix prompt
            for (int32_t i = 0; i < (int32_t) batch.n_tokens; i += n_batch) {
                const int32_t n_tokens = std::min(n_batch, (int32_t) (batch.n_tokens - i));
                llama_batch batch_view = {
                    n_tokens,
                    batch.token + i,
                    nullptr,
                    batch.pos + i,
                    batch.seq_id + i,
                    batch.logits + i,
                    0, 0, 0, // unused
                };
                if (llama_decode(ctx, batch_view)) {
                    LOG_TEE("%s : failed to eval\n", __func__);
                    return false;
                }
            }

            // process image
            for (int i = 0; i < slot.image_tokens; i += n_batch) {
                int n_eval = slot.image_tokens - i;
                if (n_eval > n_batch) {
                    n_eval = n_batch;
                }
                llama_batch batch = {int32_t(n_eval), nullptr, (slot.image_embedding + i * n_embd), nullptr, nullptr, nullptr, slot.n_past, 1, 0, };
                if (llama_decode(ctx, batch)) {
                    LOG_TEE("%s : failed to eval image\n", __func__);
                    return false;
                }
                slot.n_past += n_eval;
            }

            // process suffix prompt
            batch.n_tokens = 0;
            std::vector<llama_token> suffix_tokens = tokenize(slot.params.input_suffix, true);
            for (int i = 0; i < suffix_tokens.size(); ++i) {
                batch.token [batch.n_tokens] = suffix_tokens[i];
                batch.pos [batch.n_tokens] = slot.n_past;
                batch.seq_id[batch.n_tokens] = slot.id;
                batch.logits[batch.n_tokens] = false;
                slot.n_past += 1;
                batch.n_tokens += 1;
            }
        if(ingest_images && !ingestImages(slot, n_batch)) {
            LOG_TEE("failed processing images\n");
            return false;
        }
#else
        for (; slot.n_past < prompt_tokens.size(); ++slot.n_past) {
@@ -1111,7 +1140,6 @@ struct llama_server_context
            }

            for (auto & slot : slots) {

                if (slot.i_batch < (int) i || slot.i_batch >= (int) (i + n_tokens)) {
                    continue;
                }
@@ -1728,34 +1756,72 @@ static void parse_options_completion(const json &body, llama_client_slot* slot,
        return;
    }

    std::string data_b64 = json_value(body, "image_data", std::string(""));
    if(!data_b64.empty()) {
        if(!slot->prompt.is_array()) {
            std::string prompt = slot->prompt.get<std::string>();
            int pos = prompt.find("[(img)]");
            if(pos == std::string::npos) {
                LOG_TEE("Missing image position in prompt\n");
    const auto &images_data = body.find("image_data");
    if (images_data != body.end() && images_data->is_array())
    {
        slot->images.clear();
        for (const auto &img : *images_data)
        {
            slot_image img_sl;
            std::string data_b64 = img["data"].get<std::string>();
            img_sl.id = img.count("id") != 0 ? img["id"].get<int>() : slot->images.size();
            int width, height, channels;
            std::vector<uint8_t> image_buffer = base64_decode(data_b64);
            data_b64.clear();
            auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3);
            if(!data) {
                LOG_TEE("slot %i - failed to load image\n", slot->id);
                return;
            } else {
                // reuse infill prompt input
                slot->params.input_prefix = prompt.substr(0, pos);
                slot->params.input_suffix = prompt.substr(pos + 7); // ignore [(img)]
                slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
            }
            LOG_TEE("slot %i - RGB image %i loaded (%i x %i)\n", slot->id, img_sl.id, width, height);
            img_sl.img_data.nx = width;
            img_sl.img_data.ny = height;
            img_sl.img_data.size = width * height * 3;
            img_sl.img_data.data = new uint8_t[width * height * 3]();
            memcpy(img_sl.img_data.data, data, width * height * 3);
            stbi_image_free(data);
            img_sl.request_encode_image = true;
            slot->images.push_back(img_sl);
        }
        // process prompt
        // example: system prompt <img-102> user <img-103> describe <img-134> -> [{id: 102, prefix: 'system prompt '}, {id: 103, prefix: ' user '}, {id: 134, prefix: ' describe '}]}
        if(slot->images.size() > 0 && !slot->prompt.is_array()) {
            std::string prompt = slot->prompt.get<std::string>();
            size_t pos = 0, begin_prefix = 0;
            std::string pattern = "[img-";
            while ((pos = prompt.find(pattern, pos)) != std::string::npos) {
                size_t end_prefix = pos;
                pos += pattern.length();
                size_t end_pos = prompt.find("]", pos);
                if (end_pos != std::string::npos) {
                    std::string image_id = prompt.substr(pos, end_pos - pos);
                    try {
                        int img_id = std::stoi(image_id);
                        bool found = false;
                        for(slot_image &img : slot->images) {
                            if(img.id == img_id) {
                                found = true;
                                img.prefix_prompt = prompt.substr(begin_prefix, end_prefix - begin_prefix);
                                begin_prefix = end_pos + 1;
                                break;
                            }
                        }
                        if(!found) {
                            LOG_TEE("ERROR: Image with id %i not found.\n", img_id);
                            slot->images.clear();
                            return;
                        }
                    } catch (const std::invalid_argument& e) {
                        LOG_TEE("Invalid image number id in prompt\n");
                        slot->images.clear();
                        return;
                    }
                }
            }
            slot->prompt = "";
            slot->params.input_suffix = prompt.substr(begin_prefix);
            slot->params.cache_prompt = false; // multimodal doesn't support cache prompt
        }
        int width, height, channels;
        std::vector<uint8_t> image_buffer = base64_decode(data_b64);
        data_b64.clear();
        // decode base64
        auto data = stbi_load_from_memory(image_buffer.data(), image_buffer.size(), &width, &height, &channels, 3);
        slot->img_data.nx = width;
        slot->img_data.ny = height;
        slot->img_data.size = width * height * 3;
        slot->img_data.data = new uint8_t[slot->img_data.size]();
        memcpy(slot->img_data.data, data, slot->img_data.size);
        stbi_image_free(data);
        LOG_TEE("slot %i - RGB image loaded (%i x %i)\n", slot->id, width, height);
        slot->request_encode_image = true;
    }
#endif
}
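The parsing above generalizes the single-image path it replaces: image_data is now an array of {data, id} objects, and the prompt string is split on [img-ID] tags, with the text before each tag stored as that image's prefix_prompt and the trailing text kept as the suffix that ingestImages appends after the last image embedding. A hedged client-side sketch of a two-image request body that this code accepts, with placeholder base64 variables:

// Hypothetical two-image request body matching the [img-ID] parsing above.
const firstImageBase64  = "...";  // placeholder: base64-encoded image bytes
const secondImageBase64 = "...";  // placeholder: base64-encoded image bytes

const body = {
  prompt: "USER: [img-12] and [img-13] what are the differences between the two images?\nASSISTANT:",
  image_data: [
    { id: 12, data: firstImageBase64 },
    { id: 13, data: secondImageBase64 },
  ],
};
// The parser would derive:
//   image 12 prefix_prompt: "USER: "
//   image 13 prefix_prompt: " and "
//   remaining suffix:       " what are the differences between the two images?\nASSISTANT:"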