initialize opt ggml context if none was provided

This commit is contained in:
xaedes 2023-09-01 15:41:57 +02:00
parent 4914f855c7
commit d554a70f11
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

44
ggml.c
View file

@ -19606,13 +19606,31 @@ GGML_API void ggml_opt_init(
opt->iter = 0; opt->iter = 0;
opt->nx = nx; opt->nx = nx;
opt->just_initialized = true; opt->just_initialized = true;
if (opt->ctx == NULL) {
struct ggml_init_params ctx_opt_params;
if (opt->params.type == GGML_OPT_ADAM) {
ctx_opt_params.mem_size = GGML_MEM_ALIGN*2 + ggml_tensor_overhead()*2 + ggml_type_size(GGML_TYPE_F32)*nx*2;
if (opt->params.past > 0) {
ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
}
} else if (opt->params.type == GGML_OPT_LBFGS) {
ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
if (opt->params.past > 0) {
ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
}
}
ctx_opt_params.mem_buffer = NULL;
ctx_opt_params.no_alloc = false;
opt->ctx = ggml_init(ctx_opt_params);
}
switch (opt->params.type) { switch (opt->params.type) {
case GGML_OPT_ADAM: case GGML_OPT_ADAM:
{ {
opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->adam.pf = params.past > 0 opt->adam.pf = params.past > 0
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
: NULL; : NULL;
ggml_set_zero(opt->adam.m); ggml_set_zero(opt->adam.m);
ggml_set_zero(opt->adam.v); ggml_set_zero(opt->adam.v);
@ -19622,18 +19640,18 @@ GGML_API void ggml_opt_init(
} break; } break;
case GGML_OPT_LBFGS: case GGML_OPT_LBFGS:
{ {
opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx); opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.pf = params.past > 0 opt->lbfgs.pf = params.past > 0
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past) ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
: NULL; : NULL;
opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m); opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m); opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
ggml_set_zero(opt->lbfgs.x); ggml_set_zero(opt->lbfgs.x);
ggml_set_zero(opt->lbfgs.xp); ggml_set_zero(opt->lbfgs.xp);
ggml_set_zero(opt->lbfgs.g); ggml_set_zero(opt->lbfgs.g);