defaulting to f32 kv, and 4 threads seem to produce better results

This commit is contained in:
Concedo 2023-03-25 11:11:40 +08:00
parent 506cd62638
commit 119392f6f2
3 changed files with 5 additions and 2 deletions

View file

@@ -32,6 +32,7 @@ extern "C" {
     const int threads;
     const int max_context_length;
     const int batch_size;
+    const bool f16_kv;
     const char * model_filename;
     const int n_parts_overwrite = -1;
 };
@@ -75,7 +76,7 @@ extern "C" {
     ctx_params.n_ctx = inputs.max_context_length;
     ctx_params.n_parts = inputs.n_parts_overwrite;
     ctx_params.seed = -1;
-    ctx_params.f16_kv = true;
+    ctx_params.f16_kv = inputs.f16_kv;
     ctx_params.logits_all = false;
     ctx = llama_init_from_file(model.c_str(), ctx_params);

View file

@@ -12,6 +12,7 @@ class load_model_inputs(ctypes.Structure):
     _fields_ = [("threads", ctypes.c_int),
                 ("max_context_length", ctypes.c_int),
                 ("batch_size", ctypes.c_int),
+                ("f16_kv", ctypes.c_bool),
                 ("model_filename", ctypes.c_char_p),
                 ("n_parts_overwrite", ctypes.c_int)]
@@ -43,8 +44,9 @@ def load_model(model_filename,batch_size=8,max_context_length=512,n_parts_overwrite=-1):
     inputs.model_filename = model_filename.encode("UTF-8")
     inputs.batch_size = batch_size
     inputs.max_context_length = max_context_length #initial value to use for ctx, can be overwritten
-    inputs.threads = os.cpu_count()
+    inputs.threads = 4 #seems to outperform os.cpu_count(), it's memory bottlenecked
     inputs.n_parts_overwrite = n_parts_overwrite
+    inputs.f16_kv = False
     ret = handle.load_model(inputs)
     return ret

Binary file not shown.