This commit is contained in:
Georgi Gerganov 2024-05-17 14:00:44 +03:00
parent 9c4fdcbec8
commit 6b2f496409
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 108 additions and 1 deletions

View file

@ -159,6 +159,8 @@ int main(int argc, char ** argv) {
std::vector<float> embeddings(n_prompts * n_embd, 0);
float * emb = embeddings.data();
auto t_start = ggml_time_us();
// break into batches
int p = 0; // number of prompts processed already
int s = 0; // number of prompts in current batch
@ -169,7 +171,8 @@ int main(int argc, char ** argv) {
const uint64_t n_toks = inp.size();
// encode if at capacity
if (batch.n_tokens + n_toks > n_batch) {
//if (batch.n_tokens + n_toks > n_batch) {
if (true) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
llama_batch_clear(batch);
@ -186,6 +189,10 @@ int main(int argc, char ** argv) {
float * out = emb + p * n_embd;
batch_decode(ctx, batch, out, s, n_embd);
auto t_end = ggml_time_us();
printf("time per embedding: %.3f ms\n", (t_end - t_start) / 1000.0 / n_prompts);
// print the first part of the embeddings or, for a single prompt, the full embedding
fprintf(stdout, "\n");
for (int j = 0; j < n_prompts; j++) {