print used memory before and after optimization

This commit is contained in:
xaedes 2023-05-19 18:40:20 +02:00
parent da86a1d736
commit e19ead6e3f
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

View file

@ -1637,6 +1637,8 @@ int main(int argc, char ** argv) {
ggml_build_forward_expand(&gf, e);
ggml_graph_compute(ctx0, &gf);
size_t used_mem_before_opt = ggml_used_mem(ctx0);
float error_before_opt = ggml_get_f32_1d(e, 0);
struct ggml_opt_params opt_params_adam = ggml_opt_default_params(GGML_OPT_ADAM);
@ -1645,6 +1647,7 @@ int main(int argc, char ** argv) {
opt_params_adam.print_backward_graph = false;
opt_params_adam.n_threads = gf.n_threads;
opt_params_adam.adam.n_iter = 16;
opt_params_adam.adam.alpha = 1e-4;
opt_params_lbfgs.print_forward_graph = false;
opt_params_lbfgs.print_backward_graph = false;
@ -1658,6 +1661,8 @@ int main(int argc, char ** argv) {
ggml_opt(ctx0, opt_params_lbfgs, e);
}
size_t used_mem_after_opt = ggml_used_mem(ctx0);
model.train_its += use_adam ? opt_params_adam.adam.n_iter : opt_params_lbfgs.lbfgs.n_iter;
model.train_samples += n_batch;
@ -1666,6 +1671,9 @@ int main(int argc, char ** argv) {
float error_after_opt = ggml_get_f32_1d(e, 0);
printf("used_mem_before_opt: %zu bytes\n", used_mem_before_opt);
printf("used_mem_after_opt: %zu bytes\n", used_mem_after_opt);
if (ex % 1 == 0) {
printf("Example %d\n", ex);
printf("error_before_opt: %.6f\n", error_before_opt);