add and use function ggml_build_backward_expand to avoid stack overflows with large maximum number of nodes
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
This commit is contained in:
parent
e05e4414ac
commit
ed4319e1a7
3 changed files with 18 additions and 11 deletions
|
@ -3957,12 +3957,14 @@ int main(int argc, char ** argv) {
|
||||||
logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
|
logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
|
||||||
loss = cross_entropy_loss(ctx0, logits, target_probs);
|
loss = cross_entropy_loss(ctx0, logits, target_probs);
|
||||||
ggml_build_forward_expand(gf, loss);
|
ggml_build_forward_expand(gf, loss);
|
||||||
*gb = ggml_build_backward(ctx0, gf, true);
|
*gb = *gf;
|
||||||
|
ggml_build_backward_expand(ctx0, gf, gb, true);
|
||||||
} else {
|
} else {
|
||||||
logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
|
logits = forward_batch_wo_cache(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
|
||||||
loss = cross_entropy_loss(ctx0, logits, target_probs);
|
loss = cross_entropy_loss(ctx0, logits, target_probs);
|
||||||
ggml_build_forward_expand(gf, loss);
|
ggml_build_forward_expand(gf, loss);
|
||||||
*gb = ggml_build_backward(ctx0, gf, true);
|
*gb = *gf;
|
||||||
|
ggml_build_backward_expand(ctx0, gf, gb, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
|
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
|
||||||
|
@ -4070,13 +4072,15 @@ int main(int argc, char ** argv) {
|
||||||
};
|
};
|
||||||
struct ggml_context * ctx0 = ggml_init(cparams);
|
struct ggml_context * ctx0 = ggml_init(cparams);
|
||||||
|
|
||||||
ggml_cgraph gf = {};
|
struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
|
||||||
|
memset(gfbuf->data, 0, ggml_nbytes(gfbuf));
|
||||||
|
struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
|
||||||
|
|
||||||
int n_past = 0;
|
int n_past = 0;
|
||||||
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past);
|
struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, gf, tokens_input, sample_ctx, n_past);
|
||||||
|
|
||||||
ggml_build_forward_expand(&gf, logits);
|
ggml_build_forward_expand(gf, logits);
|
||||||
ggml_graph_compute_helper(work_buffer, &gf, params.n_threads);
|
ggml_graph_compute_helper(work_buffer, gf, params.n_threads);
|
||||||
|
|
||||||
//struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
|
//struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx);
|
||||||
//struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
|
//struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx);
|
||||||
|
|
10
ggml.c
10
ggml.c
|
@ -15787,9 +15787,7 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) {
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
|
void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep) {
|
||||||
struct ggml_cgraph result = *gf;
|
|
||||||
|
|
||||||
GGML_ASSERT(gf->n_nodes > 0);
|
GGML_ASSERT(gf->n_nodes > 0);
|
||||||
|
|
||||||
// if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
|
// if we are keeping the gradient graph, we have to detach the gradient nodes from the original graph
|
||||||
|
@ -15818,10 +15816,14 @@ struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cg
|
||||||
|
|
||||||
if (node->is_param) {
|
if (node->is_param) {
|
||||||
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
|
GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node);
|
||||||
ggml_build_forward_expand(&result, node->grad);
|
ggml_build_forward_expand(gb, node->grad);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep) {
|
||||||
|
struct ggml_cgraph result = *gf;
|
||||||
|
ggml_build_backward_expand(ctx, gf, &result, keep);
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
3
ggml.h
3
ggml.h
|
@ -1403,7 +1403,8 @@ extern "C" {
|
||||||
struct ggml_tensor * tensor);
|
struct ggml_tensor * tensor);
|
||||||
|
|
||||||
|
|
||||||
GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor);
|
||||||
|
GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool keep);
|
||||||
|
|
||||||
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
|
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
|
||||||
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
|
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue