Arm AArch64: simplify logic for calling gemm and gemv functions in ggml_compute_forward_mul_mat

This commit is contained in:
Dibakar Gope 2024-06-23 20:22:28 +00:00
parent 7a706067b5
commit ffbfabb517

View file

@ -12383,41 +12383,20 @@ UseGgmlGemm2:;
if ((ggml_n_dims(src0) == 2) && gemm && gemv) {
if (src0_start >= src0_end) return;
if (ne11 == 1)
gemv(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, 1, src0_end - src0_start);
else {
for (int iter = 0; iter < ne11 / 16; iter++) {
gemm(ne00, (float *)((char *) dst->data + (iter * 16 * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter * 16), 16,
src0_end - src0_start);
}
int rows_processed = (ne11 / 16) * 16;
for (int iter = 0; iter < (ne11 - rows_processed) / 8; iter++) {
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + iter * 8) * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata + (src1_col_stride * (rows_processed + iter * 8)), 8, src0_end - src0_start);
}
rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
for (int iter = 0; iter < (ne11 - rows_processed) / 4; iter++) {
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + iter * 4) * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata + (src1_col_stride * (rows_processed + iter * 4)), 4, src0_end - src0_start);
}
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
for (int iter = rows_processed; iter < ne11; iter++) {
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
}
}
} else if ((ggml_n_dims(src0) == 2) && gemv) {
if (src0_start >= src0_end) return;
for (int iter = 0; iter < ne11; iter++) {
// If src1 has more than three rows, gemm handles the first (ne11 - ne11 % 4) rows; gemv then handles the remaining ne11 % 4 rows (or every row when ne11 <= 3).
if (ne11 > 3)
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++)
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
} else if ((ggml_n_dims(src0) == 2) && gemv) {
if (src0_start >= src0_end) return;
for (int iter = 0; iter < ne11; iter++)
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
src0_end - src0_start);
}
} else {
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;