ggml : group mul_mat_id rows by matrix (cpu only)

This commit is contained in:
slaren 2023-12-15 00:15:24 +01:00
parent 6744dbe924
commit 66ce753abd

167
ggml.c
View file

@ -9802,27 +9802,180 @@ static void ggml_compute_forward_mul_mat(
static void ggml_compute_forward_mul_mat_id(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * ids,
const struct ggml_tensor * src1,
struct ggml_tensor * dst) {
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
// during GGML_TASK_INIT the entire src1 is converted to vec_dot_type
ggml_compute_forward_mul_mat(params, dst->src[2], src1, dst, 0, dst->ne[1]);
const struct ggml_tensor * src0 = dst->src[2]; // only for GGML_TENSOR_BINARY_OP_LOCALS
GGML_TENSOR_BINARY_OP_LOCALS
const int ith = params->ith;
const int nth = params->nth;
const enum ggml_type type = src0->type;
const bool src1_cont = ggml_is_contiguous(src1);
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1
GGML_ASSERT(nb00 == ggml_type_size(type));
GGML_ASSERT(nb10 == ggml_type_size(src1->type));
// dst cannot be transposed or permuted
GGML_ASSERT(nb0 == sizeof(float));
GGML_ASSERT(nb0 <= nb1);
GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3);
// broadcast factors
const int64_t r2 = ne12/ne02;
const int64_t r3 = ne13/ne03;
if (params->type == GGML_TASK_INIT) {
if (src1->type != vec_dot_type) {
char * wdata = params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
assert(params->wsize >= ne11*ne12*ne13*row_size);
assert(src1->type == GGML_TYPE_F32);
for (int64_t i13 = 0; i13 < ne13; ++i13) {
for (int64_t i12 = 0; i12 < ne12; ++i12) {
for (int64_t i11 = 0; i11 < ne11; ++i11) {
from_float_to_vec_dot((float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11), (void *) wdata, ne10);
wdata += row_size;
}
}
}
}
return;
}
if (params->type == GGML_TASK_FINALIZE) {
return;
}
const struct ggml_tensor * ids = src0;
const int id = ggml_get_op_params_i32(dst, 0);
const int n_as = ggml_get_op_params_i32(dst, 1);
// group rows by src0 matrix
// TODO: allocate in wdata
#define MMID_MAX_BATCH 512
int matrix_row_counts[GGML_MAX_SRC-2] = {0}; // number of rows for each matrix
int matrix_rows[GGML_MAX_SRC-2][MMID_MAX_BATCH]; // row indices for each matrix
for (int64_t i01 = 0; i01 < ids->ne[1]; i01++) {
const int32_t row_id = *(const int32_t *) ((const char *) ids->data + i01*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(row_id >= 0 && row_id < n_as);
const struct ggml_tensor * src0_row = dst->src[row_id + 2];
ggml_compute_forward_mul_mat(params, src0_row, src1, dst, i01, 1);
matrix_rows[row_id][matrix_row_counts[row_id]] = i01;
matrix_row_counts[row_id] += 1;
}
// compute each matrix multiplication in sequence
for (int cur_a = 0; cur_a < n_as; ++cur_a) {
const int64_t cne1 = matrix_row_counts[cur_a];
if (cne1 == 0) {
continue;
}
const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
const int64_t nr0 = ne01; // src0 rows
const int64_t nr1 = cne1*ne12*ne13; // src1 rows
//printf("nr0 = %lld, nr1 = %lld\n", nr0, nr1);
// distribute the thread work across the inner or outer loop based on which one is larger
const int64_t nth0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows
const int64_t nth1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows
const int64_t ith0 = ith % nth0;
const int64_t ith1 = ith / nth0;
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
const int64_t ir010 = dr0*ith0;
const int64_t ir011 = MIN(ir010 + dr0, nr0);
const int64_t ir110 = dr1*ith1;
const int64_t ir111 = MIN(ir110 + dr1, nr1);
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
// threads with no work simply yield (not sure if it helps)
if (ir010 >= ir011 || ir110 >= ir111) {
sched_yield();
continue;
}
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
// block-tiling attempt
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
// attempt to reduce false-sharing (does not seem to make a difference)
float tmp[16];
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
const int64_t i13 = (ir1/(ne12*cne1)); // TODO: remove, src1 is always a matrix
const int64_t i12 = (ir1 - i13*ne12*cne1)/cne1;
const int64_t _i11 = (ir1 - i13*ne12*cne1 - i12*cne1);
const int64_t i11 = matrix_rows[cur_a][_i11];
// broadcast src0 into src1
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;
const char * src0_row = (const char *) src0_cur->data + (0 + i02*nb02 + i03*nb03);
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
}
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
}
}
}
}
}