From b11ac01f6b3985a8f41d9a99db076982a61bfec0 Mon Sep 17 00:00:00 2001 From: mqy Date: Mon, 3 Jul 2023 16:00:47 +0800 Subject: [PATCH] rewrite: no longer consider backward compitability; plan and make_plan --- examples/baby-llama/baby-llama.cpp | 41 +++- examples/benchmark/benchmark-matmult.cpp | 46 +++- .../train-text-from-scratch.cpp | 41 +++- ggml.c | 227 ++++++++++-------- ggml.h | 52 ++-- llama.cpp | 68 +++++- tests/test-grad0.c | 66 ++++- tests/test-opt.c | 28 ++- 8 files changed, 404 insertions(+), 165 deletions(-) diff --git a/examples/baby-llama/baby-llama.cpp b/examples/baby-llama/baby-llama.cpp index 212f54d32..f147c23a2 100644 --- a/examples/baby-llama/baby-llama.cpp +++ b/examples/baby-llama/baby-llama.cpp @@ -1586,7 +1586,6 @@ int main(int argc, char ** argv) { int n_past = 0; ggml_cgraph gf = {}; - gf.n_threads = 1; get_example_targets_batch(ctx0, 64*ex+0, tokens_input, targets); @@ -1595,7 +1594,18 @@ int main(int argc, char ** argv) { struct ggml_tensor * e = square_error_loss(ctx0, targets, logits); ggml_build_forward_expand(&gf, e); - ggml_graph_compute(ctx0, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, /*n_threads*/ 1); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } float error_before_opt = ggml_get_f32_1d(e, 0); @@ -1611,7 +1621,18 @@ int main(int argc, char ** argv) { ggml_opt(ctx0, opt_params_lbfgs, e); // ggml_build_forward_expand(&gf, e); - ggml_graph_compute(ctx0, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, /*n_threads*/ 1); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } float error_after_opt = ggml_get_f32_1d(e, 0); @@ -1659,13 +1680,23 @@ int main(int argc, char ** argv) { struct ggml_context * ctx0 = ggml_init(params); ggml_cgraph gf = {}; - gf.n_threads = 1; int n_past = 0; struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past); ggml_build_forward_expand(&gf, logits); - ggml_graph_compute(ctx0, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, /*n_threads*/ 1); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx); struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx); diff --git a/examples/benchmark/benchmark-matmult.cpp b/examples/benchmark/benchmark-matmult.cpp index 39d15caeb..e4f361e13 100644 --- a/examples/benchmark/benchmark-matmult.cpp +++ b/examples/benchmark/benchmark-matmult.cpp @@ -159,13 +159,22 @@ int main(int argc, char ** argv) { // printf("Creating compute graph\n"); struct ggml_cgraph gf = ggml_build_forward(m11xm2); - gf.n_threads=benchmark_params.n_threads; - printf("cgraph->n_threads=%i\n",gf.n_threads); + printf("n_threads=%i\n", benchmark_params.n_threads); TENSOR_DUMP(m11); TENSOR_DUMP(m2); - ggml_graph_compute(ctx, &gf); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, benchmark_params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + 
GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } TENSOR_DUMP(gf.nodes[0]); @@ -187,7 +196,6 @@ int main(int argc, char ** argv) { // printf("Creating compute graph\n"); struct ggml_cgraph gf31 = ggml_build_forward(q31); - gf31.n_threads=benchmark_params.n_threads; // Set up a second graph computation to make sure we override the CPU cache lines // printf("Creating new tensor q12 & Running quantize\n"); @@ -199,8 +207,7 @@ int main(int argc, char ** argv) { //printf("Creating compute graph\n"); struct ggml_cgraph gf32 = ggml_build_forward(q32); - gf32.n_threads=benchmark_params.n_threads; - printf("cgraph->n_threads=%i\n",gf31.n_threads); + printf("n_threads=%i\n", benchmark_params.n_threads); const int dimx = sizex; const int dimy = sizey; @@ -221,14 +228,25 @@ int main(int argc, char ** argv) { long long int start = ggml_time_us(); //printf("Running ggml_graph_compute\n"); - ggml_graph_compute(ctx, &gf31); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf31, benchmark_params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf31); + if (plan.work_data) { + free(plan.work_data); + } + } + long long int stop = ggml_time_us(); long long int usec = stop-start; double gflops = (double)(flops_per_matrix)/usec/1000.0; gflops_sum += gflops; printf("%9i;%8i;%6i;%6i;%6i;%15lli;%18lli;%10.2f\n", i, - gf31.n_threads, + benchmark_params.n_threads, sizex, sizey, sizez, flops_per_matrix, usec,gflops); @@ -253,7 +271,17 @@ int main(int argc, char ** argv) { } // Running a different graph computation to make sure we override the CPU cache lines - ggml_graph_compute(ctx, &gf32); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf32, benchmark_params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf32); + if (plan.work_data) { + free(plan.work_data); + } + } } printf("\n"); printf("Average%78.2f\n",gflops_sum/((double)benchmark_params.n_iterations)); diff --git a/examples/train-text-from-scratch/train-text-from-scratch.cpp b/examples/train-text-from-scratch/train-text-from-scratch.cpp index 7f7bf3b6f..83da31531 100644 --- a/examples/train-text-from-scratch/train-text-from-scratch.cpp +++ b/examples/train-text-from-scratch/train-text-from-scratch.cpp @@ -3215,9 +3215,6 @@ int main(int argc, char ** argv) { struct ggml_cgraph * gf = (struct ggml_cgraph *) gfbuf->data; struct ggml_cgraph * gb = (struct ggml_cgraph *) gbbuf->data; - // ggml_cgraph gf = {}; - gf->n_threads = params.n_threads; - gb->n_threads = params.n_threads; get_example_targets_batch(lctx, train_samples.data(), train_samples.size(), train_tokens.data(), train_tokens.size(), ex, tokens_input, target_logits, target_probs); @@ -3246,7 +3243,17 @@ int main(int argc, char ** argv) { *gb = ggml_build_backward(ctx0, gf, true); } - ggml_graph_compute(ctx0, gf); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gf, params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, gf); + if (plan.work_data) { + free(plan.work_data); + } + } size_t used_mem_before_opt = ggml_used_mem(ctx0); @@ -3270,7 +3277,17 @@ int main(int argc, char ** argv) { model.train_samples += n_batch; model.train_tokens += n_batch * n_tokens; - 
ggml_graph_compute(ctx0, gf); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gf, params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, gf); + if (plan.work_data) { + free(plan.work_data); + } + } float error_after_opt = ggml_get_f32_1d(loss, 0); @@ -3352,13 +3369,23 @@ int main(int argc, char ** argv) { struct ggml_context * ctx0 = ggml_init(cparams); ggml_cgraph gf = {}; - gf.n_threads = params.n_threads; int n_past = 0; struct ggml_tensor * logits = forward(&model, &kv_self, ctx0, &gf, tokens_input, sample_ctx, n_past); ggml_build_forward_expand(&gf, logits); - ggml_graph_compute(ctx0, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } //struct ggml_tensor * best_samples = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, sample_ctx); //struct ggml_tensor * probs = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_vocab, sample_ctx); diff --git a/ggml.c b/ggml.c index 003506600..f019774e3 100644 --- a/ggml.c +++ b/ggml.c @@ -4583,14 +4583,13 @@ struct ggml_tensor * ggml_new_tensor_impl( /*.src0 =*/ NULL, /*.src1 =*/ NULL, /*.opt =*/ { NULL }, - /*.n_tasks =*/ 0, /*.perf_runs =*/ 0, /*.perf_cycles =*/ 0, /*.perf_time_us =*/ 0, /*.data =*/ (data == NULL && !ctx->no_alloc) ? (void *)(result + 1) : data, /*.name =*/ { 0 }, /*.extra =*/ NULL, - /*.pad =*/ { 0 }, + /*.padding =*/ { 0 }, }; // TODO: this should not be needed as long as we don't rely on aligned SIMD loads @@ -15772,7 +15771,6 @@ struct ggml_cgraph ggml_build_forward(struct ggml_tensor * tensor) { struct ggml_cgraph result = { /*.n_nodes =*/ 0, /*.n_leafs =*/ 0, - /*.n_threads =*/ GGML_DEFAULT_N_THREADS, /*.nodes =*/ { NULL }, /*.grads =*/ { NULL }, /*.leafs =*/ { NULL }, @@ -15944,7 +15942,7 @@ void clear_numa_thread_affinity(void) {} struct ggml_compute_state_shared { struct ggml_cgraph * cgraph; - struct ggml_cgraph_context * cgraph_ctx; + struct ggml_graph_compute_plan * cgraph_ctx; int64_t perf_node_start_cycles; int64_t perf_node_start_time_us; @@ -15974,7 +15972,9 @@ static void ggml_graph_compute_perf_stats_node(struct ggml_tensor * node, const static thread_ret_t ggml_graph_compute_thread(void * data) { struct ggml_compute_state * state = (struct ggml_compute_state *) data; struct ggml_cgraph * cgraph = state->shared->cgraph; - struct ggml_cgraph_context * ctx = state->shared->cgraph_ctx; + + struct ggml_graph_compute_plan * ctx = state->shared->cgraph_ctx; + const int *n_tasks_arr = ctx->n_tasks; const int n_threads = state->shared->n_threads; set_numa_thread_affinity(state->ith, n_threads); @@ -15997,7 +15997,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /* FINALIZE */ struct ggml_tensor * node = state->shared->cgraph->nodes[node_n]; if (GGML_OP_HAS_FINALIZE[node->op]) { - params.nth = node->n_tasks; + params.nth = n_tasks_arr[node_n]; ggml_compute_forward(¶ms, node); ggml_graph_compute_perf_stats_node(node, state->shared); } @@ -16008,11 +16008,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { GGML_PRINT_DEBUG_5("%s: %d/%d\n", __func__, node_n, cgraph->n_nodes); struct ggml_tensor * node = cgraph->nodes[node_n]; + const int n_tasks = n_tasks_arr[node_n]; state->shared->perf_node_start_cycles = ggml_perf_cycles(); 
state->shared->perf_node_start_time_us = ggml_perf_time_us(); - params.nth = node->n_tasks; + params.nth = n_tasks; /* INIT */ if (GGML_OP_HAS_INIT[node->op]) { @@ -16020,7 +16021,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_compute_forward(¶ms, node); } - if (node->n_tasks == 1) { + if (n_tasks == 1) { // TODO: maybe push node_n to the atomic but if other threads see n_tasks is 1, // they do something more efficient than spinning (?) params.type = GGML_TASK_COMPUTE; @@ -16052,16 +16053,17 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /* COMPUTE */ struct ggml_tensor * node = cgraph->nodes[node_n]; + const int n_tasks = n_tasks_arr[node_n]; struct ggml_compute_params params = { /*.type =*/ GGML_TASK_COMPUTE, /*.ith =*/ state->ith, - /*.nth =*/ node->n_tasks, + /*.nth =*/ n_tasks, /*.wsize =*/ ctx->work_size, /*.wdata =*/ ctx->work_data, }; - if (state->ith < node->n_tasks) { + if (state->ith < n_tasks) { ggml_compute_forward(¶ms, node); } } @@ -16070,15 +16072,14 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } // Prepare for graph computing. -// Will set: node->n_tasks, ctx->{work_size, planned} -void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph) { - GGML_ASSERT(ctx); - // This function is actually reentrant, but duplicate calls is unnecessary. - GGML_ASSERT(ctx->work_size == 0); - GGML_ASSERT(ctx->work_data == NULL); - GGML_ASSERT(!ctx->planned); +struct ggml_graph_compute_plan ggml_graph_compute_make_plan(struct ggml_cgraph * cgraph, int n_threads) { + if (n_threads <= 0) { + n_threads = GGML_DEFAULT_N_THREADS; + } - int n_threads = cgraph->n_threads; + struct ggml_graph_compute_plan ctx; + memset(&ctx, 0, sizeof(struct ggml_graph_compute_plan)); + int * n_tasks = ctx.n_tasks; size_t work_size = 0; // initialize tasks + work buffer @@ -16091,11 +16092,11 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap case GGML_OP_CPY: case GGML_OP_DUP: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; size_t cur = 0; if (ggml_is_quantized(node->type)) { - cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_threads; + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks[i]; } work_size = MAX(work_size, cur); @@ -16103,24 +16104,24 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap case GGML_OP_ADD: case GGML_OP_ADD1: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; size_t cur = 0; if (ggml_is_quantized(node->src0->type)) { - cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_threads; + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks[i]; } work_size = MAX(work_size, cur); } break; case GGML_OP_ACC: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; size_t cur = 0; if (ggml_is_quantized(node->src0->type)) { - cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_threads; + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks[i]; } work_size = MAX(work_size, cur); @@ -16144,7 +16145,7 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap case GGML_OP_ELU: case GGML_OP_RELU: { - node->n_tasks = 1; + n_tasks[i] = 1; } break; case GGML_OP_MUL: case GGML_OP_GELU: @@ -16155,32 +16156,32 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; } break; case GGML_OP_MUL_MAT: case GGML_OP_OUT_PROD: { - 
node->n_tasks = n_threads; + n_tasks[i] = n_threads; // TODO: use different scheduling for different matrix sizes //const int nr0 = ggml_nrows(node->src0); //const int nr1 = ggml_nrows(node->src1); - //node->n_tasks = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, node->n_tasks); + //n_tasks[i] = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, n_tasks[i]); size_t cur = 0; const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type; #if defined(GGML_USE_CUBLAS) if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) { - node->n_tasks = 1; // TODO: this actually is doing nothing + n_tasks[i] = 1; // TODO: this actually is doing nothing // the threads are still spinning } else #elif defined(GGML_USE_CLBLAST) if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) { - node->n_tasks = 1; // TODO: this actually is doing nothing + n_tasks[i] = 1; // TODO: this actually is doing nothing // the threads are still spinning cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node); } @@ -16188,7 +16189,7 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap #endif #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - node->n_tasks = 1; // TODO: this actually is doing nothing + n_tasks[i] = 1; // TODO: this actually is doing nothing // the threads are still spinning if (node->src0->type != GGML_TYPE_F32) { // here we need memory just for single 2D matrix from src0 @@ -16206,7 +16207,7 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap } break; case GGML_OP_SCALE: { - node->n_tasks = 1; + n_tasks[i] = 1; } break; case GGML_OP_SET: case GGML_OP_CONT: @@ -16219,7 +16220,7 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap case GGML_OP_DIAG: case GGML_OP_DIAG_MASK_ZERO: { - node->n_tasks = 1; + n_tasks[i] = 1; } break; case GGML_OP_DIAG_MASK_INF: case GGML_OP_SOFT_MAX: @@ -16227,19 +16228,19 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap case GGML_OP_ROPE: case GGML_OP_ROPE_BACK: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; } break; case GGML_OP_ALIBI: { - node->n_tasks = 1; //TODO + n_tasks[i] = 1; //TODO } break; case GGML_OP_CLAMP: { - node->n_tasks = 1; //TODO + n_tasks[i] = 1; //TODO } break; case GGML_OP_CONV_1D: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; GGML_ASSERT(node->src0->ne[3] == 1); GGML_ASSERT(node->src1->ne[2] == 1); @@ -16268,7 +16269,7 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap } break; case GGML_OP_CONV_2D: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; GGML_ASSERT(node->src1->ne[3] == 1); @@ -16303,45 +16304,45 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap } break; case GGML_OP_FLASH_ATTN: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; size_t cur = 0; const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); if (node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 + cur = sizeof(float)*ne11*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) + cur += sizeof(float)*ne11*n_tasks[i]; // this is overestimated by x2 } if (node->src1->type == 
GGML_TYPE_F16) { - cur = sizeof(float)*ne11*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*ne11*node->n_tasks; // this is overestimated by x2 + cur = sizeof(float)*ne11*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) + cur += sizeof(float)*ne11*n_tasks[i]; // this is overestimated by x2 } work_size = MAX(work_size, cur); } break; case GGML_OP_FLASH_FF: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; size_t cur = 0; if (node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + cur = sizeof(float)*node->src1->ne[1]*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) + cur += sizeof(float)*node->src1->ne[1]*n_tasks[i]; // this is overestimated by x2 } if (node->src1->type == GGML_TYPE_F16) { - cur = sizeof(float)*node->src1->ne[1]*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*node->src1->ne[1]*node->n_tasks; // this is overestimated by x2 + cur = sizeof(float)*node->src1->ne[1]*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) + cur += sizeof(float)*node->src1->ne[1]*n_tasks[i]; // this is overestimated by x2 } work_size = MAX(work_size, cur); } break; case GGML_OP_FLASH_ATTN_BACK: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; size_t cur = 0; @@ -16349,13 +16350,13 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back if (node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2 + cur = sizeof(float)*mxDn*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) + cur += sizeof(float)*mxDn*n_tasks[i]; // this is overestimated by x2 } if (node->src1->type == GGML_TYPE_F16) { - cur = sizeof(float)*mxDn*node->n_tasks; // TODO: this can become (n_tasks-1) - cur += sizeof(float)*mxDn*node->n_tasks; // this is overestimated by x2 + cur = sizeof(float)*mxDn*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) + cur += sizeof(float)*mxDn*n_tasks[i]; // this is overestimated by x2 } work_size = MAX(work_size, cur); @@ -16368,27 +16369,27 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap case GGML_OP_MAP_CUSTOM2: case GGML_OP_MAP_CUSTOM3: { - node->n_tasks = 1; + n_tasks[i] = 1; } break; case GGML_OP_CROSS_ENTROPY_LOSS: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; - size_t cur = ggml_type_size(node->type)*(node->n_tasks + node->src0->ne[0]*node->n_tasks); + size_t cur = ggml_type_size(node->type)*(n_tasks[i] + node->src0->ne[0]*n_tasks[i]); work_size = MAX(work_size, cur); } break; case GGML_OP_CROSS_ENTROPY_LOSS_BACK: { - node->n_tasks = n_threads; + n_tasks[i] = n_threads; - size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*node->n_tasks; + size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks[i]; work_size = MAX(work_size, cur); } break; case GGML_OP_NONE: { - node->n_tasks = 1; + n_tasks[i] = 1; } break; case GGML_OP_COUNT: { @@ -16402,35 +16403,31 @@ void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgrap work_size += CACHE_LINE_SIZE*(n_threads - 1); } - ctx->work_size = work_size; - ctx->work_data = NULL; - ctx->planned = true; + ctx.n_threads 
= n_threads; + ctx.work_size = work_size; + ctx.work_data = NULL; + + return ctx; } -void ggml_graph_compute_v2(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph) { - if (ctx == NULL) { - ctx = alloca(sizeof(struct ggml_cgraph_context)); +void ggml_graph_compute(struct ggml_graph_compute_plan * ctx, struct ggml_cgraph * cgraph) { + { GGML_ASSERT(ctx); - ctx->work_size = 0; - ctx->work_data = NULL; - ctx->planned = false; - } else { - // The work_size and work_data MAY have default values even if has been planned. + GGML_ASSERT(ctx->n_threads > 0); + if (ctx->work_size > 0) { GGML_ASSERT(ctx->work_data); } - } - if (!ctx->planned) { - ggml_graph_compute_plan(ctx, cgraph); - if (ctx->work_size > 0) { - ctx->work_data = malloc(ctx->work_size * sizeof(GGML_TYPE_I8)); - GGML_ASSERT(ctx->work_data); - GGML_PRINT_DEBUG("%s: allocating work buffer for graph (%zu bytes)\n", __func__, work_size); + for (int i = 0; i < cgraph->n_nodes; ++i) { + if (cgraph->nodes[i]->op != GGML_OP_NONE) { + GGML_ASSERT(ctx->n_tasks[i] > 0); + } } + } - const int n_threads = cgraph->n_threads; + const int n_threads = ctx->n_threads; struct ggml_compute_state_shared state_shared = { /*.cgraph =*/ cgraph, @@ -16494,12 +16491,6 @@ void ggml_graph_compute_v2(struct ggml_cgraph_context * ctx, struct ggml_cgraph } } -// Deprecated, keep it only for backward compatibility. -void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { - UNUSED(ctx); - ggml_graph_compute_v2(NULL, cgraph); -} - void ggml_graph_reset(struct ggml_cgraph * cgraph) { for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * grad = cgraph->grads[i]; @@ -16548,14 +16539,13 @@ static void ggml_graph_export_node(const struct ggml_tensor * tensor, const char const int64_t * ne = tensor->ne; const size_t * nb = tensor->nb; - fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %8d %16p %32s\n", + fprintf(fout, "%-6s %-6s %-12s %8d %" PRId64 " %" PRId64 " %" PRId64 " %" PRId64 " %16zu %16zu %16zu %16zu %16p %32s\n", arg, ggml_type_name(tensor->type), ggml_op_name (tensor->op), tensor->n_dims, ne[0], ne[1], ne[2], ne[3], nb[0], nb[1], nb[2], nb[3], - tensor->n_tasks, tensor->data, tensor->name); } @@ -17283,7 +17273,6 @@ static void ggml_opt_get_grad(int np, struct ggml_tensor * const ps[], float * g // static enum ggml_opt_result ggml_opt_adam( - struct ggml_context * ctx, struct ggml_opt_context * opt, struct ggml_opt_params params, struct ggml_tensor * f, @@ -17291,9 +17280,6 @@ static enum ggml_opt_result ggml_opt_adam( struct ggml_cgraph * gb) { GGML_ASSERT(ggml_is_scalar(f)); - gf->n_threads = params.n_threads; - gb->n_threads = params.n_threads; - // these will store the parameters we want to optimize struct ggml_tensor * ps[GGML_MAX_PARAMS]; @@ -17340,7 +17326,18 @@ static enum ggml_opt_result ggml_opt_adam( // compute the function value ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx, gb); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, gb); + if (plan.work_data) { + free(plan.work_data); + } + } opt->adam.fx_prev = ggml_get_f32_1d(f, 0); opt->adam.fx_best = opt->adam.fx_prev; @@ -17420,7 +17417,18 @@ static enum ggml_opt_result ggml_opt_adam( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx, gb); + + { + struct 
ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, gb); + if (plan.work_data) { + free(plan.work_data); + } + } const float fx = ggml_get_f32_1d(f, 0); @@ -17491,7 +17499,6 @@ struct ggml_lbfgs_iteration_data { }; static enum ggml_opt_result linesearch_backtracking( - struct ggml_context * ctx, const struct ggml_opt_params * params, int nx, float * x, @@ -17542,7 +17549,18 @@ static enum ggml_opt_result linesearch_backtracking( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx, gb); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params->n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, gb); + if (plan.work_data) { + free(plan.work_data); + } + } ggml_opt_get_grad(np, ps, g); @@ -17610,9 +17628,6 @@ static enum ggml_opt_result ggml_opt_lbfgs( } } - gf->n_threads = params.n_threads; - gb->n_threads = params.n_threads; - const int m = params.lbfgs.m; // these will store the parameters we want to optimize @@ -17664,7 +17679,17 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_graph_reset (gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx, gb); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(gb, params.n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, gb); + if (plan.work_data) { + free(plan.work_data); + } + } ggml_opt_get_grad(np, ps, g); @@ -17723,7 +17748,7 @@ static enum ggml_opt_result ggml_opt_lbfgs( ggml_vec_cpy_f32(nx, xp, x); ggml_vec_cpy_f32(nx, gp, g); - ls = linesearch_backtracking(ctx, ¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps); + ls = linesearch_backtracking(¶ms, nx, x, &fx, g, d, step, xp, f, gf, gb, np, ps); if (ls < 0) { // linesearch failed - go back to the previous point and return @@ -18025,7 +18050,7 @@ enum ggml_opt_result ggml_opt_resume_g( switch (opt->params.type) { case GGML_OPT_ADAM: { - result = ggml_opt_adam(ctx, opt, opt->params, f, gf, gb); + result = ggml_opt_adam(opt, opt->params, f, gf, gb); } break; case GGML_OPT_LBFGS: { diff --git a/ggml.h b/ggml.h index f949fe35f..f92f428fa 100644 --- a/ggml.h +++ b/ggml.h @@ -65,7 +65,16 @@ // ggml_set_f32(a, 3.0f); // ggml_set_f32(b, 4.0f); // -// ggml_graph_compute(ctx0, &gf); +// const int n_threads = 1; +// struct ggml_graph_compute_plan ctx = ggml_graph_compute_make_plan(&gf, n_threads); +// if (ctx.work_size > 0) { +// ctx.work_data = malloc(ctx.work_size); +// GGML_ASSERT(ctx.work_data); +// } +// ggml_graph_compute(&ctx, &gf); +// if (ctx.work_data) { +// free(ctx.work_data); +// } // // printf("f = %f\n", ggml_get_f32_1d(f, 0)); // @@ -418,9 +427,6 @@ extern "C" { struct ggml_tensor * src1; struct ggml_tensor * opt[GGML_MAX_OPT]; - // thread scheduling - int n_tasks; - // performance int perf_runs; int64_t perf_cycles; @@ -432,27 +438,30 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - char padding[4]; + char padding[8]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); - // graph compute context - struct ggml_cgraph_context { - // After call to `ggml_graph_compute_plan()`, `planned` is set as true, - // `work_size` will be updated as non-zero when buffer is required. 
When - // need buffer, caller MUST allocate memory for `work_data`. - // See https://github.com/ggerganov/ggml/issues/287 + // The default graph compute plan that needs to be prepared for ggml_graph_compute(). + // Since https://github.com/ggerganov/ggml/issues/287 + struct ggml_graph_compute_plan { + // Size of work buffer, calculated by `ggml_graph_compute_make_plan()`. size_t work_size; + // Worker buffer. + // Expect allocate/free by caller before/after calling to `ggml_graph_compute()`. void * work_data; - bool planned; // true means ready to compute graph nodes. + + int n_threads; + + // The `n_tasks` of nodes, 1:1 mapping to cgraph nodes. + int n_tasks[GGML_MAX_NODES]; }; // computation graph struct ggml_cgraph { int n_nodes; int n_leafs; - int n_threads; struct ggml_tensor * nodes[GGML_MAX_NODES]; struct ggml_tensor * grads[GGML_MAX_NODES]; @@ -1305,19 +1314,10 @@ extern "C" { GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor); GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep); - // Since https://github.com/ggerganov/ggml/issues/287 - GGML_API void ggml_graph_compute_plan(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph); - // Since https://github.com/ggerganov/ggml/issues/287 - // When `ctx` is NULL, `ggml_graph_compute_v2()` calculates work_size and allocates memory for `work_data`. - // Another use case: allocate buffer explicitly: - // - call `ggml_graph_compute_plan()`; - // - allocate memory for `ctx->work_data`; - // - finally call `ggml_graph_compute_v2()`. - // NOTE: don't manually set `ctx->planned`. - GGML_API void ggml_graph_compute_v2(struct ggml_cgraph_context * ctx, struct ggml_cgraph * cgraph); - // Deprecated, `ctx` is not required. Use `ggml_graph_compute_v2` instead. - // See https://github.com/ggerganov/ggml/issues/287 - GGML_API void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph); + // ggml_graph_compute_make_plan() needs to be called before ggml_graph_compute(). + // Returns a plan object. When plan.work_size > 0, caller must allocate memory for plan.work_data. + GGML_API struct ggml_graph_compute_plan ggml_graph_compute_make_plan(struct ggml_cgraph * cgraph, const int n_threads/*=GGML_DEFAULT_N_THREADS*/); + GGML_API void ggml_graph_compute(struct ggml_graph_compute_plan * plan, struct ggml_cgraph * cgraph); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); diff --git a/llama.cpp b/llama.cpp index 02afdeb14..d1ae57298 100644 --- a/llama.cpp +++ b/llama.cpp @@ -1309,7 +1309,7 @@ static bool llama_eval_internal( // for big prompts, if BLAS is enabled, it is better to use only one thread // otherwise, the threads are spin-lock waiting for the BLAS calls and are degrading the performance ggml_cgraph gf = {}; - gf.n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 1 : n_threads; + const int actual_n_threads = N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() ? 
1 : n_threads; struct ggml_tensor * cur; struct ggml_tensor * inpL; @@ -1612,10 +1612,30 @@ static bool llama_eval_internal( ggml_metal_get_tensor(lctx.ctx_metal, kv_self.v); } - ggml_graph_compute(ctx0, &gf); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, actual_n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } } #else - ggml_graph_compute(ctx0, &gf); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, actual_n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } #endif if (cgraph_fname) { @@ -2966,8 +2986,18 @@ int llama_apply_lora_from_file_internal(const struct llama_model & model, const } struct ggml_cgraph gf = ggml_build_forward(r); - gf.n_threads = n_threads; - ggml_graph_compute(lora_ctx, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } // we won't need these tensors again, reset the context to save memory ggml_free(lora_ctx); @@ -3120,7 +3150,6 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); ggml_cgraph gf{}; - gf.n_threads = 1; ggml_tensor * kout3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); kout3d->data = out; @@ -3140,7 +3169,18 @@ size_t llama_copy_state_data(struct llama_context * ctx, uint8_t * dst) { ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, k3d, kout3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, v3d, vout3d)); - ggml_graph_compute(cpy_ctx, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, /*n_threads*/ 1); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } ggml_free(cpy_ctx); } @@ -3226,7 +3266,6 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { ggml_context * cpy_ctx = ggml_init({ 4096, NULL, /* no_alloc */ true }); ggml_cgraph gf{}; - gf.n_threads = 1; ggml_tensor * kin3d = ggml_new_tensor_3d(cpy_ctx, kv_self.k->type, n_embd, kv_ntok, n_layer); kin3d->data = (void *) inp; @@ -3246,7 +3285,18 @@ size_t llama_set_state_data(struct llama_context * ctx, uint8_t * src) { ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, kin3d, k3d)); ggml_build_forward_expand(&gf, ggml_cpy(cpy_ctx, vin3d, v3d)); - ggml_graph_compute(cpy_ctx, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, /*n_threads*/ 1); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } ggml_free(cpy_ctx); } diff --git a/tests/test-grad0.c b/tests/test-grad0.c index a3e25214b..11bb2307f 100644 --- a/tests/test-grad0.c +++ b/tests/test-grad0.c @@ -215,15 +215,36 @@ bool check_gradient( } struct ggml_cgraph gf = ggml_build_forward (f); - gf.n_threads = n_threads; struct ggml_cgraph gb = 
ggml_build_backward(ctx0, &gf, false); - gb.n_threads = n_threads; ggml_graph_compute(ctx0, &gf); + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } + ggml_graph_reset (&gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx0, &gb); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gb, n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gb); + if (plan.work_data) { + free(plan.work_data); + } + } // ggml_graph_dump_dot(&gf, NULL, "test-grad0-forward.dot"); // ggml_graph_dump_dot(&gb, &gf, "test-grad0-backward.dot"); @@ -236,12 +257,34 @@ bool check_gradient( const float xm = x0 - eps; const float xp = x0 + eps; set_element(x[i], k, xp); - ggml_graph_compute(ctx0, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } const float f0 = ggml_get_f32_1d(f, 0); set_element(x[i], k, xm); - ggml_graph_compute(ctx0, &gf); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gf); + if (plan.work_data) { + free(plan.work_data); + } + } const float f1 = ggml_get_f32_1d(f, 0); @@ -252,7 +295,18 @@ bool check_gradient( // compute gradient using backward graph ggml_graph_reset (&gf); ggml_set_f32 (f->grad, 1.0f); - ggml_graph_compute(ctx0, &gb); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gb, n_threads); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &gb); + if (plan.work_data) { + free(plan.work_data); + } + } const float g1 = get_element(x[i]->grad, k); diff --git a/tests/test-opt.c b/tests/test-opt.c index d001615ee..cb0d58199 100644 --- a/tests/test-opt.c +++ b/tests/test-opt.c @@ -140,7 +140,19 @@ int main(int argc, const char ** argv) { struct ggml_cgraph ge = ggml_build_forward(e); ggml_graph_reset (&ge); - ggml_graph_compute(ctx, &ge); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&ge, /*n_threads*/ 1); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &ge); + if (plan.work_data) { + free(plan.work_data); + } + } + const float fe = ggml_get_f32_1d(e, 0); printf("%s: e = %.4f\n", __func__, fe); @@ -149,7 +161,19 @@ int main(int argc, const char ** argv) { ggml_opt(ctx, opt_params, e); ggml_graph_reset (&ge); - ggml_graph_compute(ctx, &ge); + + { + struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&ge, /*n_threads*/ 1); + if (plan.work_size > 0) { + plan.work_data = malloc(plan.work_size); + GGML_ASSERT(plan.work_data); + } + ggml_graph_compute(&plan, &ge); + if (plan.work_data) { + free(plan.work_data); + } + } + const float fe_opt = ggml_get_f32_1d(e, 0); printf("%s: original e = %.4f\n", __func__, fe); printf("%s: optimized e = %.4f\n", __func__, fe_opt);
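
Usage note: the calling convention introduced by this patch is the same everywhere — make a plan, allocate `work_data` if `work_size > 0`, compute, then free. Below is a minimal end-to-end sketch of that flow. The graph-building part follows the example in the ggml.h header comment; the tensors `x`, `a`, `b`, `f` and `n_threads = 1` are illustrative placeholders, not part of the patch.

    #include <stdio.h>
    #include <stdlib.h>
    #include "ggml.h"

    int main(void) {
        // small context for a toy graph
        struct ggml_init_params params = {
            /*.mem_size   =*/ 16*1024*1024,
            /*.mem_buffer =*/ NULL,
            /*.no_alloc   =*/ false,
        };
        struct ggml_context * ctx = ggml_init(params);

        // f = a*x + b
        struct ggml_tensor * x = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
        struct ggml_tensor * a = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
        struct ggml_tensor * b = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 1);
        struct ggml_tensor * f = ggml_add(ctx, ggml_mul(ctx, a, x), b);

        struct ggml_cgraph gf = ggml_build_forward(f);

        ggml_set_f32(x, 2.0f);
        ggml_set_f32(a, 3.0f);
        ggml_set_f32(b, 4.0f);

        // plan first, then allocate the work buffer if one is needed, then compute
        const int n_threads = 1;
        struct ggml_graph_compute_plan plan = ggml_graph_compute_make_plan(&gf, n_threads);
        if (plan.work_size > 0) {
            plan.work_data = malloc(plan.work_size);
            GGML_ASSERT(plan.work_data);
        }
        ggml_graph_compute(&plan, &gf);
        if (plan.work_data) {
            free(plan.work_data);
        }

        printf("f = %f\n", ggml_get_f32_1d(f, 0));

        ggml_free(ctx);
        return 0;
    }

The design point is that `ggml_graph_compute()` no longer allocates anything and no longer reads scheduling state from the tensors or the graph: per-node `n_tasks` and the work buffer size live in the plan, and the caller owns the work buffer's lifetime.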