avoid keeping in memory ALL of the gradients

The problem here stems from ggml_graph_reset. This function is called in the optimization functions, before each graph computation, to reset the gradients to zero. That requires a unique memory slot for each gradient: allocating gradient memory from a previously freed memory location could otherwise leave non-zero values in the input gradients.

During ggml_compute_backward the gradients are built up step by step by adding or subtracting new values, starting from an OP_NONE tensor that must contain all zeros. This is what made the graph reset necessary.
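To make the dependency concrete, here is a minimal standalone sketch (plain C, illustrative names only, not the ggml API) of why stepwise accumulation only works when the gradient slot really starts at zero:

#include <stdio.h>

int main(void) {
    // Stand-in for a gradient slot handed out by the allocator. If this
    // memory were recycled from a freed tensor, it would hold garbage
    // instead of 0.0f, and the accumulated result below would be wrong.
    float grad = 0.0f;

    // Stand-ins for the per-op contributions added during backward.
    const float contributions[] = { 0.5f, -1.25f, 2.0f };
    for (int i = 0; i < 3; ++i) {
        grad += contributions[i]; // the add/sub steps of ggml_compute_backward
    }

    printf("grad = %f\n", grad); // 1.25, but only because grad started at 0
    return 0;
}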

To avoid this, I now record in ggml_build_backward_expand the original OP_NONE gradient tensors in a hash table, which is passed to ggml_compute_backward. There, instead of calling add (or sub or similar) directly, I first check whether the gradient to be changed is still such a zero-valued tensor by looking it up in the hash table. If it is, it is not modified but simply replaced by the value to be added; otherwise the regular add is used (not inplace, the allocator will take care of this). This way none of those zero-tensor values appear in the final backward graph and, more importantly, they no longer need unique memory slots just to hold zeros.
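As a rough illustration of the mechanism (a simplified standalone sketch; TABLE_SIZE and the helper names are made up for the example and are not the ggml definitions shown in the diff below), the zero table is just a pointer hash set with linear probing, and the add-or-set decision is a single lookup:

#include <stddef.h>
#include <stdio.h>

#define TABLE_SIZE 17 // stand-in for GGML_GRAPH_HASHTABLE_SIZE

static size_t hash_ptr(const void * p) {
    return (size_t) p % TABLE_SIZE;
}

// insert p; returns 1 if it was already present
static int set_insert(const void * table[], const void * p) {
    size_t i = hash_ptr(p);
    while (table[i] != NULL && table[i] != p) {
        i = (i + 1) % TABLE_SIZE; // linear probing
        if (i == hash_ptr(p)) {
            return 0; // wrapped around: table full
        }
    }
    if (table[i] == p) {
        return 1;
    }
    table[i] = p;
    return 0;
}

static int set_contains(const void * table[], const void * p) {
    size_t i = hash_ptr(p);
    while (table[i] != NULL && table[i] != p) {
        i = (i + 1) % TABLE_SIZE;
        if (i == hash_ptr(p)) {
            return 0; // wrapped around without finding p
        }
    }
    return table[i] == p;
}

int main(void) {
    const void * zero_table[TABLE_SIZE] = { 0 };
    int zero_grad, accumulated_grad;

    set_insert(zero_table, &zero_grad); // gradient known to be all zeros

    // add-or-set: a gradient found in the table is *replaced* by the
    // incoming value (0 + g == g), so the zero tensor never needs its own
    // memory slot; anything else gets a regular (non-inplace) add.
    printf("replace? %d\n", set_contains(zero_table, &zero_grad));        // 1
    printf("replace? %d\n", set_contains(zero_table, &accumulated_grad)); // 0
    return 0;
}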
xaedes 2023-08-18 16:01:43 +02:00
parent a252111b45
commit f358204a5f
2 changed files with 187 additions and 126 deletions


@@ -1242,14 +1242,6 @@ struct ggml_tensor * llama_build_lora_finetune_graphs(
         ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, layer.w3, one));
     }
-    // gradient tensors (will be set to zero by ggml_graph_reset)
-    for (int i = 0; i < gf->n_nodes; ++i) {
-        if (!gf->grads[i]) continue;
-        if (gf->grads[i]->data == NULL && !ggml_is_view(gf->grads[i])) {
-            ggml_allocr_alloc(alloc, gf->grads[i]);
-        }
-        ggml_build_forward_expand(gb, ggml_scale_inplace(ctx, gf->grads[i], one));
-    }
     for (int i = 0; i < checkpoints.size(); ++i) {
         if (checkpoints[i]->data == NULL && !ggml_is_view(checkpoints[i])) {
             ggml_allocr_alloc(alloc, checkpoints[i]);

ggml.c

@@ -15009,7 +15009,89 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
 ////////////////////////////////////////////////////////////////////////////////
 
-static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace) {
+static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
+
+static size_t hash(void * p) {
+    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
+}
+
+static bool hash_insert(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            GGML_ASSERT(false);
+        }
+    }
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    // insert
+    hash_table[i] = p;
+    return false;
+}
+
+static bool hash_contains(void * hash_table[], void * p) {
+    size_t h = hash(p);
+
+    // linear probing
+    size_t i = h;
+    while (hash_table[i] != NULL && hash_table[i] != p) {
+        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
+        if (i == h) {
+            // hash table is full
+            return false;
+        }
+    }
+
+    if (hash_table[i] == p) {
+        return true;
+    }
+
+    return false;
+}
+
+// functions to change gradients considering the case that input a might be initial gradient with zero value
+
+static struct ggml_tensor * ggml_add_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+    if (hash_contains(zero_table, a)) {
+        return b;
+    } else {
+        return ggml_add_impl(ctx, a, b, false);
+    }
+}
+
+static struct ggml_tensor * ggml_acc_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, size_t nb1, size_t nb2, size_t nb3, size_t offset, void * zero_table[]) {
+    if (hash_contains(zero_table, a)) {
+        return b;
+    } else {
+        return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false);
+    }
+}
+
+static struct ggml_tensor * ggml_add1_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+    if (hash_contains(zero_table, a)) {
+        return ggml_repeat(ctx, b, a);
+    } else {
+        return ggml_add1_impl(ctx, a, b, false);
+    }
+}
+
+static struct ggml_tensor * ggml_sub_or_set(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b, void * zero_table[]) {
+    if (hash_contains(zero_table, a)) {
+        return ggml_neg(ctx, b);
+    } else {
+        return ggml_sub_impl(ctx, a, b, false);
+    }
+}
+
+static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, bool inplace, void * zero_table[]) {
     struct ggml_tensor * src0 = tensor->src[0];
     struct ggml_tensor * src1 = tensor->src[1];
@@ -15017,34 +15099,34 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_DUP:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_ADD:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_add_impl(ctx, src1->grad, tensor->grad, inplace);
+                    src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_ADD1:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_add_impl(ctx,
+                    src1->grad = ggml_add_or_set(ctx,
                             src1->grad,
                             ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_ACC:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
                     const size_t nb1 = ((int32_t *) tensor->op_params)[0];
@@ -15061,117 +15143,117 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         nb1, nb2, nb3, offset);
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src1->grad,
                                 ggml_reshape(ctx,
                                     ggml_cont(ctx, tensor_grad_view),
                                     src1->grad),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SUB:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
-                    src1->grad = ggml_sub_impl(ctx, src1->grad, tensor->grad, inplace);
+                    src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_MUL:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_mul(ctx, src1, tensor->grad),
-                                inplace);
+                                zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src1->grad,
                                 ggml_mul(ctx, src0, tensor->grad),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_DIV:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_div(ctx, tensor->grad, src1),
-                                inplace);
+                                zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_sub_impl(ctx,
+                        ggml_sub_or_set(ctx,
                                 src1->grad,
                                 ggml_mul(ctx,
                                     tensor->grad,
                                     ggml_div(ctx, tensor, src1)),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SQR:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_scale(ctx,
                                     ggml_mul(ctx, src0, tensor->grad),
                                     ggml_new_f32(ctx, 2.0f)),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SQRT:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_scale(ctx,
                                     ggml_div(ctx,
                                         tensor->grad,
                                         tensor),
                                     ggml_new_f32(ctx, 0.5f)),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_LOG:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_div(ctx,
                                     tensor->grad,
                                     src0),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SUM:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add1_impl(ctx,
+                        ggml_add1_or_set(ctx,
                                 src0->grad,
                                 tensor->grad,
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SUM_ROWS:
             {
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_repeat(ctx,
                                     tensor->grad,
                                     src0->grad),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_MEAN:
@@ -15183,20 +15265,20 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
             {
                 // necessary for llama
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_repeat_back(ctx, tensor->grad, src0->grad),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_REPEAT_BACK:
             {
                 if (src0->grad) {
                     // TODO: test this
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_repeat(ctx, tensor->grad, src0->grad),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_SILU_BACK:
@@ -15214,10 +15296,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     float eps;
                     memcpy(&eps, tensor->op_params, sizeof(float));
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_rms_norm_back(ctx, src0, tensor->grad, eps),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_RMS_NORM_BACK:
@@ -15244,16 +15326,16 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_out_prod(ctx, // [n,m]
                                     src1, // [n,p]
                                     tensor->grad), // [m,p]
-                                inplace);
+                                zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src1->grad,
                                 // ggml_mul_mat(ctx, // [n,p]
                                 //     ggml_cont(ctx, // [m,n]
@@ -15267,7 +15349,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                     src0, // [n,m]
                                     ggml_transpose(ctx, // [p,m]
                                         tensor->grad)), // [m,p]
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_OUT_PROD:
@@ -15279,17 +15361,17 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src0->grad,
                                 ggml_scale_impl(ctx, tensor->grad, src1, false),
-                                inplace);
+                                zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src1->grad,
                                 ggml_sum(ctx, ggml_mul_impl(ctx, tensor->grad, src0, false)),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_SET:
@@ -15316,23 +15398,23 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 }
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_acc_impl(ctx,
                                 tensor->grad,
                                 ggml_neg(ctx, tensor_grad_view),
                                 nb1, nb2, nb3, offset, false),
-                            inplace);
+                            zero_table);
                 }
                 if (src1->grad) {
                     src1->grad =
-                        ggml_add_impl(ctx,
+                        ggml_add_or_set(ctx,
                                 src1->grad,
                                 ggml_reshape(ctx,
                                     ggml_cont(ctx, tensor_grad_view),
                                     src1->grad),
-                                inplace);
+                                zero_table);
                 }
             } break;
         case GGML_OP_CPY:
@@ -15343,7 +15425,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // tensor = src0 * 1 + src1 * 0
                 if (src0->grad) {
                     // dsrc0 = dtensor * 1
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
                 if (src1->grad) {
                     // dsrc1 = dtensor * 0 -> noop
@@ -15355,7 +15437,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     GGML_ASSERT(ggml_is_contiguous(src0->grad));
                     GGML_ASSERT(ggml_is_contiguous(tensor->grad));
-                    src0->grad = ggml_add_impl(ctx, src0->grad, tensor->grad, inplace);
+                    src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table);
                 }
             } break;
         case GGML_OP_RESHAPE:
@@ -15363,9 +15445,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
-                            ggml_reshape(ctx, tensor->grad, src0->grad),
-                            inplace);
+                        ggml_add_or_set(ctx, src0->grad,
+                            ggml_reshape(ctx,
+                                ggml_is_contiguous(tensor->grad)
+                                    ? tensor->grad
+                                    : ggml_cont(ctx, tensor->grad),
+                                src0->grad),
+                            zero_table);
                 }
             } break;
         case GGML_OP_VIEW:
@@ -15394,7 +15480,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     nb3 = (nb3 / n0) * ng;
                 }
-                src0->grad = ggml_acc_impl(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, inplace);
+                src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table);
                 }
             } break;
         case GGML_OP_PERMUTE:
@@ -15412,14 +15498,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     axes_backward[axis2] = 2;
                     axes_backward[axis3] = 3;
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_permute(ctx,
                                 tensor->grad,
                                 axes_backward[0],
                                 axes_backward[1],
                                 axes_backward[2],
                                 axes_backward[3]),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_TRANSPOSE:
@@ -15427,9 +15513,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_transpose(ctx, tensor->grad),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_GET_ROWS:
@@ -15437,9 +15523,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama (only for tokenizer)
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad),
-                            inplace);
+                            zero_table);
                 }
                 if (src1->grad) {
                     // noop
@@ -15459,9 +15545,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     const int n_past = ((int32_t *) tensor->op_params)[0];
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_DIAG_MASK_ZERO:
@@ -15470,9 +15556,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 if (src0->grad) {
                     const int n_past = ((int32_t *) tensor->op_params)[0];
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_SOFT_MAX:
@@ -15480,9 +15566,9 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 // necessary for llama
                 if (src0->grad) {
                     src0->grad =
-                        ggml_add_impl(ctx, src0->grad,
+                        ggml_add_or_set(ctx, src0->grad,
                             ggml_soft_max_back(ctx, tensor->grad, tensor),
-                            inplace);
+                            zero_table);
                 }
             } break;
@@ -15498,7 +15584,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_dims = ((int32_t *) tensor->op_params)[1];
                     const int mode = ((int32_t *) tensor->op_params)[2];
                     const int n_ctx = ((int32_t *) tensor->op_params)[3];
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_rope_back(ctx,
                                 tensor->grad,
@@ -15506,7 +15592,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_dims,
                                 mode,
                                 n_ctx),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_ROPE_BACK:
@@ -15516,7 +15602,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     const int n_dims = ((int32_t *) tensor->op_params)[1];
                     const int mode = ((int32_t *) tensor->op_params)[2];
                     const int n_ctx = ((int32_t *) tensor->op_params)[3];
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_rope(ctx,
                                 tensor->grad,
@@ -15524,7 +15610,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                                 n_dims,
                                 mode,
                                 n_ctx),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_ALIBI:
@@ -15607,10 +15693,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         } break;
                 }
-                src0->grad = ggml_add_impl(ctx,
+                src0->grad = ggml_add_or_set(ctx,
                         src0->grad,
                         grad_q,
-                        inplace);
+                        zero_table);
                 }
                 if (src1->grad) {
@@ -15653,10 +15739,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                         } break;
                 }
-                src1->grad = ggml_add_impl(ctx,
+                src1->grad = ggml_add_or_set(ctx,
                         src1->grad,
                         grad_k,
-                        inplace);
+                        zero_table);
                 }
                 struct ggml_tensor * opt0 = tensor->src[2];
@@ -15702,10 +15788,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                        } break;
                 }
-                opt0->grad = ggml_add_impl(ctx,
+                opt0->grad = ggml_add_or_set(ctx,
                         opt0->grad,
                         grad_v,
-                        inplace);
+                        zero_table);
                 }
             } break;
         case GGML_OP_FLASH_FF:
@@ -15725,12 +15811,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     {
                         if (src0->grad) {
                             src0->grad =
-                                ggml_add_impl(ctx,
+                                ggml_add_or_set(ctx,
                                         src0->grad,
                                         ggml_mul(ctx,
                                             ggml_sgn(ctx, src0),
                                             tensor->grad),
-                                        inplace);
+                                        zero_table);
                         }
                     } break;
                 case GGML_UNARY_OP_SGN:
@@ -15742,7 +15828,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 case GGML_UNARY_OP_NEG:
                     {
                         if (src0->grad) {
-                            src0->grad = ggml_sub_impl(ctx, src0->grad, tensor->grad, inplace);
+                            src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table);
                         }
                     } break;
                 case GGML_UNARY_OP_STEP:
@@ -15762,12 +15848,12 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 case GGML_UNARY_OP_RELU:
                     {
                         if (src0->grad) {
-                            src0->grad = ggml_add_impl(ctx,
+                            src0->grad = ggml_add_or_set(ctx,
                                     src0->grad,
                                     ggml_mul(ctx,
                                         ggml_step(ctx, src0),
                                         tensor->grad),
-                                    inplace);
+                                    zero_table);
                         }
                     } break;
                 case GGML_UNARY_OP_GELU:
@@ -15782,10 +15868,10 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                     {
                         // necessary for llama
                         if (src0->grad) {
-                            src0->grad = ggml_add_impl(ctx,
+                            src0->grad = ggml_add_or_set(ctx,
                                     src0->grad,
                                     ggml_silu_back(ctx, src0, tensor->grad),
-                                    inplace);
+                                    zero_table);
                         }
                     } break;
                 default:
@@ -15803,13 +15889,13 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
         case GGML_OP_CROSS_ENTROPY_LOSS:
             {
                 if (src0->grad) {
-                    src0->grad = ggml_add_impl(ctx,
+                    src0->grad = ggml_add_or_set(ctx,
                             src0->grad,
                             ggml_cross_entropy_loss_back(ctx,
                                 src0,
                                 src1,
                                 tensor->grad),
-                            inplace);
+                            zero_table);
                 }
             } break;
         case GGML_OP_CROSS_ENTROPY_LOSS_BACK:
@@ -15827,34 +15913,6 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
     }
 }
 
-static_assert(GGML_GRAPH_HASHTABLE_SIZE > GGML_MAX_NODES * 2, "GGML_GRAPH_HT_SIZE is too small");
-
-static size_t hash(void * p) {
-    return (size_t)p % GGML_GRAPH_HASHTABLE_SIZE;
-}
-
-static bool hash_insert(void * hash_table[], void * p) {
-    size_t h = hash(p);
-
-    // linear probing
-    size_t i = h;
-    while (hash_table[i] != NULL && hash_table[i] != p) {
-        i = (i + 1) % GGML_GRAPH_HASHTABLE_SIZE;
-        if (i == h) {
-            // hash table is full
-            GGML_ASSERT(false);
-        }
-    }
-
-    if (hash_table[i] == p) {
-        return true;
-    }
-
-    // insert
-    hash_table[i] = p;
-    return false;
-}
-
 static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) {
     if (node->grad == NULL) {
         // this usually happens when we generate intermediate nodes from constants in the backward pass
@@ -15955,12 +16013,21 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
         }
     }
 
+    // remember original gradients which start with zero values
+    void ** zero_table = malloc(sizeof(void *) * GGML_GRAPH_HASHTABLE_SIZE);
+    memset(zero_table, 0, sizeof(void*) * GGML_GRAPH_HASHTABLE_SIZE);
+    for (int i = 0; i < gf->n_nodes; i++) {
+        if (gf->grads[i]) {
+            hash_insert(zero_table, gf->grads[i]);
+        }
+    }
+
     for (int i = gf->n_nodes - 1; i >= 0; i--) {
         struct ggml_tensor * node = gf->nodes[i];
 
         // because we detached the grad nodes from the original graph, we can afford inplace operations
         if (node->grad) {
-            ggml_compute_backward(ctx, node, keep);
+            ggml_compute_backward(ctx, node, keep, zero_table);
         }
     }
@@ -15972,6 +16039,8 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph *
             ggml_build_forward_expand(gb, node->grad);
         }
     }
+
+    free(zero_table);
 }
 
 struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
@@ -17574,7 +17643,7 @@ static enum ggml_opt_result ggml_opt_adam(
         }
 
         // compute the function value
-        ggml_graph_reset (gf);
+        // ggml_graph_reset (gf);
         ggml_set_f32 (f->grad, 1.0f);
 
         struct ggml_cplan cplan = ggml_graph_plan(gb, params.n_threads);
@@ -17668,7 +17737,7 @@ static enum ggml_opt_result ggml_opt_adam(
             callback(callback_data, &sched);
         }
 
-        ggml_graph_reset (gf);
+        // ggml_graph_reset (gf);
         ggml_set_f32 (f->grad, 1.0f);
 
         ggml_graph_compute(gb, &cplan);
@@ -17806,7 +17875,7 @@ static enum ggml_opt_result linesearch_backtracking(
         {
             ggml_opt_set_params(np, ps, x);
 
-            ggml_graph_reset (gf);
+            //ggml_graph_reset (gf);
            ggml_set_f32 (f->grad, 1.0f);
 
             ggml_graph_compute(gb, cplan);
@@ -17938,7 +18007,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
        {
            ggml_opt_set_params(np, ps, x);
 
-            ggml_graph_reset (gf);
+            //ggml_graph_reset (gf);
            ggml_set_f32 (f->grad, 1.0f);
 
            ggml_graph_compute(gb, &cplan);