remove unnecessary scratch buffer 0
buf 0 only held persistent data (values needed again in the backward pass), so instead of a dedicated scratch buffer we can simply disable scratch for those allocations by using buf -1; the remaining scratch buffers shift down (old buf 1 becomes buf 0, old buf 2 becomes buf 1)
parent efd7314d27
commit 59544f0cdf
1 changed file with 64 additions and 77 deletions
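For readers unfamiliar with the ggml scratch mechanism this change relies on, a sketch of the use_buf pattern follows. It assumes the ggml_set_scratch API and struct ggml_scratch that existed at the time of this commit; the free-function form, the name use_buf_sketch, and its parameter list are illustrative stand-ins for the lambda in the file, not the exact code.

    #include "ggml.h"

    // Switch tensor allocation between scratch buffers and the context's own pool.
    // buf == -1 disables scratch entirely, which is what this commit uses for
    // tensors whose values must survive until the backward pass.
    static void use_buf_sketch(struct ggml_context * ctx0, int buf, int & last_buf,
                               size_t buf_offs[2], const size_t buf_size[2], void * buf_data[2]) {
        // detach the current scratch buffer and remember how far it was filled
        size_t last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
        if (last_buf >= 0) {
            buf_offs[last_buf] = last_offs;
        }
        if (buf >= 0) {
            // re-attach the requested scratch buffer at its saved offset
            ggml_set_scratch(ctx0, { buf_offs[buf], buf_size[buf], buf_data[buf], });
        }
        // buf < 0: no scratch is set, so new tensors are allocated from ctx0's own
        // memory and live until ggml_free - the behaviour the old buf 0 was emulating
        last_buf = buf;
    }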
@@ -1347,10 +1347,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
  struct ggml_tensor * targets,
  void * compute_buf_0,
  void * compute_buf_1,
- void * compute_buf_2,
  size_t size_buf_0,
  size_t size_buf_1,
- size_t size_buf_2,
  const int n_tokens,
  const int n_batch) {
@@ -1383,13 +1381,11 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
  };
  int last_buf = -1;
- size_t buf_offs[3] = { 0, 0, 0 };
- size_t buf_size[3] = { size_buf_0,
-                        size_buf_1,
-                        size_buf_2 };
- void * buf_data[3] = { compute_buf_0,
-                        compute_buf_1,
-                        compute_buf_2 };
+ size_t buf_offs[2] = { 0, 0 };
+ size_t buf_size[2] = { size_buf_0,
+                        size_buf_1 };
+ void * buf_data[2] = { compute_buf_0,
+                        compute_buf_1 };
  auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) {
      size_t last_offs = 0;
      last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
@@ -1406,7 +1402,7 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
  };
  bool track_max_mem = false;
- size_t buf_maxs[3] = { 0, 0, 0 };
+ size_t buf_maxs[2] = { 0, 0 };
  auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) {
      if (buf < 0) return;
@@ -1500,15 +1496,15 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
      layer.w3->grad = ggml_dup_tensor(ctx0, layer.w3->grad);
  }
  clr_buf(0);
  clr_buf(1);
- clr_buf(2);
  use_buf(-1);
  struct ggml_tensor * t00 = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, N*n_batch); assert_shape_1d(t00, N*n_batch);
  memcpy(t00->data, tokens_input->data, ggml_element_size(t00)*N*n_batch);
- use_buf(0);
+ use_buf(-1);
  struct ggml_tensor * t01 = expand(gf, ggml_get_rows(ctx0, model->tok_embeddings, t00)); assert_shape_2d(t01, n_embd, N*n_batch);
@@ -1546,39 +1542,39 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
  struct ggml_tensor * cur = t01;
  for (int il = 0; il < n_layer; ++il) {
-     clr_buf(1);
+     clr_buf(0);
      struct my_llama_layer & layer = model->layers[il];
-     // tensors with values necessary for backward pass are in persistent buf(0)
-     // other tensors with buf(1) are only temporary needed, and their memory reused after layer is completed.
-     use_buf(0); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
-     use_buf(1); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
-     use_buf(0); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
-     use_buf(0); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
-     use_buf(0); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
-     use_buf(0); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd);
-     use_buf(0); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
-     use_buf(0); struct ggml_tensor * t13 = expand(gf, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
-     use_buf(0); struct ggml_tensor * t14 = expand(gf, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
-     use_buf(0); struct ggml_tensor * t15 = expand(gf, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
-     use_buf(0); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
-     use_buf(1); struct ggml_tensor * t17 = expand(gf, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
-     use_buf(0); struct ggml_tensor * t18 = expand(gf, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
-     use_buf(0); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
-     use_buf(1); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
-     use_buf(1); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
-     use_buf(0); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch);
-     use_buf(0); struct ggml_tensor * t27 = expand(gf, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch);
-     use_buf(0); struct ggml_tensor * t28 = expand(gf, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch);
-     use_buf(1); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch);
-     use_buf(0); struct ggml_tensor * t30 = expand(gf, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch);
+     // tensors with values necessary for backward pass are in persistent buf(-1)
+     // other tensors with buf(0) and buf(1) are only temporary needed, and their memory reused after layer is completed.
+     use_buf(-1); struct ggml_tensor * t02 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t02, n_embd, N*n_batch);
+     use_buf( 0); struct ggml_tensor * t03 = expand(gf, ggml_repeat (ctx0, layer.attention_norm, t02)); assert_shape_2d(t03, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t04 = expand(gf, ggml_mul (ctx0, t02, t03)); assert_shape_2d(t04, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t05 = expand(gf, ggml_mul_mat (ctx0, layer.wq, t04)); assert_shape_2d(t05, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t06 = expand(gf, ggml_reshape_4d (ctx0, t05, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t06, n_embd/n_head, n_head, N, n_batch);
+     use_buf(-1); struct ggml_tensor * t07 = expand(gf, ggml_rope_inplace (ctx0, t06, n_past, n_rot, rope_mode)); assert_shape_4d(t07, n_embd/n_head, n_head, N, n_batch);
+     use_buf(-1); struct ggml_tensor * t08 = expand(gf, ggml_mul_mat (ctx0, layer.wk, t04)); assert_shape_2d(t08, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t09 = expand(gf, ggml_reshape_4d (ctx0, t08, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t09, n_embd/n_head, n_head, N, n_batch);
+     use_buf(-1); struct ggml_tensor * t10 = expand(gf, ggml_rope_inplace (ctx0, t09, n_past, n_rot, rope_mode)); assert_shape_4d(t10, n_embd/n_head, n_head, N, n_batch);
+     use_buf(-1); struct ggml_tensor * t11 = expand(gf, ggml_mul_mat (ctx0, t04, layer.wv)); assert_shape_2d(t11, N*n_batch, n_embd);
+     use_buf(-1); struct ggml_tensor * t12 = expand(gf, ggml_reshape_4d (ctx0, t11, N, n_batch, n_embd/n_head, n_head)); assert_shape_4d(t12, N, n_batch, n_embd/n_head, n_head);
+     use_buf(-1); struct ggml_tensor * t13 = expand(gf, ggml_permute (ctx0, t07, 0, 2, 1, 3)); assert_shape_4d(t13, n_embd/n_head, N, n_head, n_batch);
+     use_buf(-1); struct ggml_tensor * t14 = expand(gf, ggml_permute (ctx0, t10, 0, 2, 1, 3)); assert_shape_4d(t14, n_embd/n_head, N, n_head, n_batch);
+     use_buf(-1); struct ggml_tensor * t15 = expand(gf, ggml_permute (ctx0, t12, 0, 3, 1, 2)); assert_shape_4d(t15, N, n_embd/n_head, n_head, n_batch);
+     use_buf(-1); struct ggml_tensor * t16 = expand(gf, ggml_flash_attn (ctx0, t13, t14, t15, true)); assert_shape_4d(t16, n_embd/n_head, N, n_head, n_batch);
+     use_buf( 0); struct ggml_tensor * t17 = expand(gf, ggml_permute (ctx0, t16, 0, 2, 1, 3)); assert_shape_4d(t17, n_embd/n_head, n_head, N, n_batch);
+     use_buf(-1); struct ggml_tensor * t18 = expand(gf, ggml_cont (ctx0, t17)); assert_shape_4d(t18, n_embd/n_head, n_head, N, n_batch);
+     use_buf(-1); struct ggml_tensor * t19 = expand(gf, ggml_reshape_2d (ctx0, t18, n_embd, N*n_batch)); assert_shape_2d(t19, n_embd, N*n_batch);
+     use_buf( 0); struct ggml_tensor * t20 = expand(gf, ggml_mul_mat (ctx0, layer.wo, t19)); assert_shape_2d(t20, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t21 = expand(gf, ggml_add (ctx0, t20, cur)); assert_shape_2d(t21, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t22 = expand(gf, ggml_rms_norm (ctx0, t21)); assert_shape_2d(t22, n_embd, N*n_batch);
+     use_buf( 0); struct ggml_tensor * t23 = expand(gf, ggml_repeat (ctx0, layer.ffn_norm, t22)); assert_shape_2d(t23, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t24 = expand(gf, ggml_mul (ctx0, t23, t22)); assert_shape_2d(t24, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t25 = expand(gf, ggml_mul_mat (ctx0, layer.w3, t24)); assert_shape_2d(t25, n_ff, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t26 = expand(gf, ggml_mul_mat (ctx0, layer.w1, t24)); assert_shape_2d(t26, n_ff, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t27 = expand(gf, ggml_silu (ctx0, t26)); assert_shape_2d(t27, n_ff, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t28 = expand(gf, ggml_mul (ctx0, t27, t25)); assert_shape_2d(t28, n_ff, N*n_batch);
+     use_buf( 0); struct ggml_tensor * t29 = expand(gf, ggml_mul_mat (ctx0, layer.w2, t28)); assert_shape_2d(t29, n_embd, N*n_batch);
+     use_buf(-1); struct ggml_tensor * t30 = expand(gf, ggml_add (ctx0, t21, t29)); assert_shape_2d(t30, n_embd, N*n_batch);
      t02L[il] = t02;
      t03L[il] = t03;
      t04L[il] = t04;
@@ -1611,8 +1607,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
      cur = t30;
  }
- clr_buf(1);
- use_buf(1);
+ clr_buf(0);
+ use_buf(0);
  struct ggml_tensor * t31 = expand(gf, ggml_rms_norm (ctx0, cur)); assert_shape_2d(t31, n_embd, N*n_batch);
  struct ggml_tensor * t32 = expand(gf, ggml_repeat (ctx0, model->norm, t31)); assert_shape_2d(t32, n_embd, N*n_batch);
  struct ggml_tensor * t33 = expand(gf, ggml_mul (ctx0, t32, t31)); assert_shape_2d(t33, n_embd, N*n_batch);
@@ -1720,13 +1716,13 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
  *gb = *gf;
- // t36->grad gets set to one by optimizer, so we need to create the tensor.
- // initialize it with 1.0f to make sure.
+ // t36->grad gets set to one by optimizer, so we need the tensor.
+ GGML_ASSERT(t36->grad != NULL);
+ // initialize it with 1.0f to make sure.
  // use_buf(-1);
  // t36->grad = expand(gb, ggml_new_f32(ctx0, 1.0f));
- use_buf(1);
+ use_buf(0);
  t35->grad = expand(gb, ggml_cross_entropy_loss_back(ctx0, t35, targets, t36->grad)); assert_shape_3d(t35->grad, n_vocab, N, n_batch);
  t34->grad = expand(gb, ggml_reshape_2d (ctx0, t35->grad, n_vocab, N*n_batch)); assert_shape_2d(t34->grad, n_vocab, N*n_batch);
  t33->grad = expand(gb, ggml_out_prod (ctx0, model->output, ggml_transpose(ctx0, t34->grad))); assert_shape_2d(t33->grad, n_embd, N*n_batch);
@@ -1737,8 +1733,8 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
  model->norm->grad = expand(gb, add_or_set(model->norm->grad, ggml_repeat_back(ctx0, t32->grad, model->norm))); assert_shape_1d(model->norm->grad, n_embd);
  model->output->grad = expand(gb, add_or_set(model->output->grad, ggml_out_prod(ctx0, t33, t34->grad))); assert_shape_2d(model->output->grad, n_embd, n_vocab);
- clr_buf(2);
- use_buf(2);
+ clr_buf(1);
+ use_buf(1);
  t31->grad = expand(gb, ggml_mul(ctx0, t33->grad, t32)); assert_shape_2d(t31->grad, n_embd, N*n_batch);
  struct ggml_tensor * back_layer_inp = t31;
@@ -1778,13 +1774,13 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
  struct ggml_tensor * t29 = t29L[il];
  struct ggml_tensor * t30 = t30L[il];
- clr_buf(1);
- use_buf(1);
+ clr_buf(0);
+ use_buf(0);
  t30->grad = expand(gb, ggml_rms_norm_back(ctx0, t30, back_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
  if (grad_layer_inp) {
      t30->grad = expand(gb, ggml_add(ctx0, t30->grad, grad_layer_inp->grad)); assert_shape_2d(t30->grad, n_embd, N*n_batch);
  }
- clr_buf(2);
+ clr_buf(1);
  t29->grad = t30->grad; assert_shape_2d(t29->grad, n_embd, N*n_batch);
  t28->grad = expand(gb, ggml_out_prod(ctx0, layer.w2, ggml_transpose(ctx0, t29->grad))); assert_shape_2d(t28->grad, n_ff, N*n_batch);
  t27->grad = expand(gb, ggml_mul(ctx0, t28->grad, t25)); assert_shape_2d(t27->grad, n_ff, N*n_batch);
@@ -1795,10 +1791,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
      ggml_out_prod(ctx0, layer.w3, ggml_transpose(ctx0, t25->grad)))); assert_shape_2d(t24->grad, n_embd, N*n_batch);
  t23->grad = expand(gb, ggml_mul(ctx0, t24->grad, t22)); assert_shape_2d(t23->grad, n_embd, N*n_batch);
  t22->grad = expand(gb, ggml_mul(ctx0, t24->grad, ggml_repeat(ctx0, layer.ffn_norm, t24->grad))); assert_shape_2d(t22->grad, n_embd, N*n_batch);
- use_buf(2);
+ use_buf(1);
  t21->grad = expand(gb, ggml_add(ctx0, t30->grad, ggml_rms_norm_back(ctx0, t21, t22->grad))); assert_shape_2d(t21->grad, n_embd, N*n_batch);
  grad_layer_inp = t21;
- use_buf(1);
+ use_buf(0);
  t20->grad = t21->grad; assert_shape_2d(t20->grad, n_embd, N*n_batch);
  t19->grad = expand(gb, ggml_out_prod(ctx0, layer.wo, ggml_transpose(ctx0, t20->grad))); assert_shape_2d(t19->grad, n_embd, N*n_batch);
  t18->grad = expand(gb, ggml_reshape_4d(ctx0, t19->grad, n_embd/n_head, n_head, N, n_batch)); assert_shape_4d(t18->grad, n_embd/n_head, n_head, N, n_batch);
@@ -1822,9 +1818,10 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
      ggml_out_prod(ctx0, layer.wk, ggml_transpose(ctx0, t08->grad))),
      ggml_out_prod(ctx0, layer.wq, ggml_transpose(ctx0, t05->grad)))); assert_shape_2d(t04->grad, n_embd, N*n_batch);
  t03->grad = expand(gb, ggml_mul(ctx0, t04->grad, t02)); assert_shape_2d(t04->grad, n_embd, N*n_batch);
- use_buf(2);
+ use_buf(1);
  t02->grad = expand(gb, ggml_mul(ctx0, t04->grad, ggml_repeat(ctx0, layer.attention_norm, t02))); assert_shape_2d(t02->grad, n_embd, N*n_batch);
  back_layer_inp = t02;
+ // use_buf(0);
  use_buf(-1);
  layer.attention_norm->grad = expand(gb, add_or_set(layer.attention_norm->grad, ggml_repeat_back(ctx0, t03->grad, layer.attention_norm))); assert_shape_1d(layer.attention_norm->grad, n_embd);
@@ -1836,19 +1833,21 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
  layer.w1->grad = expand(gb, add_or_set(layer.w1->grad, ggml_out_prod(ctx0, t24, t26->grad))); assert_shape_2d(layer.w1->grad, n_embd, n_ff);
  layer.w2->grad = expand(gb, add_or_set(layer.w2->grad, ggml_out_prod(ctx0, t28, t29->grad))); assert_shape_2d(layer.w2->grad, n_ff, n_embd);
  layer.w3->grad = expand(gb, add_or_set(layer.w3->grad, ggml_out_prod(ctx0, t24, t25->grad))); assert_shape_2d(layer.w3->grad, n_embd, n_ff);
+ // use_buf(0);
  }
- clr_buf(1);
- use_buf(1);
+ clr_buf(0);
+ use_buf(0);
  t01->grad = expand(gb, ggml_add_inplace(ctx0, grad_layer_inp->grad, ggml_rms_norm_back(ctx0, t01, back_layer_inp->grad))); assert_shape_2d(t01->grad, n_embd, N*n_batch);
  use_buf(-1);
  model->tok_embeddings->grad = expand(gb, ggml_get_rows_back(ctx0, t01->grad, t00, model->tok_embeddings)); assert_shape_2d(model->tok_embeddings->grad, n_embd, n_vocab);
+ // clr_buf(1);
+ // clr_buf(0);
  *logits = t35;
  if (track_max_mem) {
      printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
      printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
-     printf("%s: max size compute buf2: %zu\n", __func__, buf_maxs[2]);
  }
  return t36;
@@ -2649,7 +2648,6 @@ struct train_params {
  int mem_compute_gb;
  int mem_compute0_gb;
  int mem_compute1_gb;
- int mem_compute2_gb;
  };
  struct train_params get_default_train_params() {
@@ -2694,10 +2692,9 @@ struct train_params get_default_train_params() {
  params.adam_decay = 1e-3;
  params.mem_model_gb = 2;
- params.mem_compute_gb = 8;
- params.mem_compute0_gb = 24;
- params.mem_compute1_gb = 8;
- params.mem_compute2_gb = 8;
+ params.mem_compute_gb = 24;
+ params.mem_compute0_gb = 8;
+ params.mem_compute1_gb = 2;
  return params;
  }
@@ -2744,7 +2741,6 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
  fprintf(stderr, " --mem-compute N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute_gb);
  fprintf(stderr, " --mem-compute0 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
  fprintf(stderr, " --mem-compute1 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb);
- fprintf(stderr, " --mem-compute2 N Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute2_gb);
  fprintf(stderr, "\n");
  }
@@ -2954,12 +2950,6 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
      break;
  }
  params->mem_compute1_gb = std::stoi(argv[i]);
- } else if (arg == "--mem-compute2") {
-     if (++i >= argc) {
-         invalid_param = true;
-         break;
-     }
-     params->mem_compute2_gb = std::stoi(argv[i]);
  } else if (arg == "-h" || arg == "--help") {
      train_print_usage(argc, argv, &default_params);
      exit(0);
@@ -3117,10 +3107,8 @@ int main(int argc, char ** argv) {
  size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
  size_t size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
- size_t size_buf_2 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute2_gb);
  uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
  uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
- uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
  GGML_ASSERT(n_tokens < (int) train_tokens.size());
  std::vector<int> train_samples;
@@ -3182,8 +3170,8 @@ int main(int argc, char ** argv) {
      &model, ctx0,
      gf, gb,
      &logits, tokens_input, target_probs,
-     compute_buf_0, compute_buf_1, compute_buf_2,
-     size_buf_0, size_buf_1, size_buf_2,
+     compute_buf_0, compute_buf_1,
+     size_buf_0, size_buf_1,
      n_tokens, n_batch);
  } else if (params.use_flash) {
      logits = forward_batch_wo_cache_flash_attn(&model, ctx0, gf, tokens_input, n_tokens, n_batch);
@@ -3335,7 +3323,6 @@ int main(int argc, char ** argv) {
  delete[] compute_addr;
  delete[] compute_buf_0;
  delete[] compute_buf_1;
- delete[] compute_buf_2;
  ggml_free(model.ctx);
  return 0;