Renaming and moving a bunch of variables around.
This commit is contained in:
parent
e098171aa7
commit
a968553c6f
1 changed files with 24 additions and 24 deletions
44
ggml.c
44
ggml.c
|
@ -11983,20 +11983,28 @@ UseGgmlGemm2:;
|
|||
const int64_t dr0 = (nr0 + nth0 - 1)/nth0;
|
||||
const int64_t dr1 = (nr1 + nth1 - 1)/nth1;
|
||||
|
||||
const int64_t ir010 = dr0*ith0;
|
||||
const int64_t ir011 = MIN(ir010 + dr0, nr0);
|
||||
const int64_t ir0_start = dr0*ith0;
|
||||
const int64_t ir0_end = MIN(ir0_start + dr0, nr0);
|
||||
|
||||
const int64_t ir110 = dr1*ith1;
|
||||
const int64_t ir111 = MIN(ir110 + dr1, nr1);
|
||||
const int64_t ir1_start = dr1*ith1;
|
||||
const int64_t ir1_end = MIN(ir1_start + dr1, nr1);
|
||||
|
||||
//printf("ir010 = %6lld, ir011 = %6lld, ir110 = %6lld, ir111 = %6lld\n", ir010, ir011, ir110, ir111);
|
||||
//printf("ir0_start = %6lld, ir0_end = %6lld, ir1_start = %6lld, ir1_end = %6lld\n", ir0_start, ir0_end, ir1_start, ir1_end);
|
||||
|
||||
// threads with no work simply yield (not sure if it helps)
|
||||
if (ir010 >= ir011 || ir110 >= ir111) {
|
||||
if (ir0_start >= ir0_end || ir1_start >= ir1_end) {
|
||||
sched_yield();
|
||||
return;
|
||||
}
|
||||
|
||||
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
||||
int64_t num_rows_per_vec_dot = vec_dot_num_rows;
|
||||
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
||||
// this check can be removed once they are extended to support odd numbered rows/cols too
|
||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
|
||||
num_rows_per_vec_dot = 1;
|
||||
}
|
||||
|
||||
assert(ne12% ne02 == 0);
|
||||
assert(ne13% ne03 == 0);
|
||||
|
||||
|
@ -12004,23 +12012,15 @@ UseGgmlGemm2:;
|
|||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
|
||||
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
||||
int64_t nrc = vec_dot_num_rows;
|
||||
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
||||
// this check can be removed once they are extended to support odd numbered rows/cols too
|
||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
|
||||
nrc = 1;
|
||||
}
|
||||
|
||||
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
|
||||
|
||||
// attempt to reduce false-sharing (does not seem to make a difference)
|
||||
// 16 * 2, accounting for mmla kernels
|
||||
float tmp[32];
|
||||
|
||||
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
|
||||
for (int64_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir1_end; ir1 += num_rows_per_vec_dot) {
|
||||
const int64_t i13 = (ir1/(ne12*ne1));
|
||||
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
|
||||
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
|
||||
|
@ -12045,16 +12045,16 @@ UseGgmlGemm2:;
|
|||
: (i11*nb11 + i12*nb12 + i13*nb13));
|
||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ++ir0) {
|
||||
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
|
||||
//}
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir0_end; ir0 += num_rows_per_vec_dot) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], (num_rows_per_vec_dot >1 ? 16 : 0), src0_row + ir0*nb01, (num_rows_per_vec_dot >1 ? nb01 : 0), src1_col, (num_rows_per_vec_dot >1 ? src1_col_stride : 0), num_rows_per_vec_dot);
|
||||
}
|
||||
|
||||
for (int cn = 0; cn < nrc; ++cn) {
|
||||
memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||
for (int cn = 0; cn < num_rows_per_vec_dot; ++cn) {
|
||||
memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir0_end) - iir0)*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue