Arm AArch64: minor code refactoring to resolve a build issue with CMake
This commit is contained in:
parent
8ee6779147
commit
a657246d62
3 changed files with 547 additions and 800 deletions
1277
ggml-aarch64.cpp
1277
ggml-aarch64.cpp
File diff suppressed because it is too large
Load diff
|
@ -24,17 +24,10 @@ block_q8_0x4 make_block_q8_0x4(const block_q8_0 * const in[4], unsigned int bloc
|
|||
block_q8_0x8 make_block_q8_0x8(const block_q8_0 * const in[8], unsigned int block_len);
|
||||
|
||||
// GEMV
|
||||
void ggml_gemv_q4_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
void ggml_gemv_q4_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
void ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
void ggml_gemv_q8_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
void ggml_gemv_q8_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
// GEMV routine installed as the `.gemv` hook in the type_traits table.
// NOTE(review): judging by the name, this multiplies Q4_0 data by Q8_0 data — confirm against the definition.
// Parameter roles below are inferred from the mul_mat call site
//   gemv(ne00, dst, src0->data, wdata, 1, ne01, ith, nth)
// and should be verified against the implementation:
//   n        - passed ne00 at the call site (presumably the shared inner dimension)
//   s        - float output buffer (dst data, offset per row by the caller)
//   vx       - opaque input buffer; src0->data at the call site
//   vy       - opaque input buffer; the (possibly converted) src1 row at the call site
//   nr       - passed 1 at every gemv call site
//   nc       - passed ne01 at the call site
//   ith, nth - forwarded unchanged from the caller; presumably thread index and thread count
void ggml_gemv_q4_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
|
||||
// GEMM
|
||||
void ggml_gemm_q4_0_q8_0_aarch64_sve256 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
void ggml_gemm_q4_0_q8_0_aarch64_neon (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
void ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm(int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
void ggml_gemm_q8_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
// GEMM routine installed as the `.gemm` hook in the type_traits table.
// NOTE(review): judging by the name, this multiplies Q4_0 data by Q8_0 data — confirm against the definition.
// Parameter roles below are inferred from the mul_mat call sites
//   gemm(ne00, dst + row_offset, src0->data, wdata + row_offset, {16|8|4}, ne01, ith, nth)
// and should be verified against the implementation:
//   n        - passed ne00 at the call site (presumably the shared inner dimension)
//   s        - float output buffer (dst data, offset by the caller to the first row of the tile)
//   vx       - opaque input buffer; src0->data at the call site
//   vy       - opaque input buffer; the (possibly converted) src1 rows for this tile
//   nr       - number of src1 rows handled per call; the caller uses 16, 8 or 4
//   nc       - passed ne01 at the call site
//   ith, nth - forwarded unchanged from the caller; presumably thread index and thread count
void ggml_gemm_q4_0_q8_0_aarch64 (int n, float * GGML_RESTRICT s, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, int ith, int nth);
|
||||
|
||||
#ifdef __cplusplus
|
||||
}
|
||||
|
|
|
@ -38,7 +38,7 @@
|
|||
#include <unistd.h>
|
||||
#endif
|
||||
|
||||
#ifdef __ARM_FEATURE_MATMUL_INT8
|
||||
#if defined(__ARM_NEON) || defined(__ARM_FEATURE_SVE) || defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
#undef GGML_USE_LLAMAFILE
|
||||
#endif
|
||||
|
||||
|
@ -915,16 +915,8 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
#if defined(__ARM_FEATURE_SVE)
|
||||
.gemv = ggml_gemv_q4_0_q8_0_aarch64_sve256,
|
||||
.gemm = ggml_gemm_q4_0_q8_0_aarch64_sve256,
|
||||
#elif defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
.gemv = ggml_gemv_q4_0_q8_0_aarch64_neon,
|
||||
.gemm = ggml_gemm_q4_0_q8_0_aarch64_neon,
|
||||
#elif defined(__ARM_NEON)
|
||||
.gemv = ggml_gemv_q4_0_q8_0_aarch64_neon_noi8mm,
|
||||
.gemm = ggml_gemm_q4_0_q8_0_aarch64_neon_noi8mm,
|
||||
#endif
|
||||
.gemv = ggml_gemv_q4_0_q8_0_aarch64,
|
||||
.gemm = ggml_gemm_q4_0_q8_0_aarch64,
|
||||
}
|
||||
};
|
||||
|
||||
|
@ -12242,15 +12234,15 @@ UseGgmlGemm1:;
|
|||
}
|
||||
}
|
||||
}
|
||||
if ((type == GGML_TYPE_Q4_0_AARCH64) && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
|
||||
if (from_float_to_mat && gemm && (ne11 >= 4) && (ne12 == 1) && (ne13 == 1)) {
|
||||
for (int64_t i11 = 0; i11 < ne11 / 4; ++i11) {
|
||||
from_float_to_mat((float *)((char *) src1->data + i11 * 4 * nb11), (void *) wdata, ne10, 4, ggml_cpu_has_matmul_int8() ? 8 : 4);
|
||||
wdata += row_size * 4;
|
||||
}
|
||||
for (int64_t i11 = (ne11 / 4) * 4; i11 < ne11; ++i11) {
|
||||
from_float_to_vec_dot((float *)((char *) src1->data + i11 * nb11), (void *) wdata, ne10);
|
||||
wdata += row_size;
|
||||
}
|
||||
wdata += row_size;
|
||||
}
|
||||
}
|
||||
else {
|
||||
for (int64_t i13 = 0; i13 < ne13; ++i13) {
|
||||
|
@ -12340,24 +12332,29 @@ UseGgmlGemm2:;
|
|||
//if (ith == 0)
|
||||
// printf("MUL_MAT = [%d, %d, %d, %d] x [%d, %d, %d, %d] = %d x %d = %d. Fp Ops/Ch %d\n", ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, nchunk0, nchunk1, nchunk0 * nchunk1, ne00 * nr0 * nr1 / nchunk0 / nchunk1);
|
||||
|
||||
if ((ggml_n_dims(src0) == 2) && (ne11 == 1) && (type == GGML_TYPE_Q4_0_AARCH64)) {
|
||||
gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
|
||||
if ((ggml_n_dims(src0) == 2) && gemm && gemv) {
|
||||
if (ne11 == 1) gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
|
||||
else {
|
||||
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), 16, ne01, ith, nth);
|
||||
}
|
||||
int rows_processed = (ne11 / 16) * 16;
|
||||
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), 8, ne01, ith, nth);
|
||||
}
|
||||
rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
|
||||
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), 4, ne01, ith, nth);
|
||||
}
|
||||
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
|
||||
for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
|
||||
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * row_size) : (row_iter * nb11)), 1, ne01, ith, nth);
|
||||
}
|
||||
}
|
||||
}
|
||||
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 2) && (type == GGML_TYPE_Q4_0_AARCH64)) {
|
||||
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), 16, ne01, ith, nth);
|
||||
}
|
||||
int rows_processed = (ne11 / 16) * 16;
|
||||
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 8; row_iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + row_iter * 8) * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 8) * row_size : ((rows_processed + row_iter * 8) * nb11)), 8, ne01, ith, nth);
|
||||
}
|
||||
rows_processed = rows_processed + ((ne11 - rows_processed) / 8) * 8;
|
||||
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
|
||||
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), 4, ne01, ith, nth);
|
||||
}
|
||||
rows_processed = rows_processed + ((ne11 - rows_processed) / 4) * 4;
|
||||
for (int row_iter = rows_processed; row_iter < ne11; row_iter++) {
|
||||
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
|
||||
else if ((ggml_n_dims(src0) == 2) && gemv) {
|
||||
for (int row_iter = 0; row_iter < ne11; row_iter++) {
|
||||
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * row_size) : (row_iter * nb11)), 1, ne01, ith, nth);
|
||||
}
|
||||
}
|
||||
else {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue