revert disabling of threading for rms_norm and norm
This commit is contained in:
parent
5d9fed7e7f
commit
47ad186628
1 changed file with 4 additions and 5 deletions
9
ggml.c
9
ggml.c
|
@@ -9666,7 +9666,7 @@ static void ggml_compute_forward_norm_f32(
|
|||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
|
@@ -9743,7 +9743,7 @@ static void ggml_compute_forward_rms_norm_f32(
|
|||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
|
||||
|
||||
ggml_float sum = 0.0;
|
||||
|
@@ -9823,7 +9823,7 @@ static void ggml_compute_forward_rms_norm_back_f32(
|
|||
// TODO: optimize
|
||||
for (int64_t i03 = 0; i03 < ne03; i03++) {
|
||||
for (int64_t i02 = 0; i02 < ne02; i02++) {
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) { // i think this must not be threaded, because we need mean over all x*x
|
||||
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
|
||||
// src1 is same shape as src0 => same indices
|
||||
const auto i11 = i01;
|
||||
const auto i12 = i02;
|
||||
|
@@ -14537,8 +14537,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
|
|||
case GGML_OP_RMS_NORM:
|
||||
case GGML_OP_RMS_NORM_BACK:
|
||||
{
|
||||
// i think this cannot be threaded, because we need mean over all items, not just the slices each thread sees.
|
||||
node->n_tasks = 1;
|
||||
node->n_tasks = n_threads;
|
||||
} break;
|
||||
case GGML_OP_MUL_MAT:
|
||||
{
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue