diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 9244088dc..63f976f0d 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1337,6 +1337,82 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     return inpL;
 }
 
+// expand the graph nodes without creating leafs.
+struct ggml_tensor * expand(struct ggml_cgraph * g, struct ggml_tensor * t) {
+    // check if already visited
+    for (int i = 0; i < g->n_nodes; i++) {
+        if (g->nodes[i] == t) {
+            return t;
+        }
+    }
+
+    for (int i = 0; i < g->n_leafs; i++) {
+        if (g->leafs[i] == t) {
+            return t;
+        }
+    }
+
+    if (t->src0) {
+        expand(g, t->src0);
+    }
+
+    if (t->src1) {
+        expand(g, t->src1);
+    }
+
+    for (int i = 0; i < GGML_MAX_OPT; ++i) {
+        if (t->opt[i]) {
+            expand(g, t->opt[i]);
+        }
+    }
+
+    GGML_ASSERT(g->n_nodes < GGML_MAX_NODES);
+
+    if (strlen(t->name) == 0) {
+        snprintf(t->name, sizeof(t->name), "node_%d", g->n_nodes);
+    }
+
+    g->nodes[g->n_nodes] = t;
+    g->grads[g->n_nodes] = t->grad;
+    g->n_nodes++;
+    return t;
+}
+
+void graph_set_leafs_grads(struct ggml_cgraph * g) {
+    // moves leaf nodes to g->leafs.
+    // i.e. g->n_nodes might change.
+    int n_nodes = 0;
+    for (int i = 0; i < g->n_nodes; ++i) {
+        struct ggml_tensor * node = g->nodes[i];
+        const bool is_leaf = node->op == GGML_OP_NONE && node->grad == NULL;
+        if (is_leaf) {
+            GGML_ASSERT(g->n_leafs < GGML_MAX_NODES);
+
+            if (strlen(node->name) == 0) {
+                snprintf(node->name, sizeof(node->name), "leaf_%d", g->n_leafs);
+            }
+
+            g->leafs[g->n_leafs] = node;
+            g->n_leafs++;
+        } else {
+            GGML_ASSERT(n_nodes < GGML_MAX_NODES);
+
+            if (strlen(node->name) == 0) {
+                snprintf(node->name, sizeof(node->name), "node_%d", n_nodes);
+            }
+
+            g->nodes[n_nodes] = node;
+            g->grads[n_nodes] = node->grad;
+            n_nodes++;
+        }
+    }
+    for (int i=n_nodes; i < g->n_nodes; ++i) {
+        g->nodes[i] = NULL;
+        g->grads[i] = NULL;
+    }
+    g->n_nodes = n_nodes;
+}
+
 struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         struct my_llama_model * model,
         struct ggml_context   * ctx0,
@@ -1375,11 +1451,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     const int n_ff = get_n_ff(&hparams);
     const int rope_mode = 0;
 
-    auto expand = [] (struct ggml_cgraph * g, struct ggml_tensor * t) -> struct ggml_tensor * {
-        ggml_build_forward_expand(g, t);
-        return t;
-    };
-
     int last_buf = -1;
     size_t buf_offs[2] = { 0, 0 };
     size_t buf_size[2] = { size_buf_0,
@@ -1423,6 +1494,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         }
     };
 
+
    auto view__q = [ctx0, n_embd, n_head, N, n_batch] (struct ggml_tensor * t) -> struct ggml_tensor * {
        int64_t ne0 = n_embd/n_head;
        int64_t ne1 = N;
@@ -1472,28 +1544,21 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 
     use_buf(-1);
 
-    // need to create grads for model parameters, so that expand(..) correctly populates cgraph->leafs & cgraph->grads
-    // this wastes memory, because unnecessary grad for each op is automatically created:
-    // the automatically generated grad is unnecessary because we later manually set the grad (e.g. t35->grad = expand(gb, ...) ).
-    // this discards the automatically generated grad resulting in wasted memory.
-    // TODO: improve this, possibly by changing expand(..) to not use ggml_build_forward_expand.
-    //       expand should correctly set cgraph->nodes.
-    //       cgraph->leafs & cgraph->grads could be set in another pass after the last expand call.
-    model->tok_embeddings->grad = ggml_dup_tensor(ctx0, model->tok_embeddings->grad);
-    model->norm->grad = ggml_dup_tensor(ctx0, model->norm->grad);
-    model->output->grad = ggml_dup_tensor(ctx0, model->output->grad);
+    model->tok_embeddings->grad = NULL;
+    model->norm->grad = NULL;
+    model->output->grad = NULL;
     for (int il = 0; il < n_layer; ++il) {
         struct my_llama_layer & layer = model->layers[il];
-        layer.attention_norm->grad = ggml_dup_tensor(ctx0, layer.attention_norm->grad);
-        layer.wq->grad = ggml_dup_tensor(ctx0, layer.wq->grad);
-        layer.wk->grad = ggml_dup_tensor(ctx0, layer.wk->grad);
-        layer.wv->grad = ggml_dup_tensor(ctx0, layer.wv->grad);
-        layer.wo->grad = ggml_dup_tensor(ctx0, layer.wo->grad);
-        layer.ffn_norm->grad = ggml_dup_tensor(ctx0, layer.ffn_norm->grad);
-        layer.w1->grad = ggml_dup_tensor(ctx0, layer.w1->grad);
-        layer.w2->grad = ggml_dup_tensor(ctx0, layer.w2->grad);
-        layer.w3->grad = ggml_dup_tensor(ctx0, layer.w3->grad);
+        layer.attention_norm->grad = NULL;
+        layer.wq->grad = NULL;
+        layer.wk->grad = NULL;
+        layer.wv->grad = NULL;
+        layer.wo->grad = NULL;
+        layer.ffn_norm->grad = NULL;
+        layer.w1->grad = NULL;
+        layer.w2->grad = NULL;
+        layer.w3->grad = NULL;
     }
 
     clr_buf(0);
@@ -1717,10 +1782,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     *gb = *gf;
 
     // t36->grad gets set to one by optimizer, so we need the tensor.
-    GGML_ASSERT(t36->grad != NULL);
     // initialize it with 1.0f to make sure.
-    // use_buf(-1);
-    // t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));
+    use_buf(-1);
+    t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));
 
     use_buf(0);
     t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);
@@ -1839,7 +1903,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     use_buf(0);
     t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
     use_buf(-1);
-    model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
+    model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
 
     // clr_buf(1);
     // clr_buf(0);
@@ -1850,6 +1914,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
     }
 
+    // now that all grads are created, set the graph leafs and grads
+    graph_set_leafs_grads(gf);
+    graph_set_leafs_grads(gb);
+
     return t36;
 }
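
For context, the call pattern this patch sets up inside forward_batch_wo_cache_flash_attn_train is roughly the following. This is only an illustrative sketch assembled from the hunks above (gf/gb, ctx0, targets, t35 and t36 are the graphs and tensors built by that function), not additional code from the patch:

    // forward graph: expand() records compute nodes without creating leafs or grads
    expand(gf, t36);                                   // t36 is the loss tensor returned to the optimizer

    // backward graph starts as a copy of the forward graph
    *gb = *gf;
    t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));  // seed gradient of the loss
    t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad));
    // ... the remaining grads (layer weights, tok_embeddings, ...) are set manually, bottom-up ...

    // once every grad tensor exists, split leafs from nodes and fill cgraph->grads
    graph_set_leafs_grads(gf);
    graph_set_leafs_grads(gb);

Because the parameter grads now start out as NULL and are assigned explicitly, no throwaway grad tensors are allocated per op, which is the memory saving the removed TODO comment was asking for.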