Replace memcpy with a reshape operation so that the graph is not cut at the input

This makes it possible to store other values into the input tensor and then simply recompute the graph without rebuilding it.
This commit is contained in:
xaedes 2023-07-02 21:36:56 +02:00
parent c6a18e15c1
commit ce937bc431
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -965,8 +965,8 @@ struct ggml_tensor * forward_batch_wo_cache(
const int n_rot = hparams.n_rot;
const int n_ff = get_n_ff(&hparams);
struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N*n_batch);
// inpL shape [n_embd,N*n_batch,1]
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
@ -1168,7 +1168,7 @@ struct ggml_tensor * forward_batch_wo_cache(
}
// run the computation
ggml_build_forward_expand(gf, inpL);
// ggml_build_forward_expand(gf, inpL);
return inpL;
}
@ -1193,8 +1193,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
const int n_rot = hparams.n_rot;
const int n_ff = get_n_ff(&hparams);
struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N*n_batch);
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
assert_shape_2d(inpL, n_embd, N*n_batch);
@ -1336,7 +1337,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
}
// run the computation
ggml_build_forward_expand(gf, inpL);
// ggml_build_forward_expand(gf, inpL);
return inpL;
}
@ -1563,8 +1564,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
use_buf(-1);
struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
struct ggml_tensor * t00 = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch);
use_buf(-1);
@ -2082,8 +2083,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
use_buf(-1);
struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
struct ggml_tensor * t00 = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch);
use_buf(-1);