From 5d124d0cb4ebf834aa136aade847092777078c35 Mon Sep 17 00:00:00 2001 From: xaedes Date: Thu, 15 Jun 2023 20:34:56 +0200 Subject: [PATCH] fix track_max_mem in forward_batch_wo_cache_flash_attn_train --- .../train-text-from-scratch.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 54dc2beed..828a2a9b7 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1445,17 +1445,22 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( const int n_ff = get_n_ff(&hparams); const int rope_mode = 0; + bool track_max_mem = true; + int last_buf = -1; size_t buf_offs[2] = { 0, 0 }; size_t buf_size[2] = { size_buf_0, size_buf_1 }; void * buf_data[2] = { compute_buf_0, compute_buf_1 }; - auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data] (int buf) { + size_t buf_maxs[2] = { 0, 0 }; + + auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs] (int buf) { size_t last_offs = 0; last_offs = ggml_set_scratch(ctx0, { 0, 0, nullptr, }); if (last_buf >= 0) { buf_offs[last_buf] = last_offs; + buf_maxs[last_buf] = std::max(buf_maxs[last_buf], buf_offs[last_buf]); } if (buf >= 0) { size_t offs = buf_offs[buf]; @@ -1466,8 +1471,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( last_buf = buf; }; - bool track_max_mem = false; - size_t buf_maxs[2] = { 0, 0 }; auto clr_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs, track_max_mem] (int buf) { if (buf < 0) return; @@ -1903,6 +1906,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train( *logits = t35; + clr_buf(0); + clr_buf(1); + if (track_max_mem) { printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]); printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);