revert disabling of threading for rms_norm and norm

This commit is contained in:
xaedes 2023-05-07 21:55:25 +02:00
parent 5d9fed7e7f
commit 47ad186628
No known key found for this signature in database
GPG key ID: 30030EDD817EA2B1

9
ggml.c
View file

@ -9666,7 +9666,7 @@ static void ggml_compute_forward_norm_f32(
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
@ -9743,7 +9743,7 @@ static void ggml_compute_forward_rms_norm_f32(
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
@ -9823,7 +9823,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
// TODO: optimize
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
// src1 is same shape as src0 => same indices
const auto i11 = i01;
const auto i12 = i02;
@ -14537,8 +14537,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
case GGML_OP_RMS_NORM:
case GGML_OP_RMS_NORM_BACK:
{
// i think this cannot be threaded, because we need mean over all items, not just the slices each thread sees.
node->n_tasks = 1;
node->n_tasks = n_threads;
} break;
case GGML_OP_MUL_MAT:
{