Working! Thanks to @nullhook

commit 76dde26844
parent d2b1d3a439
Author: strikingLoo
Date:   2023-03-21 18:32:51 -07:00


@@ -759,13 +759,10 @@ bool llama_eval(
         // capture input sentence embedding
         ggml_build_forward_expand(&gf, inpL);
         ggml_graph_compute (ctx0, &gf);
-        printf("Compute went ok\n");
         std::vector<float> embedding_representation;
         embedding_representation.resize(n_embd);
         memcpy(embedding_representation.data(), (float *) ggml_get_data(inpL) + (n_embd * (N - 1)), sizeof(float) * n_embd);
-        printf("About to display\n");
         display_embedding(embedding_representation);
-        printf("About to free\n");
         ggml_free(ctx0);
         return true;
     }
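
The memcpy in this hunk is the core of the change: inpL holds one row of n_embd floats per input token, so offsetting by n_embd * (N - 1) selects the hidden state of the last token, which the commit uses as the embedding for the whole input sentence. A minimal standalone sketch of that extraction, with extract_last_embedding as a hypothetical helper name (not part of this commit):

    #include <cstring>
    #include <vector>

    // Copy row (N - 1) out of a buffer laid out as N rows of n_embd floats.
    // In the diff above, data would come from ggml_get_data(inpL).
    std::vector<float> extract_last_embedding(const float * data, int n_embd, int N) {
        std::vector<float> out(n_embd);
        std::memcpy(out.data(), data + (size_t) n_embd * (N - 1), sizeof(float) * n_embd);
        return out;
    }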
@@ -943,13 +940,14 @@ int main(int argc, char ** argv) {
     }
     if (params.embedding){
-        printf("got right before second call.\n");
-        const int64_t t_start_us = ggml_time_us(); //HERE
-        if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
-            fprintf(stderr, "Failed to predict\n");
-            return 1;
-        }
+        embd = embd_inp;
+        if (embd.size() > 0) {
+            const int64_t t_start_us = ggml_time_us();
+            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
+                fprintf(stderr, "Failed to predict\n");
+                return 1;
+            }
+        }
     }
-    //ggml_free(model.ctx);
 
     if (params.use_color) {
         printf(ANSI_COLOR_RESET);
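
Read together, the two hunks remove the temporary debug printf calls and wire up the embedding path in main(): the full tokenized prompt (embd_inp) is copied into embd and evaluated in a single guarded call, with the final true argument switching llama_eval into the embedding mode shown in the first hunk. A condensed, hypothetical view of the resulting flow (names taken from the diff; surrounding setup and the timing variable omitted):

    if (params.embedding) {
        embd = embd_inp;            // evaluate the whole prompt in one pass
        if (embd.size() > 0) {      // skip evaluation for an empty prompt
            // embedding mode: llama_eval() captures inpL, copies the last
            // token's row, and calls display_embedding() before returning
            if (!llama_eval(model, params.n_threads, n_past, embd, logits, mem_per_token, true)) {
                fprintf(stderr, "Failed to predict\n");
                return 1;
            }
        }
    }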