diff --git a/ggml.c b/ggml.c index f03c26fb1..ab59771b6 100644 --- a/ggml.c +++ b/ggml.c @@ -9666,7 +9666,7 @@ static void ggml_compute_forward_norm_f32( // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); ggml_float sum = 0.0; @@ -9743,7 +9743,7 @@ static void ggml_compute_forward_rms_norm_f32( // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); ggml_float sum = 0.0; @@ -9823,7 +9823,7 @@ static void ggml_compute_forward_rms_norm_back_f32( // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { - for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x + for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // src1 is same shape as src0 => same indices const auto i11 = i01; const auto i12 = i02; @@ -14537,8 +14537,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph) case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM_BACK: { - // i think this cannot be threaded, because we need mean over all items, not just the slices each thread sees. - node->n_tasks = 1; + node->n_tasks = n_threads; } break; case GGML_OP_MUL_MAT: {