Replace memcpy with a reshape operation so that the graph is not cut at the input.
This makes it possible to store other values into the input tensor and then simply recompute the graph without rebuilding it.
This commit is contained in:
parent
c6a18e15c1
commit
ce937bc431
1 changed file with 11 additions and 10 deletions
|
@@ -965,8 +965,8 @@ struct ggml_tensor * forward_batch_wo_cache(
|
|||
const int n_rot = hparams.n_rot;
|
||||
const int n_ff = get_n_ff(&hparams);
|
||||
|
||||
struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
|
||||
memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
|
||||
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
||||
struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N*n_batch);
|
||||
|
||||
// inpL shape [n_embd,N*n_batch,1]
|
||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
|
||||
|
@@ -1168,7 +1168,7 @@ struct ggml_tensor * forward_batch_wo_cache(
|
|||
}
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(gf, inpL);
|
||||
// ggml_build_forward_expand(gf, inpL);
|
||||
|
||||
return inpL;
|
||||
}
|
||||
|
@@ -1193,8 +1193,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
|
|||
const int n_rot = hparams.n_rot;
|
||||
const int n_ff = get_n_ff(&hparams);
|
||||
|
||||
struct ggml_tensor * tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch);
|
||||
memcpy(tokens->data, tokens_input->data, ggml_element_size(tokens)*N*n_batch);
|
||||
|
||||
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
||||
struct ggml_tensor * tokens = ggml_reshape_1d(ctx0, tokens_input, N*n_batch);
|
||||
|
||||
struct ggml_tensor * inpL = ggml_get_rows(ctx0, model->tok_embeddings, tokens);
|
||||
assert_shape_2d(inpL, n_embd, N*n_batch);
|
||||
|
@@ -1336,7 +1337,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn(
|
|||
}
|
||||
|
||||
// run the computation
|
||||
ggml_build_forward_expand(gf, inpL);
|
||||
// ggml_build_forward_expand(gf, inpL);
|
||||
|
||||
return inpL;
|
||||
}
|
||||
|
@@ -1563,8 +1564,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
|
|||
|
||||
use_buf(-1);
|
||||
|
||||
struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
|
||||
memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
|
||||
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
||||
struct ggml_tensor * t00 = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch);
|
||||
|
||||
use_buf(-1);
|
||||
|
||||
|
@@ -2082,8 +2083,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
|
|||
|
||||
use_buf(-1);
|
||||
|
||||
struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
|
||||
memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
|
||||
GGML_ASSERT(tokens_input->type == GGML_TYPE_I32);
|
||||
struct ggml_tensor * t00 = ggml_reshape_1d(ctx0, tokens_input, N*n_batch); assert_shape_1d(t00, N*n_batch);
|
||||
|
||||
use_buf(-1);
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue