Arm AArch64: minor code refactoring for rebase

This commit is contained in:
Dibakar Gope 2024-05-01 06:53:48 +00:00 committed by Dibakar Gope
parent 441ab64989
commit 8ee6779147
4 changed files with 31 additions and 35 deletions

View file

@ -92,7 +92,7 @@ size_t quantize_q4_0_aarch64(const float * GGML_RESTRICT src, void * GGML_RESTRI
}
}
void quantize_row_q8_0_aarch64(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int nrows_interleaved, int blocklen_per_row) {
void quantize_q8_0_aarch64(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k, int nrows_interleaved, int blocklen_per_row) {
assert(QK8_0 == 32);
assert(k % QK8_0 == 0);
const int nb = k / QK8_0;

View file

@ -13,7 +13,7 @@ extern "C" {
#endif
// Quantization
void quantize_row_q8_0_aarch64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, int nrows_interleaved, int blocklen_per_row);
void quantize_q8_0_aarch64(const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k, int nrows_interleaved, int blocklen_per_row);
// Quantization utilizing an importance matrix (a.k.a. "Activation aWare Quantization")
size_t quantize_q4_0_aarch64(const float * GGML_RESTRICT src, void * GGML_RESTRICT dst, int64_t nrows, int64_t n_per_row, const float * imatrix);

View file

@ -14760,6 +14760,16 @@ static bool validate_fp16(ggml_fp16_t f, size_t i) {
} \
}
#define VALIDATE_ROW_DATA_DVEC_F16_IMPL(type, data, nb, nr) \
const type * q = (const type *) (data); \
for (size_t i = 0; i < (nb); ++i) { \
for (size_t j = 0; j < (nr); ++j) { \
if (!validate_fp16(q[i].d[j], i)) { \
return false; \
} \
} \
}
bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbytes) {
if (type < 0 || type >= GGML_TYPE_COUNT) {
fprintf(stderr, "%s: invalid type %d\n", __func__, type);
@ -14977,6 +14987,19 @@ bool ggml_validate_row_data(enum ggml_type type, const void * data, size_t nbyte
{
VALIDATE_ROW_DATA_D_F16_IMPL(block_iq4_nl, data, nb);
} break;
case GGML_TYPE_Q4_0_AARCH64:
{
#if defined(__ARM_FEATURE_SVE)
if (svcntw() == 8) {
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x8, data, nbytes / sizeof(block_q4_0x8), 8);
}
else if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) {
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
}
#elif defined(__ARM_NEON)
VALIDATE_ROW_DATA_DVEC_F16_IMPL(block_q4_0x4, data, nbytes / sizeof(block_q4_0x4), 4);
#endif
} break;
case GGML_TYPE_I8:
case GGML_TYPE_I16:
case GGML_TYPE_I32:

View file

@ -705,7 +705,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
#else
.nrows = 1,
#endif
.from_float_to_mat = quantize_row_q8_0_aarch64,
.from_float_to_mat = quantize_q8_0_aarch64,
},
[GGML_TYPE_Q8_1] = {
.type_name = "q8_1",
@ -909,16 +909,12 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
.blck_size = QK4_0,
.type_size = sizeof(block_q4_0),
.is_quantized = true,
.to_float = (ggml_to_float_t) dequantize_row_q4_0,
.from_float = quantize_row_q4_0,
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
.vec_dot = ggml_vec_dot_q4_0_q8_0,
.to_float = NULL,
.from_float = NULL,
.from_float_reference = NULL,
.vec_dot = NULL,
.vec_dot_type = GGML_TYPE_Q8_0,
#if defined (__ARM_FEATURE_MATMUL_INT8)
.nrows = 2,
#else
.nrows = 1,
#endif
#if defined(__ARM_FEATURE_SVE)
.gemv = ggml_gemv_q4_0_q8_0_aarch64_sve256,
.gemm = ggml_gemm_q4_0_q8_0_aarch64_sve256,
@ -12347,8 +12343,7 @@ UseGgmlGemm2:;
if ((ggml_n_dims(src0) == 2) && (ne11 == 1) && (type == GGML_TYPE_Q4_0_AARCH64)) {
gemv(ne00, (float *)((char *) dst->data), (const char *) src0->data, (const char *) wdata, 1, ne01, ith, nth);
}
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 16) && (type == GGML_TYPE_Q4_0_AARCH64)) {
// use nrows-sized 16, 8, and 4 GEMM kernels
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 2) && (type == GGML_TYPE_Q4_0_AARCH64)) {
for (int row_iter = 0; row_iter < ne11 / 16; row_iter++) {
gemm(ne00, (float *)((char *) dst->data + (row_iter * 16 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 16) * row_size : (row_iter * 16 * nb11)), 16, ne01, ith, nth);
}
@ -12365,28 +12360,6 @@ UseGgmlGemm2:;
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
}
}
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 8) && (type == GGML_TYPE_Q4_0_AARCH64)) {
// use nrows-sized 8, and 4 GEMM kernels
for (int row_iter = 0; row_iter < ne11 / 8; row_iter++) {
gemm(ne00, (float *)((char *) dst->data + (row_iter * 8 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 8) * row_size : (row_iter * 8 * nb11)), 8, ne01, ith, nth);
}
int rows_processed = (ne11 / 8) * 8;
for (int row_iter = 0; row_iter < (ne11 - rows_processed) / 4; row_iter++) {
gemm(ne00, (float *)((char *) dst->data + ((rows_processed + row_iter * 4) * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (rows_processed + row_iter * 4) * row_size : ((rows_processed + row_iter * 4) * nb11)), 4, ne01, ith, nth);
}
for (int row_iter = ((ne11 / 8) * 8) + ((ne11 - rows_processed) / 4 * 4); row_iter < ne11; row_iter++) {
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
}
}
else if ((ggml_n_dims(src0) == 2) && (ne11 >= 4) && (type == GGML_TYPE_Q4_0_AARCH64)) {
// use nrows-sized 4 GEMM kernel
for (int row_iter = 0; row_iter < ne11 / 4; row_iter++) {
gemm(ne00, (float *)((char *) dst->data + (row_iter * 4 * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter * 4) * row_size : (row_iter * 4 * nb11)), 4, ne01, ith, nth);
}
for (int row_iter = (ne11 / 4) * 4; row_iter < ne11; row_iter++) {
gemv(ne00, (float *)((char *) dst->data + (row_iter * nb1)), (const char *) src0->data, (const char *) wdata + (src1_cont || src1->type != vec_dot_type ? (row_iter)*row_size : (row_iter * nb11)), 1, ne01, ith, nth);
}
}
else {
// The first chunk comes from our thread_id, the rest will get auto-assigned.
int current_chunk = ith;