Arm AArch64: simplify logic for calling gemm and gemv functions in ggml_compute_forward_mul_mat
This commit is contained in:
parent
7a706067b5
commit
ffbfabb517
1 changed files with 11 additions and 32 deletions
|
@ -12383,41 +12383,20 @@ UseGgmlGemm2:;
|
|||
|
||||
if ((ggml_n_dims(src0) == 2) && gemm && gemv) {
|
||||
if (src0_start >= src0_end) return;
|
||||
if (ne11 == 1)
|
||||
gemv(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, 1, src0_end - src0_start);
|
||||
else {
|
||||
for (int iter = 0; iter < ne11 / 16; iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + (iter * 16 * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter * 16), 16,
|
||||
src0_end - src0_start);
|
||||
}
|
||||
int rows_processed = (ne11 / 16) * 16;
|
||||
for (int iter = 0; iter < (ne11 - rows_processed) / 8; iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + iter * 8) * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata + (src1_col_stride * (rows_processed + iter * 8)), 8, src0_end - src0_start);
|
||||
}
|
||||
rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
|
||||
for (int iter = 0; iter < (ne11 - rows_processed) / 4; iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + iter * 4) * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata + (src1_col_stride * (rows_processed + iter * 4)), 4, src0_end - src0_start);
|
||||
}
|
||||
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
|
||||
for (int iter = rows_processed; iter < ne11; iter++) {
|
||||
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
}
|
||||
}
|
||||
} else if ((ggml_n_dims(src0) == 2) && gemv) {
|
||||
if (src0_start >= src0_end) return;
|
||||
for (int iter = 0; iter < ne11; iter++) {
|
||||
// If there are more than three rows in src1, use gemm; otherwise, use gemv.
|
||||
if (ne11 > 3)
|
||||
gemm(ne00, (float *)((char *) dst->data) + src0_start, ne01, (const char *) src0->data + src0_start * nb01,
|
||||
(const char *) src1_wdata, ne11 - ne11 % 4, src0_end - src0_start);
|
||||
for (int iter = ne11 - ne11 % 4; iter < ne11; iter++)
|
||||
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
} else if ((ggml_n_dims(src0) == 2) && gemv) {
|
||||
if (src0_start >= src0_end) return;
|
||||
for (int iter = 0; iter < ne11; iter++)
|
||||
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)) + src0_start, ne01,
|
||||
(const char *) src0->data + src0_start * nb01, (const char *) src1_wdata + (src1_col_stride * iter), 1,
|
||||
src0_end - src0_start);
|
||||
}
|
||||
} else {
|
||||
// The first chunk comes from our thread_id, the rest will get auto-assigned.
|
||||
int current_chunk = ith;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue