train-text-from-scratch: automatically allocate compute memory

parent f9b5d9b760
commit c993246bfd

1 changed file with 94 additions and 83 deletions
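Instead of reserving fixed-size compute buffers up front (sized in whole GB from params.mem_compute_gb / params.mem_compute0_gb and allocated with `new`), the training graphs are now built once per graph evaluation order against a measuring allocator; the order with the smaller memory high-water mark is kept, and exactly the measured number of bytes is then allocated for the compute tensors. The per-example context and graph rebuilds are gone as well: a single pre-built graph is handed to ggml_opt_resume_g(), and the work buffer is sized from ggml_graph_plan(). A self-contained sketch of the measure-then-allocate pattern follows the diff.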
@@ -1081,13 +1081,6 @@ int main(int argc, char ** argv) {
     printf("%s: opt_size = %zu bytes (%.1f MB)\n", __func__, ggml_get_mem_size(opt->ctx), (float) ggml_get_mem_size(opt->ctx) / (1024.0f*1024.0f));
     printf("%s: opt iter %d\n", __func__, opt->iter);
 
-    // TODO: use std::vector<uint8_t> intead of "new"
-    size_t compute_size = 1024ll*1024ll*1024ll*((size_t) params.mem_compute_gb);
-    uint8_t * compute_addr = new uint8_t[compute_size];
-
-    size_t size_buf_0 = 1024ll*1024ll*1024ll*((size_t) params.mem_compute0_gb);
-    uint8_t * compute_buf_0 = new uint8_t[size_buf_0];
-
     int n_tokens = model.hparams.n_ctx;
     int n_vocab = model.hparams.n_vocab;
     int n_batch = params.common.n_batch;
@@ -1124,9 +1117,82 @@ int main(int argc, char ** argv) {
     ggml_allocr_alloc(alloc, target_probs);
     ggml_allocr_free(alloc);
 
-    if (params.use_alloc) {
-        alloc = ggml_allocr_new(compute_buf_0, size_buf_0, tensor_alignment);
-    }
+    // context for compute tensors without their data
+    size_t estimated_compute_size_wo_data = (
+        ggml_tensor_overhead()*GGML_MAX_NODES*2
+      + (GGML_OBJECT_SIZE+GGML_GRAPH_SIZE)*(
+            params.common.use_checkpointing ? 3 : 2
+        )
+    );
+    struct ggml_init_params ctx_compute_params = {
+        estimated_compute_size_wo_data, // mem_size
+        NULL, // mem_buffer
+        true, // no_alloc
+    };
+    struct ggml_context * ctx_compute = NULL;
+
+    struct ggml_tensor * loss = NULL;
+    struct ggml_tensor * logits = NULL;
+
+    struct ggml_cgraph * gf = NULL;
+    struct ggml_cgraph * gb = NULL;
+    struct ggml_cgraph * gb_tmp = NULL;
+
+    // measure required memory for compute tensors
+    size_t best_compute_size = SIZE_MAX;
+    enum ggml_cgraph_eval_order best_order = GGML_CGRAPH_EVAL_ORDER_COUNT;
+    // find best evaluation order
+    for (unsigned order = 0; order < (unsigned) GGML_CGRAPH_EVAL_ORDER_COUNT; ++order) {
+        ctx_compute = ggml_init(ctx_compute_params);
+        alloc = ggml_allocr_new_measure(tensor_alignment);
+        gf = ggml_new_graph(ctx_compute);
+        gf->order = (enum ggml_cgraph_eval_order) order;
+        gb = ggml_new_graph(ctx_compute);
+        gb_tmp = params.common.use_checkpointing
+            ? ggml_new_graph(ctx_compute)
+            : NULL;
+        loss = llama_build_train_graphs(
+            &model, alloc, ctx_compute,
+            gf, gb, gb_tmp,
+            &logits, tokens_input, target_probs,
+            n_tokens, n_batch,
+            params.common.use_flash,
+            params.common.use_checkpointing
+        );
+        size_t max_compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
+        if (max_compute_size < best_compute_size) {
+            best_compute_size = max_compute_size;
+            best_order = gf->order;
+        }
+        ggml_allocr_free(alloc);
+        ggml_free(ctx_compute);
+    }
+    size_t max_compute_size = best_compute_size;
+    printf("%s: compute_size = %zu bytes (%.1f MB)\n", __func__, max_compute_size, (float) max_compute_size / (1024.0f*1024.0f));
+    printf("%s: evaluation order = %s\n", __func__,
+        (best_order == GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT) ? "LEFT_TO_RIGHT" :
+        (best_order == GGML_CGRAPH_EVAL_ORDER_RIGHT_TO_LEFT) ? "RIGHT_TO_LEFT" :
+        "invalid");
+
+    // allocate compute tensors
+    mem_compute_data.resize(max_compute_size);
+    ctx_compute = ggml_init(ctx_compute_params);
+    alloc = ggml_allocr_new(mem_compute_data.data(), mem_compute_data.size(), tensor_alignment);
+    gf = ggml_new_graph(ctx_compute);
+    gf->order = best_order;
+    gb = ggml_new_graph(ctx_compute);
+    gb_tmp = params.common.use_checkpointing
+        ? ggml_new_graph(ctx_compute)
+        : NULL;
+    loss = llama_build_train_graphs(
+        &model, alloc, ctx_compute,
+        gf, gb, gb_tmp,
+        &logits, tokens_input, target_probs,
+        n_tokens, n_batch,
+        params.common.use_flash,
+        params.common.use_checkpointing
+    );
+    ggml_allocr_free(alloc);
 
     std::vector<llama_token> train_tokens;
     std::vector<size_t> train_samples_begin;
@@ -1204,83 +1270,30 @@ int main(int argc, char ** argv) {
     opt_cb_data.last_time = ggml_time_ms();
     opt_cb_data.millis_per_iter = 0.0;
 
-    int64_t t0 = ggml_time_ms();
+    // measure required memory for work buffer
+    size_t max_work_size = ggml_graph_plan(gb, params.common.n_threads).work_size + GGML_OBJECT_SIZE;
+    printf("%s: work_size = %zu bytes (%.1f MB)\n", __func__, max_work_size, (float) max_work_size / (1024.0f*1024.0f));
 
-    for (int ex = 0; ex < params.n_examples; ++ex) {
-
-        struct ggml_init_params cparams = {
-            compute_size, // mem_size
-            compute_addr, // mem_buffer
+    // context for work buffer
+    struct ggml_init_params ctx_work_params = {
+        max_work_size, // mem_size
+        NULL, // mem_buffer
         false, // no_alloc
     };
-        struct ggml_context * ctx0 = ggml_init(cparams);
+    struct ggml_context * ctx_work = ggml_init(ctx_work_params);
 
-        ggml_set_no_alloc(ctx0, (alloc != NULL));
+    int64_t t0 = ggml_time_ms();
 
-        if (alloc) {
-            ggml_allocr_reset(alloc);
-        }
+    ggml_opt_resume_g(ctx_work, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
 
-        int n_past = 0;
-
-        struct ggml_cgraph * gf = ggml_new_graph(ctx0);
-        struct ggml_cgraph * gb = ggml_new_graph(ctx0);
-        struct ggml_cgraph * gb_tmp = params.common.use_checkpointing
-            ? ggml_new_graph(ctx0)
-            : NULL;
-
-        GGML_ASSERT(n_past == 0);
-
-        struct ggml_tensor * loss = NULL;
-        struct ggml_tensor * logits = NULL;
-
-        loss = llama_build_train_graphs(
-            &model, alloc, ctx0,
-            gf, gb, gb_tmp,
-            &logits, tokens_input, target_probs,
-            n_tokens, n_batch,
-            params.common.use_flash,
-            params.common.use_checkpointing
-        );
-
-        size_t used_mem_before_opt = ggml_used_mem(ctx0);
-
-        opt->params.adam.sched = learning_schedule(
-            opt->iter,
-            params.common.warmup,
-            params.common.cos_decay_steps,
-            params.common.adam_alpha,
-            params.common.adam_min_alpha,
-            params.common.cos_decay_min,
-            params.common.cos_decay_restart,
-            params.common.enable_restart);
-
-        printf("%s: opt->params.adam.sched %.5f\n", __func__, opt->params.adam.sched);
-
-        ggml_opt_resume_g(ctx0, opt, loss, gf, gb, &train_opt_callback, (void *) &opt_cb_data);
-
-        size_t used_mem_after_opt = ggml_used_mem(ctx0);
-
-        int n_iter = params.common.adam_n_iter;
-        train->train_its = opt->iter;
-        train->train_samples += n_batch * n_iter;
-        train->train_tokens += n_batch * n_tokens * n_iter;
-
-        if (params.print_info_interval > 0 && ex % params.print_info_interval == 0) {
-            printf("Example %d, opt iter %d\n", ex, opt->iter);
-            printf("error_before_opt: %.6f\n", opt->loss_before);
-            printf("error_after_opt: %.6f\n", opt->loss_after);
-            printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
-            printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
-        }
-
-        ggml_free(ctx0);
-    }
+    ggml_free(ctx_work);
+    ggml_free(ctx_compute);
+    ggml_free(ctx_input);
 
     int64_t t1 = ggml_time_ms();
-    int64_t d = t1-t0;
-    double dd = (double) d * 1e-3;
-    printf("%s: total training time=%f seconds\n", __func__, dd);
+    printf("%s: total training time: ", __func__);
+    print_duration((double) (t1 - t0));
+    printf("\n");
 
     int new_iters = opt->iter - opt_cb_data.last_save_iter;
     if (new_iters > 0) {
@@ -1295,8 +1308,6 @@ int main(int argc, char ** argv) {
         ggml_allocr_free(alloc);
     }
 
-    delete[] compute_addr;
-    delete[] compute_buf_0;
    ggml_free(opt->ctx);
    free_train_state(train);
    ggml_free(model.ctx);
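For orientation, here is a minimal sketch of the measure-then-allocate pattern the diff applies to the training graphs. The ggml / ggml-alloc calls are the ones appearing above; the toy matmul graph, the sizes, and the build_graph() helper are invented for illustration and are not part of the commit:

#include "ggml.h"
#include "ggml-alloc.h"
#include <cstdio>
#include <vector>

// toy stand-in for llama_build_train_graphs: a single matmul whose
// tensors are allocated through the given allocator
static struct ggml_tensor * build_graph(struct ggml_context * ctx, struct ggml_allocr * alloc, struct ggml_cgraph * gf) {
    struct ggml_tensor * a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 64);
    ggml_allocr_alloc(alloc, a);
    struct ggml_tensor * b = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 64, 16);
    ggml_allocr_alloc(alloc, b);
    struct ggml_tensor * c = ggml_mul_mat(ctx, a, b);
    ggml_allocr_alloc(alloc, c);
    ggml_build_forward_expand(gf, c);
    return c;
}

int main() {
    const size_t tensor_alignment = 32; // same role as the constant used in the diff

    // the context holds only tensor/graph metadata; tensor data lives in the compute buffer
    struct ggml_init_params ctx_params = {
        ggml_tensor_overhead()*GGML_MAX_NODES + GGML_OBJECT_SIZE + GGML_GRAPH_SIZE, // mem_size
        NULL, // mem_buffer
        true, // no_alloc
    };

    // pass 1: build the graph against a measuring allocator to record the required size
    struct ggml_context * ctx = ggml_init(ctx_params);
    struct ggml_allocr * alloc = ggml_allocr_new_measure(tensor_alignment);
    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    build_graph(ctx, alloc, gf);
    size_t compute_size = ggml_allocr_max_size(alloc) + tensor_alignment;
    ggml_allocr_free(alloc);
    ggml_free(ctx);
    printf("compute_size = %zu bytes\n", compute_size);

    // pass 2: allocate exactly that much memory and rebuild the same graph for real
    std::vector<uint8_t> compute_data(compute_size);
    ctx = ggml_init(ctx_params);
    alloc = ggml_allocr_new(compute_data.data(), compute_data.size(), tensor_alignment);
    gf = ggml_new_graph(ctx);
    build_graph(ctx, alloc, gf);
    ggml_allocr_free(alloc);
    ggml_free(ctx);
    return 0;
}

Building the graph twice is what makes the buffer exact: the measuring allocator writes no data, it only tracks the high-water mark that ggml_allocr_max_size() reports, so the second pass runs in a buffer with no headroom guesswork.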