fixed sampling

This commit is contained in:
ningshanwutuobang 2023-06-06 22:29:34 +08:00
parent a91487093b
commit 9c6117cd8d
4 changed files with 16 additions and 15 deletions

.gitignore

@@ -34,6 +34,7 @@ models/*
 /benchmark-matmult
 /vdot
 /Pipfile
+/embd_input_test
 
 build-info.h
 arm_neon.h


@@ -16,7 +16,7 @@ class MyModel:
         c_str = [c_char_p(i.encode()) for i in args]
         args_c = (c_char_p * argc)(*c_str)
         self.model = c_void_p(libc.create_mymodel(argc, args_c))
-        print("self.model", self.model)
+        # print("self.model", self.model)
 
     def eval_float(self, x):
         libc.eval_float(self.model, x.astype(np.float32).ctypes.data_as(POINTER(c_float)), x.shape[0])
@@ -31,17 +31,16 @@ class MyModel:
         s = libc.sampling(self.model)
         return s
 
-model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin"])
-print(model)
-model.eval_string("""There is a better way to deal with the formula, """)
+model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])
+# print(model)
+model.eval_string("""user: what is the color of the flag of UN?""")
 # model.eval_token(100)
-x = np.random.random((10,5120))# , dtype=np.float32)
+x = np.random.random((10, 5120))# , dtype=np.float32)
+# print(x[0,0], x[0,1],x[1,0])
 model.eval_float(x)
-print(libc)
-# print(x[0,0], x[0,1],x[1,0])
-for i in range(100):
-    print(model.sampling().decode(), end="")
+model.eval_string("""assistant:""")
+# model.eval_float(x)
+# print(libc)
+for i in range(50):
+    print(model.sampling().decode(), end="", flush=True)
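
Read as a plain script, the new version of this example boils down to the sketch below. MyModel is the ctypes wrapper defined earlier in the same file; the model path, the -c 2048 context size, and the 10x5120 random embedding matrix are just the example values taken from the diff (5120 matches the hidden size of the 13B model used here), not requirements.

import numpy as np

# Load the model through the C wrapper (example path and context size).
model = MyModel(["main", "--model", "../llama.cpp/models/ggml-vic13b-q4_1.bin", "-c", "2048"])

# Evaluate a text prefix, inject 10 externally produced embedding rows
# (eval_float converts them to float32 before handing them to the C side),
# then evaluate a text suffix.
model.eval_string("""user: what is the color of the flag of UN?""")
x = np.random.random((10, 5120))
model.eval_float(x)
model.eval_string("""assistant:""")

# sampling() now also feeds the sampled token back into the context
# (see the C change below), so the loop only prints the returned text.
for i in range(50):
    print(model.sampling().decode(), end="", flush=True)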


@@ -266,6 +266,7 @@ const char* sampling(struct MyModel* mymodel) {
     llama_context* ctx = mymodel->ctx;
     int id = sampling_id(mymodel);
     std::string ret = llama_token_to_str(ctx, id);
+    eval_id(mymodel, id);
     return ret.c_str();
 }
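
With eval_id(mymodel, id) now called inside sampling(), each call both returns the sampled token's text and advances the model's context by that token; this is the actual fix, and it is why the test program below can drop its own sampling_id()/eval_id() pair and simply print what sampling() returns.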


@@ -21,10 +21,10 @@ int main(int argc, char** argv) {
     // printf("eval float end\n");
     eval_string(mymodel, mymodel->params.prompt.c_str());
     for (int i=0;i < 50; i++) {
-        int id = sampling_id(mymodel);
-        printf("%s", llama_token_to_str(mymodel->ctx, id));
+        // int id = sampling_id(mymodel);
+        printf("%s", sampling(mymodel)); // llama_token_to_str(mymodel->ctx, id));
         fflush(stdout);
-        eval_id(mymodel, id);
+        // eval_id(mymodel, id);
     }
     printf("\n");
     return 0;