diff --git a/ggml.c b/ggml.c index f8eddd816..27232af28 100644 --- a/ggml.c +++ b/ggml.c @@ -10717,8 +10717,6 @@ static void ggml_compute_forward_mul_mat( float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); - assert(ne00 % 32 == 0); - for (int64_t ic = 0; ic < ne11; ++ic) { vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); } @@ -16078,328 +16076,327 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { n_threads = GGML_DEFAULT_N_THREADS; } + size_t work_size = 0; + struct ggml_cplan cplan; memset(&cplan, 0, sizeof(struct ggml_cplan)); - int * n_tasks = cplan.n_tasks; + // thread scheduling for the different operations + work buffer size estimation + for (int i = 0; i < cgraph->n_nodes; i++) { + int n_tasks = 1; - size_t work_size = 0; + struct ggml_tensor * node = cgraph->nodes[i]; - // initialize tasks + work buffer - { - // thread scheduling for the different operations - for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; + switch (node->op) { + case GGML_OP_CPY: + case GGML_OP_DUP: + { + n_tasks = n_threads; - switch (node->op) { - case GGML_OP_CPY: - case GGML_OP_DUP: - { - n_tasks[i] = n_threads; + size_t cur = 0; + if (ggml_is_quantized(node->type)) { + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks; + } - size_t cur = 0; - if (ggml_is_quantized(node->type)) { - cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->ne[0] * n_tasks[i]; - } + work_size = MAX(work_size, cur); + } break; + case GGML_OP_ADD: + case GGML_OP_ADD1: + { + n_tasks = n_threads; - work_size = MAX(work_size, cur); - } break; - case GGML_OP_ADD: - case GGML_OP_ADD1: - { - n_tasks[i] = n_threads; + size_t cur = 0; - size_t cur = 0; + if (ggml_is_quantized(node->src0->type)) { + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks; + } - if (ggml_is_quantized(node->src0->type)) { - cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src0->ne[0] * n_tasks[i]; - } + work_size = MAX(work_size, cur); + } break; + case GGML_OP_ACC: + { + n_tasks = n_threads; - work_size = MAX(work_size, cur); - } break; - case GGML_OP_ACC: - { - n_tasks[i] = n_threads; + size_t cur = 0; - size_t cur = 0; + if (ggml_is_quantized(node->src0->type)) { + cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks; + } - if (ggml_is_quantized(node->src0->type)) { - cur = GGML_TYPE_SIZE[GGML_TYPE_F32] * node->src1->ne[0] * n_tasks[i]; - } + work_size = MAX(work_size, cur); + } break; + case GGML_OP_SUB: + case GGML_OP_DIV: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_LOG: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ABS: + case GGML_OP_SGN: + case GGML_OP_NEG: + case GGML_OP_STEP: + case GGML_OP_TANH: + case GGML_OP_ELU: + case GGML_OP_RELU: + { + n_tasks = 1; + } break; + case GGML_OP_MUL: + case GGML_OP_GELU: + case GGML_OP_GELU_QUICK: + case GGML_OP_SILU: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + { + n_tasks = n_threads; + } break; + case GGML_OP_MUL_MAT: + case GGML_OP_OUT_PROD: + { + n_tasks = n_threads; - work_size = MAX(work_size, cur); - } break; - case GGML_OP_SUB: - case GGML_OP_DIV: - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_LOG: - case GGML_OP_SUM: - case GGML_OP_SUM_ROWS: - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_REPEAT: - case GGML_OP_REPEAT_BACK: - case GGML_OP_ABS: - case GGML_OP_SGN: - case GGML_OP_NEG: - case GGML_OP_STEP: - case GGML_OP_TANH: - case GGML_OP_ELU: - case GGML_OP_RELU: - { - n_tasks[i] = 1; - } break; - case GGML_OP_MUL: - case GGML_OP_GELU: - case GGML_OP_GELU_QUICK: - case GGML_OP_SILU: - case GGML_OP_SILU_BACK: - case GGML_OP_NORM: - case GGML_OP_RMS_NORM: - case GGML_OP_RMS_NORM_BACK: - { - n_tasks[i] = n_threads; - } break; - case GGML_OP_MUL_MAT: - case GGML_OP_OUT_PROD: - { - n_tasks[i] = n_threads; + // TODO: use different scheduling for different matrix sizes + //const int nr0 = ggml_nrows(node->src0); + //const int nr1 = ggml_nrows(node->src1); - // TODO: use different scheduling for different matrix sizes - //const int nr0 = ggml_nrows(node->src0); - //const int nr1 = ggml_nrows(node->src1); + //n_tasks = MIN(n_threads, MAX(1, nr0/128)); + //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks%d\n", nr0, nr1, nr0*nr1, n_tasks); - //n_tasks[i] = MIN(n_threads, MAX(1, nr0/128)); - //printf("nr0 = %8d, nr1 = %8d, nr0*nr1 = %8d, n_tasks = %d\n", nr0, nr1, nr0*nr1, n_tasks[i]); - - size_t cur = 0; - const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type; + size_t cur = 0; + const enum ggml_type vec_dot_type = type_traits[node->src0->type].vec_dot_type; #if defined(GGML_USE_CUBLAS) - if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) { - n_tasks[i] = 1; // TODO: this actually is doing nothing - // the threads are still spinning - } - else + if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + } + else #elif defined(GGML_USE_CLBLAST) if (ggml_cl_can_mul_mat(node->src0, node->src1, node)) { - n_tasks[i] = 1; // TODO: this actually is doing nothing - // the threads are still spinning + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning cur = ggml_cl_mul_mat_get_wsize(node->src0, node->src1, node); } else #endif #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) - if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { - n_tasks[i] = 1; // TODO: this actually is doing nothing - // the threads are still spinning - if (node->src0->type != GGML_TYPE_F32) { - // here we need memory just for single 2D matrix from src0 - cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); - } - } else + if (ggml_compute_forward_mul_mat_use_blas(node->src0, node->src1, node)) { + n_tasks = 1; // TODO: this actually is doing nothing + // the threads are still spinning + if (node->src0->type != GGML_TYPE_F32) { + // here we need memory just for single 2D matrix from src0 + cur = GGML_TYPE_SIZE[GGML_TYPE_F32]*(node->src0->ne[0]*node->src0->ne[1]); + } + } else #endif - if (node->src1->type != vec_dot_type) { - cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type]; - } else { - cur = 0; - } + if (node->src1->type != vec_dot_type) { + cur = GGML_TYPE_SIZE[vec_dot_type]*ggml_nelements(node->src1)/GGML_BLCK_SIZE[vec_dot_type]; + } else { + cur = 0; + } - work_size = MAX(work_size, cur); - } break; - case GGML_OP_SCALE: - { - n_tasks[i] = 1; - } break; - case GGML_OP_SET: - case GGML_OP_CONT: - case GGML_OP_RESHAPE: - case GGML_OP_VIEW: - case GGML_OP_PERMUTE: - case GGML_OP_TRANSPOSE: - case GGML_OP_GET_ROWS: - case GGML_OP_GET_ROWS_BACK: - case GGML_OP_DIAG: - case GGML_OP_DIAG_MASK_ZERO: - { - n_tasks[i] = 1; - } break; - case GGML_OP_DIAG_MASK_INF: - case GGML_OP_SOFT_MAX: - case GGML_OP_SOFT_MAX_BACK: - case GGML_OP_ROPE: - case GGML_OP_ROPE_BACK: - { - n_tasks[i] = n_threads; - } break; - case GGML_OP_ALIBI: - { - n_tasks[i] = 1; //TODO - } break; - case GGML_OP_CLAMP: - { - n_tasks[i] = 1; //TODO - } break; - case GGML_OP_CONV_1D: - { - n_tasks[i] = n_threads; + work_size = MAX(work_size, cur); + } break; + case GGML_OP_SCALE: + { + n_tasks = 1; + } break; + case GGML_OP_SET: + case GGML_OP_CONT: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_GET_ROWS: + case GGML_OP_GET_ROWS_BACK: + case GGML_OP_DIAG: + case GGML_OP_DIAG_MASK_ZERO: + { + n_tasks = 1; + } break; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { + n_tasks = n_threads; + } break; + case GGML_OP_ALIBI: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_CLAMP: + { + n_tasks = 1; //TODO + } break; + case GGML_OP_CONV_1D: + { + n_tasks = n_threads; - GGML_ASSERT(node->src0->ne[3] == 1); - GGML_ASSERT(node->src1->ne[2] == 1); - GGML_ASSERT(node->src1->ne[3] == 1); + GGML_ASSERT(node->src0->ne[3] == 1); + GGML_ASSERT(node->src1->ne[2] == 1); + GGML_ASSERT(node->src1->ne[3] == 1); - size_t cur = 0; - const int nk = node->src0->ne[0]; + size_t cur = 0; + const int nk = node->src0->ne[0]; - if (node->src0->type == GGML_TYPE_F16 && + if (node->src0->type == GGML_TYPE_F16 && node->src1->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*( - nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + - ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] - ); - } else if (node->src0->type == GGML_TYPE_F32 && - node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*( - nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + - ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] - ); - } else { - GGML_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_CONV_2D: - { - n_tasks[i] = n_threads; - - GGML_ASSERT(node->src1->ne[3] == 1); - - const int64_t ne00 = node->src0->ne[0]; // W - const int64_t ne01 = node->src0->ne[1]; // H - const int64_t ne02 = node->src0->ne[2]; // C - const int64_t ne03 = node->src0->ne[3]; // N - - const int64_t ne10 = node->src1->ne[0]; // W - const int64_t ne11 = node->src1->ne[1]; // H - const int64_t ne12 = node->src1->ne[2]; // C - - const int64_t nk = ne00*ne01; - - UNUSED(ne02); - UNUSED(ne03); - UNUSED(nk); - - size_t cur = 0; - - if (node->src0->type == GGML_TYPE_F16 && + cur = sizeof(ggml_fp16_t)*( + nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else if (node->src0->type == GGML_TYPE_F32 && node->src1->type == GGML_TYPE_F32) { - cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12); - } else if (node->src0->type == GGML_TYPE_F32 && - node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)* (ne10*ne11*ne12); - } else { - GGML_ASSERT(false); - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_FLASH_ATTN: - { - n_tasks[i] = n_threads; - - size_t cur = 0; - - const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); - - if (node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*ne11*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) - cur += sizeof(float)*ne11*n_tasks[i]; // this is overestimated by x2 - } - - if (node->src1->type == GGML_TYPE_F16) { - cur = sizeof(float)*ne11*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) - cur += sizeof(float)*ne11*n_tasks[i]; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_FLASH_FF: - { - n_tasks[i] = n_threads; - - size_t cur = 0; - - if (node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*node->src1->ne[1]*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) - cur += sizeof(float)*node->src1->ne[1]*n_tasks[i]; // this is overestimated by x2 - } - - if (node->src1->type == GGML_TYPE_F16) { - cur = sizeof(float)*node->src1->ne[1]*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) - cur += sizeof(float)*node->src1->ne[1]*n_tasks[i]; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - n_tasks[i] = n_threads; - - size_t cur = 0; - - const int64_t D = node->src0->ne[0]; - const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); - const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back - if (node->src1->type == GGML_TYPE_F32) { - cur = sizeof(float)*mxDn*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) - cur += sizeof(float)*mxDn*n_tasks[i]; // this is overestimated by x2 - } - - if (node->src1->type == GGML_TYPE_F16) { - cur = sizeof(float)*mxDn*n_tasks[i]; // TODO: this can become (n_tasks[i]-1) - cur += sizeof(float)*mxDn*n_tasks[i]; // this is overestimated by x2 - } - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_WIN_PART: - case GGML_OP_WIN_UNPART: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1: - case GGML_OP_MAP_CUSTOM2: - case GGML_OP_MAP_CUSTOM3: - { - n_tasks[i] = 1; - } break; - case GGML_OP_CROSS_ENTROPY_LOSS: - { - n_tasks[i] = n_threads; - - size_t cur = ggml_type_size(node->type)*(n_tasks[i] + node->src0->ne[0]*n_tasks[i]); - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - n_tasks[i] = n_threads; - - size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks[i]; - - work_size = MAX(work_size, cur); - } break; - case GGML_OP_NONE: - { - n_tasks[i] = 1; - } break; - case GGML_OP_COUNT: - { + cur = sizeof(float)*( + nk*ggml_up32(node->src0->ne[1])*node->src0->ne[2] + + ( 2*(nk/2) + node->src1->ne[0])*node->src1->ne[1] + ); + } else { GGML_ASSERT(false); - } break; - } + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CONV_2D: + { + n_tasks = n_threads; + + GGML_ASSERT(node->src1->ne[3] == 1); + + const int64_t ne00 = node->src0->ne[0]; // W + const int64_t ne01 = node->src0->ne[1]; // H + const int64_t ne02 = node->src0->ne[2]; // C + const int64_t ne03 = node->src0->ne[3]; // N + + const int64_t ne10 = node->src1->ne[0]; // W + const int64_t ne11 = node->src1->ne[1]; // H + const int64_t ne12 = node->src1->ne[2]; // C + + const int64_t nk = ne00*ne01; + + UNUSED(ne02); + UNUSED(ne03); + UNUSED(nk); + + size_t cur = 0; + + if (node->src0->type == GGML_TYPE_F16 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12); + } else if (node->src0->type == GGML_TYPE_F32 && + node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)* (ne10*ne11*ne12); + } else { + GGML_ASSERT(false); + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_ATTN: + { + n_tasks = n_threads; + + size_t cur = 0; + + const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*ne11*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_FF: + { + n_tasks = n_threads; + + size_t cur = 0; + + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*node->src1->ne[1]*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*node->src1->ne[1]*n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_FLASH_ATTN_BACK: + { + n_tasks = n_threads; + + size_t cur = 0; + + const int64_t D = node->src0->ne[0]; + const int64_t ne11 = ggml_up(node->src1->ne[1], GGML_SOFT_MAX_UNROLL); + const int64_t mxDn = MAX(D, ne11) * 2; // *2 because of S and SM in ggml_compute_forward_flash_attn_back + if (node->src1->type == GGML_TYPE_F32) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } + + if (node->src1->type == GGML_TYPE_F16) { + cur = sizeof(float)*mxDn*n_tasks; // TODO: this can become (n_tasks-1) + cur += sizeof(float)*mxDn*n_tasks; // this is overestimated by x2 + } + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_WIN_PART: + case GGML_OP_WIN_UNPART: + case GGML_OP_MAP_UNARY: + case GGML_OP_MAP_BINARY: + case GGML_OP_MAP_CUSTOM1: + case GGML_OP_MAP_CUSTOM2: + case GGML_OP_MAP_CUSTOM3: + { + n_tasks = 1; + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: + { + n_tasks = n_threads; + + size_t cur = ggml_type_size(node->type)*(n_tasks + node->src0->ne[0]*n_tasks); + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_CROSS_ENTROPY_LOSS_BACK: + { + n_tasks = n_threads; + + size_t cur = ggml_type_size(node->type)*node->src0->ne[0]*n_tasks; + + work_size = MAX(work_size, cur); + } break; + case GGML_OP_NONE: + { + n_tasks = 1; + } break; + case GGML_OP_COUNT: + { + GGML_ASSERT(false); + } break; } + + cplan.n_tasks[i] = n_tasks; } if (work_size > 0) {