From 6ed4893391ec3135e59b8e7c726fd4847fbd1a18 Mon Sep 17 00:00:00 2001
From: ningshanwutuobang
Date: Wed, 7 Jun 2023 23:44:54 +0800
Subject: [PATCH] add end condition for generation

---
 examples/embd_input/embd_input.py       | 28 ++++++++++++++-----------
 examples/embd_input/embd_input_test.cpp |  8 +++++--
 2 files changed, 22 insertions(+), 14 deletions(-)

diff --git a/examples/embd_input/embd_input.py b/examples/embd_input/embd_input.py
index d4831d46a..ebce1bb45 100644
--- a/examples/embd_input/embd_input.py
+++ b/examples/embd_input/embd_input.py
@@ -33,16 +33,20 @@ class MyModel:
         s = libc.sampling(self.model)
         return s
 
-model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
-# print(model)
-model.eval_string("""user: what is the color of the flag of UN?""")
-# model.eval_token(100)
-x = np.random.random((10, 5120))# , dtype=np.float32)
-model.eval_float(x)
-model.eval_string("""assistant:""")
-# print(x[0,0], x[0,1],x[1,0])
-# model.eval_float(x)
-# print(libc)
+if __name__ == "__main__":
+    model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
+    # print(model)
+    model.eval_string("""user: what is the color of the flag of UN?""")
+    # model.eval_token(100)
+    x = np.random.random((10, 5120))  # , dtype=np.float32)
+    model.eval_float(x)
+    model.eval_string("""assistant:""")
+    # print(x[0,0], x[0,1], x[1,0])
+    # model.eval_float(x)
+    # print(libc)
 
-for i in range(50):
-    print(model.sampling().decode(), end="", flush=True)
+    for i in range(500):
+        tmp = model.sampling().decode()
+        if tmp == "":
+            break
+        print(tmp, end="", flush=True)
diff --git a/examples/embd_input/embd_input_test.cpp b/examples/embd_input/embd_input_test.cpp
index 94287e37f..d83febeb2 100644
--- a/examples/embd_input/embd_input_test.cpp
+++ b/examples/embd_input/embd_input_test.cpp
@@ -1,6 +1,7 @@
 #include "embd_input.h"
 #include <stdlib.h>
 #include <random>
+#include <string.h>
 
 int main(int argc, char** argv) {
 
@@ -20,9 +21,12 @@ int main(int argc, char** argv) {
     eval_string(mymodel, "assistant:");
     // printf("eval float end\n");
     eval_string(mymodel, mymodel->params.prompt.c_str());
-    for (int i=0;i < 50; i++) {
+    const char* tmp;
+    for (int i = 0; i < 500; i++) {
         // int id = sampling_id(mymodel);
-        printf("%s", sampling(mymodel)); // llama_token_to_str(mymodel->ctx, id));
+        tmp = sampling(mymodel);
+        if (strlen(tmp) == 0) break;
+        printf("%s", tmp); // llama_token_to_str(mymodel->ctx, id));
         fflush(stdout);
         // eval_id(mymodel, id);
     }
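
Note on the stop condition: both loops now treat an empty string returned by sampling() as the end-of-sequence signal, with the raised 500-iteration cap as a fallback so generation cannot run forever. A side effect of the new if __name__ == "__main__" guard is that embd_input.py can be imported without running the demo. Below is a minimal sketch of driving the same stop condition from another script; it assumes embd_input.py is importable as a module and that sampling() returns an empty byte string once the model emits its end-of-sequence token. The model path is a placeholder.

    import numpy as np
    from embd_input import MyModel  # import no longer runs the demo, thanks to the __main__ guard

    # Placeholder model path; use any ggml model compatible with this example.
    model = MyModel(["main", "--model", "path/to/model.bin", "-c", "2048"])
    model.eval_string("user: what is the color of the flag of UN?")
    model.eval_float(np.random.random((10, 5120)))  # stand-in embedding batch; 5120 matches the 13B model in the demo
    model.eval_string("assistant:")

    pieces = []
    for _ in range(500):                   # hard cap, mirrors the patched loop
        piece = model.sampling().decode()
        if piece == "":                    # empty piece signals end of sequence
            break
        pieces.append(piece)
    print("".join(pieces))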