metal : small-batch mat-mul kernels
ggml-ci
This commit is contained in:
parent
991f8aabee
commit
f45c40e31c
4 changed files with 402 additions and 24 deletions
|
@ -192,6 +192,30 @@ typedef struct {
|
||||||
int16_t r3;
|
int16_t r3;
|
||||||
} ggml_metal_kargs_mul_mv;
|
} ggml_metal_kargs_mul_mv;
|
||||||
|
|
||||||
|
// Arguments of the kernel_mul_mv_ext_* (small-batch mat-mul) kernels.
//
// The host fills this struct with a positional initializer and the kernel reads
// it through `constant ggml_metal_kargs_mul_mv_ext &`, so the member order and
// types below are part of the host/kernel contract and must not be changed.
typedef struct {
    // src0: extents and byte strides
    int32_t  ne00;
    int32_t  ne01;
    int32_t  ne02;
    uint64_t nb00;
    uint64_t nb01;
    uint64_t nb02;
    uint64_t nb03;
    // src1: extents and byte strides
    int32_t  ne10;
    int32_t  ne11;
    int32_t  ne12;
    uint64_t nb10;
    uint64_t nb11;
    uint64_t nb12;
    uint64_t nb13;
    // dst: extents used to index the float output
    int32_t  ne0;
    int32_t  ne1;
    // divisors applied to the src1 batch indices (i12, i13) to pick the src0 batch
    int16_t  r2;
    int16_t  r3;
    int16_t  nsg;   // simdgroups per threadgroup
    int16_t  nxpsg; // threads of a simdgroup that cooperate on one src0 row
    int16_t  r1ptg; // src1 rows handled by one threadgroup
} ggml_metal_kargs_mul_mv_ext;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
int32_t nei0;
|
int32_t nei0;
|
||||||
int32_t nei1;
|
int32_t nei1;
|
||||||
|
|
|
@ -175,6 +175,30 @@ enum ggml_metal_kernel_type {
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
|
||||||
|
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
|
||||||
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
|
||||||
|
@ -699,6 +723,30 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
|
||||||
|
@ -1930,28 +1978,128 @@ static void ggml_metal_encode_node(
|
||||||
// to the matrix-vector kernel
|
// to the matrix-vector kernel
|
||||||
int ne11_mm_min = 4;
|
int ne11_mm_min = 4;
|
||||||
|
|
||||||
#if 0
|
if ((src0t == GGML_TYPE_F16 || // TODO: helper function
|
||||||
// the numbers below are measured on M2 Ultra for 7B and 13B models
|
src0t == GGML_TYPE_Q4_0 ||
|
||||||
// these numbers do not translate to other devices or model sizes
|
src0t == GGML_TYPE_Q4_1 ||
|
||||||
// TODO: need to find a better approach
|
src0t == GGML_TYPE_Q5_0 ||
|
||||||
if ([device.name isEqualToString:@"Apple M2 Ultra"]) {
|
src0t == GGML_TYPE_Q5_1 ||
|
||||||
switch (src0t) {
|
src0t == GGML_TYPE_Q8_0
|
||||||
case GGML_TYPE_F16: ne11_mm_min = 2; break;
|
) &&
|
||||||
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break;
|
src1t == GGML_TYPE_F32 &&
|
||||||
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break;
|
(ne00%256 == 0) && // TODO: this can be relaxed to 128 for nxpsg == 8
|
||||||
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break;
|
(ne11 >= 2 && ne11 <= 8)) {
|
||||||
case GGML_TYPE_Q4_0:
|
|
||||||
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
|
|
||||||
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
|
|
||||||
case GGML_TYPE_Q5_0: // not tested yet
|
|
||||||
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
|
|
||||||
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
|
|
||||||
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
|
|
||||||
default: ne11_mm_min = 1; break;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
#endif
|
|
||||||
|
|
||||||
|
// TODO: determine the optimal parameters based on grid utilization
|
||||||
|
const int nsg = 2; // TODO: or 4?
|
||||||
|
const int nxpsg = ne11 < 3 ? 16 : 8;
|
||||||
|
const int nypsg = 32/nxpsg;
|
||||||
|
const int r0ptg = nypsg*nsg;
|
||||||
|
int r1ptg = 4;
|
||||||
|
|
||||||
|
switch (ne11) {
|
||||||
|
case 2:
|
||||||
|
r1ptg = 2; break;
|
||||||
|
case 3:
|
||||||
|
case 6:
|
||||||
|
r1ptg = 3; break;
|
||||||
|
case 4:
|
||||||
|
case 7:
|
||||||
|
case 8:
|
||||||
|
r1ptg = 4; break;
|
||||||
|
case 5:
|
||||||
|
r1ptg = 5; break;
|
||||||
|
};
|
||||||
|
|
||||||
|
assert(nxpsg >= 8);
|
||||||
|
assert(nxpsg%8 == 0);
|
||||||
|
|
||||||
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
|
switch (src0->type) {
|
||||||
|
case GGML_TYPE_F16:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q4_0:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q4_1:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q5_0:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q5_1:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
case GGML_TYPE_Q8_0:
|
||||||
|
switch (r1ptg) {
|
||||||
|
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
|
||||||
|
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
|
||||||
|
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
|
||||||
|
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
} break;
|
||||||
|
default: GGML_ABORT("not implemented");
|
||||||
|
}
|
||||||
|
|
||||||
|
ggml_metal_kargs_mul_mv_ext args = {
|
||||||
|
/*.ne00 =*/ ne00,
|
||||||
|
/*.ne01 =*/ ne01,
|
||||||
|
/*.ne02 =*/ ne02,
|
||||||
|
/*.nb00 =*/ nb00,
|
||||||
|
/*.nb01 =*/ nb01,
|
||||||
|
/*.nb02 =*/ nb02,
|
||||||
|
/*.nb03 =*/ nb03,
|
||||||
|
/*.ne10 =*/ ne10,
|
||||||
|
/*.ne11 =*/ ne11,
|
||||||
|
/*.ne12 =*/ ne12,
|
||||||
|
/*.nb10 =*/ nb10,
|
||||||
|
/*.nb11 =*/ nb11,
|
||||||
|
/*.nb12 =*/ nb12,
|
||||||
|
/*.nb13 =*/ nb13,
|
||||||
|
/*.ne0 =*/ ne0,
|
||||||
|
/*.ne1 =*/ ne1,
|
||||||
|
/*.r2 =*/ r2,
|
||||||
|
/*.r3 =*/ r3,
|
||||||
|
/*.nsg =*/ nsg,
|
||||||
|
/*.nxpsg =*/ nxpsg,
|
||||||
|
/*.r1ptg =*/ r1ptg,
|
||||||
|
};
|
||||||
|
|
||||||
|
[encoder setComputePipelineState:pipeline];
|
||||||
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||||
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
|
|
||||||
|
//printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
|
} else
|
||||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||||
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
if ([device supportsFamily:MTLGPUFamilyApple7] &&
|
||||||
|
|
|
@ -47,6 +47,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
||||||
reg = (type4x4)(*src);
|
reg = (type4x4)(*src);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Reads the il-th half4 counted from `src` and stores it in `reg`, converted to type4.
template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
    reg = (type4)(*(src + il));
}
|
||||||
|
|
||||||
#if defined(GGML_METAL_USE_BF16)
|
#if defined(GGML_METAL_USE_BF16)
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
|
||||||
|
@ -73,6 +78,21 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
|
||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dequantizes 4 values of a q4_0 block into `reg`.
//
// il/4 selects which nibble of each byte is used (0: low nibble, non-zero: high nibble),
// il%4 selects which pair of 16-bit words of the quants is read.
template <typename type4>
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
    // quants start one 16-bit word into the block (presumably right after the scale `d` - confirm against block_q4_0)
    device const uint16_t * qs = ((device const uint16_t *)xb + 1);

    // the nibbles are masked in place instead of being shifted down, so the scale absorbs
    // the shift: /16 for the high nibble, and a further /256 for the upper byte of the word
    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float md = -8.h * xb->d;

    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    for (int i = 0; i < 2; i++) {
        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
    }
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||||
|
@ -92,6 +112,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
|
||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dequantizes 4 values of a q4_1 block into `reg`.
//
// il/4 selects which nibble of each byte is used (0: low nibble, non-zero: high nibble),
// il%4 selects which pair of 16-bit words of the quants is read.
template <typename type4>
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
    // quants start two 16-bit words into the block (presumably right after `d` and `m` - confirm against block_q4_1)
    device const uint16_t * qs = ((device const uint16_t *)xb + 2);

    // the nibbles are masked in place instead of being shifted down, so the scale absorbs
    // the shift: /16 for the high nibble, and a further /256 for the upper byte of the word
    const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
    const float d2 = d1 / 256.f;
    const float m  = xb->m;

    const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
    const ushort mask1 = mask0 << 8;

    for (int i = 0; i < 2; i++) {
        reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
        reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
    }
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 3);
|
||||||
|
@ -124,6 +159,14 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
|
||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dequantizes 4 values of a q5_0 block into `reg`.
//
// Placeholder implementation: runs the 16-value dequantize_q5_0 with il/4 and
// keeps only element il%4 of the resulting 4x4 matrix.
template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
    // TODO: implement
    float4x4 tmp;
    dequantize_q5_0(xb, il/4, tmp);
    // explicit conversion, consistent with the other dequantize helpers, so that
    // type4 does not have to be float4
    reg = (type4) tmp[il%4];
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
|
||||||
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
device const uint16_t * qs = ((device const uint16_t *)xb + 4);
|
||||||
|
@ -156,10 +199,18 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
|
||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dequantizes 4 values of a q5_1 block into `reg`.
//
// Placeholder implementation: runs the 16-value dequantize_q5_1 with il/4 and
// keeps only element il%4 of the resulting 4x4 matrix.
template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
    // TODO: implement
    float4x4 tmp;
    dequantize_q5_1(xb, il/4, tmp);
    // explicit conversion, consistent with the other dequantize helpers, so that
    // type4 does not have to be float4
    reg = (type4) tmp[il%4];
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
|
||||||
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
device const int8_t * qs = ((device const int8_t *)xb->qs);
|
||||||
const half d = xb->d;
|
const float d = xb->d;
|
||||||
|
|
||||||
float4x4 reg_f;
|
float4x4 reg_f;
|
||||||
|
|
||||||
|
@ -170,6 +221,16 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
|
||||||
reg = (type4x4) reg_f;
|
reg = (type4x4) reg_f;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Dequantizes 4 consecutive quants of a q8_0 block into `reg`, scaled by the block's `d`.
// For non-negative il the index 4*(il%4) + 16*(il/4) + i equals 4*il + i.
template <typename type4>
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
    device const int8_t * qs = ((device const int8_t *)xb->qs);
    const float d = xb->d;

    for (int i = 0; i < 4; i++) {
        reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
    }
}
|
||||||
|
|
||||||
template <typename type4x4>
|
template <typename type4x4>
|
||||||
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
|
||||||
const float d = xb->d;
|
const float d = xb->d;
|
||||||
|
@ -1752,6 +1813,142 @@ kernel void kernel_mul_mv_q8_0_f32(
|
||||||
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Small-batch mat-mul: multiplies rows of src0 (stored as q_t) with up to r1ptg float rows of src1.
//
// Template parameters:
//   nxpsg  - threads of a simdgroup that cooperate on one src0 row (32/nxpsg rows per simdgroup)
//   r1ptg  - src1 rows accumulated by each thread
//   q_t    - storage type of src0
//   bp32   - number of q_t units per 32 src0 elements (8 for half4, 1 for the block types
//            in the instantiations of kernel_mul_mv_ext_q_f32_disp)
//   deq_t4 - dequantizes 4 values of src0 into a float4
//
// Each thread accumulates partial dot products, the nxpsg threads of a row are then
// reduced with simd_shuffle_down, and the thread with tx == 0 writes the results to dst.
template<short nxpsg, short r1ptg, typename q_t, short bp32, void (*deq_t4)(device const q_t *, short, thread float4 &) >
void kernel_mul_mv_ext_q_f32_impl(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    // float4 chunks processed per thread in one iteration of the main loop
    const short chpt = 4;

    //const short nxpsg = (32);
    const short nypsg = (32/nxpsg);

    // position of this thread inside the simdgroup: tx along the row, ty selects the row
    const short tx = tiisg%nxpsg;
    const short ty = tiisg/nxpsg;

    const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty; // src0 row
    const int i11 = tgpig.y*r1ptg;                               // first src1 row of this threadgroup
    const int i1m = tgpig.z;                                     // flattened batch index

    const int i12 = i1m%args.ne12;
    const int i13 = i1m/args.ne12;

    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    // rows past the end fall back to the start of src0 so the loads below stay in bounds;
    // their results are never written
    device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + (tx/8)*bp32 : (device const q_t *) src0;

    device const float4 * y4[r1ptg];

    for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
        // same fallback for src1 rows past ne11
        y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
    }

    float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };

    for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
#pragma unroll(chpt)
        for (short ch = 0; ch < chpt; ++ch) {
            float4 lx;

            deq_t4(xq + (4*ch*nxpsg)/(32/bp32), tx%8, lx);

#pragma unroll(r1ptg)
            for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
                sumf[ir1] += dot(lx, y4[ir1][ch*nxpsg]);
            }
        }

        // advance both operands by the (4*chpt)*nxpsg elements consumed in this iteration
        xq += ((4*chpt)*nxpsg)/(32/bp32);

        for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
            y4[ir1] += ((4*chpt)*nxpsg)/4;
        }
    }

    // reduce the partial sums of the nxpsg threads that share a row
    for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
        if (nxpsg >= 32) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
        }
        if (nxpsg >= 16) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  8);
        }
        if (nxpsg >= 8) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  4);
        }
        if (nxpsg >= 4) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  2);
        }
        if (nxpsg >= 2) {
            sumf[ir1] += simd_shuffle_down(sumf[ir1],  1);
        }

        //sumf[ir1] = simd_sum(sumf[ir1]);
    }

    if (tx == 0) {
        for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
            device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;

            if (i01 < args.ne01) {
                dst_f32[i01] = sumf[ir1];
            }
        }
    }
}
|
||||||
|
|
||||||
|
// Kernel entry point: forwards to the kernel_mul_mv_ext_q_f32_impl specialization that
// matches the runtime value args.nxpsg. Only 8, 16 and 32 are handled; for any other
// value no case matches and the kernel returns without writing dst.
template<short r1ptg, typename q_t, short bp32, void (*deq_t4)(device const q_t *, short, thread float4 &)>
kernel void kernel_mul_mv_ext_q_f32_disp(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    switch (args.nxpsg) {
        case 8:  kernel_mul_mv_ext_q_f32_impl<8,  r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
        case 16: kernel_mul_mv_ext_q_f32_impl<16, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
        case 32: kernel_mul_mv_ext_q_f32_impl<32, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
    }
}
|
||||||
|
|
||||||
|
// Function type shared by every kernel_mul_mv_ext_q_f32_disp specialization. The template
// arguments used here only serve to name the type; the parameter list does not depend on them.
typedef decltype(kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q_f32_t;

// Host-visible kernels: one per src0 type and per r1ptg in {2, 3, 4, 5}.
// Template arguments are <r1ptg, q_t, bp32, deq_t4>.
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]]  kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, half4,      8, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]]  kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, half4,      8, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]]  kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, half4,      8, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]]  kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, half4,      8, dequantize_f16_t4>;

template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_0, 1, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_0, 1, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_0, 1, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_0, 1, dequantize_q4_0_t4>;

template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_1, 1, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_1, 1, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_1, 1, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_1, 1, dequantize_q4_1_t4>;

template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_0, 1, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_0, 1, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_0, 1, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_0, 1, dequantize_q5_0_t4>;

template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_1, 1, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_1, 1, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_1, 1, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_1, 1, dequantize_q5_1_t4>;

template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 1, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q8_0, 1, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q8_0, 1, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q8_0, 1, dequantize_q8_0_t4>;
|
||||||
|
|
||||||
#define N_MV_T_T 4
|
#define N_MV_T_T 4
|
||||||
|
|
||||||
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
||||||
|
|
|
@ -3572,6 +3572,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
|
||||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
|
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
|
||||||
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
|
||||||
|
|
||||||
|
for (int i = 1; i < 64; ++i) {
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
|
||||||
|
}
|
||||||
|
|
||||||
#if 1
|
#if 1
|
||||||
for (ggml_type type_a : base_types) {
|
for (ggml_type type_a : base_types) {
|
||||||
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue