Fixed MPT OOMs (out-of-memory errors)
This commit is contained in:
parent
8bd9a3a48b
commit
c1b293d31a
1 changed file with 9 additions and 11 deletions
|
@ -120,21 +120,20 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
|
||||||
ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight
|
ctx_size += n_layer * (4 * n_embd * n_embd * ggml_type_sizef(wtype)); // mlp_mlp_up_weight
|
||||||
ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight
|
ctx_size += n_layer * (n_embd * n_embd * 4 * ggml_type_sizef(wtype)); // mlp_mlp_down_weight
|
||||||
|
|
||||||
ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_k
|
ctx_size += (n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16)); // memory_k
|
||||||
ctx_size += n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16); // memory_v
|
ctx_size += (n_ctx * n_layer * n_embd * ggml_type_sizef(GGML_TYPE_F16)); // memory_v
|
||||||
|
|
||||||
ctx_size += (1 + 6 * n_layer) * 512; // object overhead
|
ctx_size += (6 + 6 * n_layer) * 512; // object overhead
|
||||||
|
|
||||||
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
|
printf("%s: ggml ctx size = %6.2f MB\n", __func__, ctx_size / (1024.0 * 1024.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
// create the ggml context
|
// create the ggml context
|
||||||
{
|
{
|
||||||
struct ggml_init_params params = {
|
struct ggml_init_params params;
|
||||||
.mem_size = ctx_size,
|
params.mem_size = ctx_size;
|
||||||
.mem_buffer = NULL,
|
params.mem_buffer = NULL;
|
||||||
.no_alloc = false,
|
params.no_alloc = false;
|
||||||
};
|
|
||||||
|
|
||||||
model.ctx = ggml_init(params);
|
model.ctx = ggml_init(params);
|
||||||
if (!model.ctx) {
|
if (!model.ctx) {
|
||||||
|
@ -307,17 +306,16 @@ bool mpt_eval(const mpt_model & model, const int n_threads, const int n_past,
|
||||||
|
|
||||||
// use 2 scratch buffers
|
// use 2 scratch buffers
|
||||||
// TODO: very hacky solution - reimplement in a more elegant way
|
// TODO: very hacky solution - reimplement in a more elegant way
|
||||||
static size_t scr0_size = 256u*1024*1024;
|
static size_t scr0_size = (n_ctx>2048?1024u:512u)*1024*1024;
|
||||||
static void * scr0 = malloc(scr0_size);
|
static void * scr0 = malloc(scr0_size);
|
||||||
|
|
||||||
static size_t scr1_size = 256u*1024*1024;
|
static size_t scr1_size = (n_ctx>2048?1024u:512u)*1024*1024;
|
||||||
static void * scr1 = malloc(scr1_size);
|
static void * scr1 = malloc(scr1_size);
|
||||||
|
|
||||||
if (mem_per_token > 0 && mem_per_token * N > buf_size) {
|
if (mem_per_token > 0 && mem_per_token * N > buf_size) {
|
||||||
const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
|
const size_t buf_size_new = 1.1 * (mem_per_token * N); // add 10% to account for ggml object overhead
|
||||||
// printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
|
// printf("\n%s: reallocating buffer from %zu to %zu bytes\n", __func__,
|
||||||
// buf_size, buf_size_new);
|
// buf_size, buf_size_new);
|
||||||
|
|
||||||
// reallocate
|
// reallocate
|
||||||
buf_size = buf_size_new;
|
buf_size = buf_size_new;
|
||||||
buf = realloc(buf, buf_size);
|
buf = realloc(buf, buf_size);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue