Compare commits

...
Sign in to create a new pull request.

2 commits

Author SHA1 Message Date
Georgi Gerganov
a1cdd29cd2 ggml : rms_norm in chunks 2023-05-20 10:15:54 +03:00
Georgi Gerganov
5a317898e8 ggml : process mul mat rows in chunks 2023-05-20 10:15:53 +03:00

138
ggml.c
View file

@ -3590,6 +3590,9 @@ struct ggml_compute_params {
// work buffer for all threads // work buffer for all threads
size_t wsize; size_t wsize;
void * wdata; void * wdata;
// atomic counter used to distribute chunks of work
atomic_int * aic;
}; };
// //
@ -9030,18 +9033,20 @@ static void ggml_compute_forward_rms_norm_f32(
GGML_ASSERT(ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
atomic_store(params->aic, 0);
return; return;
} }
GGML_ASSERT(src0->nb[0] == sizeof(float)); GGML_ASSERT(src0->nb[0] == sizeof(float));
const int ith = params->ith; const int ith = params->ith; UNUSED(ith);
const int nth = params->nth; const int nth = params->nth;
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1]; const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2]; const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3]; const int64_t ne03 = src0->ne[3]; UNUSED(ne03);
const size_t nb01 = src0->nb[1]; const size_t nb01 = src0->nb[1];
const size_t nb02 = src0->nb[2]; const size_t nb02 = src0->nb[2];
@ -9053,30 +9058,45 @@ static void ggml_compute_forward_rms_norm_f32(
const float eps = 1e-6f; // TODO: make this a parameter const float eps = 1e-6f; // TODO: make this a parameter
// TODO: optimize const int nr = ggml_nrows(src0);
for (int64_t i03 = 0; i03 < ne03; i03++) { const int dr = (nr + 8*nth - 1)/(8*nth);
for (int64_t i02 = 0; i02 < ne02; i02++) {
for (int64_t i01 = ith; i01 < ne01; i01 += nth) {
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0; while (true) {
for (int64_t i00 = 0; i00 < ne00; i00++) { const int ir0 = atomic_fetch_add(params->aic, dr);
sum += (ggml_float)(x[i00] * x[i00]);
}
float mean = sum/ne00; for (int ir = ir0; ir < ir0 + dr; ++ir) {
if (ir >= nr) {
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); break;
memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }
const float scale = 1.0f/sqrtf(mean + eps);
ggml_vec_scale_f32(ne00, y, scale);
} }
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03);
ggml_float sum = 0.0;
for (int64_t i00 = 0; i00 < ne00; i00++) {
sum += (ggml_float)(x[i00] * x[i00]);
}
float mean = sum/ne00;
float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3);
memcpy(y, x, ne00 * sizeof(float));
// for (int i00 = 0; i00 < ne00; i00++) {
// y[i00] = x[i00];
// }
const float scale = 1.0f/sqrtf(mean + eps);
ggml_vec_scale_f32(ne00, y, scale);
}
if (ir0 + dr >= nr) {
break;
} }
} }
} }
@ -9751,7 +9771,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
const int nb2 = dst->nb[2]; const int nb2 = dst->nb[2];
const int nb3 = dst->nb[3]; const int nb3 = dst->nb[3];
const int ith = params->ith; const int ith = params->ith; UNUSED(ith);
const int nth = params->nth; const int nth = params->nth;
GGML_ASSERT(ne02 == ne12); GGML_ASSERT(ne02 == ne12);
@ -9867,6 +9887,8 @@ static void ggml_compute_forward_mul_mat_q_f32(
} }
} }
atomic_store(params->aic, 0);
return; return;
} }
@ -9874,43 +9896,48 @@ static void ggml_compute_forward_mul_mat_q_f32(
return; return;
} }
// parallelize by src0 rows using ggml_vec_dot_q
// total rows in src0
const int nr = ne01*ne02*ne03;
// rows per thread
const int dr = (nr + nth - 1)/nth;
// row range for this thread
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
void * wdata = params->wdata; void * wdata = params->wdata;
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
for (int ir = ir0; ir < ir1; ++ir) { // parallelize by src0 rows using ggml_vec_dot_q
// src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int i13 = i03; const int nr = ggml_nrows(src0);
const int i12 = i02; const int dr = (nr + 8*nth - 1)/(8*nth);
const int i0 = i01; while (true) {
const int i2 = i02; const int ir0 = atomic_fetch_add(params->aic, dr);
const int i3 = i03;
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); for (int ir = ir0; ir < ir0 + dr; ++ir) {
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); if (ir >= nr) {
break;
}
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); // src0 indices
const int i03 = ir/(ne02*ne01);
const int i02 = (ir - i03*ne02*ne01)/ne01;
const int i01 = (ir - i03*ne02*ne01 - i02*ne01);
assert(ne00 % 32 == 0); const int i13 = i03;
const int i12 = i02;
for (int64_t ic = 0; ic < ne11; ++ic) { const int i0 = i01;
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); const int i2 = i02;
const int i3 = i03;
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size));
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
assert(ne00 % 32 == 0);
for (int64_t ic = 0; ic < ne11; ++ic) {
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
}
}
if (ir0 + dr >= nr) {
break;
} }
} }
@ -13749,6 +13776,7 @@ struct ggml_compute_state_shared {
// synchronization primitives // synchronization primitives
atomic_int n_ready; atomic_int n_ready;
atomic_int aic;
atomic_bool has_work; atomic_bool has_work;
atomic_bool stop; // stop all threads atomic_bool stop; // stop all threads
}; };
@ -13817,6 +13845,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
/*.spin =*/ GGML_LOCK_INITIALIZER, /*.spin =*/ GGML_LOCK_INITIALIZER,
/*.n_threads =*/ n_threads, /*.n_threads =*/ n_threads,
/*.n_ready =*/ 0, /*.n_ready =*/ 0,
/*.aic =*/ 0,
/*.has_work =*/ false, /*.has_work =*/ false,
/*.stop =*/ false, /*.stop =*/ false,
}; };
@ -13837,6 +13866,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = n_threads, .nth = n_threads,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL, .wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
}, },
.node = NULL, .node = NULL,
.shared = &state_shared, .shared = &state_shared,
@ -14126,6 +14156,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
/*.nth =*/ node->n_tasks, /*.nth =*/ node->n_tasks,
/*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0, /*.wsize =*/ cgraph->work ? ggml_nbytes(cgraph->work) : 0,
/*.wdata =*/ cgraph->work ? cgraph->work->data : NULL, /*.wdata =*/ cgraph->work ? cgraph->work->data : NULL,
/*.aic =*/ &state_shared.aic,
}; };
ggml_compute_forward(&params, node); ggml_compute_forward(&params, node);
@ -14149,6 +14180,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = node->n_tasks, .nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL, .wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
}; };
workers[j].node = node; workers[j].node = node;
} }
@ -14164,6 +14196,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
} }
params.type = GGML_TASK_COMPUTE; params.type = GGML_TASK_COMPUTE;
params.aic = &state_shared.aic;
ggml_compute_forward(&params, node); ggml_compute_forward(&params, node);
// wait for thread pool // wait for thread pool
@ -14204,6 +14237,7 @@ void ggml_graph_compute(struct ggml_context * ctx, struct ggml_cgraph * cgraph)
.nth = node->n_tasks, .nth = node->n_tasks,
.wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0, .wsize = cgraph->work ? ggml_nbytes(cgraph->work) : 0,
.wdata = cgraph->work ? cgraph->work->data : NULL, .wdata = cgraph->work ? cgraph->work->data : NULL,
.aic = &state_shared.aic,
}; };
workers[j].node = node; workers[j].node = node;
} }