Update llama_model_load() from master branch
This commit is contained in:
parent
5195fed013
commit
1c545e51ed
1 changed files with 6 additions and 2 deletions
|
@ -128,7 +128,8 @@ struct llama_context
|
|||
|
||||
/* Original code by @ggerganov */
|
||||
// load the model's weights from a file
|
||||
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx) {
|
||||
bool llama_model_load(const std::string & fname, llama_model & model, gpt_vocab & vocab, int n_ctx, ggml_type memory_type = GGML_TYPE_F32) {
|
||||
|
||||
fprintf(stderr, "%s: loading model from '%s' - please wait ...\n", __func__, fname.c_str());
|
||||
|
||||
std::vector<char> f_buf(1024*1024);
|
||||
|
@ -1071,9 +1072,12 @@ llama_context* llama_init_from_params(const gpt_params& params) {
|
|||
llama_model model{};
|
||||
gpt_vocab vocab{};
|
||||
|
||||
const ggml_type memory_type = params.memory_f16 ? GGML_TYPE_F16 : GGML_TYPE_F32;
|
||||
|
||||
// Compute time taken to load model
|
||||
const int64_t t_start = ggml_time_us();
|
||||
bool ret = llama_model_load(params.model, model, vocab, 1024);
|
||||
|
||||
bool ret = llama_model_load(params.model, model, vocab, params.n_ctx, memory_type);
|
||||
const int64_t t_end = ggml_time_us();
|
||||
if(!ret)
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue