diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp
index faa60ec8b..075e0307f 100644
--- a/examples/train-text-from-scratch/train-text-from-scratch.cpp
+++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp
@@ -1932,11 +1932,9 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
         void * compute_buf_0,
         void * compute_buf_1,
         void * compute_buf_2,
-        void * compute_buf_3,
         size_t size_buf_0,
         size_t size_buf_1,
         size_t size_buf_2,
-        size_t size_buf_3,
         const int n_tokens,
         const int n_batch) {
 
@@ -1966,16 +1964,14 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
     bool track_max_mem = true;
 
     int last_buf = -1;
-    size_t buf_offs[4] = { 0, 0, 0, 0 };
-    size_t buf_size[4] = { size_buf_0,
+    size_t buf_offs[3] = { 0, 0, 0 };
+    size_t buf_size[3] = { size_buf_0,
                            size_buf_1,
-                           size_buf_2,
-                           size_buf_3 };
-    void * buf_data[4] = { compute_buf_0,
+                           size_buf_2 };
+    void * buf_data[3] = { compute_buf_0,
                            compute_buf_1,
-                           compute_buf_2,
-                           compute_buf_3 };
-    size_t buf_maxs[4] = { 0, 0, 0, 0 };
+                           compute_buf_2 };
+    size_t buf_maxs[3] = { 0, 0, 0 };
 
     auto use_buf = [ctx0, &last_buf, &buf_offs, &buf_size, &buf_data, &buf_maxs] (int buf) {
         size_t last_offs = 0;
@@ -2083,7 +2079,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
     clr_buf(0);
     clr_buf(1);
     clr_buf(2);
-    clr_buf(3);
 
     use_buf(-1);
 
@@ -2112,22 +2107,22 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
     // example for 16 layers:
     // inp ~ implicit zeroth checkpoint == input
-    // L00 f 4b
-    // L01 f 4b
+    // L00 f 4b [
+    // L01 f 4b    4th second forward pass
     // L02 f 4b
-    // L03 fc4b    first checkpoint
-    // L04 f 3b
-    // L05 f 3b
+    // L03 fc4b ]  first checkpoint
+    // L04 f 3b [
+    // L05 f 3b    3rd second forward pass
     // L06 f 3b
-    // L07 fc3b    second checkpoint
-    // L08 f 2b
-    // L09 f 2b
+    // L07 fc3b ]  second checkpoint
+    // L08 f 2b [
+    // L09 f 2b    2nd second forward pass
     // L10 f 2b
-    // L11 fc2b    third checkpoint
-    // L12 f 1b
-    // L13 f 1b
+    // L11 fc2b ]  third checkpoint
+    // L12 f 1b [
+    // L13 f 1b    1st second forward pass
     // L14 f 1b
-    // L15 f 1b
+    // L15 f 1b ]
 
     // need to remember these for the backward pass
     std::vector<struct ggml_tensor*> t02L; t02L.resize(n_layer, NULL);
@@ -2162,7 +2157,6 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
     struct ggml_tensor * cur = t01;
 
-    int chk_idx = 0;
     for (int il = 0; il < n_layer; ++il) {
         struct my_llama_layer & layer = model->layers[il];
 
@@ -2455,13 +2449,11 @@ struct ggml_tensor * forward_batch_wo_cache_flash_attn_train_grad_checkpointing(
     clr_buf(0);
     clr_buf(1);
     clr_buf(2);
-    clr_buf(3);
 
     if (track_max_mem) {
         printf("%s: max size compute buf0: %zu\n", __func__, buf_maxs[0]);
         printf("%s: max size compute buf1: %zu\n", __func__, buf_maxs[1]);
         printf("%s: max size compute buf2: %zu\n", __func__, buf_maxs[2]);
-        printf("%s: max size compute buf3: %zu\n", __func__, buf_maxs[3]);
     }
 
     // now that all grads are created, set the graph leafs and grads
@@ -3434,7 +3426,6 @@ struct train_params get_default_train_params() {
     params.mem_compute0_gb = 8;
     params.mem_compute1_gb = 1;
     params.mem_compute2_gb = 2;
-    params.mem_compute3_gb = 1;
 
     return params;
 }
@@ -3486,7 +3477,6 @@ void train_print_usage(int /*argc*/, char ** argv, const struct train_params * p
     fprintf(stderr, "  --mem-compute0 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute0_gb);
     fprintf(stderr, "  --mem-compute1 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute1_gb);
     fprintf(stderr, "  --mem-compute2 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute2_gb);
-    fprintf(stderr, "  --mem-compute3 N           Memory to allocate for compute in gigabytes. (default %d)\n", params->mem_compute3_gb);
     fprintf(stderr, "\n");
 }
 
@@ -3724,12 +3714,6 @@ bool train_params_parse(int argc, char ** argv, struct train_params * params) {
                 break;
             }
             params->mem_compute2_gb = std::stoi(argv[i]);
-        } else if (arg == "--mem-compute3") {
-            if (++i >= argc) {
-                invalid_param = true;
-                break;
-            }
-            params->mem_compute3_gb = std::stoi(argv[i]);
         } else if (arg == "-h" || arg == "--help") {
             train_print_usage(argc, argv, &default_params);
             exit(0);
@@ -3892,11 +3876,9 @@ int main(int argc, char ** argv) {
     size_t    size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
     size_t    size_buf_1 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute1_gb);
     size_t    size_buf_2 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute2_gb);
-    size_t    size_buf_3 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute3_gb);
     uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
     uint8_t * compute_buf_1 = new uint8_t[size_buf_1];
     uint8_t * compute_buf_2 = new uint8_t[size_buf_2];
-    uint8_t * compute_buf_3 = new uint8_t[size_buf_3];
 
     GGML_ASSERT(n_tokens < (int) train_tokens.size());
     std::vector<int> train_samples;
@@ -3924,9 +3906,9 @@
         }
 
         struct ggml_init_params cparams = {
-            /*.mem_size   =*/ compute_size,
-            /*.mem_buffer =*/ compute_addr,
-            /*.no_alloc   =*/ false,
+            compute_size, // mem_size
+            compute_addr, // mem_buffer
+            false,        // no_alloc
         };
 
         struct ggml_context * ctx0 = ggml_init(cparams);
@@ -3960,8 +3942,8 @@
                 &model, ctx0, gf, gb,
                 &logits, tokens_input, target_probs,
-                compute_buf_0, compute_buf_1, compute_buf_2, compute_buf_3,
-                size_buf_0, size_buf_1, size_buf_2, size_buf_3,
+                compute_buf_0, compute_buf_1, compute_buf_2,
+                size_buf_0, size_buf_1, size_buf_2,
                 n_tokens, n_batch);
         } else if (params.use_scratch) {
             loss = forward_batch_wo_cache_flash_attn_train(
@@ -4082,9 +4064,9 @@
         printf("---\n");
         for (int i=0; i