Extracting the logic to it's own function.

This commit is contained in:
Kunnis 2024-05-08 23:58:59 -05:00
parent a968553c6f
commit 7b932e4908

136
ggml.c
View file

@ -11769,6 +11769,87 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
}
#endif
static void ggml_compute_forward_mul_mat_one_chunk(
const struct ggml_compute_params* params,
struct ggml_tensor* dst,
const int64_t num_rows_per_vec_dot,
const int64_t ir0_start,
const int64_t ir0_end,
const int64_t ir1_start,
const int64_t ir1_end,
const bool src1_cont,
const int64_t r2,
const int64_t r3,
enum ggml_type vec_dot_type,
const void* wdata,
const size_t row_size,
const ggml_vec_dot_t vec_dot
) {
const struct ggml_tensor* src0 = dst->src[0];
const struct ggml_tensor* src1 = dst->src[1];
GGML_TENSOR_BINARY_OP_LOCALS
const enum ggml_type type = src0->type;
assert(ne12 % ne02 == 0);
assert(ne13 % ne03 == 0);
// block-tiling attempt
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
// attempt to reduce false-sharing (does not seem to make a difference)
// 16 * 2, accounting for mmla kernels
float tmp[32];
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
const int64_t i13 = (ir1 / (ne12 * ne1));
const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
// broadcast src0 into src1
const int64_t i03 = i13 / r3;
const int64_t i02 = i12 / r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;
const char* src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char* src1_col = (const char*)wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
: (i11 * nb11 + i12 * nb12 + i13 * nb13));
float* dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
}
}
}
}
}
static void ggml_compute_forward_mul_mat(
const struct ggml_compute_params * params,
struct ggml_tensor * dst,
@ -12005,60 +12086,7 @@ UseGgmlGemm2:;
num_rows_per_vec_dot = 1;
}
assert(ne12% ne02 == 0);
assert(ne13% ne03 == 0);
// block-tiling attempt
const int64_t blck_0 = 16;
const int64_t blck_1 = 16;
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
// attempt to reduce false-sharing (does not seem to make a difference)
// 16 * 2, accounting for mmla kernels
float tmp[32];
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
const int64_t i13 = (ir1/(ne12*ne1));
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
// broadcast src0 into src1
const int64_t i03 = i13/r3;
const int64_t i02 = i12/r2;
const int64_t i1 = i11;
const int64_t i2 = i12;
const int64_t i3 = i13;
const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
// desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
// if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
//}
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot >1 ? 16 : 0), src0_row + ir0*nb01, (num_rows_per_vec_dot >1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot >1 ? src1_col_stride : 0), num_rows_per_vec_dot);
}
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
}
}
}
}
ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, r2, r3, vec_dot_type, wdata, row_size, vec_dot);
}
// ggml_compute_forward_mul_mat_id