From 578a38234c409a27f35990ceb2b5f8d3e1e7e6d0 Mon Sep 17 00:00:00 2001 From: Yann Follet Date: Thu, 23 May 2024 01:42:10 +0000 Subject: [PATCH] fix json generation, use " not ' --- examples/embedding/embedding.cpp | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index bbe33377c..82b567327 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -204,11 +204,16 @@ int main(int argc, char ** argv) { if (n_prompts > 1) { fprintf(stdout, "\n"); printf("cosine similarity matrix:\n\n"); + for (int i = 0; i < n_prompts; i++) { + fprintf(stdout, "%6.6s ", prompts[i].c_str()); + } + fprintf(stdout, "\n"); for (int i = 0; i < n_prompts; i++) { for (int j = 0; j < n_prompts; j++) { float sim = llama_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); fprintf(stdout, "%6.2f ", sim); } + fprintf(stdout, "%1.10s",prompts[i].c_str()); fprintf(stdout, "\n"); } } @@ -217,9 +222,9 @@ int main(int argc, char ** argv) { if (params.embd_out=="json" || params.embd_out=="json+" || params.embd_out=="array") { const bool notArray = params.embd_out!="array"; - fprintf(stdout, notArray?"{\n 'object': 'list',\n 'data': [\n":"["); + fprintf(stdout, notArray?"{\n \"object\": \"list\",\n \"data\": [\n":"["); for (int j = 0;;) { // at least one iteration (one prompt) - if (notArray) fprintf(stdout, " {\n 'object': 'embedding',\n 'index': %d,\n 'embedding': ",j); + if (notArray) fprintf(stdout, " {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); fprintf(stdout, "["); for (int i = 0;;) { // at least one iteration (n_embd > 0) fprintf(stdout, params.embd_normalize==0?"%1.0f":"%1.7f", emb[j * n_embd + i]); @@ -233,7 +238,7 @@ int main(int argc, char ** argv) { fprintf(stdout, notArray?"\n ]":"]\n"); if (params.embd_out=="json+" && n_prompts > 1) { - fprintf(stdout, ",\n cosineSimilarity: [\n"); + fprintf(stdout, ",\n \"cosineSimilarity\": [\n"); for (int i = 0;;) { // at least two iteration (n_prompts > 1) fprintf(stdout, " ["); for (int j = 0;;) { // at least two iteration (n_prompts > 1)