Extracting the logic to its own function.
parent a968553c6f
commit 7b932e4908
1 changed file with 82 additions and 54 deletions
ggml.c (136 changed lines)
@@ -11769,6 +11769,87 @@ static bool ggml_compute_forward_mul_mat_use_blas(struct ggml_tensor * dst) {
 }
 #endif
 
+static void ggml_compute_forward_mul_mat_one_chunk(
+    const struct ggml_compute_params* params,
+    struct ggml_tensor* dst,
+    const int64_t num_rows_per_vec_dot,
+    const int64_t ir0_start,
+    const int64_t ir0_end,
+    const int64_t ir1_start,
+    const int64_t ir1_end,
+
+    const bool src1_cont,
+    const int64_t r2,
+    const int64_t r3,
+    enum ggml_type vec_dot_type,
+    const void* wdata,
+    const size_t row_size,
+    const ggml_vec_dot_t vec_dot
+) {
+
+    const struct ggml_tensor* src0 = dst->src[0];
+    const struct ggml_tensor* src1 = dst->src[1];
+
+    GGML_TENSOR_BINARY_OP_LOCALS
+
+    const enum ggml_type type = src0->type;
+
+    assert(ne12 % ne02 == 0);
+    assert(ne13 % ne03 == 0);
+
+    // block-tiling attempt
+    const int64_t blck_0 = 16;
+    const int64_t blck_1 = 16;
+
+    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
+
+    // attempt to reduce false-sharing (does not seem to make a difference)
+    // 16 * 2, accounting for mmla kernels
+    float tmp[32];
+
+    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
+        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
+            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
+                const int64_t i13 = (ir1 / (ne12 * ne1));
+                const int64_t i12 = (ir1 - i13 * ne12 * ne1) / ne1;
+                const int64_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1);
+
+                // broadcast src0 into src1
+                const int64_t i03 = i13 / r3;
+                const int64_t i02 = i12 / r2;
+
+                const int64_t i1 = i11;
+                const int64_t i2 = i12;
+                const int64_t i3 = i13;
+
+                const char* src0_row = (const char*)src0->data + (0 + i02 * nb02 + i03 * nb03);
+
+                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
+                // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
+                // the original src1 data pointer, so we should index using the indices directly
+                // TODO: this is a bit of a hack, we should probably have a better way to handle this
+                const char* src1_col = (const char*)wdata +
+                    (src1_cont || src1->type != vec_dot_type
+                        ? (i11 + i12 * ne11 + i13 * ne12 * ne11) * row_size
+                        : (i11 * nb11 + i12 * nb12 + i13 * nb13));
+                float* dst_col = (float*)((char*)dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3));
+
+                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
+                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
+                //}
+
+                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
+                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0 * nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
+                }
+
+                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
+                    memcpy(&dst_col[iir0 + cn * nb1 / nb0], tmp + (cn * 16), (MIN(iir0 + blck_0, ir0_end) - iir0) * sizeof(float));
+                }
+            }
+        }
+    }
+}
+
 static void ggml_compute_forward_mul_mat(
         const struct ggml_compute_params * params,
         struct ggml_tensor * dst,
@@ -12005,60 +12086,7 @@ UseGgmlGemm2:;
         num_rows_per_vec_dot = 1;
     }
 
-    assert(ne12 % ne02 == 0);
-    assert(ne13 % ne03 == 0);
-
-    // block-tiling attempt
-    const int64_t blck_0 = 16;
-    const int64_t blck_1 = 16;
-
-    const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
-
-    // attempt to reduce false-sharing (does not seem to make a difference)
-    // 16 * 2, accounting for mmla kernels
-    float tmp[32];
-
-    for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
-        for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
-            for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
-                const int64_t i13 = (ir1/(ne12*ne1));
-                const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
-                const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
-
-                // broadcast src0 into src1
-                const int64_t i03 = i13/r3;
-                const int64_t i02 = i12/r2;
-
-                const int64_t i1 = i11;
-                const int64_t i2 = i12;
-                const int64_t i3 = i13;
-
-                const char * src0_row = (const char *) src0->data + (0 + i02*nb02 + i03*nb03);
-
-                // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
-                // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
-                // the original src1 data pointer, so we should index using the indices directly
-                // TODO: this is a bit of a hack, we should probably have a better way to handle this
-                const char * src1_col = (const char *) wdata +
-                    (src1_cont || src1->type != vec_dot_type
-                        ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
-                        : (i11*nb11 + i12*nb12 + i13*nb13));
-                float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
-
-                //for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
-                //    vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
-                //}
-
-                for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
-                    vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot > 1 ? 16 : 0), src0_row + ir0*nb01, (num_rows_per_vec_dot > 1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot > 1 ? src1_col_stride : 0), num_rows_per_vec_dot);
-                }
-
-                for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
-                    memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
-                }
-            }
-        }
-    }
+    ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot, ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, r2, r3, vec_dot_type, wdata, row_size, vec_dot);
 }
 
 // ggml_compute_forward_mul_mat_id
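For context (not part of this commit): the new call site hands the extracted helper a pair of row ranges, ir0_start..ir0_end over the rows of src0 and ir1_start..ir1_end over the broadcast rows of src1. The standalone C sketch below only illustrates how such ranges could be derived with an assumed even split across threads; the variables nth, ith, nr0 and nr1 are hypothetical stand-ins, and the actual partitioning logic in ggml_compute_forward_mul_mat is elided from this diff.

// Hypothetical, self-contained illustration of even row-range partitioning
// across threads; not the code from this commit.
#include <stdint.h>
#include <stdio.h>

// Split [0, n) into nth nearly equal chunks and return chunk ith as [start, end).
static void split_range(int64_t n, int ith, int nth, int64_t * start, int64_t * end) {
    const int64_t per_thread = (n + nth - 1) / nth; // ceil(n / nth)
    *start = per_thread * ith;
    *end   = *start + per_thread < n ? *start + per_thread : n;
}

int main(void) {
    const int     nth = 4;    // assumed thread count
    const int64_t nr0 = 4096; // assumed number of src0 rows (dst rows)
    const int64_t nr1 = 37;   // assumed number of src1 rows times broadcast batches

    for (int ith = 0; ith < nth; ++ith) {
        int64_t ir0_start, ir0_end, ir1_start, ir1_end;
        split_range(nr0, 0,   1,   &ir0_start, &ir0_end); // every thread covers all src0 rows in this sketch
        split_range(nr1, ith, nth, &ir1_start, &ir1_end); // src1 rows are split across threads

        // In ggml.c the equivalent ranges are forwarded to the extracted helper:
        // ggml_compute_forward_mul_mat_one_chunk(params, dst, num_rows_per_vec_dot,
        //     ir0_start, ir0_end, ir1_start, ir1_end, src1_cont, r2, r3,
        //     vec_dot_type, wdata, row_size, vec_dot);
        printf("thread %d: ir0 [%lld, %lld)  ir1 [%lld, %lld)\n",
               ith, (long long) ir0_start, (long long) ir0_end,
               (long long) ir1_start, (long long) ir1_end);
    }
    return 0;
}

The sketch just prints the per-thread ranges; each worker ends up writing a disjoint band of dst rows, which is why the extracted helper takes both ranges explicitly instead of reading thread state.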