train-text-from-scratch: automatically allocate compute memory

xaedes 2023-09-17 17:52:22 +02:00
parent f9b5d9b760
commit c993246bfd


@@ -1081,13 +1081,6 @@ int main(int argc, char ** argv) {
printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
printf("%s: opt iter %d\n", __func__, opt->iter);
// TODO: use std::vector<uint8_t> instead of "new"
size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
uint8_t * compute_addr = new uint8_t[compute_size];
size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
int n_tokens = model.hparams.n_ctx;
int n_vocab = model.hparams.n_vocab;
int n_batch = params.common.n_batch;
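
The fixed-size buffers removed above (sized from params.mem_compute_gb and params.mem_compute0_gb and allocated with new[]) are replaced further down in this commit by a std::vector backed buffer whose size is measured rather than user-supplied, which also resolves the removed TODO. A minimal sketch of that replacement pattern, with names matching the added code below:

// Sketch only (not part of the diff): the compute buffer becomes a std::vector
// sized from the measured requirement instead of a GB count from the command line.
std::vector<uint8_t> mem_compute_data;                    // owns the compute buffer; freed automatically
mem_compute_data.resize(max_compute_size);                // max_compute_size is measured, see the next hunk
struct ggml_allocr * alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
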
@@ -1124,9 +1117,82 @@ int main(int argc, char ** argv) {
ggml_allocr_alloc(alloc, target_probs);
ggml_allocr_free(alloc);
if (params.use_alloc) {
alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
// context for compute tensors without their data
size_t estimated_compute_size_wo_data = (
ggml_tensor_overhead()*GGML_MAX_NODES*2
+ (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
params.common.use_checkpointing ? 3 : 2
)
);
struct ggml_init_params ctx_compute_params = {
estimated_compute_size_wo_data, // mem_size
NULL, // mem_buffer
true, // no_alloc
};
struct ggml_context * ctx_compute = NULL;
struct ggml_tensor * loss = NULL;
struct ggml_tensor * logits = NULL;
struct ggml_cgraph * gf = NULL;
struct ggml_cgraph * gb = NULL;
struct ggml_cgraph * gb_tmp = NULL;
// measure required memory for compute tensors
size_t best_compute_size = SIZE_MAX;
enum ggml_cgraph_eval_order best_order = GGML_CGRAPH_EVAL_ORDER_COUNT;
// find best evaluation order
for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new_measure(tensor_alignment);
gf = ggml_new_graph(ctx_compute);
gf->order = (enum ggml_cgraph_eval_order) order;
gb = ggml_new_graph(ctx_compute);
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
: NULL;
loss = llama_build_train_graphs(
&model, alloc, ctx_compute,
gf, gb, gb_tmp,
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
params.common.use_checkpointing
);
size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
if (max_compute_size < best_compute_size) {
best_compute_size = max_compute_size;
best_order = gf->order;
}
ggml_allocr_free(alloc);
ggml_free(ctx_compute);
}
size_t max_compute_size = best_compute_size;
printf("%s: compute_size = %zu bytes (%.1f MB)\n", __func__, max_compute_size, (float) max_compute_size / (1024.0f*1024.0f));
printf("%s: evaluation order = %s\n", __func__,
(best_order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? "LEFT_TO_RIGHT" :
(best_order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? "RIGHT_TO_LEFT" :
"invalid");
// allocate compute tensors
mem_compute_data.resize(max_compute_size);
ctx_compute = ggml_init(ctx_compute_params);
alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
gf = ggml_new_graph(ctx_compute);
gf->order = best_order;
gb = ggml_new_graph(ctx_compute);
gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx_compute)
: NULL;
loss = llama_build_train_graphs(
&model, alloc, ctx_compute,
gf, gb, gb_tmp,
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
params.common.use_checkpointing
);
ggml_allocr_free(alloc);
std::vector<llama_token> train_tokens;
std::vector<size_t> train_samples_begin;
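
The block added above follows ggml-alloc's measure-then-allocate pattern: the graphs are first built against a measuring allocator inside a no_alloc context (which only stores tensor and graph metadata, the quantity estimated_compute_size_wo_data accounts for), the peak requirement is read back with ggml_allocr_max_size, and only then is a real buffer of that size allocated and the graphs rebuilt against it. Both graph evaluation orders are measured because the order changes tensor lifetimes and therefore peak memory; the cheaper one is kept as best_order. A condensed sketch of the pattern, assuming the ggml-alloc API of this era (ggml_allocr_new_measure / ggml_allocr_max_size / ggml_allocr_new):

// Condensed sketch of the measure-then-allocate pattern used above.
// 1) measure: build the graph through a measuring allocator, no real data is touched
struct ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
// ... build graphs and allocate their tensors through `alloc` ...
size_t needed = ggml_allocr_max_size(alloc) + tensor_alignment;   // peak usage plus alignment slack
ggml_allocr_free(alloc);
// 2) allocate: back a real allocator with a buffer of exactly the measured size
std::vector<uint8_t> buf(needed);
alloc = ggml_allocr_new(buf.data(), buf.size(), tensor_alignment);
// ... rebuild the same graphs, now getting real addresses ...
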
@@ -1204,83 +1270,30 @@ int main(int argc, char ** argv) {
opt_cb_data.last_time = ggml_time_ms();
opt_cb_data.millis_per_iter = 0.0;
// measure required memory for work buffer
size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
// context for work buffer
struct ggml_init_params ctx_work_params = {
max_work_size, // mem_size
NULL, // mem_buffer
false, // no_alloc
};
struct ggml_context * ctx_work = ggml_init(ctx_work_params);
int64_t t0 = ggml_time_ms();
for (int ex = 0; ex < params.n_examples; ++ex) {
ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
struct ggml_init_params cparams = {
compute_size, // mem_size
compute_addr, // mem_buffer
false, // no_alloc
};
struct ggml_context * ctx0 = ggml_init(cparams);
ggml_set_no_alloc(ctx0, (alloc != NULL));
if (alloc) {
ggml_allocr_reset(alloc);
}
int n_past = 0;
struct ggml_cgraph * gf = ggml_new_graph(ctx0);
struct ggml_cgraph * gb = ggml_new_graph(ctx0);
struct ggml_cgraph * gb_tmp = params.common.use_checkpointing
? ggml_new_graph(ctx0)
: NULL;
GGML_ASSERT(n_past == 0);
struct ggml_tensor * loss = NULL;
struct ggml_tensor * logits = NULL;
loss = llama_build_train_graphs(
&model, alloc, ctx0,
gf, gb, gb_tmp,
&logits, tokens_input, target_probs,
n_tokens, n_batch,
params.common.use_flash,
params.common.use_checkpointing
);
size_t used_mem_before_opt = ggml_used_mem(ctx0);
opt->params.adam.sched = learning_schedule(
opt->iter,
params.common.warmup,
params.common.cos_decay_steps,
params.common.adam_alpha,
params.common.adam_min_alpha,
params.common.cos_decay_min,
params.common.cos_decay_restart,
params.common.enable_restart);
printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
size_t used_mem_after_opt = ggml_used_mem(ctx0);
int n_iter = params.common.adam_n_iter;
train->train_its = opt->iter;
train->train_samples += n_batch * n_iter;
train->train_tokens += n_batch * n_tokens * n_iter;
if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
printf("Example %d, opt iter %d\n", ex, opt->iter);
printf("error_before_opt: %.6f\n", opt->loss_before);
printf("error_after_opt: %.6f\n", opt->loss_after);
printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
}
ggml_free(ctx0);
}
ggml_free(ctx_work);
ggml_free(ctx_compute);
ggml_free(ctx_input);
int64_t t1 = ggml_time_ms();
int64_t d = t1-t0;
double dd = (double) d * 1e-3;
printf("%s: total training time=%f seconds\n", __func__, dd);
printf("%s: total training time: ", __func__);
print_duration((double) (t1 - t0));
printf("\n");
int new_iters = opt->iter - opt_cb_data.last_save_iter;
if (new_iters > 0) {
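
The work-buffer sizing added at the top of this hunk asks ggml_graph_plan how much scratch memory graph execution needs for the chosen thread count, then backs a single ggml context of that size and reuses it for every optimization step, replacing the per-example ctx0 creation removed above. A rough sketch of how such a plan-derived buffer is normally consumed, assuming the ggml_cplan / ggml_graph_compute API of this era (ggml_opt_resume_g does the equivalent internally using ctx_work):

// Rough sketch (not part of the diff) of driving a graph with an explicitly sized work buffer.
struct ggml_cplan plan = ggml_graph_plan(gb, params.common.n_threads);
std::vector<uint8_t> work_data(plan.work_size);   // scratch space for the worker threads
plan.work_data = work_data.data();
ggml_graph_compute(gb, &plan);
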
@@ -1295,8 +1308,6 @@ int main(int argc, char ** argv) {
ggml_allocr_free(alloc);
}
delete[] compute_addr;
delete[] compute_buf_0;
ggml_free(opt->ctx);
free_train_state(train);
ggml_free(model.ctx);