diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index f6e146b80..db7a52842 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -965,8 +965,8 @@ struct ggml_tensor * forward_batch_wo_cache( const int n_rot = hparams.n_rot; const int n_ff = get_n_ff(&hparams); - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); - memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); + GGML_ASSERT(tokens_input->type == GGML_TYPE_I32); + struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); // inpL shape [n_embd,N*n_batch,1] struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); @@ -1168,7 +1168,7 @@ struct ggml_tensor * forward_batch_wo_cache( } // run the computation - ggml_build_forward_expand(gf, inpL); + // ggml_build_forward_expand(gf, inpL); return inpL; } @@ -1193,8 +1193,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( const int n_rot = hparams.n_rot; const int n_ff = get_n_ff(&hparams); - struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); - memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch); + + GGML_ASSERT(tokens_input->type == GGML_TYPE_I32); + struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens); assert_shape_2d(inpL, n_embd, N*n_batch); @@ -1336,7 +1337,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn( } // run the computation - ggml_build_forward_expand(gf, inpL); + // ggml_build_forward_expand(gf, inpL); return inpL; } @@ -1563,8 +1564,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( use_buf(-1); - struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch); - memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch); + GGML_ASSERT(tokens_input->type == GGML_TYPE_I32); + struct ggml_tensor * t00 = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch); use_buf(-1); @@ -2082,8 +2083,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing( use_buf(-1); - struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch); - memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch); + GGML_ASSERT(tokens_input->type == GGML_TYPE_I32); + struct ggml_tensor * t00 = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch); use_buf(-1);