ggml : add mmla kernels for quantized GEMM (#4966)
* ggml: aarch64: implement smmla kernel for q8_0_q8_0 quantized gemm armv8.2-a and above supports MMLA instructions that have higher throughput than DOT. this commit adds mmla kernel for q8_0_q8_0 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8" On AWS Graviton3 processors this kernel resulted up to 1.5x improvement for prompt evaluation throughput compared to the default sdot kernel. * ggml: aarch64: implement smmla kernel for q4_0_q8_0 quantized gemm armv8.2-a and above supports MMLA instructions that have higher throughput than DOT. this commit adds mmla kernel for q4_0_q8_0 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8" On AWS Graviton3 processors this kernel resulted up to 1.5x improvement for prompt evaluation throughput compared to the default sdot kernel. * ggml: aarch64: implement smmla kernel for q4_1_q8_1 quantized gemm armv8.2-a and above supports MMLA instructions that have higher throughput than DOT. this commit adds mmla kernel for q4_1_q8_1 gemm. The feature is enabled if the platform supports "__ARM_FEATURE_MATMUL_INT8" On AWS Graviton3 processors this kernel resulted up to 1.5x improvement for prompt evaluation throughput compared to the default sdot kernel. * ggml: update unit tests for the new vec_dot interface * llama.cpp: add MATMUL_INT8 capability to system_info
This commit is contained in:
parent
e4640d8fdf
commit
a07d0fee1f
10 changed files with 441 additions and 88 deletions
164
ggml.c
164
ggml.c
|
@ -428,8 +428,8 @@ int64_t ggml_cycles_per_ms(void) {
|
|||
|
||||
static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
||||
|
||||
static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y);
|
||||
static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y);
|
||||
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc);
|
||||
static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc);
|
||||
|
||||
static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
||||
[GGML_TYPE_I8] = {
|
||||
|
@ -457,6 +457,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.is_quantized = false,
|
||||
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f32,
|
||||
.vec_dot_type = GGML_TYPE_F32,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_F16] = {
|
||||
.type_name = "f16",
|
||||
|
@ -468,6 +469,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) ggml_fp32_to_fp16_row,
|
||||
.vec_dot = (ggml_vec_dot_t) ggml_vec_dot_f16,
|
||||
.vec_dot_type = GGML_TYPE_F16,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q4_0] = {
|
||||
.type_name = "q4_0",
|
||||
|
@ -479,6 +481,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_0_reference,
|
||||
.vec_dot = ggml_vec_dot_q4_0_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
.nrows = 2,
|
||||
#else
|
||||
.nrows = 1,
|
||||
#endif
|
||||
},
|
||||
[GGML_TYPE_Q4_1] = {
|
||||
.type_name = "q4_1",
|
||||
|
@ -490,6 +497,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_1_reference,
|
||||
.vec_dot = ggml_vec_dot_q4_1_q8_1,
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
.nrows = 2,
|
||||
#else
|
||||
.nrows = 1,
|
||||
#endif
|
||||
},
|
||||
[4] = { // GGML_TYPE_Q4_2
|
||||
.type_name = "DEPRECATED",
|
||||
|
@ -501,6 +513,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_COUNT,
|
||||
.nrows = 1,
|
||||
},
|
||||
[5] = { // GGML_TYPE_Q4_3
|
||||
.type_name = "DEPRECATED",
|
||||
|
@ -512,6 +525,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = NULL,
|
||||
.vec_dot = NULL,
|
||||
.vec_dot_type = GGML_TYPE_COUNT,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q5_0] = {
|
||||
.type_name = "q5_0",
|
||||
|
@ -523,6 +537,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_0_reference,
|
||||
.vec_dot = ggml_vec_dot_q5_0_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q5_1] = {
|
||||
.type_name = "q5_1",
|
||||
|
@ -534,6 +549,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_1_reference,
|
||||
.vec_dot = ggml_vec_dot_q5_1_q8_1,
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q8_0] = {
|
||||
.type_name = "q8_0",
|
||||
|
@ -545,6 +561,11 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q8_0_reference,
|
||||
.vec_dot = ggml_vec_dot_q8_0_q8_0,
|
||||
.vec_dot_type = GGML_TYPE_Q8_0,
|
||||
#if defined (__ARM_FEATURE_MATMUL_INT8)
|
||||
.nrows = 2,
|
||||
#else
|
||||
.nrows = 1,
|
||||
#endif
|
||||
},
|
||||
[GGML_TYPE_Q8_1] = {
|
||||
.type_name = "q8_1",
|
||||
|
@ -554,6 +575,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float = quantize_row_q8_1,
|
||||
.from_float_reference = (ggml_from_float_t) quantize_row_q8_1_reference,
|
||||
.vec_dot_type = GGML_TYPE_Q8_1,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q2_K] = {
|
||||
.type_name = "q2_K",
|
||||
|
@ -565,6 +587,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q2_K_reference,
|
||||
.vec_dot = ggml_vec_dot_q2_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q3_K] = {
|
||||
.type_name = "q3_K",
|
||||
|
@ -576,6 +599,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q3_K_reference,
|
||||
.vec_dot = ggml_vec_dot_q3_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q4_K] = {
|
||||
.type_name = "q4_K",
|
||||
|
@ -587,6 +611,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q4_K_reference,
|
||||
.vec_dot = ggml_vec_dot_q4_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q5_K] = {
|
||||
.type_name = "q5_K",
|
||||
|
@ -598,6 +623,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q5_K_reference,
|
||||
.vec_dot = ggml_vec_dot_q5_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q6_K] = {
|
||||
.type_name = "q6_K",
|
||||
|
@ -609,6 +635,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t) quantize_row_q6_K_reference,
|
||||
.vec_dot = ggml_vec_dot_q6_K_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ2_XXS] = {
|
||||
.type_name = "iq2_xxs",
|
||||
|
@ -620,6 +647,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = NULL,
|
||||
.vec_dot = ggml_vec_dot_iq2_xxs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ2_XS] = {
|
||||
.type_name = "iq2_xs",
|
||||
|
@ -631,6 +659,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = NULL,
|
||||
.vec_dot = ggml_vec_dot_iq2_xs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_IQ3_XXS] = {
|
||||
.type_name = "iq3_xxs",
|
||||
|
@ -642,6 +671,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
|
|||
.from_float_reference = (ggml_from_float_t)quantize_row_iq3_xxs_reference,
|
||||
.vec_dot = ggml_vec_dot_iq3_xxs_q8_K,
|
||||
.vec_dot_type = GGML_TYPE_Q8_K,
|
||||
.nrows = 1,
|
||||
},
|
||||
[GGML_TYPE_Q8_K] = {
|
||||
.type_name = "q8_K",
|
||||
|
@ -1212,7 +1242,13 @@ inline static void ggml_vec_neg_f32 (const int n, float * y, const float * x)
|
|||
inline static void ggml_vec_mul_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]*y[i]; }
|
||||
inline static void ggml_vec_div_f32 (const int n, float * z, const float * x, const float * y) { for (int i = 0; i < n; ++i) z[i] = x[i]/y[i]; }
|
||||
|
||||
static void ggml_vec_dot_f32(const int n, float * restrict s, const float * restrict x, const float * restrict y) {
|
||||
static void ggml_vec_dot_f32(int n, float * restrict s, size_t bs, const float * restrict x, size_t bx, const float * restrict y, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
#ifdef GGML_SIMD
|
||||
float sumf = 0.0f;
|
||||
const int np = (n & ~(GGML_F32_STEP - 1));
|
||||
|
@ -1249,7 +1285,13 @@ static void ggml_vec_dot_f32(const int n, float * restrict s, const float * rest
|
|||
*s = sumf;
|
||||
}
|
||||
|
||||
static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t * restrict x, ggml_fp16_t * restrict y) {
|
||||
static void ggml_vec_dot_f16(int n, float * restrict s, size_t bs, ggml_fp16_t * restrict x, size_t bx, ggml_fp16_t * restrict y, size_t by, int nrc) {
|
||||
assert(nrc == 1);
|
||||
UNUSED(nrc);
|
||||
UNUSED(bx);
|
||||
UNUSED(by);
|
||||
UNUSED(bs);
|
||||
|
||||
ggml_float sumf = 0.0;
|
||||
|
||||
#if defined(GGML_SIMD)
|
||||
|
@ -1455,7 +1497,7 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
|
|||
#endif
|
||||
}
|
||||
|
||||
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s); }
|
||||
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); }
|
||||
inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; }
|
||||
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
|
||||
inline static void ggml_vec_log_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = logf(x[i]); }
|
||||
|
@ -9992,6 +10034,7 @@ static void ggml_compute_forward_mul_mat(
|
|||
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
|
||||
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
|
||||
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
|
||||
int64_t const vec_dot_num_rows = type_traits[type].nrows;
|
||||
|
||||
GGML_ASSERT(ne0 == ne01);
|
||||
GGML_ASSERT(ne1 == ne11);
|
||||
|
@ -10159,12 +10202,23 @@ static void ggml_compute_forward_mul_mat(
|
|||
const int64_t blck_0 = 16;
|
||||
const int64_t blck_1 = 16;
|
||||
|
||||
// dot kernels can handle 1 row and col at a time, but mmla kernels can process 2 rows and cols
|
||||
int64_t nrc = vec_dot_num_rows;
|
||||
// TODO: currently the mmla kernels support only even numbered rows/cols.
|
||||
// this check can be removed once they are extended to support odd numbered rows/cols too
|
||||
if ((nr0 % 2 != 0) || (ne11 % 2 != 0)) {
|
||||
nrc = 1;
|
||||
}
|
||||
|
||||
const size_t src1_col_stride = src1_cont || src1->type != vec_dot_type ? row_size : nb11;
|
||||
|
||||
// attempt to reduce false-sharing (does not seem to make a difference)
|
||||
float tmp[16];
|
||||
// 16 * 2, accounting for mmla kernels
|
||||
float tmp[32];
|
||||
|
||||
for (int64_t iir1 = ir110; iir1 < ir111; iir1 += blck_1) {
|
||||
for (int64_t iir0 = ir010; iir0 < ir011; iir0 += blck_0) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ++ir1) {
|
||||
for (int64_t ir1 = iir1; ir1 < iir1 + blck_1 && ir1 < ir111; ir1 += nrc) {
|
||||
const int64_t i13 = (ir1/(ne12*ne1));
|
||||
const int64_t i12 = (ir1 - i13*ne12*ne1)/ne1;
|
||||
const int64_t i11 = (ir1 - i13*ne12*ne1 - i12*ne1);
|
||||
|
@ -10187,17 +10241,19 @@ static void ggml_compute_forward_mul_mat(
|
|||
(src1_cont || src1->type != vec_dot_type
|
||||
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
|
||||
: (i11*nb11 + i12*nb12 + i13*nb13));
|
||||
|
||||
float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
|
||||
|
||||
//for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
// vec_dot(ne00, &dst_col[ir0], src0_row + ir0*nb01, src1_col);
|
||||
//}
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ir0 += nrc) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], (nrc>1 ? 16 : 0), src0_row + ir0*nb01, (nrc>1 ? nb01 : 0), src1_col, (nrc>1 ? src1_col_stride : 0), nrc);
|
||||
}
|
||||
|
||||
for (int cn = 0; cn < nrc; ++cn) {
|
||||
memcpy(&dst_col[iir0 + cn*nb1/nb0], tmp + (cn*16), (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||
}
|
||||
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10386,7 +10442,7 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
//}
|
||||
|
||||
for (int64_t ir0 = iir0; ir0 < iir0 + blck_0 && ir0 < ir011; ++ir0) {
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], src0_row + ir0*nb01, src1_col);
|
||||
vec_dot(ne00, &tmp[ir0 - iir0], 0, src0_row + ir0*nb01, 0, src1_col, 0, 1);
|
||||
}
|
||||
memcpy(&dst_col[iir0], tmp, (MIN(iir0 + blck_0, ir011) - iir0)*sizeof(float));
|
||||
}
|
||||
|
@ -11568,7 +11624,7 @@ static void ggml_compute_forward_soft_max_back_f32(
|
|||
|
||||
// linear runtime, no additional memory
|
||||
float dot_y_dy = 0;
|
||||
ggml_vec_dot_f32 (nc, &dot_y_dy, y, dy);
|
||||
ggml_vec_dot_f32 (nc, &dot_y_dy, 0, y, 0, dy, 0, 1);
|
||||
ggml_vec_cpy_f32 (nc, dx, dy);
|
||||
ggml_vec_acc1_f32(nc, dx, -dot_y_dy);
|
||||
ggml_vec_mul_f32 (nc, dx, dx, y);
|
||||
|
@ -12369,9 +12425,9 @@ static void ggml_compute_forward_conv_transpose_1d_f16_f32(
|
|||
const int i1n = i10*ne11;
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
float v = 0;
|
||||
ggml_vec_dot_f16(ne02, &v,
|
||||
(ggml_fp16_t *) wdata_src + i1n,
|
||||
(ggml_fp16_t *) wdata_kernel + i00*ne02);
|
||||
ggml_vec_dot_f16(ne02, &v, 0,
|
||||
(ggml_fp16_t *) wdata_src + i1n, 0,
|
||||
(ggml_fp16_t *) wdata_kernel + i00*ne02, 0, 1);
|
||||
dst_data[i10*s0 + i00] += v;
|
||||
}
|
||||
}
|
||||
|
@ -12466,9 +12522,9 @@ static void ggml_compute_forward_conv_transpose_1d_f32(
|
|||
const int i1n = i10*ne11;
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
float v = 0;
|
||||
ggml_vec_dot_f32(ne02, &v,
|
||||
wdata_src + i1n,
|
||||
wdata_kernel + i00*ne02);
|
||||
ggml_vec_dot_f32(ne02, &v, 0,
|
||||
wdata_src + i1n, 0,
|
||||
wdata_kernel + i00*ne02, 0, 1);
|
||||
dst_data[i10*s0 + i00] += v;
|
||||
}
|
||||
}
|
||||
|
@ -12783,9 +12839,9 @@ static void ggml_compute_forward_conv_transpose_2d(
|
|||
for (int i01 = 0; i01 < ne01; i01++) {
|
||||
for (int i00 = 0; i00 < ne00; i00++) {
|
||||
float v = 0;
|
||||
ggml_vec_dot_f16(ne03, &v,
|
||||
wdata_src + i1n,
|
||||
wdata_kernel + i01*ne00*ne03 + i00*ne03);
|
||||
ggml_vec_dot_f16(ne03, &v, 0,
|
||||
wdata_src + i1n, 0,
|
||||
wdata_kernel + i01*ne00*ne03 + i00*ne03, 0, 1);
|
||||
dst_data[(i11*stride + i01)*ne0 + i10*stride + i00] += v;
|
||||
}
|
||||
}
|
||||
|
@ -13214,9 +13270,9 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
const int i1 = ik1;
|
||||
|
||||
ggml_vec_dot_f32(neq0,
|
||||
S + i1,
|
||||
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
||||
S + i1, 0,
|
||||
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
|
||||
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
|
||||
}
|
||||
|
||||
// scale
|
||||
|
@ -13299,9 +13355,9 @@ static void ggml_compute_forward_flash_attn_f32(
|
|||
const int iv3 = iq3;
|
||||
|
||||
ggml_vec_dot_f32(masked_begin,
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||
S);
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
|
||||
(float *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
|
||||
S, 0, 1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -13404,9 +13460,9 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||
const int i1 = ik1;
|
||||
|
||||
ggml_vec_dot_f16(neq0,
|
||||
S + i1,
|
||||
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
||||
S + i1, 0,
|
||||
(ggml_fp16_t *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
|
||||
(ggml_fp16_t *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
|
||||
}
|
||||
} else {
|
||||
for (int64_t ic = 0; ic < nek1; ic += GGML_VEC_DOT_UNROLL) {
|
||||
|
@ -13508,9 +13564,9 @@ static void ggml_compute_forward_flash_attn_f16(
|
|||
const int iv3 = iq3;
|
||||
|
||||
ggml_vec_dot_f16(nev0,
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)),
|
||||
S16);
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
|
||||
(ggml_fp16_t *) ((char *) v->data + ( ic*nbv1 + iv2*nbv2 + iv3*nbv3)), 0,
|
||||
S16, 0, 1);
|
||||
}
|
||||
} else {
|
||||
for (int64_t ic = 0; ic < nev1; ic += GGML_VEC_DOT_UNROLL) {
|
||||
|
@ -13652,9 +13708,9 @@ static void ggml_compute_forward_flash_ff_f16(
|
|||
const int i1 = ib01;
|
||||
|
||||
ggml_vec_dot_f16(nea0,
|
||||
S + i1,
|
||||
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)),
|
||||
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)));
|
||||
S + i1, 0,
|
||||
(ggml_fp16_t *) ((char *) b0->data + (ib01*nbb01 + ib02*nbb02 + ib03*nbb03)), 0,
|
||||
(ggml_fp16_t *) ((char *) a->data + ( ia1*nba1 + ia2*nba2 + ia3*nba3)), 0, 1);
|
||||
}
|
||||
|
||||
ggml_vec_add_f32(neb01, S, S, (float *) b1->data);
|
||||
|
@ -13677,9 +13733,9 @@ static void ggml_compute_forward_flash_ff_f16(
|
|||
for (int64_t ic = 0; ic < nec01; ++ic) {
|
||||
|
||||
ggml_vec_dot_f16(neb01,
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)),
|
||||
(ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)),
|
||||
S16);
|
||||
(float *) ((char *) dst->data + (ic*nb0 + i1*nb1 + i2*nb2 + i3*nb3)), 0,
|
||||
(ggml_fp16_t *) ((char *) c0->data + ( ic*nbc01 + i2*nbc02 + i3*nbc03)), 0,
|
||||
S16, 0, 1);
|
||||
}
|
||||
|
||||
ggml_vec_add_f32(nec01,
|
||||
|
@ -13866,9 +13922,9 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
const int i1 = ik1;
|
||||
|
||||
ggml_vec_dot_f32(neq0,
|
||||
S + i1,
|
||||
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)),
|
||||
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)));
|
||||
S + i1, 0,
|
||||
(float *) ((char *) k->data + (ik1*nbk1 + ik2*nbk2 + ik3*nbk3)), 0,
|
||||
(float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)), 0, 1);
|
||||
}
|
||||
|
||||
// scale
|
||||
|
@ -14013,7 +14069,7 @@ static void ggml_compute_forward_flash_attn_back_f32(
|
|||
|
||||
// S = SM * (S - dot(SM, S))
|
||||
float dot_SM_gradSM = 0;
|
||||
ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, SM, S);
|
||||
ggml_vec_dot_f32 (masked_begin, &dot_SM_gradSM, 0, SM, 0, S, 0, 1);
|
||||
ggml_vec_acc1_f32(M, S, -dot_SM_gradSM);
|
||||
ggml_vec_mul_f32 (masked_begin, S, S, SM);
|
||||
|
||||
|
@ -18382,7 +18438,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|||
}
|
||||
|
||||
// compute the initial gradient in the search direction
|
||||
ggml_vec_dot_f32(nx, &dginit, g, d);
|
||||
ggml_vec_dot_f32(nx, &dginit, 0, g, 0, d, 0, 1);
|
||||
|
||||
// make sure that d points to a descent direction
|
||||
if (0 < dginit) {
|
||||
|
@ -18432,7 +18488,7 @@ static enum ggml_opt_result linesearch_backtracking(
|
|||
return count;
|
||||
}
|
||||
|
||||
ggml_vec_dot_f32(nx, &dg, g, d);
|
||||
ggml_vec_dot_f32(nx, &dg, 0, g, 0, d, 0, 1);
|
||||
|
||||
// check the Wolfe condition
|
||||
if (dg < params->lbfgs.wolfe * dginit) {
|
||||
|
@ -18693,8 +18749,8 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
// ys = y^t \cdot s -> 1 / \rho.
|
||||
// yy = y^t \cdot y.
|
||||
//
|
||||
ggml_vec_dot_f32(nx, &ys, &lm_y[end[0]*nx], &lm_s[end[0]*nx]);
|
||||
ggml_vec_dot_f32(nx, &yy, &lm_y[end[0]*nx], &lm_y[end[0]*nx]);
|
||||
ggml_vec_dot_f32(nx, &ys, 0, &lm_y[end[0]*nx], 0, &lm_s[end[0]*nx], 0, 1);
|
||||
ggml_vec_dot_f32(nx, &yy, 0, &lm_y[end[0]*nx], 0, &lm_y[end[0]*nx], 0, 1);
|
||||
|
||||
lm_ys[end[0]] = ys;
|
||||
|
||||
|
@ -18713,7 +18769,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
for (int i = 0; i < bound; ++i) {
|
||||
j[0] = (j[0] + m - 1) % m;
|
||||
// \alpha_{j} = \rho_{j} s^{t}_{j} \cdot q_{k+1}
|
||||
ggml_vec_dot_f32(nx, &lm_alpha[j[0]], &lm_s[j[0]*nx], d);
|
||||
ggml_vec_dot_f32(nx, &lm_alpha[j[0]], 0, &lm_s[j[0]*nx], 0, d, 0, 1);
|
||||
lm_alpha[j[0]] /= lm_ys[j[0]];
|
||||
// q_{i} = q_{i+1} - \alpha_{i} y_{i}
|
||||
ggml_vec_mad_f32(nx, d, &lm_y[j[0]*nx], -lm_alpha[j[0]]);
|
||||
|
@ -18723,7 +18779,7 @@ static enum ggml_opt_result ggml_opt_lbfgs(
|
|||
|
||||
for (int i = 0; i < bound; ++i) {
|
||||
// \beta_{j} = \rho_{j} y^t_{j} \cdot \gamma_{i}
|
||||
ggml_vec_dot_f32(nx, &beta, &lm_y[j[0]*nx], d);
|
||||
ggml_vec_dot_f32(nx, &beta, 0, &lm_y[j[0]*nx], 0, d, 0, 1);
|
||||
beta /= lm_ys[j[0]];
|
||||
// \gamma_{i+1} = \gamma_{i} + (\alpha_{j} - \beta_{j}) s_{j}
|
||||
ggml_vec_mad_f32(nx, d, &lm_s[j[0]*nx], lm_alpha[j[0]] - beta);
|
||||
|
@ -20611,4 +20667,12 @@ int ggml_cpu_has_vsx(void) {
|
|||
#endif
|
||||
}
|
||||
|
||||
int ggml_cpu_has_matmul_int8(void) {
|
||||
#if defined(__ARM_FEATURE_MATMUL_INT8)
|
||||
return 1;
|
||||
#else
|
||||
return 0;
|
||||
#endif
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue