llama : support batched embeddings (#5466)
* batched embedding: pool outputs by sequence id. updated embedding example * bring back non-causal attention * embd : minor improvements * llama : minor --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
This commit is contained in:
parent
ad014bba97
commit
03bf161eb6
6 changed files with 163 additions and 54 deletions
|
@ -7,6 +7,51 @@
|
|||
#pragma warning(disable: 4244 4267) // possible loss of data
|
||||
#endif
|
||||
|
||||
static std::vector<std::string> split_lines(const std::string & s) {
|
||||
std::string line;
|
||||
std::vector<std::string> lines;
|
||||
std::stringstream ss(s);
|
||||
while (std::getline(ss, line)) {
|
||||
lines.push_back(line);
|
||||
}
|
||||
return lines;
|
||||
}
|
||||
|
||||
static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
|
||||
for (size_t i = 0; i < tokens.size(); i++) {
|
||||
llama_batch_add(batch, tokens[i], i, { seq_id }, false);
|
||||
}
|
||||
}
|
||||
|
||||
static void normalize(float * vec, float * out, int n) {
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n; i++) {
|
||||
norm += vec[i] * vec[i];
|
||||
}
|
||||
norm = sqrt(norm);
|
||||
for (int i = 0; i < n; i++) {
|
||||
out[i] = vec[i] / norm;
|
||||
}
|
||||
}
|
||||
|
||||
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd) {
|
||||
// clear previous kv_cache values (irrelevant for embeddings)
|
||||
llama_kv_cache_clear(ctx);
|
||||
|
||||
// run model
|
||||
fprintf(stderr, "%s: n_tokens = %d, n_seq = %d\n", __func__, batch.n_tokens, n_seq);
|
||||
if (llama_decode(ctx, batch) < 0) {
|
||||
fprintf(stderr, "%s : failed to decode\n", __func__);
|
||||
}
|
||||
|
||||
// normalize on copy
|
||||
for (int k = 0; k < n_seq; k++) {
|
||||
float * emb = llama_get_embeddings_ith(ctx, k);
|
||||
float * out = output + k * n_embd;
|
||||
normalize(emb, out, n_embd);
|
||||
}
|
||||
}
|
||||
|
||||
int main(int argc, char ** argv) {
|
||||
gpt_params params;
|
||||
|
||||
|
@ -55,59 +100,84 @@ int main(int argc, char ** argv) {
|
|||
fprintf(stderr, "%s\n", get_system_info(params).c_str());
|
||||
}
|
||||
|
||||
int n_past = 0;
|
||||
// split the prompt into lines
|
||||
std::vector<std::string> prompts = split_lines(params.prompt);
|
||||
|
||||
// tokenize the prompt
|
||||
auto embd_inp = ::llama_tokenize(ctx, params.prompt, true);
|
||||
// max batch size
|
||||
const uint64_t n_batch = params.n_batch;
|
||||
GGML_ASSERT(params.n_batch == params.n_ctx);
|
||||
|
||||
// tokenize the prompts and trim
|
||||
std::vector<std::vector<int32_t>> inputs;
|
||||
for (const auto & prompt : prompts) {
|
||||
auto inp = ::llama_tokenize(ctx, prompt, true);
|
||||
if (inp.size() > n_batch) {
|
||||
inp.resize(n_batch);
|
||||
}
|
||||
inputs.push_back(inp);
|
||||
}
|
||||
|
||||
// tokenization stats
|
||||
if (params.verbose_prompt) {
|
||||
fprintf(stderr, "\n");
|
||||
fprintf(stderr, "%s: prompt: '%s'\n", __func__, params.prompt.c_str());
|
||||
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, embd_inp.size());
|
||||
for (int i = 0; i < (int) embd_inp.size(); i++) {
|
||||
fprintf(stderr, "%6d -> '%s'\n", embd_inp[i], llama_token_to_piece(ctx, embd_inp[i]).c_str());
|
||||
for (int i = 0; i < (int) inputs.size(); i++) {
|
||||
fprintf(stderr, "%s: prompt %d: '%s'\n", __func__, i, prompts[i].c_str());
|
||||
fprintf(stderr, "%s: number of tokens in prompt = %zu\n", __func__, inputs[i].size());
|
||||
for (int j = 0; j < (int) inputs[i].size(); j++) {
|
||||
fprintf(stderr, "%6d -> '%s'\n", inputs[i][j], llama_token_to_piece(ctx, inputs[i][j]).c_str());
|
||||
}
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
}
|
||||
|
||||
if (embd_inp.size() > (size_t)n_ctx) {
|
||||
fprintf(stderr, "%s: error: prompt is longer than the context window (%zu tokens, n_ctx = %d)\n",
|
||||
__func__, embd_inp.size(), n_ctx);
|
||||
return 1;
|
||||
}
|
||||
|
||||
while (!embd_inp.empty()) {
|
||||
int n_tokens = std::min(params.n_batch, (int) embd_inp.size());
|
||||
if (llama_decode(ctx, llama_batch_get_one(embd_inp.data(), n_tokens, n_past, 0))) {
|
||||
fprintf(stderr, "%s : failed to eval\n", __func__);
|
||||
return 1;
|
||||
}
|
||||
n_past += n_tokens;
|
||||
embd_inp.erase(embd_inp.begin(), embd_inp.begin() + n_tokens);
|
||||
}
|
||||
// initialize batch
|
||||
const int n_prompts = prompts.size();
|
||||
struct llama_batch batch = llama_batch_init(n_batch, 0, n_prompts);
|
||||
|
||||
// allocate output
|
||||
const int n_embd = llama_n_embd(model);
|
||||
auto * embeddings = llama_get_embeddings(ctx);
|
||||
std::vector<float> embeddings(n_prompts * n_embd, 0);
|
||||
float * emb = embeddings.data();
|
||||
|
||||
// l2-normalize embeddings
|
||||
float norm = 0;
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
norm += embeddings[i] * embeddings[i];
|
||||
}
|
||||
norm = sqrt(norm);
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
embeddings[i] /= norm;
|
||||
// break into batches
|
||||
int p = 0; // number of prompts processed already
|
||||
int s = 0; // number of prompts in current batch
|
||||
for (int k = 0; k < n_prompts; k++) {
|
||||
// clamp to n_batch tokens
|
||||
auto & inp = inputs[k];
|
||||
const uint64_t n_toks = inp.size();
|
||||
|
||||
// encode if at capacity
|
||||
if (batch.n_tokens + n_toks > n_batch) {
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
llama_batch_clear(batch);
|
||||
p += s;
|
||||
s = 0;
|
||||
}
|
||||
|
||||
// add to batch
|
||||
batch_add_seq(batch, inp, s);
|
||||
s += 1;
|
||||
}
|
||||
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
printf("%f ", embeddings[i]);
|
||||
}
|
||||
printf("\n");
|
||||
// final batch
|
||||
float * out = emb + p * n_embd;
|
||||
batch_decode(ctx, batch, out, s, n_embd);
|
||||
|
||||
// print first 3 embeddings
|
||||
for (int j = 0; j < std::min(3, n_prompts); j++) {
|
||||
fprintf(stderr, "embedding %d: ", j);
|
||||
for (int i = 0; i < n_embd; i++) {
|
||||
fprintf(stderr, "%f ", emb[j * n_embd + i]);
|
||||
}
|
||||
fprintf(stderr, "\n\n");
|
||||
}
|
||||
fprintf(stderr, "\n");
|
||||
|
||||
// clean up
|
||||
llama_print_timings(ctx);
|
||||
llama_free(ctx);
|
||||
llama_free_model(model);
|
||||
|
||||
llama_backend_free();
|
||||
|
||||
return 0;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue