From ce937bc431f7ac88f5e5b0bab2475bcc673369ca Mon Sep 17 00:00:00 2001
From: xaedes
Date: Sun, 2 Jul 2023 21:36:56 +0200
Subject: [PATCH] replace memcpy with reshape operation so that the graph is
 not cut at the input

this makes it possible to store other values into the input tensor and then
simply recompute the graph without rebuilding it
---
 .../train-text-from-scratch.cpp              | 21 +++++++++++----------
 1 file changed, 11 insertions(+), 10 deletions(-)

diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index f6e146b80..db7a52842 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -965,8 +965,8 @@ struct ggml_tensor * forward_batch_wo_cache(
     const int n_rot = hparams.n_rot;
     const int n_ff = get_n_ff(&hparams);
 
-    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
-    memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
+    GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
+    struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N*n_batch);
 
     // inpL shape [n_embd,N*n_batch,1]
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
@@ -1168,7 +1168,7 @@ struct ggml_tensor * forward_batch_wo_cache(
     }
 
     // run the computation
-    ggml_build_forward_expand(gf, inpL);
+    // ggml_build_forward_expand(gf, inpL);
 
     return inpL;
 }
@@ -1193,8 +1193,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     const int n_rot = hparams.n_rot;
     const int n_ff = get_n_ff(&hparams);
 
-    struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
-    memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
+
+    GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
+    struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N*n_batch);
 
     struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
     assert_shape_2d(inpL, n_embd, N*n_batch);
@@ -1336,7 +1337,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
     }
 
     // run the computation
-    ggml_build_forward_expand(gf, inpL);
+    // ggml_build_forward_expand(gf, inpL);
 
     return inpL;
 }
@@ -1563,8 +1564,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 
     use_buf(-1);
 
-    struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
-    memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
+    GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
+    struct ggml_tensor * t00 = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch);
 
     use_buf(-1);
 
@@ -2082,8 +2083,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
 
     use_buf(-1);
 
-    struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
-    memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
+    GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
+    struct ggml_tensor * t00 = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch);
 
     use_buf(-1);
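
For illustration, a minimal standalone sketch of the recompute pattern the commit
message describes: because ggml_reshape_1d creates a view, tokens_input stays a
leaf of the graph instead of being cut off behind a memcpy, so new token ids can
be written into it and the same graph rerun. This sketch is not part of the
patch; the tensor names and sizes below are made up, and it assumes the ggml API
contemporary with this commit (mid-2023), where ggml_build_forward returns a
cgraph by value and ggml_graph_compute takes the ggml_context.

// sketch.cpp - recompute an already-built graph after overwriting its input
#include "ggml.h"
#include <cstring>

int main() {
    struct ggml_init_params params = {
        /*.mem_size   =*/ 16*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx0 = ggml_init(params);

    const int n_vocab = 8;
    const int n_embd  = 4;
    const int N       = 3;

    // stand-ins for model->tok_embeddings and the caller-owned tokens_input
    struct ggml_tensor * tok_embeddings = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_embd, n_vocab);
    ggml_set_f32(tok_embeddings, 0.5f);
    struct ggml_tensor * tokens_input = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N);

    // build the graph once: the reshape is a view, so tokens_input remains
    // an input (leaf) of the graph
    struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N);
    struct ggml_tensor * inpL   = ggml_get_rows(ctx0, tok_embeddings, tokens);
    struct ggml_cgraph gf = ggml_build_forward(inpL);

    // first batch: write token ids into the input tensor, run the graph
    const int32_t batch0[3] = { 1, 2, 3 };
    memcpy(tokens_input->data, batch0, sizeof(batch0));
    ggml_graph_compute(ctx0, &gf);

    // next batch: overwrite the same input tensor and simply recompute --
    // the graph is reused as-is, no rebuild needed
    const int32_t batch1[3] = { 4, 5, 6 };
    memcpy(tokens_input->data, batch1, sizeof(batch1));
    ggml_graph_compute(ctx0, &gf);

    ggml_free(ctx0);
    return 0;
}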