fix: add an end condition for generation

This commit is contained in:
ningshanwutuobang 2023-06-07 23:44:54 +08:00
parent ba1f617d7d
commit 6ed4893391
2 changed files with 22 additions and 14 deletions

View file

@@ -33,16 +33,20 @@ class MyModel:
         s = libc.sampling(self.model)
         return s
-model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
-# print(model)
-model.eval_string("""user: what is the color of the flag of UN?""")
-# model.eval_token(100)
-x = np.random.random((10, 5120))# , dtype=np.float32)
-model.eval_float(x)
-model.eval_string("""assistant:""")
-# print(x[0,0], x[0,1],x[1,0])
-# model.eval_float(x)
-# print(libc)
-for i in range(50):
-    print(model.sampling().decode(), end="", flush=True)
+if __name__ == "__main__":
+    model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
+    # print(model)
+    model.eval_string("""user: what is the color of the flag of UN?""")
+    # model.eval_token(100)
+    x = np.random.random((10, 5120))# , dtype=np.float32)
+    model.eval_float(x)
+    model.eval_string("""assistant:""")
+    # print(x[0,0], x[0,1],x[1,0])
+    # model.eval_float(x)
+    # print(libc)
+    for i in range(500):
+        tmp = model.sampling().decode()
+        if tmp == "":
+            break
+        print(tmp, end="", flush=True)

View file

@@ -1,6 +1,7 @@
 #include "embd_input.h"
 #include <stdlib.h>
 #include <random>
+#include <string.h>

 int main(int argc, char** argv) {
@@ -20,9 +21,12 @@ int main(int argc, char** argv) {
     eval_string(mymodel, "assistant:");
     // printf("eval float end\n");
     eval_string(mymodel, mymodel->params.prompt.c_str());
-    for (int i=0;i < 50; i++) {
+    const char* tmp;
+    for (int i=0;i < 500; i++) {
         // int id = sampling_id(mymodel);
-        printf("%s", sampling(mymodel)); // llama_token_to_str(mymodel->ctx, id));
+        tmp = sampling(mymodel);
+        if (strlen(tmp) == 0) break;
+        printf("%s", tmp); // llama_token_to_str(mymodel->ctx, id));
         fflush(stdout);
         // eval_id(mymodel, id);
     }