initialize opt ggml context if none was provided
This commit is contained in:
parent 4914f855c7
commit d554a70f11
1 changed file with 31 additions and 13 deletions
ggml.c (+31, −13)
@@ -19606,13 +19606,31 @@ GGML_API void ggml_opt_init(
     opt->iter = 0;
     opt->nx = nx;
     opt->just_initialized = true;
+    if (opt->ctx == NULL) {
+        struct ggml_init_params ctx_opt_params;
+        if (opt->params.type == GGML_OPT_ADAM) {
+            ctx_opt_params.mem_size = GGML_MEM_ALIGN*2 + ggml_tensor_overhead()*2 + ggml_type_size(GGML_TYPE_F32)*nx*2;
+            if (opt->params.past > 0) {
+                ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+            }
+        } else if (opt->params.type == GGML_OPT_LBFGS) {
+            ctx_opt_params.mem_size = GGML_MEM_ALIGN*9 + ggml_tensor_overhead()*9 + ggml_type_size(GGML_TYPE_F32)*(nx*5 + opt->params.lbfgs.m*2 + nx*opt->params.lbfgs.m*2);
+            if (opt->params.past > 0) {
+                ctx_opt_params.mem_size += GGML_MEM_ALIGN + ggml_tensor_overhead() + ggml_type_size(GGML_TYPE_F32)*opt->params.past;
+            }
+        }
+        ctx_opt_params.mem_buffer = NULL;
+        ctx_opt_params.no_alloc   = false;
+
+        opt->ctx = ggml_init(ctx_opt_params);
+    }
     switch (opt->params.type) {
         case GGML_OPT_ADAM:
             {
-                opt->adam.m  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->adam.v  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->adam.m  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->adam.v  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->adam.pf = params.past > 0
-                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
                     : NULL;
                 ggml_set_zero(opt->adam.m);
                 ggml_set_zero(opt->adam.v);
@@ -19622,18 +19640,18 @@ GGML_API void ggml_opt_init(
             } break;
         case GGML_OPT_LBFGS:
             {
-                opt->lbfgs.x  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->lbfgs.xp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->lbfgs.g  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->lbfgs.gp = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
-                opt->lbfgs.d  = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.x  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.xp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.g  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.gp = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
+                opt->lbfgs.d  = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, nx);
                 opt->lbfgs.pf = params.past > 0
-                    ? ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.past)
+                    ? ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.past)
                     : NULL;
-                opt->lbfgs.lmal = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
-                opt->lbfgs.lmys = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, params.lbfgs.m);
-                opt->lbfgs.lms  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
-                opt->lbfgs.lmy  = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmal = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lmys = ggml_new_tensor_1d(opt->ctx, GGML_TYPE_F32, params.lbfgs.m);
+                opt->lbfgs.lms  = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
+                opt->lbfgs.lmy  = ggml_new_tensor_2d(opt->ctx, GGML_TYPE_F32, nx, params.lbfgs.m);
                 ggml_set_zero(opt->lbfgs.x);
                 ggml_set_zero(opt->lbfgs.xp);
                 ggml_set_zero(opt->lbfgs.g);
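The mem_size arithmetic reserves three things per tensor: one GGML_MEM_ALIGN pad, one tensor header (ggml_tensor_overhead()), and the F32 payload. ADAM keeps 2 tensors of nx elements (m and v); L-BFGS keeps 9: five of nx elements (x, xp, g, gp, d), two of lbfgs.m elements (lmal, lmys), and two of nx*lbfgs.m elements (lms, lmy), which is exactly the nx*5 + m*2 + nx*m*2 term above. A minimal sketch restating that cost model as a standalone helper; opt_tensors_size is an illustrative name, not a ggml function, while GGML_MEM_ALIGN, ggml_tensor_overhead(), and ggml_type_size() are the real ggml symbols:

    #include <stddef.h>
    #include <stdint.h>
    #include "ggml.h"

    // Illustrative helper (not part of ggml): bytes for n_tensors F32 tensors
    // holding n_floats floats in total, using the commit's per-tensor cost
    // model: alignment pad + tensor header + data.
    static size_t opt_tensors_size(size_t n_tensors, int64_t n_floats) {
        return GGML_MEM_ALIGN*n_tensors
             + ggml_tensor_overhead()*n_tensors
             + ggml_type_size(GGML_TYPE_F32)*n_floats;
    }

    // ADAM state:   m, v                     -> opt_tensors_size(2, nx*2)
    // L-BFGS state: x, xp, g, gp, d,
    //               lmal, lmys, lms, lmy     -> opt_tensors_size(9, nx*5 + m*2 + nx*m*2)
    // Optional past-values tensor pf adds    -> opt_tensors_size(1, params.past)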
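With this change a caller no longer has to pre-size a context for the optimizer state: passing a NULL context makes ggml_opt_init allocate its own, which it stores in opt->ctx. A minimal usage sketch, assuming the ggml_opt_init(ctx, opt, params, nx) signature at this revision (which assigns the ctx argument into opt->ctx before the hunk shown above) and the existing ggml_opt_default_params() helper:

    #include <stdint.h>
    #include "ggml.h"

    int main(void) {
        const int64_t nx = 1024; // number of parameters being optimized

        struct ggml_opt_context opt = {0};
        struct ggml_opt_params params = ggml_opt_default_params(GGML_OPT_ADAM);

        // NULL context: ggml_opt_init now creates one sized for the ADAM state
        ggml_opt_init(NULL, &opt, params, nx);

        // ... build the graph and run the optimizer with opt ...

        ggml_free(opt.ctx); // release the internally created context
        return 0;
    }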