diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index 15f60513f..a4b41e7fb 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -4391,6 +4391,12 @@ int main(int argc, char ** argv) {
     uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
     uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
 
+    ggml_allocr * alloc = NULL;
+    if (params.use_alloc) {
+        static const size_t tensor_alignment = 32;
+        alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
+    }
+
     GGML_ASSERT(n_tokens < (int) train_tokens.size());
     std::vector train_samples;
     train_samples.push_back(0);
@@ -4437,33 +4443,48 @@ int main(int argc, char ** argv) {
         };
         struct ggml_context * ctx0 = ggml_init(cparams);
 
+        ggml_set_no_alloc(ctx0, false);
+
+        // don't use alloc for input tensors, so we can safely fill them with data
         struct ggml_tensor * after_opt_best_samples = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
         //struct ggml_tensor * after_opt_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch);
         struct ggml_tensor * target_logits = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
         struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch);
 
+        ggml_set_no_alloc(ctx0, (alloc != NULL));
+
+        if (alloc) {
+            ggml_allocr_reset(alloc);
+        }
+
         opt_cb_data.tokens_input = tokens_input;
         opt_cb_data.target_logits = target_logits;
         opt_cb_data.target_probs = target_probs;
 
         int n_past = 0;
 
-        struct ggml_tensor * gfbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
-        struct ggml_tensor * gbbuf = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sizeof(struct ggml_cgraph) / ggml_type_size(GGML_TYPE_I32) + (sizeof(struct ggml_cgraph) % ggml_type_size(GGML_TYPE_I32) ? 1 : 0));
-
-        memset(gfbuf->data, 0, ggml_nbytes(gfbuf));
-        memset(gbbuf->data, 0, ggml_nbytes(gbbuf));
-
-        struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data;
-        struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data;
+        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
+        struct ggml_cgraph * gb = ggml_new_graph(ctx0);
+        struct ggml_cgraph * gb_tmp = (params.use_unified || params.use_alloc)
+            ? ggml_new_graph(ctx0)
+            : NULL;
 
         GGML_ASSERT(n_past == 0);
 
         struct ggml_tensor * loss = NULL;
         struct ggml_tensor * logits = NULL;
 
-        if (params.use_checkpointing) {
+        if (params.use_alloc || params.use_unified) {
+            loss = llama_build_train_graphs(
+                &model, alloc, ctx0,
+                gf, gb, gb_tmp,
+                &logits, tokens_input, target_probs,
+                n_tokens, n_batch,
+                params.use_flash,
+                params.use_checkpointing
+            );
+        } else if (params.use_checkpointing) {
             loss = forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
                 &model, ctx0, gf, gb,
@@ -4641,6 +4662,10 @@ int main(int argc, char ** argv) {
         }
     }
 
+    if (alloc) {
+        ggml_allocr_free(alloc);
+    }
+
     delete[] compute_addr;
     delete[] compute_buf_0;
     delete[] compute_buf_1;
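
Below is a minimal, self-contained sketch (not part of the patch) of the ggml_allocr pattern the change above relies on: input tensors keep context-backed data so they can be filled directly, while graph tensors are created with no_alloc enabled and are placed into a caller-owned compute buffer by the allocator. The buffer size, tensor shape, and the ggml_sqr op are illustrative placeholders; only the allocator calls (ggml_allocr_new / ggml_allocr_reset / ggml_allocr_alloc_graph / ggml_allocr_free) mirror what the patched train-text-from-scratch.cpp does.

    // sketch only: sizes, shapes, and the ggml_sqr op are made up for illustration
    #include "ggml.h"
    #include "ggml-alloc.h"
    #include <cstdint>
    #include <vector>

    int main() {
        static const size_t tensor_alignment = 32;
        std::vector<uint8_t> compute_buf(64u*1024*1024);   // hypothetical compute buffer

        // allocator that places tensor data inside compute_buf
        struct ggml_allocr * alloc = ggml_allocr_new(compute_buf.data(), compute_buf.size(), tensor_alignment);

        struct ggml_init_params cparams = {
            /*.mem_size   =*/ 16u*1024*1024,               // metadata + input tensor data
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx0 = ggml_init(cparams);

        // inputs: allocated by the context, so they can be filled with data right away
        ggml_set_no_alloc(ctx0, false);
        struct ggml_tensor * input = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, 1024);

        // graph tensors: no context allocation, the allocator assigns offsets in compute_buf
        ggml_set_no_alloc(ctx0, true);
        ggml_allocr_reset(alloc);                          // start of a new iteration

        struct ggml_tensor * out = ggml_sqr(ctx0, input);  // stand-in for the real training graph
        struct ggml_cgraph * gf  = ggml_new_graph(ctx0);
        ggml_build_forward_expand(gf, out);
        ggml_allocr_alloc_graph(alloc, gf);                // place all graph tensors in the buffer

        // ... graph compute / optimizer step would go here ...

        ggml_free(ctx0);
        ggml_allocr_free(alloc);
        return 0;
    }

Calling ggml_allocr_reset at the start of each training iteration lets the same compute buffer be reused for every example, which is why the patch resets the allocator right after creating the per-iteration context instead of letting graph tensors accumulate in the ggml context.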