Arm AArch64: minor code changes for rebase
This commit is contained in:
parent
7ac03e5fe8
commit
e2c1c47fa8
1 changed files with 13 additions and 19 deletions
|
@ -12370,29 +12370,31 @@ UseGgmlGemm2:;
|
||||||
//if (ith == 0)
|
//if (ith == 0)
|
||||||
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
|
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
|
||||||
|
|
||||||
|
const void * src1_wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||||
|
const size_t src1_col_stride = ggml_is_contiguous(src1) || src1->type != vec_dot_type ? ggml_row_size(vec_dot_type, ne10) : nb11;
|
||||||
if ((ggml_n_dims(src0) == 2) && gemm && gemv) {
|
if ((ggml_n_dims(src0) == 2) && gemm && gemv) {
|
||||||
if (ne11 == 1) gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
|
if (ne11 == 1) gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) src1_wdata, 1, ne01, ith, nth);
|
||||||
else {
|
else {
|
||||||
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
|
for (int iter = 0; iter < ne11 / 16; iter++) {
|
||||||
gemm(ne00, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), 16, ne01, ith, nth);
|
gemm(ne00, (float *)((char *) dst->data + (iter * 16 * nb1)), (const char *) src0->data, (const char *) src1_wdata + (src1_col_stride * iter * 16), 16, ne01, ith, nth);
|
||||||
}
|
}
|
||||||
int rows_processed = (ne11 / 16) * 16;
|
int rows_processed = (ne11 / 16) * 16;
|
||||||
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) {
|
for (int iter = 0; iter < (ne11 - rows_processed) / 8; iter++) {
|
||||||
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), 8, ne01, ith, nth);
|
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + iter * 8) * nb1)), (const char *) src0->data, (const char *) src1_wdata + (src1_col_stride * (rows_processed + iter * 8)), 8, ne01, ith, nth);
|
||||||
}
|
}
|
||||||
rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
|
rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
|
||||||
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
|
for (int iter = 0; iter < (ne11 - rows_processed) / 4; iter++) {
|
||||||
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), 4, ne01, ith, nth);
|
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + iter * 4) * nb1)), (const char *) src0->data, (const char *) src1_wdata + (src1_col_stride * (rows_processed + iter * 4)), 4, ne01, ith, nth);
|
||||||
}
|
}
|
||||||
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
|
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
|
||||||
for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
|
for (int iter = rows_processed; iter < ne11; iter++) {
|
||||||
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * row_size) : (row_iter * nb11)), 1, ne01, ith, nth);
|
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)), (const char *) src0->data, (const char *) src1_wdata + (src1_col_stride * iter), 1, ne01, ith, nth);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else if ((ggml_n_dims(src0) == 2) && gemv) {
|
else if ((ggml_n_dims(src0) == 2) && gemv) {
|
||||||
for (int row_iter = 0; row_iter < ne11; row_iter++) {
|
for (int iter = 0; iter < ne11; iter++) {
|
||||||
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * row_size) : (row_iter * nb11)), 1, ne01, ith, nth);
|
gemv(ne00, (float *)((char *) dst->data + (iter * nb1)), (const char *) src0->data, (const char *) src1_wdata + (src1_col_stride * iter), 1, ne01, ith, nth);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
else {
|
else {
|
||||||
|
@ -22030,12 +22032,4 @@ int ggml_cpu_has_matmul_int8(void) {
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
int ggml_cpu_has_sve(void) {
|
|
||||||
#if defined(__ARM_FEATURE_SVE)
|
|
||||||
return 1;
|
|
||||||
#else
|
|
||||||
return 0;
|
|
||||||
#endif
|
|
||||||
}
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue