fixed MPT ooms
This commit is contained in:
parent
8bd9a3a48b
commit
c1b293d31a
1 changed files with 9 additions and 11 deletions
|
@ -120,21 +120,20 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
|
|||
ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight
|
||||
ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight
|
||||
|
||||
ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_k
|
||||
ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_v
|
||||
ctx_size += (n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16)); // memory_k
|
||||
ctx_size += (n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16)); // memory_v
|
||||
|
||||
ctx_size += (1 + 6 * n_layer) * 512; // object overhead
|
||||
ctx_size += (6 + 6 * n_layer) * 512; // object overhead
|
||||
|
||||
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
|
||||
}
|
||||
|
||||
// create the ggml context
|
||||
{
|
||||
struct ggml_init_params params = {
|
||||
.mem_size = ctx_size,
|
||||
.mem_buffer = NULL,
|
||||
.no_alloc = false,
|
||||
};
|
||||
struct ggml_init_params params;
|
||||
params.mem_size = ctx_size;
|
||||
params.mem_buffer = NULL;
|
||||
params.no_alloc = false;
|
||||
|
||||
model.ctx = ggml_init(params);
|
||||
if (!model.ctx) {
|
||||
|
@ -307,17 +306,16 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
|||
|
||||
// use 2 scratch buffers
|
||||
// TODO: very hacky solution - reimplement in a more elegant way
|
||||
static size_t scr0_size = 256u*1024*1024;
|
||||
static size_t scr0_size = (n_ctx>2048?1024u:512u)*1024*1024;
|
||||
static void * scr0 = malloc(scr0_size);
|
||||
|
||||
static size_t scr1_size = 256u*1024*1024;
|
||||
static size_t scr1_size = (n_ctx>2048?1024u:512u)*1024*1024;
|
||||
static void * scr1 = malloc(scr1_size);
|
||||
|
||||
if (mem_per_token > 0 && mem_per_token * N > buf_size) {
|
||||
const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
|
||||
// printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
|
||||
// buf_size, buf_size_new);
|
||||
|
||||
// reallocate
|
||||
buf_size = buf_size_new;
|
||||
buf = realloc(buf, buf_size);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue