initialize opt ggml context if none was provided

This commit is contained in:
xaedes 2023-09-01 15:41:57 +02:00
parent 4914f855c7
commit d554a70f11
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

44
ggml.c
View file

@ -19606,13 +19606,31 @@ GGML_API void ggml_opt_init(
opt->iter = 0;
opt->nx = nx;
opt->just_initialized = true;
if (opt->ctx == NULL) {
struct ggml_init_params ctx_opt_params;
if (opt->params.type == GGML_OPT_ADAM) {
ctx_opt_params.mem_size = GGML_MEM_ALIGN*2 + ggml_tensor_overhead()*2 + ggml_type_size(GGML_TYPE_F32)*nx*2;
if (opt->params.past > 0) {
ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
}
} else if (opt->params.type == GGML_OPT_LBFGS) {
ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
if (opt->params.past > 0) {
ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
}
}
ctx_opt_params.mem_buffer = NULL;
ctx_opt_params.no_alloc = false;
opt->ctx = ggml_init(ctx_opt_params);
}
switch (opt->params.type) {
case GGML_OPT_ADAM:
{
opt->adam.m = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.v = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->adam.m = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->adam.v = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->adam.pf = params.past > 0
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
: NULL;
ggml_set_zero(opt->adam.m);
ggml_set_zero(opt->adam.v);
@ -19622,18 +19640,18 @@ GGML_API void ggml_opt_init(
} break;
case GGML_OPT_LBFGS:
{
opt->lbfgs.x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->lbfgs.g = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->lbfgs.d = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
opt->lbfgs.x = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.g = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.d = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
opt->lbfgs.pf = params.past > 0
? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
: NULL;
opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
opt->lbfgs.lms = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
opt->lbfgs.lmy = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
opt->lbfgs.lms = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
opt->lbfgs.lmy = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
ggml_set_zero(opt->lbfgs.x);
ggml_set_zero(opt->lbfgs.xp);
ggml_set_zero(opt->lbfgs.g);