ggml/examples: add backend support for numerical optimization (ggml/949)
* CUDA eval works * stochastic gradient descent op * Adam except decay * CUDA CROSS_ENTROPY_LOSS_BACK * CUDA mnist-fc training works * backend CLI arg * refactor gguf load * remove sched from opt_step_adam * implement l1 regularization (weight decay) * extra call to add optimizer * initialize gradients with ggml_graph_reset * gradient accumulation * increment iter per eval instead of epoch * adjust backend interfaces * fix ggml_graph_reset without backend * fix ggml graph export/import * fixup * rename * revert ggml_opt changes * more general CUDA repeat_back * update documentation, fix CNN * validation split * add clarifying comment * optimize PyTorch training * adjust buffer size, thread count * fix 0.0f validation split * Update examples/mnist/mnist-common.cpp Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * fix gradient accumulation * tensor flag for accumulators -> tensor hash set * Update include/ggml.h Co-authored-by: slaren <slarengh@gmail.com> * Update tests/test-backend-ops.cpp Co-authored-by: slaren <slarengh@gmail.com> * Update tests/test-backend-ops.cpp Co-authored-by: slaren <slarengh@gmail.com> * fix test prints * Update src/ggml-backend.c Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> * better CUDA support for noncontiguous out_prod * add comment --------- Co-authored-by: Georgi Gerganov <ggerganov@gmail.com> Co-authored-by: slaren <slarengh@gmail.com>
This commit is contained in:
parent
a6809c6a2e
commit
424c5d00a9
24 changed files with 883 additions and 129 deletions
424
ggml/src/ggml.c
424
ggml/src/ggml.c
|
@ -1,6 +1,7 @@
|
|||
#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
|
||||
#define _USE_MATH_DEFINES // For M_PI on MSVC
|
||||
|
||||
#include "ggml-backend.h"
|
||||
#include "ggml-impl.h"
|
||||
#include "ggml-cpu-impl.h"
|
||||
#include "ggml-quants.h"
|
||||
|
@ -2997,9 +2998,10 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
|
|||
|
||||
"CROSS_ENTROPY_LOSS",
|
||||
"CROSS_ENTROPY_LOSS_BACK",
|
||||
"OPT_STEP_ADAMW",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
|
||||
static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
|
||||
|
||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
||||
"none",
|
||||
|
@ -3090,9 +3092,10 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
|||
|
||||
"cross_entropy_loss(x,y)",
|
||||
"cross_entropy_loss_back(x,y)",
|
||||
"adamw(x)",
|
||||
};
|
||||
|
||||
static_assert(GGML_OP_COUNT == 79, "GGML_OP_COUNT != 79");
|
||||
static_assert(GGML_OP_COUNT == 80, "GGML_OP_COUNT != 80");
|
||||
|
||||
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||
|
||||
|
@ -4094,7 +4097,11 @@ static void ggml_set_op_params_f32(struct ggml_tensor * tensor, uint32_t i, floa
|
|||
}
|
||||
|
||||
struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor) {
|
||||
memset(tensor->data, 0, ggml_nbytes(tensor));
|
||||
if (tensor->buffer) {
|
||||
ggml_backend_tensor_memset(tensor, 0, 0, ggml_nbytes(tensor));
|
||||
} else {
|
||||
memset(tensor->data, 0, ggml_nbytes(tensor));
|
||||
}
|
||||
return tensor;
|
||||
}
|
||||
|
||||
|
@ -8320,11 +8327,46 @@ struct ggml_tensor * ggml_cross_entropy_loss_back(
|
|||
return result;
|
||||
}
|
||||
|
||||
// opt_step_adamw
|
||||
|
||||
struct ggml_tensor * ggml_opt_step_adamw(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
float alpha,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float eps,
|
||||
float wd) {
|
||||
GGML_ASSERT(a->grad);
|
||||
GGML_ASSERT(alpha > 0.0f);
|
||||
GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f);
|
||||
GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f);
|
||||
GGML_ASSERT(eps >= 0.0f);
|
||||
GGML_ASSERT(wd >= 0.0f && wd <= 1.0f);
|
||||
|
||||
struct ggml_tensor * result = ggml_view_tensor(ctx, a);
|
||||
|
||||
result->op = GGML_OP_OPT_STEP_ADAMW;
|
||||
result->grad = NULL;
|
||||
result->src[0] = a;
|
||||
result->src[1] = a->grad;
|
||||
result->src[2] = ggml_dup_tensor(ctx, a->grad);
|
||||
result->src[3] = ggml_dup_tensor(ctx, a->grad);
|
||||
|
||||
const int64_t iter = 1;
|
||||
memcpy(&result->op_params[0], &iter, sizeof(int64_t));
|
||||
ggml_set_op_params_f32(result, 2, alpha);
|
||||
ggml_set_op_params_f32(result, 3, beta1);
|
||||
ggml_set_op_params_f32(result, 4, beta2);
|
||||
ggml_set_op_params_f32(result, 5, eps);
|
||||
ggml_set_op_params_f32(result, 6, wd);
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
void ggml_set_param(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * tensor) {
|
||||
void ggml_set_param(struct ggml_context * ctx, struct ggml_tensor * tensor) {
|
||||
tensor->flags |= GGML_TENSOR_FLAG_PARAM;
|
||||
|
||||
GGML_ASSERT(tensor->grad == NULL);
|
||||
|
@ -8332,6 +8374,13 @@ void ggml_set_param(
|
|||
ggml_format_name(tensor->grad, "%s (grad)", tensor->name);
|
||||
}
|
||||
|
||||
void ggml_set_loss(struct ggml_tensor * tensor) {
|
||||
GGML_ASSERT(ggml_is_scalar(tensor));
|
||||
GGML_ASSERT(tensor->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(tensor->grad);
|
||||
tensor->flags |= GGML_TENSOR_FLAG_LOSS;
|
||||
}
|
||||
|
||||
// ggml_compute_forward_dup
|
||||
|
||||
static void ggml_compute_forward_dup_same_cont(
|
||||
|
@ -17406,7 +17455,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|||
const int64_t ir0 = dr*ith;
|
||||
const int64_t ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
float * d = (float *) opt0->data;
|
||||
const float d_by_nr = ((const float *) opt0->data)[0] / (float) nr;
|
||||
|
||||
for (int64_t i1 = ir0; i1 < ir1; i1++) {
|
||||
float * ds0 = (float *)((char *) dst->data + i1*dst->nb[1]);
|
||||
|
@ -17430,7 +17479,7 @@ static void ggml_compute_forward_cross_entropy_loss_back_f32(
|
|||
|
||||
// grad(src0) = (softmax(src0) - src1) * grad(cross_entropy_loss(src0, src1)) / nr
|
||||
ggml_vec_sub_f32(nc, ds0, ds0, s1);
|
||||
ggml_vec_scale_f32(nc, ds0, d[0] / (float) nr);
|
||||
ggml_vec_scale_f32(nc, ds0, d_by_nr);
|
||||
|
||||
#ifndef NDEBUG
|
||||
for (int i = 0; i < nc; ++i) {
|
||||
|
@ -17459,6 +17508,94 @@ static void ggml_compute_forward_cross_entropy_loss_back(
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_opt_step_adamw_f32(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
const struct ggml_tensor * src0_grad = dst->src[1];
|
||||
const struct ggml_tensor * src0_grad_m = dst->src[2];
|
||||
const struct ggml_tensor * src0_grad_v = dst->src[3];
|
||||
GGML_ASSERT(ggml_are_same_shape(src0, src0_grad));
|
||||
|
||||
const int ith = params->ith;
|
||||
const int nth = params->nth;
|
||||
|
||||
const int nr = ggml_nrows(src0);
|
||||
|
||||
GGML_TENSOR_UNARY_OP_LOCALS
|
||||
GGML_ASSERT(nb00 == sizeof(float));
|
||||
|
||||
// rows per thread
|
||||
const int dr = (nr + nth - 1)/nth;
|
||||
|
||||
// row range for this thread
|
||||
const int ir0 = dr*ith;
|
||||
const int ir1 = MIN(ir0 + dr, nr);
|
||||
|
||||
/* const float gnorm = 1.0f; */
|
||||
int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t));
|
||||
const float alpha = ggml_get_op_params_f32(dst, 2);
|
||||
const float beta1 = ggml_get_op_params_f32(dst, 3);
|
||||
const float beta2 = ggml_get_op_params_f32(dst, 4);
|
||||
const float eps = ggml_get_op_params_f32(dst, 5);
|
||||
const float wd = ggml_get_op_params_f32(dst, 6);
|
||||
|
||||
const float beta1h = alpha/(1.0f - powf(beta1, iter));
|
||||
const float beta2h = 1.0f/(1.0f - powf(beta2, iter));
|
||||
|
||||
for (int ir = ir0; ir < ir1; ++ir) {
|
||||
const int64_t i03 = ir/(ne02*ne01);
|
||||
const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
|
||||
const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
|
||||
|
||||
const size_t offset = i03*nb03 + i02*nb02 + i01*nb01;
|
||||
|
||||
float * w = (float *) ((char *) src0->data + offset); // weight
|
||||
const float * g = (const float *) ((const char *) src0_grad->data + offset); // grad
|
||||
float * m = (float *) ((char *) src0_grad_m->data + offset);
|
||||
float * v = (float *) ((char *) src0_grad_v->data + offset);
|
||||
|
||||
for (int i00 = 0; i00 < ne00; ++i00) {
|
||||
m[i00] = m[i00]*beta1 + g[i00]*(1.0f - beta1);
|
||||
v[i00] = v[i00]*beta2 + g[i00]*g[i00]*(1.0f - beta2);
|
||||
|
||||
const float mh = m[i00]*beta1h;
|
||||
const float vh = sqrtf(v[i00]*beta2h) + eps;
|
||||
|
||||
// The weight decay is applied independently of the Adam momenta m and v.
|
||||
// This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss.
|
||||
// See: https://arxiv.org/pdf/1711.05101v3.pdf
|
||||
w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh;
|
||||
}
|
||||
}
|
||||
|
||||
ggml_barrier(params->threadpool);
|
||||
if (ith != 0) {
|
||||
return;
|
||||
}
|
||||
|
||||
iter++;
|
||||
memcpy(&dst->op_params[0], &iter, sizeof(int64_t));
|
||||
}
|
||||
|
||||
static void ggml_compute_forward_opt_step_adamw(
|
||||
const struct ggml_compute_params * params,
|
||||
struct ggml_tensor * dst) {
|
||||
|
||||
const struct ggml_tensor * src0 = dst->src[0];
|
||||
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
ggml_compute_forward_opt_step_adamw_f32(params, dst);
|
||||
} break;
|
||||
default:
|
||||
{
|
||||
GGML_ABORT("fatal error");
|
||||
}
|
||||
}
|
||||
}
|
||||
/////////////////////////////////
|
||||
|
||||
static void ggml_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor) {
|
||||
|
@ -17804,6 +17941,11 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
ggml_compute_forward_cross_entropy_loss_back(params, tensor);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
{
|
||||
ggml_compute_forward_opt_step_adamw(params, tensor);
|
||||
}
|
||||
break;
|
||||
case GGML_OP_NONE:
|
||||
{
|
||||
// nop
|
||||
|
@ -17958,7 +18100,7 @@ void ggml_build_backward_gradient_checkpointing(
|
|||
struct ggml_tensor * * checkpoints,
|
||||
int n_checkpoints) {
|
||||
ggml_graph_cpy(gf, gb_tmp);
|
||||
ggml_build_backward_expand(ctx, gf, gb_tmp, true);
|
||||
ggml_build_backward_expand(ctx, gf, gb_tmp, false, true);
|
||||
|
||||
if (n_checkpoints <= 0) {
|
||||
ggml_graph_cpy(gb_tmp, gb);
|
||||
|
@ -17996,42 +18138,93 @@ void ggml_build_backward_gradient_checkpointing(
|
|||
ggml_hash_map_free(replacements);
|
||||
}
|
||||
|
||||
// functions to change gradients considering the case that input a might be initial gradient with zero value
|
||||
// utility functions to change gradients
|
||||
// if a is in acc_table, modify gradients in-place and mark result as gradient accumulator
|
||||
// else if a is in zero_table, replace a
|
||||
// else, just add/subtract/etc. the gradients
|
||||
|
||||
static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
|
||||
static struct ggml_tensor * ggml_add_or_set(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_hash_set * zero_table,
|
||||
struct ggml_hash_set * acc_table) {
|
||||
if (ggml_hash_contains(acc_table, a)) {
|
||||
struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true);
|
||||
const size_t insert_result = ggml_hash_insert(acc_table, ret);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
|
||||
return ret;
|
||||
}
|
||||
if (ggml_hash_contains(zero_table, a)) {
|
||||
return b;
|
||||
} else {
|
||||
return ggml_add_impl(ctx, a, b, false);
|
||||
}
|
||||
return ggml_add_impl(ctx, a, b, false);
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, struct ggml_hash_set * zero_table) {
|
||||
static struct ggml_tensor * ggml_acc_or_set(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
const size_t nb1,
|
||||
const size_t nb2,
|
||||
const size_t nb3,
|
||||
const size_t offset,
|
||||
struct ggml_hash_set * zero_table,
|
||||
struct ggml_hash_set * acc_table) {
|
||||
if (ggml_hash_contains(acc_table, a)) {
|
||||
struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true);
|
||||
const size_t insert_result = ggml_hash_insert(acc_table, ret);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
|
||||
return ret;
|
||||
}
|
||||
if (ggml_hash_contains(zero_table, a)) {
|
||||
struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f);
|
||||
struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN
|
||||
return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false);
|
||||
} else {
|
||||
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
|
||||
}
|
||||
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
|
||||
static struct ggml_tensor * ggml_add1_or_set(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_hash_set * zero_table,
|
||||
struct ggml_hash_set * acc_table) {
|
||||
if (ggml_hash_contains(acc_table, a)) {
|
||||
struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true);
|
||||
const size_t insert_result = ggml_hash_insert(acc_table, ret);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
|
||||
return ret;
|
||||
}
|
||||
if (ggml_hash_contains(zero_table, a)) {
|
||||
return ggml_repeat(ctx, b, a);
|
||||
} else {
|
||||
return ggml_add1_impl(ctx, a, b, false);
|
||||
}
|
||||
return ggml_add1_impl(ctx, a, b, false);
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, struct ggml_hash_set * zero_table) {
|
||||
static struct ggml_tensor * ggml_sub_or_set(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_tensor * a,
|
||||
struct ggml_tensor * b,
|
||||
struct ggml_hash_set * zero_table,
|
||||
struct ggml_hash_set * acc_table) {
|
||||
if (ggml_hash_contains(acc_table, a)) {
|
||||
struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true);
|
||||
const size_t insert_result = ggml_hash_insert(acc_table, ret);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
|
||||
return ret;
|
||||
}
|
||||
if (ggml_hash_contains(zero_table, a)) {
|
||||
return ggml_neg(ctx, b);
|
||||
} else {
|
||||
return ggml_sub_impl(ctx, a, b, false);
|
||||
}
|
||||
return ggml_sub_impl(ctx, a, b, false);
|
||||
}
|
||||
|
||||
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table) {
|
||||
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) {
|
||||
struct ggml_tensor * src0 = tensor->src[0];
|
||||
struct ggml_tensor * src1 = tensor->src[1];
|
||||
struct ggml_tensor * src2 = tensor->src[2];
|
||||
|
@ -18040,38 +18233,38 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
case GGML_OP_DUP:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
if (ggml_are_same_shape(src0, src1)) {
|
||||
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
||||
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
|
||||
} else {
|
||||
src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table);
|
||||
src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table);
|
||||
}
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ADD1:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad = ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ACC:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
|
||||
|
@ -18093,16 +18286,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_reshape(ctx,
|
||||
ggml_cont(ctx, tensor_grad_view),
|
||||
src1->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUB:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
||||
src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
|
@ -18112,14 +18305,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_mul(ctx, src1, tensor->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad =
|
||||
ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_mul(ctx, src0, tensor->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_DIV:
|
||||
|
@ -18129,7 +18322,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_div(ctx, tensor->grad, src1),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad =
|
||||
|
@ -18138,7 +18331,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_mul(ctx,
|
||||
tensor->grad,
|
||||
ggml_div(ctx, tensor, src1)),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SQR:
|
||||
|
@ -18150,7 +18343,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_scale(ctx,
|
||||
ggml_mul(ctx, src0, tensor->grad),
|
||||
2.0f),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SQRT:
|
||||
|
@ -18164,7 +18357,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
tensor->grad,
|
||||
tensor),
|
||||
0.5f),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_LOG:
|
||||
|
@ -18176,7 +18369,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_div(ctx,
|
||||
tensor->grad,
|
||||
src0),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SIN:
|
||||
|
@ -18188,7 +18381,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_mul(ctx,
|
||||
tensor->grad,
|
||||
ggml_cos(ctx, src0)),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_COS:
|
||||
|
@ -18200,7 +18393,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_mul(ctx,
|
||||
tensor->grad,
|
||||
ggml_sin(ctx, src0)),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUM:
|
||||
|
@ -18210,7 +18403,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_add1_or_set(ctx,
|
||||
src0->grad,
|
||||
tensor->grad,
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
|
@ -18222,7 +18415,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_repeat(ctx,
|
||||
tensor->grad,
|
||||
src0->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MEAN:
|
||||
|
@ -18237,7 +18430,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_repeat_back(ctx, tensor->grad, src0->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
|
@ -18247,7 +18440,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_repeat(ctx, tensor->grad, src0->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CONCAT:
|
||||
|
@ -18272,7 +18465,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
|
@ -18320,7 +18513,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_add_or_set(ctx,
|
||||
src0->grad, // [n,m,q1,r1]
|
||||
s1_tg, // [n,m,q1,r1]
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad =
|
||||
|
@ -18338,7 +18531,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0, // [n,m,q1,r1]
|
||||
ggml_transpose(ctx, // [p,m,qq,rr]
|
||||
tensor->grad)), // [m,p,qq,rr]
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
|
@ -18360,7 +18553,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_scale_impl(ctx, tensor->grad, s, false),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
|
@ -18389,7 +18582,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
tensor->grad,
|
||||
ggml_neg(ctx, tensor_grad_view),
|
||||
nb1, nb2, nb3, offset, false),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
|
||||
if (src1->grad) {
|
||||
|
@ -18399,7 +18592,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_reshape(ctx,
|
||||
ggml_cont(ctx, tensor_grad_view),
|
||||
src1->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
|
@ -18410,7 +18603,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// tensor = src0 * 1 + src1 * 0
|
||||
if (src0->grad) {
|
||||
// dsrc0 = dtensor * 1
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
// dsrc1 = dtensor * 0 -> noop
|
||||
|
@ -18422,7 +18615,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
if (src0->grad) {
|
||||
GGML_ASSERT(ggml_is_contiguous(src0->grad));
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_RESHAPE:
|
||||
|
@ -18436,7 +18629,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
? tensor->grad
|
||||
: ggml_cont(ctx, tensor->grad),
|
||||
src0->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_VIEW:
|
||||
|
@ -18465,7 +18658,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
nb3 = (nb3 / n0) * ng;
|
||||
}
|
||||
|
||||
src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
|
||||
src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_PERMUTE:
|
||||
|
@ -18490,7 +18683,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
axes_backward[1],
|
||||
axes_backward[2],
|
||||
axes_backward[3]),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_TRANSPOSE:
|
||||
|
@ -18500,7 +18693,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad =
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_transpose(ctx, tensor->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
@ -18512,7 +18705,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// last ggml_get_rows_back argument src0->grad is only
|
||||
// necessary to setup correct output shape
|
||||
ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
// noop
|
||||
|
@ -18536,7 +18729,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
/* ggml_diag_mask_inf_impl() shouldn't be here */
|
||||
/* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */
|
||||
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_DIAG_MASK_ZERO:
|
||||
|
@ -18547,7 +18740,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad =
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
|
@ -18557,7 +18750,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad =
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_soft_max_back(ctx, tensor->grad, tensor),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
|
||||
} break;
|
||||
|
@ -18598,7 +18791,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
attn_factor,
|
||||
beta_fast,
|
||||
beta_slow),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
|
@ -18634,7 +18827,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
beta_fast,
|
||||
beta_slow,
|
||||
false),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CLAMP:
|
||||
|
@ -18659,7 +18852,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src1->grad = ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_IM2COL_BACK:
|
||||
|
@ -18688,7 +18881,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_POOL_2D_BACK:
|
||||
|
@ -18753,7 +18946,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
grad_q,
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k);
|
||||
|
@ -18761,7 +18954,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src1->grad = ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
grad_k,
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
if (src2->grad) {
|
||||
struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v);
|
||||
|
@ -18769,7 +18962,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src2->grad = ggml_add_or_set(ctx,
|
||||
src2->grad,
|
||||
grad_v,
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_FLASH_ATTN_BACK:
|
||||
|
@ -18795,7 +18988,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_mul(ctx,
|
||||
ggml_sgn(ctx, src0),
|
||||
tensor->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_SGN:
|
||||
|
@ -18807,7 +19000,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
case GGML_UNARY_OP_NEG:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_STEP:
|
||||
|
@ -18832,7 +19025,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
ggml_mul(ctx,
|
||||
ggml_step(ctx, src0),
|
||||
tensor->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_SIGMOID:
|
||||
|
@ -18854,7 +19047,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_silu_back(ctx, src0, tensor->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_EXP:
|
||||
|
@ -18863,7 +19056,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_mul(ctx, tensor, tensor->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
|
@ -18893,13 +19086,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0,
|
||||
src1,
|
||||
tensor->grad),
|
||||
zero_table);
|
||||
zero_table, acc_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
{
|
||||
GGML_ABORT("fatal error"); // not supported
|
||||
}
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
{
|
||||
GGML_ABORT("fatal error"); // not supported
|
||||
}
|
||||
case GGML_OP_NONE:
|
||||
{
|
||||
// nop
|
||||
|
@ -18989,7 +19186,7 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor *
|
|||
ggml_build_forward_impl(cgraph, tensor, true);
|
||||
}
|
||||
|
||||
void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
|
||||
void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate, bool keep) {
|
||||
GGML_ASSERT(gf->n_nodes > 0);
|
||||
GGML_ASSERT(gf->grads);
|
||||
|
||||
|
@ -19005,21 +19202,35 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
|
|||
}
|
||||
}
|
||||
|
||||
// remember original gradients which start with zero values
|
||||
// keep tables of original gradients for replacement/accumulation logic
|
||||
struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size);
|
||||
struct ggml_hash_set acc_table = ggml_hash_set_new(gf->size);
|
||||
for (int i = 0; i < gf->n_nodes; i++) {
|
||||
if (gf->grads[i]) {
|
||||
ggml_hash_insert(&zero_table, gf->grads[i]);
|
||||
struct ggml_tensor * node = gf->nodes[i];
|
||||
|
||||
if (node->grad) {
|
||||
{
|
||||
const size_t insert_result = ggml_hash_insert(&zero_table, node->grad);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
|
||||
}
|
||||
|
||||
// only gradients of trainable parameters should be accumulated
|
||||
if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) {
|
||||
const size_t insert_result = ggml_hash_insert(&acc_table, node->grad);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_FULL);
|
||||
GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = gf->n_nodes - 1; i >= 0; i--) {
|
||||
struct ggml_tensor * node = gf->nodes[i];
|
||||
|
||||
// inplace operations to add gradients are not created by ggml_compute_backward
|
||||
// inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation
|
||||
// use allocator to automatically make inplace operations
|
||||
if (node->grad) {
|
||||
ggml_compute_backward(ctx, node, &zero_table);
|
||||
ggml_compute_backward(ctx, node, &zero_table, &acc_table);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -19033,8 +19244,30 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
|
|||
}
|
||||
|
||||
ggml_hash_set_free(&zero_table);
|
||||
ggml_hash_set_free(&acc_table);
|
||||
}
|
||||
|
||||
void ggml_build_opt_adamw(
|
||||
struct ggml_context * ctx,
|
||||
struct ggml_cgraph * gf,
|
||||
struct ggml_cgraph * gb,
|
||||
float alpha,
|
||||
float beta1,
|
||||
float beta2,
|
||||
float eps,
|
||||
float wd) {
|
||||
for (int i = 0; i < gf->n_nodes; i++) {
|
||||
struct ggml_tensor * node = gf->nodes[i];
|
||||
|
||||
if (node->flags & GGML_TENSOR_FLAG_PARAM) {
|
||||
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
|
||||
struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, alpha, beta1, beta2, eps, wd);
|
||||
ggml_build_forward_expand(gb, opt_step);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
static void * incr_ptr_aligned(void ** p, size_t size, size_t align) {
|
||||
void * ptr = *p;
|
||||
ptr = (void *) GGML_PAD((uintptr_t) ptr, align);
|
||||
|
@ -19162,10 +19395,28 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) {
|
|||
GGML_ASSERT(cgraph->grads != NULL);
|
||||
|
||||
for (int i = 0; i < cgraph->n_nodes; i++) {
|
||||
struct ggml_tensor * grad = cgraph->grads[i];
|
||||
struct ggml_tensor * node = cgraph->nodes[i];
|
||||
|
||||
if (grad) {
|
||||
ggml_set_zero(grad);
|
||||
// initial gradients of loss should be 1, 0 otherwise
|
||||
if (node->grad) {
|
||||
if (node->flags & GGML_TENSOR_FLAG_LOSS) {
|
||||
GGML_ASSERT(node->grad->buffer);
|
||||
GGML_ASSERT(node->type == GGML_TYPE_F32);
|
||||
GGML_ASSERT(ggml_is_scalar(node));
|
||||
|
||||
const float onef = 1.0f;
|
||||
ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad));
|
||||
} else {
|
||||
ggml_set_zero(node->grad);
|
||||
}
|
||||
}
|
||||
|
||||
GGML_ASSERT(node);
|
||||
if (node->op == GGML_OP_OPT_STEP_ADAMW) {
|
||||
// set iteration to 1 and clear momenta
|
||||
ggml_set_op_params_i32(node, 0, 1);
|
||||
ggml_set_zero(node->src[2]);
|
||||
ggml_set_zero(node->src[3]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -19458,6 +19709,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
|
|||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
case GGML_OP_OPT_STEP_ADAMW:
|
||||
{
|
||||
n_tasks = n_threads;
|
||||
} break;
|
||||
|
@ -21851,7 +22103,7 @@ enum ggml_opt_result ggml_opt_resume(
|
|||
ggml_build_forward_expand(gf, f);
|
||||
|
||||
struct ggml_cgraph * gb = ggml_graph_dup(ctx, gf);
|
||||
ggml_build_backward_expand(ctx, gf, gb, true);
|
||||
ggml_build_backward_expand(ctx, gf, gb, false, true);
|
||||
|
||||
return ggml_opt_resume_g(ctx, opt, f, gf, gb, NULL, NULL);
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue