avoid keeping in memory ALL of the gradients
The problem here stems from ggml_graph_reset. This function is called in the optimization function, before each graph computation, to reset the gradients to zero. This required a unique memory slot for each gradient: allocating memory from a previosly freed memory location might lead to non-zero input gradients. During ggml_compute_backward the gradients are build stepwise by adding or substracting new values, starting from a OP_NONE tensor which needs to contain zero-values. This requires the graph reset. To avoid this I now remember in ggml_build_backward_expand the original OP_NONE gradient tensors in a hash table, which is passed to ggml_compute_backward. There instead of using add (or sub or similar) I test whether the existing gradient to be changed is a zero-valued-tensor by looking up its existence in the hash table. When it is such a zero-tensor it will not be modified, but replaced by the value to be added, otherwise the regular add (not inplace, allocator will take care of this) will be used. This way none of those zero-tensor values will be necessary in the final backward graph and more importantly they won't need a unique memory slot, just to make them zero.
This commit is contained in:
parent
a252111b45
commit
f358204a5f
2 changed files with 187 additions and 126 deletions
|
@ -1242,14 +1242,6 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
|
|||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
|
||||
}
|
||||
|
||||
// gradient tensors (will be set to zero by ggml_graph_reset)
|
||||
for (int i = 0; i < gf->n_nodes; ++i) {
|
||||
if (!gf->grads[i]) continue;
|
||||
if (gf->grads[i]->data == NULL && !ggml_is_view(gf->grads[i])) {
|
||||
ggml_allocr_alloc(alloc, gf->grads[i]);
|
||||
}
|
||||
ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one));
|
||||
}
|
||||
for (int i = 0; i < checkpoints.size(); ++i) {
|
||||
if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
|
||||
ggml_allocr_alloc(alloc, checkpoints[i]);
|
||||
|
|
305
ggml.c
305
ggml.c
|
@ -15009,7 +15009,89 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
|
|||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
||||
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
|
||||
static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
|
||||
|
||||
static size_t hash(void * p) {
|
||||
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
}
|
||||
|
||||
static bool hash_insert(void * hash_table[], void * p) {
|
||||
size_t h = hash(p);
|
||||
|
||||
// linear probing
|
||||
size_t i = h;
|
||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
if (i == h) {
|
||||
// hash table is full
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_table[i] == p) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// insert
|
||||
hash_table[i] = p;
|
||||
return false;
|
||||
}
|
||||
|
||||
static bool hash_contains(void * hash_table[], void * p) {
|
||||
size_t h = hash(p);
|
||||
|
||||
// linear probing
|
||||
size_t i = h;
|
||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
if (i == h) {
|
||||
// hash table is full
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_table[i] == p) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
|
||||
// functions to change gradients considering the case that input a might be initial gradient with zero value
|
||||
|
||||
static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
||||
if (hash_contains(zero_table, a)) {
|
||||
return b;
|
||||
} else {
|
||||
return ggml_add_impl(ctx, a, b, false);
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
|
||||
if (hash_contains(zero_table, a)) {
|
||||
return b;
|
||||
} else {
|
||||
return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
||||
if (hash_contains(zero_table, a)) {
|
||||
return ggml_repeat(ctx, b, a);
|
||||
} else {
|
||||
return ggml_add1_impl(ctx, a, b, false);
|
||||
}
|
||||
}
|
||||
|
||||
static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
|
||||
if (hash_contains(zero_table, a)) {
|
||||
return ggml_neg(ctx, b);
|
||||
} else {
|
||||
return ggml_sub_impl(ctx, a, b, false);
|
||||
}
|
||||
}
|
||||
|
||||
static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace, void * zero_table[]) {
|
||||
struct ggml_tensor * src0 = tensor->src[0];
|
||||
struct ggml_tensor * src1 = tensor->src[1];
|
||||
|
||||
|
@ -15017,34 +15099,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
case GGML_OP_DUP:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ADD:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
|
||||
src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ADD1:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad = ggml_add_impl(ctx,
|
||||
src1->grad = ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ACC:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
const size_t nb1 = ((int32_t *) tensor->op_params)[0];
|
||||
|
@ -15061,117 +15143,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
nb1, nb2, nb3, offset);
|
||||
|
||||
src1->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_reshape(ctx,
|
||||
ggml_cont(ctx, tensor_grad_view),
|
||||
src1->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUB:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
|
||||
src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MUL:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_mul(ctx, src1, tensor->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_mul(ctx, src0, tensor->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_DIV:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_div(ctx, tensor->grad, src1),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad =
|
||||
ggml_sub_impl(ctx,
|
||||
ggml_sub_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_mul(ctx,
|
||||
tensor->grad,
|
||||
ggml_div(ctx, tensor, src1)),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SQR:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_scale(ctx,
|
||||
ggml_mul(ctx, src0, tensor->grad),
|
||||
ggml_new_f32(ctx, 2.0f)),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SQRT:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_scale(ctx,
|
||||
ggml_div(ctx,
|
||||
tensor->grad,
|
||||
tensor),
|
||||
ggml_new_f32(ctx, 0.5f)),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_LOG:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_div(ctx,
|
||||
tensor->grad,
|
||||
src0),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUM:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add1_impl(ctx,
|
||||
ggml_add1_or_set(ctx,
|
||||
src0->grad,
|
||||
tensor->grad,
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SUM_ROWS:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_repeat(ctx,
|
||||
tensor->grad,
|
||||
src0->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_MEAN:
|
||||
|
@ -15183,20 +15265,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_repeat_back(ctx, tensor->grad, src0->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_REPEAT_BACK:
|
||||
{
|
||||
if (src0->grad) {
|
||||
// TODO: test this
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_repeat(ctx, tensor->grad, src0->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SILU_BACK:
|
||||
|
@ -15214,10 +15296,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
float eps;
|
||||
memcpy(&eps, tensor->op_params, sizeof(float));
|
||||
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
|
@ -15244,16 +15326,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_out_prod(ctx, // [n,m]
|
||||
src1, // [n,p]
|
||||
tensor->grad), // [m,p]
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
// ggml_mul_mat(ctx, // [n,p]
|
||||
// ggml_cont(ctx, // [m,n]
|
||||
|
@ -15267,7 +15349,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
src0, // [n,m]
|
||||
ggml_transpose(ctx, // [p,m]
|
||||
tensor->grad)), // [m,p]
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_OUT_PROD:
|
||||
|
@ -15279,17 +15361,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_scale_impl(ctx, tensor->grad, src1, false),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
src1->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
|
@ -15316,23 +15398,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
}
|
||||
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_acc_impl(ctx,
|
||||
tensor->grad,
|
||||
ggml_neg(ctx, tensor_grad_view),
|
||||
nb1, nb2, nb3, offset, false),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
|
||||
if (src1->grad) {
|
||||
src1->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
ggml_reshape(ctx,
|
||||
ggml_cont(ctx, tensor_grad_view),
|
||||
src1->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CPY:
|
||||
|
@ -15343,7 +15425,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// tensor = src0 * 1 + src1 * 0
|
||||
if (src0->grad) {
|
||||
// dsrc0 = dtensor * 1
|
||||
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
// dsrc1 = dtensor * 0 -> noop
|
||||
|
@ -15355,7 +15437,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
if (src0->grad) {
|
||||
GGML_ASSERT(ggml_is_contiguous(src0->grad));
|
||||
GGML_ASSERT(ggml_is_contiguous(tensor->grad));
|
||||
src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||
src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_RESHAPE:
|
||||
|
@ -15363,9 +15445,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx, src0->grad,
|
||||
ggml_reshape(ctx, tensor->grad, src0->grad),
|
||||
inplace);
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_reshape(ctx,
|
||||
ggml_is_contiguous(tensor->grad)
|
||||
? tensor->grad
|
||||
: ggml_cont(ctx, tensor->grad),
|
||||
src0->grad),
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_VIEW:
|
||||
|
@ -15394,7 +15480,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
nb3 = (nb3 / n0) * ng;
|
||||
}
|
||||
|
||||
src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
|
||||
src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_PERMUTE:
|
||||
|
@ -15412,14 +15498,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
axes_backward[axis2] = 2;
|
||||
axes_backward[axis3] = 3;
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx, src0->grad,
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_permute(ctx,
|
||||
tensor->grad,
|
||||
axes_backward[0],
|
||||
axes_backward[1],
|
||||
axes_backward[2],
|
||||
axes_backward[3]),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_TRANSPOSE:
|
||||
|
@ -15427,9 +15513,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx, src0->grad,
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_transpose(ctx, tensor->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_GET_ROWS:
|
||||
|
@ -15437,9 +15523,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// necessary for llama (only for tokenizer)
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx, src0->grad,
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
if (src1->grad) {
|
||||
// noop
|
||||
|
@ -15459,9 +15545,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
if (src0->grad) {
|
||||
const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx, src0->grad,
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_DIAG_MASK_ZERO:
|
||||
|
@ -15470,9 +15556,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
if (src0->grad) {
|
||||
const int n_past = ((int32_t *) tensor->op_params)[0];
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx, src0->grad,
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SOFT_MAX:
|
||||
|
@ -15480,9 +15566,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx, src0->grad,
|
||||
ggml_add_or_set(ctx, src0->grad,
|
||||
ggml_soft_max_back(ctx, tensor->grad, tensor),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
|
||||
} break;
|
||||
|
@ -15498,7 +15584,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_rope_back(ctx,
|
||||
tensor->grad,
|
||||
|
@ -15506,7 +15592,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
n_dims,
|
||||
mode,
|
||||
n_ctx),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ROPE_BACK:
|
||||
|
@ -15516,7 +15602,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
const int n_dims = ((int32_t *) tensor->op_params)[1];
|
||||
const int mode = ((int32_t *) tensor->op_params)[2];
|
||||
const int n_ctx = ((int32_t *) tensor->op_params)[3];
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_rope(ctx,
|
||||
tensor->grad,
|
||||
|
@ -15524,7 +15610,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
n_dims,
|
||||
mode,
|
||||
n_ctx),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_ALIBI:
|
||||
|
@ -15607,10 +15693,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
}
|
||||
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
grad_q,
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
|
||||
if (src1->grad) {
|
||||
|
@ -15653,10 +15739,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
}
|
||||
|
||||
src1->grad = ggml_add_impl(ctx,
|
||||
src1->grad = ggml_add_or_set(ctx,
|
||||
src1->grad,
|
||||
grad_k,
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
|
||||
struct ggml_tensor * opt0 = tensor->src[2];
|
||||
|
@ -15702,10 +15788,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
} break;
|
||||
}
|
||||
|
||||
opt0->grad = ggml_add_impl(ctx,
|
||||
opt0->grad = ggml_add_or_set(ctx,
|
||||
opt0->grad,
|
||||
grad_v,
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_FLASH_FF:
|
||||
|
@ -15725,12 +15811,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
if (src0->grad) {
|
||||
src0->grad =
|
||||
ggml_add_impl(ctx,
|
||||
ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_mul(ctx,
|
||||
ggml_sgn(ctx, src0),
|
||||
tensor->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_SGN:
|
||||
|
@ -15742,7 +15828,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
case GGML_UNARY_OP_NEG:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
|
||||
src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_STEP:
|
||||
|
@ -15762,12 +15848,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
case GGML_UNARY_OP_RELU:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_mul(ctx,
|
||||
ggml_step(ctx, src0),
|
||||
tensor->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_UNARY_OP_GELU:
|
||||
|
@ -15782,10 +15868,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
{
|
||||
// necessary for llama
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_silu_back(ctx, src0, tensor->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
default:
|
||||
|
@ -15803,13 +15889,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
case GGML_OP_CROSS_ENTROPY_LOSS:
|
||||
{
|
||||
if (src0->grad) {
|
||||
src0->grad = ggml_add_impl(ctx,
|
||||
src0->grad = ggml_add_or_set(ctx,
|
||||
src0->grad,
|
||||
ggml_cross_entropy_loss_back(ctx,
|
||||
src0,
|
||||
src1,
|
||||
tensor->grad),
|
||||
inplace);
|
||||
zero_table);
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
|
||||
|
@ -15827,34 +15913,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
|
|||
}
|
||||
}
|
||||
|
||||
static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
|
||||
|
||||
static size_t hash(void * p) {
|
||||
return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
}
|
||||
|
||||
static bool hash_insert(void * hash_table[], void * p) {
|
||||
size_t h = hash(p);
|
||||
|
||||
// linear probing
|
||||
size_t i = h;
|
||||
while (hash_table[i] != NULL && hash_table[i] != p) {
|
||||
i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
|
||||
if (i == h) {
|
||||
// hash table is full
|
||||
GGML_ASSERT(false);
|
||||
}
|
||||
}
|
||||
|
||||
if (hash_table[i] == p) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// insert
|
||||
hash_table[i] = p;
|
||||
return false;
|
||||
}
|
||||
|
||||
static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
|
||||
if (node->grad == NULL) {
|
||||
// this usually happens when we generate intermediate nodes from constants in the backward pass
|
||||
|
@ -15955,12 +16013,21 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
|
|||
}
|
||||
}
|
||||
|
||||
// remember original gradients which start with zero values
|
||||
void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
|
||||
memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
|
||||
for (int i = 0; i < gf->n_nodes; i++) {
|
||||
if (gf->grads[i]) {
|
||||
hash_insert(zero_table, gf->grads[i]);
|
||||
}
|
||||
}
|
||||
|
||||
for (int i = gf->n_nodes - 1; i >= 0; i--) {
|
||||
struct ggml_tensor * node = gf->nodes[i];
|
||||
|
||||
// because we detached the grad nodes from the original graph, we can afford inplace operations
|
||||
if (node->grad) {
|
||||
ggml_compute_backward(ctx, node, keep);
|
||||
ggml_compute_backward(ctx, node, keep, zero_table);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -15972,6 +16039,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
|
|||
ggml_build_forward_expand(gb, node->grad);
|
||||
}
|
||||
}
|
||||
|
||||
free(zero_table);
|
||||
}
|
||||
|
||||
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
|
||||
|
@ -17574,7 +17643,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
}
|
||||
|
||||
// compute the function value
|
||||
ggml_graph_reset (gf);
|
||||
// ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
|
||||
|
@ -17668,7 +17737,7 @@ static enum ggml_opt_result ggml_opt_adam(
|
|||
callback(callback_data, &sched);
|
||||
}
|
||||
|
||||
ggml_graph_reset (gf);
|
||||
// ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
ggml_graph_compute(gb, &cplan);
|
||||
|
@ -17806,7 +17875,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|||
{
|
||||
ggml_opt_set_params(np, ps, x);
|
||||
|
||||
ggml_graph_reset (gf);
|
||||
//ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
ggml_graph_compute(gb, cplan);
|
||||
|
@ -17938,7 +18007,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
{
|
||||
ggml_opt_set_params(np, ps, x);
|
||||
|
||||
ggml_graph_reset (gf);
|
||||
//ggml_graph_reset (gf);
|
||||
ggml_set_f32 (f->grad, 1.0f);
|
||||
|
||||
ggml_graph_compute(gb, &cplan);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue