Add mmap support for model files

Slaren 2023-03-29 02:03:43 +02:00 committed by Justine Tunney
parent 3bcc129ba8
commit c03ae8dca1
3 changed files with 74 additions and 22 deletions

llama.cpp

@@ -12,6 +12,13 @@
 #include <cassert>
 #include <cstring>
 
+// headers for POSIX mmap
+#if defined (__unix__) || defined (__APPLE__)
+# include <sys/mman.h>
+# include <fcntl.h>
+# include <unistd.h>
+#endif
+
 #define LLAMA_USE_SCRATCH
 #define LLAMA_MAX_SCRATCH_BUFFERS 16
@@ -246,6 +253,7 @@ static bool kv_cache_init(
     struct ggml_init_params params;
     params.mem_size = cache.buf.size();
     params.mem_buffer = cache.buf.data();
+    params.no_alloc = false;
 
     cache.ctx = ggml_init(params);
@@ -288,6 +296,26 @@ struct llama_context_params llama_context_default_params() {
 // model loading
 //
 
+void * mmap_file(const char* fname) {
+#if defined(MAP_FAILED)
+    // POSIX mmap
+    int fd = open(fname, O_RDONLY);
+    size_t len = lseek(fd, 0, SEEK_END);
+    void * mm_addr = mmap(NULL, len, PROT_READ, MAP_SHARED, fd, 0);
+    if (mm_addr == MAP_FAILED) {
+        perror("mmap failed");
+        mm_addr = NULL;
+    }
+    close(fd);
+    return mm_addr;
+#else
+    // TODO: windows support
+    (void)(fname); // suppress warnings
+    return NULL;
+#endif
+}
+
 static bool llama_model_load(
         const std::string & fname,
         llama_context & lctx,
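
Note: the `#else` branch in `mmap_file` above leaves Windows as a TODO. For orientation only, a hypothetical Win32 counterpart might look like the sketch below (not part of this commit; `mmap_file_win32` is an illustrative name, and the calls are the standard Win32 file-mapping API):

    #include <windows.h>

    static void * mmap_file_win32(const char * fname) {
        HANDLE hFile = CreateFileA(fname, GENERIC_READ, FILE_SHARE_READ,
                                   NULL, OPEN_EXISTING, FILE_ATTRIBUTE_NORMAL, NULL);
        if (hFile == INVALID_HANDLE_VALUE) {
            return NULL; // no file, no mapping
        }
        HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL);
        CloseHandle(hFile); // the mapping object holds its own file reference
        if (hMapping == NULL) {
            return NULL;
        }
        void * mm_addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
        CloseHandle(hMapping); // the view keeps pages mapped until UnmapViewOfFile
        return mm_addr; // NULL on failure, matching the POSIX path's contract
    }
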
@@ -303,6 +331,7 @@ static bool llama_model_load(
     lctx.t_start_us = t_start_us;
 
+    // TODO: this could probably be smaller when using mmap
     std::vector<char> f_buf(1024*1024);
 
     auto & model = lctx.model;
@ -449,39 +478,49 @@ static bool llama_model_load(
}
}
bool use_mmap = (n_parts == 1);
// try to memory map the model file
void* mm_addr = NULL;
if (use_mmap) {
mm_addr = mmap_file(fname.c_str());
if (mm_addr == NULL) {
use_mmap = false;
}
}
auto & ctx = model.ctx;
size_t ctx_size = 0;
{
const auto & hparams = model.hparams;
const int n_embd = hparams.n_embd;
const int n_layer = hparams.n_layer;
const int n_ctx = hparams.n_ctx;
const int n_vocab = hparams.n_vocab;
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
if (!use_mmap) {
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // tok_embeddings
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
ctx_size += n_embd*ggml_type_sizef(GGML_TYPE_F32); // norm
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
ctx_size += n_embd*n_vocab*ggml_type_sizef(vtype); // output
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // attention_norm
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wq
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wk
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wv
ctx_size += n_layer*(n_embd*n_embd*ggml_type_sizef(wtype)); // wo
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
ctx_size += n_layer*(n_embd*ggml_type_sizef(GGML_TYPE_F32)); // ffn_norm
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_k
ctx_size += n_ctx*n_layer*n_embd*ggml_type_sizef(memory_type); // memory_v
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w1
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w2
ctx_size += n_layer*(n_ff*n_embd*ggml_type_sizef(wtype)); // w3
}
ctx_size += (5 + 10*n_layer)*256; // object overhead
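
Mapping is attempted only for single-part models: a multi-part checkpoint splits each tensor's data across the part files, so no single mapping would hold a contiguous payload. The contract is also best-effort; a hedged sketch of the decision, with `try_mmap` as a hypothetical helper wrapping the logic above:

    // Hypothetical helper, not in the commit: mapping is best-effort, and a
    // false result simply routes loading through the existing fin.read() path.
    static bool try_mmap(const std::string & fname, int n_parts, void ** mm_addr) {
        if (n_parts != 1) {
            return false; // tensors are split across part files, not contiguous
        }
        *mm_addr = mmap_file(fname.c_str());
        return *mm_addr != NULL;
    }
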
@@ -514,6 +553,7 @@ static bool llama_model_load(
         struct ggml_init_params params = {
             /*.mem_size =*/ lctx.model.buf.size(),
             /*.mem_buffer =*/ lctx.model.buf.data(),
+            /*.no_alloc =*/ use_mmap,
         };
 
         model.ctx = ggml_init(params);
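
With `no_alloc` set to `use_mmap`, the context reserves room only for tensor metadata; weight payloads never enter the context buffer. A minimal sketch of the pattern, assuming ggml's public API (`file_offset_of_w` is an illustrative placeholder for the offset recorded while parsing the file):

    struct ggml_init_params params = {
        /*.mem_size   =*/ ctx_size,
        /*.mem_buffer =*/ NULL, // let ggml allocate its metadata pool
        /*.no_alloc   =*/ true, // skip tensor data allocation entirely
    };
    struct ggml_context * ctx = ggml_init(params);
    struct ggml_tensor * w = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, n_embd, n_vocab);
    w->data = (char *) mm_addr + file_offset_of_w; // patch in the mapped weights
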
@@ -595,7 +635,7 @@ static bool llama_model_load(
             fname_part += "." + std::to_string(i);
         }
 
-        fprintf(stderr, "%s: loading model part %d/%d from '%s'\n", __func__, i+1, n_parts, fname_part.c_str());
+        fprintf(stderr, "%s: loading model part %d/%d from '%s'%s\n", __func__, i+1, n_parts, fname_part.c_str(), use_mmap ? " (memory mapped)" : "");
 
         fin = std::ifstream(fname_part, std::ios::binary);
         fin.rdbuf()->pubsetbuf(f_buf.data(), f_buf.size());
@@ -736,7 +776,14 @@ static bool llama_model_load(
                 }
 
                 if (part_id == 0) {
-                    fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+                    if (mm_addr) {
+                        off_t offset = fin.tellg();
+                        tensor->data = (char *) mm_addr + offset;
+                        fin.seekg(ggml_nbytes(tensor), std::ios::cur);
+                    }
+                    else {
+                        fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
+                    }
                 } else {
                     fin.seekg(ggml_nbytes(tensor), std::ios::cur);
                 }
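
This bookkeeping is the heart of the change: `fin.tellg()` gives the tensor payload's file offset, the tensor is pointed at that offset inside the mapping, and `seekg()` advances the stream past the payload so header parsing stays in sync. No bytes are copied; the kernel faults pages in on first access. A condensed sketch of the pattern (`attach_or_read` is a hypothetical helper name):

    // Hypothetical distillation of the branch above: parse headers with the
    // stream, alias payloads out of the mapping, keep both cursors aligned.
    static void attach_or_read(std::ifstream & fin, void * mm_addr,
                               struct ggml_tensor * tensor) {
        if (mm_addr) {
            off_t offset = fin.tellg();                    // payload starts here
            tensor->data = (char *) mm_addr + offset;      // zero-copy alias
            fin.seekg(ggml_nbytes(tensor), std::ios::cur); // skip payload bytes
        } else {
            fin.read(reinterpret_cast<char *>(tensor->data), ggml_nbytes(tensor));
        }
    }
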
@@ -849,6 +896,7 @@ static bool llama_eval_internal(
     struct ggml_init_params params = {
         /*.mem_size =*/ buf_compute.size(),
         /*.mem_buffer =*/ buf_compute.data(),
+        /*.no_alloc =*/ false,
     };
 
     struct ggml_context * ctx0 = ggml_init(params);