ggml: new optimization interface (ggml/988)

Johannes Gäßler 2024-11-16 22:17:59 +02:00 committed by Georgi Gerganov
parent 5c9a8b22b1
commit 8a43e940ab
15 changed files with 2663 additions and 1633 deletions


@@ -12216,11 +12216,16 @@ static void ggml_compute_forward_opt_step_adamw_f32(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst) {
-    const struct ggml_tensor * src0        = dst->src[0];
-    const struct ggml_tensor * src0_grad   = dst->src[1];
-    const struct ggml_tensor * src0_grad_m = dst->src[2];
-    const struct ggml_tensor * src0_grad_v = dst->src[3];
+    const struct ggml_tensor * src0         = dst->src[0];
+    const struct ggml_tensor * src0_grad    = dst->src[1];
+    const struct ggml_tensor * src0_grad_m  = dst->src[2];
+    const struct ggml_tensor * src0_grad_v  = dst->src[3];
+    const struct ggml_tensor * adamw_params = dst->src[4];
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m));
     GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v));
+    GGML_ASSERT(ggml_nelements(adamw_params) == 7);

     const int ith = params->ith;
     const int nth = params->nth;
@@ -12237,16 +12242,14 @@ static void ggml_compute_forward_opt_step_adamw_f32(
     const int ir0 = dr*ith;
     const int ir1 = MIN(ir0 + dr, nr);

-    /* const float gnorm = 1.0f; */
-    int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
-    const float alpha = ggml_get_op_params_f32(dst, 2);
-    const float beta1 = ggml_get_op_params_f32(dst, 3);
-    const float beta2 = ggml_get_op_params_f32(dst, 4);
-    const float eps   = ggml_get_op_params_f32(dst, 5);
-    const float wd    = ggml_get_op_params_f32(dst, 6);
-    const float beta1h = alpha/(1.0f - powf(beta1, iter));
-    const float beta2h =  1.0f/(1.0f - powf(beta2, iter));
+    const float * adamw_params_ptr = ggml_get_data_f32(adamw_params);
+    const float alpha  = adamw_params_ptr[0];
+    const float beta1  = adamw_params_ptr[1];
+    const float beta2  = adamw_params_ptr[2];
+    const float eps    = adamw_params_ptr[3];
+    const float wd     = adamw_params_ptr[4];
+    const float beta1h = adamw_params_ptr[5];
+    const float beta2h = adamw_params_ptr[6];

     for (int ir = ir0; ir < ir1; ++ir) {
         const int64_t i03 = ir/(ne02*ne01);
@@ -12270,17 +12273,9 @@ static void ggml_compute_forward_opt_step_adamw_f32(
             // The weight decay is applied independently of the Adam momenta m and v.
             // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
             // See: https://arxiv.org/pdf/1711.05101v3.pdf
-            w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
+            w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh;
         }
     }
-
-    ggml_barrier(params->threadpool);
-    if (ith != 0) {
-        return;
-    }
-
-    iter++;
-    memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
 }

 static void ggml_compute_forward_opt_step_adamw(
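
The net effect of these hunks is that the AdamW hyperparameters are no longer baked into the node's op_params at graph-build time: the kernel now reads seven floats at run time from an extra source tensor (dst->src[4]), and the iteration counter together with the bias-correction terms beta1h/beta2h is maintained by the caller rather than being bumped inside the kernel (hence the removal of the trailing ggml_barrier/iter++ block). Below is a minimal host-side sketch of the layout the kernel expects; the helper name is hypothetical, and the assumption that alpha is no longer folded into beta1h is inferred from the new "- alpha*mh/vh" update line, not from code shown in this diff.

// Hypothetical helper: fill the 7-element F32 buffer backing the
// "adamw_params" tensor that ggml_compute_forward_opt_step_adamw_f32
// reads through ggml_get_data_f32(). The index order mirrors the kernel
// above; the function itself is an illustration, not ggml API.
#include <math.h>
#include <stdint.h>

static void fill_adamw_params(float buf[7],
                              float alpha, float beta1, float beta2,
                              float eps, float wd, int64_t iter) {
    buf[0] = alpha; // learning rate
    buf[1] = beta1; // decay rate of the first moment m
    buf[2] = beta2; // decay rate of the second moment v
    buf[3] = eps;   // epsilon added to the denominator
    buf[4] = wd;    // decoupled weight decay
    // Bias-correction factors, recomputed by the caller once per step.
    // Assumption: alpha is NOT folded into beta1h here, since the kernel
    // now applies it explicitly in "w[i00] = ... - alpha*mh/vh".
    buf[5] = 1.0f/(1.0f - powf(beta1, iter)); // beta1h
    buf[6] = 1.0f/(1.0f - powf(beta2, iter)); // beta2h
}

Keeping these values in a tensor rather than in op_params means they can be updated every optimizer step without rebuilding the graph, which is presumably why the per-kernel iteration bookkeeping could be dropped.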