fix track_max_mem in forward_batch_wo_cache_flash_attn_train
parent 8a88e5855c
commit 5d124d0cb4
1 changed file with 9 additions and 3 deletions
@@ -1445,17 +1445,22 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
     const int n_ff = get_n_ff(&hparams);
     const int rope_mode = 0;
 
+    bool track_max_mem = true;
+
     int last_buf = -1;
     size_t buf_offs[2] = { 0, 0 };
     size_t buf_size[2] = { size_buf_0,
                            size_buf_1 };
     void * buf_data[2] = { compute_buf_0,
                            compute_buf_1 };
-    auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) {
+    size_t buf_maxs[2] = { 0, 0 };
+
+    auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs] (int buf) {
         size_t last_offs = 0;
         last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, });
         if (last_buf >= 0) {
             buf_offs[last_buf] = last_offs;
+            buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]);
         }
         if (buf >= 0) {
             size_t offs = buf_offs[buf];
@@ -1466,8 +1471,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
         last_buf = buf;
     };
 
-    bool track_max_mem = false;
-    size_t buf_maxs[2] = { 0, 0 };
 
     auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) {
         if (buf < 0) return;
@@ -1903,6 +1906,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train(
 
     *logits = t35;
 
+    clr_buf(0);
+    clr_buf(1);
+
     if (track_max_mem) {
         printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
         printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
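For orientation, below is a minimal, standalone C++ sketch of the scratch-buffer high-water-mark pattern that the patched use_buf/clr_buf lambdas implement. It does not use ggml: the alloc lambda is a made-up stand-in for the offset bookkeeping done via ggml_set_scratch, and only the names track_max_mem, last_buf, buf_offs, buf_maxs, use_buf, and clr_buf are taken from the diff above; everything else is illustrative.

// Standalone illustration (not ggml): track the peak usage of two scratch
// buffers the way the patched use_buf/clr_buf pair does.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
    bool   track_max_mem = true;
    int    last_buf      = -1;
    size_t buf_offs[2]   = { 0, 0 };   // current write offset per buffer
    size_t buf_maxs[2]   = { 0, 0 };   // high-water mark per buffer

    // Pretend-allocator (assumption): bump the offset of the active buffer.
    auto alloc = [&] (size_t n) { if (last_buf >= 0) buf_offs[last_buf] += n; };

    // Switch the active buffer; fold the old buffer's offset into its max,
    // as the added line inside use_buf does in the diff.
    auto use_buf = [&] (int buf) {
        if (last_buf >= 0) {
            buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]);
        }
        last_buf = buf;
    };

    // Reset a buffer, remembering its peak first.
    auto clr_buf = [&] (int buf) {
        if (buf < 0) return;
        buf_maxs[buf] = std::max(buf_maxs[buf], buf_offs[buf]);
        buf_offs[buf] = 0;
    };

    // Simulated workload: alternate buffers, allocating as we go.
    use_buf(0); alloc(128);
    use_buf(1); alloc(512);
    use_buf(0); alloc(64);

    // As in the patch: clear both buffers before reporting, so the most
    // recent allocations are counted, then print the peaks.
    clr_buf(0);
    clr_buf(1);
    if (track_max_mem) {
        printf("max size compute buf0: %zu\n", buf_maxs[0]);
        printf("max size compute buf1: %zu\n", buf_maxs[1]);
    }
    return 0;
}

Run as-is, the sketch prints the peak offset of each buffer, mirroring the two printf calls at the end of the third hunk.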