llama : enable mmap in quantize on Linux -> 31% faster

Cebtenzzre 2023-09-09 23:04:53 -04:00
parent e6616cf0db
commit 32bc3f4fcf


@@ -5658,7 +5658,18 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         nthread = std::thread::hardware_concurrency();
     }
-    std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, /*use_mmap*/ false));
+    // mmap consistently increases speed on Linux, is inconsistent on macOS
+    // (possibly related to free memory), and has not been tested on Windows.
+#ifdef __linux__
+    constexpr bool use_mmap = true;
+#else
+    constexpr bool use_mmap = false;
+#endif
+    std::unique_ptr<llama_model_loader> ml(new llama_model_loader(fname_inp, use_mmap));
+    if (ml->use_mmap) {
+        ml->mapping.reset(new llama_mmap(&ml->file, /* prefetch */ 0, ggml_is_numa()));
+    }
     llama_model model;
     llm_load_arch(*ml, model);
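
For context: llama_mmap wraps mmap on POSIX systems, and its prefetch argument bounds an optional readahead hint. Passing /* prefetch */ 0 here skips that hint, so during quantization pages are faulted in lazily, once each, as the loop below visits every tensor exactly one time. A minimal sketch of the POSIX side under those assumptions (map_file is a hypothetical name, not the llama.cpp API):

#include <algorithm>
#include <cstddef>
#include <fcntl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <unistd.h>

// Map a read-only view of a file; prefetch == 0 means "no readahead hint".
static void * map_file(const char * path, size_t & size, size_t prefetch) {
    int fd = open(path, O_RDONLY);
    if (fd < 0) { return nullptr; }
    struct stat st;
    if (fstat(fd, &st) != 0) { close(fd); return nullptr; }
    size = (size_t) st.st_size;
    void * addr = mmap(nullptr, size, PROT_READ, MAP_SHARED, fd, 0);
    close(fd); // the mapping outlives the descriptor
    if (addr == MAP_FAILED) { return nullptr; }
    if (prefetch > 0) {
        // ask the kernel to page in up to 'prefetch' bytes ahead of use
        posix_madvise(addr, std::min(size, prefetch), POSIX_MADV_WILLNEED);
    }
    return addr; // caller munmap()s when done
}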
@@ -5736,10 +5747,12 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         const std::string name = ggml_get_name(tensor);
-        if (read_data.size() < ggml_nbytes(tensor)) {
-            read_data.resize(ggml_nbytes(tensor));
+        if (!ml->use_mmap) {
+            if (read_data.size() < ggml_nbytes(tensor)) {
+                read_data.resize(ggml_nbytes(tensor));
+            }
+            tensor->data = read_data.data();
         }
-        tensor->data = read_data.data();
         ml->load_data_for(tensor);
         LLAMA_LOG_INFO("[%4d/%4d] %36s - [%s], type = %6s, ",
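
The second hunk makes the copy into read_data conditional: when mmap is enabled, ml->load_data_for() can instead point tensor->data directly at the mapping, so the file bytes are consumed zero-copy out of the page cache. A simplified sketch of the two resulting paths (get_tensor_data and the offs/nbytes parameters are illustrative names, not the llama.cpp API):

#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <vector>

// Return a pointer to a tensor's bytes, either aliasing the mapping
// (zero-copy) or staged through a reusable scratch buffer (copy path).
static void * get_tensor_data(bool use_mmap, uint8_t * mapping_addr, FILE * f,
                              size_t offs, size_t nbytes,
                              std::vector<uint8_t> & scratch) {
    if (use_mmap) {
        return mapping_addr + offs;       // view straight into the mapped file
    }
    if (scratch.size() < nbytes) {
        scratch.resize(nbytes);           // grow once, reuse across tensors
    }
    fseek(f, (long) offs, SEEK_SET);
    if (fread(scratch.data(), 1, nbytes, f) != nbytes) { return nullptr; }
    return scratch.data();
}

Skipping the staging copy is where the reported 31% comes from: each tensor is read once, quantized, and written out, so the extra memcpy per tensor was pure overhead on Linux.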