commit 6b2f496409 (parent 9c4fdcbec8)

    wip

2 changed files with 108 additions and 1 deletion
@@ -159,6 +159,8 @@ int main(int argc, char ** argv) {
     std::vector<float> embeddings(n_prompts * n_embd, 0);
     float * emb = embeddings.data();
 
+    auto t_start = ggml_time_us();
+
     // break into batches
     int p = 0; // number of prompts processed already
     int s = 0; // number of prompts in current batch
@@ -169,7 +171,8 @@ int main(int argc, char ** argv) {
         const uint64_t n_toks = inp.size();
 
         // encode if at capacity
-        if (batch.n_tokens + n_toks > n_batch) {
+        //if (batch.n_tokens + n_toks > n_batch) {
+        if (true) {
             float * out = emb + p * n_embd;
             batch_decode(ctx, batch, out, s, n_embd);
             llama_batch_clear(batch);
@@ -186,6 +189,10 @@ int main(int argc, char ** argv) {
     float * out = emb + p * n_embd;
     batch_decode(ctx, batch, out, s, n_embd);
 
+    auto t_end = ggml_time_us();
+
+    printf("time per embedding: %.3f ms\n", (t_end - t_start) / 1000.0 / n_prompts);
+
     // print the first part of the embeddings or for a single prompt, the full embedding
     fprintf(stdout, "\n");
     for (int j = 0; j < n_prompts; j++) {
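For context on what the timing addition measures: the loop in these hunks accumulates tokenized prompts into a batch and flushes it through batch_decode() whenever the next prompt would push the batch past n_batch tokens (the temporary `if (true)` forces a flush on every prompt, consistent with the "wip" nature of the commit), and the new t_start/t_end pair reports the average wall-clock time per embedding. The sketch below is only an illustration of that control flow and of the per-embedding average in self-contained C++; std::chrono stands in for ggml_time_us(), and encode_batch() is a hypothetical placeholder for the example's batch_decode(), so nothing here reflects the actual llama.cpp API.

```cpp
// Sketch only: flush-at-capacity batching with an average time-per-item report.
// encode_batch() is a hypothetical stand-in for batch_decode(); std::chrono
// replaces ggml_time_us() (which returns microseconds).
#include <chrono>
#include <cstdio>
#include <vector>

// Pretend "encoder": writes one embedding per prompt in the batch.
static void encode_batch(const std::vector<std::vector<int>> & batch,
                         float * out, int n_embd) {
    for (size_t i = 0; i < batch.size(); ++i) {
        for (int d = 0; d < n_embd; ++d) {
            out[i * n_embd + d] = 0.0f; // placeholder work
        }
    }
}

int main() {
    const int n_embd    = 8;   // embedding size (toy value)
    const int n_batch   = 32;  // max tokens per batch (toy value)
    const int n_prompts = 100;

    // Toy tokenized prompts of varying length.
    std::vector<std::vector<int>> inputs(n_prompts);
    for (int k = 0; k < n_prompts; ++k) {
        inputs[k].assign(5 + k % 10, 1);
    }

    std::vector<float> embeddings(n_prompts * n_embd, 0.0f);
    float * emb = embeddings.data();

    const auto t_start = std::chrono::steady_clock::now();

    // Break into batches: accumulate prompts, flush when the next one would
    // exceed n_batch tokens (the condition the wip commit temporarily bypasses).
    std::vector<std::vector<int>> batch;
    size_t batch_tokens = 0;
    int p = 0; // number of prompts processed already
    for (int k = 0; k < n_prompts; ++k) {
        const size_t n_toks = inputs[k].size();
        if (batch_tokens + n_toks > (size_t) n_batch) {
            encode_batch(batch, emb + p * n_embd, n_embd);
            p += (int) batch.size();
            batch.clear();
            batch_tokens = 0;
        }
        batch.push_back(inputs[k]);
        batch_tokens += n_toks;
    }
    // Final (possibly partial) batch.
    encode_batch(batch, emb + p * n_embd, n_embd);

    const auto t_end = std::chrono::steady_clock::now();
    const double us  = std::chrono::duration<double, std::micro>(t_end - t_start).count();
    printf("time per embedding: %.3f ms\n", us / 1000.0 / n_prompts);
    return 0;
}
```

Dividing the total elapsed microseconds by 1000 and then by n_prompts mirrors the printf added in the last hunk: it reports milliseconds per embedding averaged over all prompts, independent of how the prompts were grouped into batches.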