fixed MPT ooms

This commit is contained in:
Concedo 2023-06-03 18:37:13 +08:00
parent 8bd9a3a48b
commit c1b293d31a

View file

@ -120,21 +120,20 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight
ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight
ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_k ctx_size += (n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16)); // memory_k
ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_v ctx_size += (n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16)); // memory_v
ctx_size += (1 + 6 * n_layer) * 512; // object overhead ctx_size += (6 + 6 * n_layer) * 512; // object overhead
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0)); printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
} }
// create the ggml context // create the ggml context
{ {
struct ggml_init_params params = { struct ggml_init_params params;
.mem_size = ctx_size, params.mem_size = ctx_size;
.mem_buffer = NULL, params.mem_buffer = NULL;
.no_alloc = false, params.no_alloc = false;
};
model.ctx = ggml_init(params); model.ctx = ggml_init(params);
if (!model.ctx) { if (!model.ctx) {
@ -307,17 +306,16 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
// use 2 scratch buffers // use 2 scratch buffers
// TODO: very hacky solution - reimplement in a more elegant way // TODO: very hacky solution - reimplement in a more elegant way
static size_t scr0_size = 256u*1024*1024; static size_t scr0_size = (n_ctx>2048?1024u:512u)*1024*1024;
static void * scr0 = malloc(scr0_size); static void * scr0 = malloc(scr0_size);
static size_t scr1_size = 256u*1024*1024; static size_t scr1_size = (n_ctx>2048?1024u:512u)*1024*1024;
static void * scr1 = malloc(scr1_size); static void * scr1 = malloc(scr1_size);
if (mem_per_token > 0 && mem_per_token * N > buf_size) { if (mem_per_token > 0 && mem_per_token * N > buf_size) {
const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
// printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__, // printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
// buf_size, buf_size_new); // buf_size, buf_size_new);
// reallocate // reallocate
buf_size = buf_size_new; buf_size = buf_size_new;
buf = realloc(buf, buf_size); buf = realloc(buf, buf_size);