fix track_max_mem in forward_batch_wo_cache_flash_attn_train

This commit is contained in:
xaedes 2023-06-15 20:34:56 +02:00
parent 8a88e5855c
commit 5d124d0cb4
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1445,17 +1445,22 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
const int n_ff = get_n_ff(&hparams); const int n_ff = get_n_ff(&hparams);
const int rope_mode = 0; const int rope_mode = 0;
bool track_max_mem = true;
int last_buf = -1; int last_buf = -1;
size_t buf_offs[2] = { 0, 0 }; size_t buf_offs[2] = { 0, 0 };
size_t buf_size[2] = { size_buf_0, size_t buf_size[2] = { size_buf_0,
size_buf_1 }; size_buf_1 };
void * buf_data[2] = { compute_buf_0, void * buf_data[2] = { compute_buf_0,
compute_buf_1 }; compute_buf_1 };
auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) { size_t buf_maxs[2] = { 0, 0 };
auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs] (int buf) {
size_t last_offs = 0; size_t last_offs = 0;
last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, }); last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
if (last_buf >= 0) { if (last_buf >= 0) {
buf_offs[last_buf] = last_offs; buf_offs[last_buf] = last_offs;
buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]);
} }
if (buf >= 0) { if (buf >= 0) {
size_t offs = buf_offs[buf]; size_t offs = buf_offs[buf];
@ -1466,8 +1471,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
last_buf = buf; last_buf = buf;
}; };
bool track_max_mem = false;
size_t buf_maxs[2] = { 0, 0 };
auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) { auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) {
if (buf < 0) return; if (buf < 0) return;
@ -1903,6 +1906,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
*logits = t35; *logits = t35;
clr_buf(0);
clr_buf(1);
if (track_max_mem) { if (track_max_mem) {
printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]); printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]); printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);