diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 069e460c1..80cf2e255 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -1088,15 +1088,46 @@ int main(int argc, char ** argv) { size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb); uint8_t * compute_buf_0 = new uint8_t[size_buf_0]; - ggml_allocr * alloc = NULL; - if (params.use_alloc) { - alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment); - } - int n_tokens = model.hparams.n_ctx; int n_vocab = model.hparams.n_vocab; int n_batch = params.common.n_batch; + std::vector mem_input_data; + std::vector mem_compute_data; + + ggml_allocr * alloc = NULL; + + // context for input tensors without their data + struct ggml_init_params ctx_input_params = { + ggml_tensor_overhead() * 2, // mem_size + NULL, // mem_buffer + true, // no_alloc + }; + struct ggml_context * ctx_input = ggml_init(ctx_input_params); + + // the input tensors + struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx_input, GGML_TYPE_I32, n_tokens, n_batch); + struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx_input, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); + + // measure required memory for input tensors + alloc = ggml_allocr_new_measure(tensor_alignment); + ggml_allocr_alloc(alloc, tokens_input); + ggml_allocr_alloc(alloc, target_probs); + size_t max_input_size = ggml_allocr_max_size(alloc) + tensor_alignment; + ggml_allocr_free(alloc); + printf("%s: input_size = %zu bytes (%.1f MB)\n", __func__, max_input_size, (float) max_input_size / (1024.0f*1024.0f)); + + // allocate input tensors + mem_input_data.resize(max_input_size); + alloc = ggml_allocr_new(mem_input_data.data(), mem_input_data.size(), tensor_alignment); + ggml_allocr_alloc(alloc, tokens_input); + ggml_allocr_alloc(alloc, target_probs); + ggml_allocr_free(alloc); + + if (params.use_alloc) { + alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment); + } + std::vector train_tokens; std::vector train_samples_begin; std::vector train_samples_size; @@ -1167,8 +1198,8 @@ int main(int argc, char ** argv) { opt_cb_data.shuffled_samples_begin = train_shuffled_samples_begin.data(); opt_cb_data.shuffled_samples_size = train_shuffled_samples_size.data(); opt_cb_data.samples_count = train_samples_size.size(); - opt_cb_data.tokens_input = NULL; - opt_cb_data.target_probs = NULL; + opt_cb_data.tokens_input = tokens_input; + opt_cb_data.target_probs = target_probs; opt_cb_data.first_iter = opt->iter; opt_cb_data.last_time = ggml_time_ms(); opt_cb_data.millis_per_iter = 0.0; @@ -1184,21 +1215,12 @@ int main(int argc, char ** argv) { }; struct ggml_context * ctx0 = ggml_init(cparams); - ggml_set_no_alloc(ctx0, false); - - // don't use alloc for input tensors, so we can safely fill them with data - struct ggml_tensor * tokens_input = ggml_new_tensor_2d(ctx0, GGML_TYPE_I32, n_tokens, n_batch); - struct ggml_tensor * target_probs = ggml_new_tensor_3d(ctx0, GGML_TYPE_F32, n_vocab, n_tokens, n_batch); - ggml_set_no_alloc(ctx0, (alloc != NULL)); if (alloc) { ggml_allocr_reset(alloc); } - opt_cb_data.tokens_input = tokens_input; - opt_cb_data.target_probs = target_probs; - int n_past = 0; struct ggml_cgraph * gf = ggml_new_graph(ctx0);