ggml : alternative thread distribution for mul_mat

ggml-ci
Georgi Gerganov 2023-07-24 16:35:34 +03:00
parent 0822d27613
commit a2eb57e796

ggml.c (51 changed lines)

@@ -10513,17 +10513,29 @@ static void ggml_compute_forward_mul_mat(
     const void * wdata    = (src1->type == vec_dot_type) ? src1->data : params->wdata;
     const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
 
-    // src1 rows
-    const int64_t nr1 = ne11*ne12*ne13;
+    const int64_t nr0 = ne01;           // src0 rows
+    const int64_t nr1 = ne11*ne12*ne13; // src1 rows
 
-    /*int64_t ir010 = 0;*/
-    /*int64_t ir011 = ne01;*/
+    //printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
 
-    // parallelize by src0 rows
-    const int64_t dr = (ne01 + nth - 1)/nth;
+    // distribute the thread work across the inner or outer loop based on which one is larger
 
-    const int64_t ir010 = dr*ith;
-    const int64_t ir011 = MIN(ir010 + dr, ne01);
+    const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
+    const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
+
+    const int64_t ith0 = ith % nth0;
+    const int64_t ith1 = ith / nth0;
+
+    const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
+    const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
+
+    const int64_t ir010 = dr0*ith0;
+    const int64_t ir011 = MIN(ir010 + dr0, nr0);
+
+    const int64_t ir110 = dr1*ith1;
+    const int64_t ir111 = MIN(ir110 + dr1, nr1);
+
+    //printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
 
     assert(ne12 % ne02 == 0);
     assert(ne13 % ne03 == 0);
@@ -10532,11 +10544,12 @@ static void ggml_compute_forward_mul_mat(
     const int64_t r2 = ne12/ne02;
     const int64_t r3 = ne13/ne03;
 
-    for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
+    for (int64_t ir1 = ir110; ir1 < ir111; ++ir1) {
         const int64_t i13 = (ir1/(ne12*ne11));
         const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
         const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
 
+        // broadcast src0 into src1
         const int64_t i03 = i13/r3;
         const int64_t i02 = i12/r2;
 
@@ -10544,7 +10557,7 @@ static void ggml_compute_forward_mul_mat(
         const int64_t i2 = i12;
         const int64_t i3 = i13;
 
-        const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
+        const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
 
         // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
         //       if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
@@ -10557,28 +10570,14 @@ static void ggml_compute_forward_mul_mat(
         float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
 
-        for (int64_t ir = ir010; ir < ir011; ++ir) {
-            vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
+        for (int64_t ir0 = ir010; ir0 < ir011; ++ir0) {
+            vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
         }
     }
-
-    //int64_t t1 = ggml_time_us();
-    //static int64_t acc = 0;
-    //acc += t1 - t0;
-    //if (t1 - t0 > 10) {
-    //    printf("\n");
-    //    printf("ne00 = %5d, ne01 = %5d, ne02 = %5d, ne03 = %5d\n", ne00, ne01, ne02, ne03);
-    //    printf("nb00 = %5d, nb01 = %5d, nb02 = %5d, nb03 = %5d\n", nb00, nb01, nb02, nb03);
-    //    printf("ne10 = %5d, ne11 = %5d, ne12 = %5d, ne13 = %5d\n", ne10, ne11, ne12, ne13);
-    //    printf("XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX task %d/%d: %d us, acc = %d\n", ith, nth, (int) (t1 - t0), (int) acc);
-    //}
 }
 
 // ggml_compute_forward_out_prod
 
 static void ggml_compute_forward_out_prod_f32(
         const struct ggml_compute_params * params,
         const struct ggml_tensor * src0,
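
For illustration, the following standalone C program isolates the work-split logic from the first hunk above. The variable names mirror the commit, but the program itself and the sizes (nr0, nr1, nth) are made-up examples and not part of ggml.c: all threads are assigned to whichever of the two loops (src0 rows or src1 rows) is larger, and each flat thread index ith is turned into a pair of half-open row ranges.

/*
 * Standalone sketch of the thread-work distribution above.
 * Illustration only; nr0, nr1 and nth are example values.
 */
#include <stdio.h>
#include <stdint.h>

#define MIN(a, b) ((a) < (b) ? (a) : (b))

int main(void) {
    const int64_t nr0 = 32;   // src0 rows (ne01)           -- example value
    const int64_t nr1 = 512;  // src1 rows (ne11*ne12*ne13) -- example value
    const int     nth = 4;    // number of threads          -- example value

    for (int ith = 0; ith < nth; ++ith) {
        // all threads go to the larger of the two loops; the other gets a single slot
        const int64_t nth0 = nr0 > nr1 ? nth : 1; // threads across src0 rows
        const int64_t nth1 = nr0 > nr1 ? 1 : nth; // threads across src1 rows

        // unflatten the thread index into an (ith0, ith1) grid coordinate
        const int64_t ith0 = ith % nth0;
        const int64_t ith1 = ith / nth0;

        // rows per thread in each dimension, rounded up
        const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
        const int64_t dr1 = (nr1 + nth1 - 1)/nth1;

        // half-open ranges: src0 rows [ir010, ir011), src1 rows [ir110, ir111)
        const int64_t ir010 = dr0*ith0;
        const int64_t ir011 = MIN(ir010 + dr0, nr0);
        const int64_t ir110 = dr1*ith1;
        const int64_t ir111 = MIN(ir110 + dr1, nr1);

        printf("thread %d: src0 rows [%3lld, %3lld), src1 rows [%3lld, %3lld)\n",
               ith, (long long) ir010, (long long) ir011,
                    (long long) ir110, (long long) ir111);
    }

    return 0;
}

With these example sizes nr1 is larger, so nth0 == 1 and each of the four threads covers all 32 src0 rows while taking a 128-row chunk of the src1 rows [0, 512). For a tall src0 against a single src1 row (e.g. single-token generation), the comparison flips and the threads split the src0 rows instead, which is the case the old dr/ir010/ir011 code handled exclusively.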