metal : small-batch mat-mul kernels

ggml-ci
This commit is contained in:
Georgi Gerganov 2024-11-11 13:16:12 +02:00
parent 991f8aabee
commit f45c40e31c
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
4 changed files with 402 additions and 24 deletions

View file

@ -192,6 +192,30 @@ typedef struct {
int16_t r3; int16_t r3;
} ggml_metal_kargs_mul_mv; } ggml_metal_kargs_mul_mv;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne10;
int32_t ne11;
int32_t ne12;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ne0;
int32_t ne1;
int16_t r2;
int16_t r3;
int16_t nsg;
int16_t nxpsg;
int16_t r1ptg;
} ggml_metal_kargs_mul_mv_ext;
typedef struct { typedef struct {
int32_t nei0; int32_t nei0;
int32_t nei1; int32_t nei1;

View file

@ -175,6 +175,30 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@ -699,6 +723,30 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2, mul_mv_ext_f16_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3, mul_mv_ext_f16_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4, mul_mv_ext_f16_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5, mul_mv_ext_f16_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2, mul_mv_ext_q4_0_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3, mul_mv_ext_q4_0_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4, mul_mv_ext_q4_0_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5, mul_mv_ext_q4_0_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2, mul_mv_ext_q4_1_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3, mul_mv_ext_q4_1_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4, mul_mv_ext_q4_1_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5, mul_mv_ext_q4_1_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2, mul_mv_ext_q5_0_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3, mul_mv_ext_q5_0_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4, mul_mv_ext_q5_0_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5, mul_mv_ext_q5_0_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2, mul_mv_ext_q5_1_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3, mul_mv_ext_q5_1_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4, mul_mv_ext_q5_1_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5, mul_mv_ext_q5_1_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2, mul_mv_ext_q8_0_f32_r1_2, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3, mul_mv_ext_q8_0_f32_r1_3, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4, mul_mv_ext_q8_0_f32_r1_4, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5, mul_mv_ext_q8_0_f32_r1_5, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@ -1930,28 +1978,128 @@ static void ggml_metal_encode_node(
// to the matrix-vector kernel // to the matrix-vector kernel
int ne11_mm_min = 4; int ne11_mm_min = 4;
#if 0 if ((src0t == GGML_TYPE_F16 || // TODO: helper function
// the numbers below are measured on M2 Ultra for 7B and 13B models src0t == GGML_TYPE_Q4_0 ||
// these numbers do not translate to other devices or model sizes src0t == GGML_TYPE_Q4_1 ||
// TODO: need to find a better approach src0t == GGML_TYPE_Q5_0 ||
if ([device.name isEqualToString:@"Apple M2 Ultra"]) { src0t == GGML_TYPE_Q5_1 ||
switch (src0t) { src0t == GGML_TYPE_Q8_0
case GGML_TYPE_F16: ne11_mm_min = 2; break; ) &&
case GGML_TYPE_Q8_0: ne11_mm_min = 7; break; src1t == GGML_TYPE_F32 &&
case GGML_TYPE_Q2_K: ne11_mm_min = 15; break; (ne00%256 == 0) && // TODO: this can be relaxed to 128 for nxpsg == 8
case GGML_TYPE_Q3_K: ne11_mm_min = 7; break; (ne11 >= 2 && ne11 <= 8)) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1: ne11_mm_min = 15; break;
case GGML_TYPE_Q4_K: ne11_mm_min = 11; break;
case GGML_TYPE_Q5_0: // not tested yet
case GGML_TYPE_Q5_1: ne11_mm_min = 13; break; // not tested yet
case GGML_TYPE_Q5_K: ne11_mm_min = 7; break;
case GGML_TYPE_Q6_K: ne11_mm_min = 7; break;
default: ne11_mm_min = 1; break;
}
}
#endif
// TODO: determine the optimal parameters based on grid utilization
const int nsg = 2; // TODO: or 4?
const int nxpsg = ne11 < 3 ? 16 : 8;
const int nypsg = 32/nxpsg;
const int r0ptg = nypsg*nsg;
int r1ptg = 4;
switch (ne11) {
case 2:
r1ptg = 2; break;
case 3:
case 6:
r1ptg = 3; break;
case 4:
case 7:
case 8:
r1ptg = 4; break;
case 5:
r1ptg = 5; break;
};
assert(nxpsg >= 8);
assert(nxpsg%8 == 0);
id<MTLComputePipelineState> pipeline = nil;
switch (src0->type) {
case GGML_TYPE_F16:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_F16_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_Q4_0:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_0_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_Q4_1:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q4_1_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_Q5_0:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_0_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_Q5_1:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q5_1_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
case GGML_TYPE_Q8_0:
switch (r1ptg) {
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_2].pipeline; break;
case 3: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_3].pipeline; break;
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_4].pipeline; break;
case 5: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32_R1_5].pipeline; break;
default: GGML_ABORT("not implemented");
} break;
default: GGML_ABORT("not implemented");
}
ggml_metal_kargs_mul_mv_ext args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne10 =*/ ne10,
/*.ne11 =*/ ne11,
/*.ne12 =*/ ne12,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
/*.nsg =*/ nsg,
/*.nxpsg =*/ nxpsg,
/*.r1ptg =*/ r1ptg,
};
[encoder setComputePipelineState:pipeline];
[encoder setBytes:&args length:sizeof(args) atIndex:0];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
//printf("ne01 = %lld nr0ptg = %d\n", ne01, nr0ptg);
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + r0ptg - 1)/r0ptg, (ne11 + r1ptg - 1)/r1ptg, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if ([device supportsFamily:MTLGPUFamilyApple7] && if ([device supportsFamily:MTLGPUFamilyApple7] &&

View file

@ -47,6 +47,11 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
reg = (type4x4)(*src); reg = (type4x4)(*src);
} }
template <typename type4>
void dequantize_f16_t4(device const half4 * src, short il, thread type4 & reg) {
reg = (type4)(*(src + il));
}
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)
template <typename type4x4> template <typename type4x4>
void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) { void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & reg) {
@ -73,6 +78,21 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
reg = (type4x4) reg_f; reg = (type4x4) reg_f;
} }
template <typename type4>
void dequantize_q4_0_t4(device const block_q4_0 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + md;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + md;
}
}
template <typename type4x4> template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) { void dequantize_q4_1(device const block_q4_1 * xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2); device const uint16_t * qs = ((device const uint16_t *)xb + 2);
@ -92,6 +112,21 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
reg = (type4x4) reg_f; reg = (type4x4) reg_f;
} }
template <typename type4>
void dequantize_q4_1_t4(device const block_q4_1 * xb, short il, thread type4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const float d1 = (il/4) ? (xb->d / 16.h) : xb->d;
const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = (il/4) ? 0x00F0 : 0x000F;
const ushort mask1 = mask0 << 8;
for (int i = 0; i < 2; i++) {
reg[2*i + 0] = d1 * (qs[2*(il%4) + i] & mask0) + m;
reg[2*i + 1] = d2 * (qs[2*(il%4) + i] & mask1) + m;
}
}
template <typename type4x4> template <typename type4x4>
void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) { void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 3); device const uint16_t * qs = ((device const uint16_t *)xb + 3);
@ -124,6 +159,14 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
reg = (type4x4) reg_f; reg = (type4x4) reg_f;
} }
template <typename type4>
void dequantize_q5_0_t4(device const block_q5_0 * xb, short il, thread type4 & reg) {
// TODO: implement
float4x4 tmp;
dequantize_q5_0(xb, il/4, tmp);
reg = tmp[il%4];
}
template <typename type4x4> template <typename type4x4>
void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) { void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 4); device const uint16_t * qs = ((device const uint16_t *)xb + 4);
@ -156,10 +199,18 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
reg = (type4x4) reg_f; reg = (type4x4) reg_f;
} }
template <typename type4>
void dequantize_q5_1_t4(device const block_q5_1 * xb, short il, thread type4 & reg) {
// TODO: implement
float4x4 tmp;
dequantize_q5_1(xb, il/4, tmp);
reg = tmp[il%4];
}
template <typename type4x4> template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) { void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs); device const int8_t * qs = ((device const int8_t *)xb->qs);
const half d = xb->d; const float d = xb->d;
float4x4 reg_f; float4x4 reg_f;
@ -170,6 +221,16 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
reg = (type4x4) reg_f; reg = (type4x4) reg_f;
} }
template <typename type4>
void dequantize_q8_0_t4(device const block_q8_0 *xb, short il, thread type4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
const float d = xb->d;
for (int i = 0; i < 4; i++) {
reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
}
}
template <typename type4x4> template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) { void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d; const float d = xb->d;
@ -1752,6 +1813,142 @@ kernel void kernel_mul_mv_q8_0_f32(
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
} }
template<short nxpsg, short r1ptg, typename q_t, short bp32, void (*deq_t4)(device const q_t *, short, thread float4 &) >
void kernel_mul_mv_ext_q_f32_impl(
constant ggml_metal_kargs_mul_mv_ext & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
const short chpt = 4;
//const short nxpsg = (32);
const short nypsg = (32/nxpsg);
const short tx = tiisg%nxpsg;
const short ty = tiisg/nxpsg;
const int i01 = tgpig.x*(nypsg*args.nsg) + nypsg*sgitg + ty;
const int i11 = tgpig.y*r1ptg;
const int i1m = tgpig.z;
const int i12 = i1m%args.ne12;
const int i13 = i1m/args.ne12;
const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
const uint64_t offset1 = i11*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
device const q_t * xq = (i01 < args.ne01) ? (device const q_t *) (src0 + offset0) + (tx/8)*bp32 : (device const q_t *) src0;
device const float4 * y4[r1ptg];
for (int ir1 = 0; ir1 < r1ptg; ++ir1) {
y4[ir1] = (i11 + ir1 < args.ne11) ? (device const float4 *) (src1 + offset1 + ir1*args.nb11) + tx : (device const float4 *) src1;
}
float sumf[r1ptg] = { [ 0 ... r1ptg - 1 ] = 0.0f };
for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
#pragma unroll(chpt)
for (short ch = 0; ch < chpt; ++ch) {
float4 lx;
deq_t4(xq + (4*ch*nxpsg)/(32/bp32), tx%8, lx);
#pragma unroll(r1ptg)
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
sumf[ir1] += dot(lx, y4[ir1][ch*nxpsg]);
}
}
xq += ((4*chpt)*nxpsg)/(32/bp32);
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
y4[ir1] += ((4*chpt)*nxpsg)/4;
}
}
for (short ir1 = 0; ir1 < r1ptg; ++ir1) {
if (nxpsg >= 32) {
sumf[ir1] += simd_shuffle_down(sumf[ir1], 16);
}
if (nxpsg >= 16) {
sumf[ir1] += simd_shuffle_down(sumf[ir1], 8);
}
if (nxpsg >= 8) {
sumf[ir1] += simd_shuffle_down(sumf[ir1], 4);
}
if (nxpsg >= 4) {
sumf[ir1] += simd_shuffle_down(sumf[ir1], 2);
}
if (nxpsg >= 2) {
sumf[ir1] += simd_shuffle_down(sumf[ir1], 1);
}
//sumf[ir1] = simd_sum(sumf[ir1]);
}
if (tx == 0) {
for (short ir1 = 0; ir1 < r1ptg && i11 + ir1 < args.ne11; ++ir1) {
device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)(i11 + ir1)*args.ne0;
if (i01 < args.ne01) {
dst_f32[i01] = sumf[ir1];
}
}
}
}
template<short r1ptg, typename q_t, short bp32, void (*deq_t4)(device const q_t *, short, thread float4 &)>
kernel void kernel_mul_mv_ext_q_f32_disp(
constant ggml_metal_kargs_mul_mv_ext & args,
device const char * src0,
device const char * src1,
device char * dst,
uint3 tgpig[[threadgroup_position_in_grid]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) {
switch (args.nxpsg) {
case 8: kernel_mul_mv_ext_q_f32_impl<8, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
case 16: kernel_mul_mv_ext_q_f32_impl<16, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
case 32: kernel_mul_mv_ext_q_f32_impl<32, r1ptg, q_t, bp32, deq_t4>(args, src0, src1, dst, tgpig, tiisg, sgitg); break;
}
}
typedef decltype(kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 32, dequantize_q8_0_t4>) mul_mv_ext_q_f32_t;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, half4, 8, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, half4, 8, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, half4, 8, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_f16_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, half4, 8, dequantize_f16_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_0, 1, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_0, 1, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_0, 1, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_0, 1, dequantize_q4_0_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q4_1, 1, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q4_1, 1, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q4_1, 1, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q4_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q4_1, 1, dequantize_q4_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_0, 1, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_0, 1, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_0, 1, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_0, 1, dequantize_q5_0_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q5_1, 1, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q5_1, 1, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q5_1, 1, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q5_1_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q5_1, 1, dequantize_q5_1_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_2")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<2, block_q8_0, 1, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_3")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<3, block_q8_0, 1, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_4")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<4, block_q8_0, 1, dequantize_q8_0_t4>;
template [[host_name("kernel_mul_mv_ext_q8_0_f32_r1_5")]] kernel mul_mv_ext_q_f32_t kernel_mul_mv_ext_q_f32_disp<5, block_q8_0, 1, dequantize_q8_0_t4>;
#define N_MV_T_T 4 #define N_MV_T_T 4
template<typename T0, typename T04, typename T1, typename T14, typename args_t> template<typename T0, typename T04, typename T1, typename T14, typename args_t>

View file

@ -3572,6 +3572,15 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 32, 4));
test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4)); test_cases.emplace_back(new test_rwkv_wkv6(GGML_TYPE_F32, 32, 64, 128, 4));
for (int i = 1; i < 64; ++i) {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q5_1, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q8_0, GGML_TYPE_F32, 16, i, 256, { 1, 1}, {1, 1}));
}
#if 1 #if 1
for (ggml_type type_a : base_types) { for (ggml_type type_a : base_types) {
for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (ggml_type type_b : {GGML_TYPE_F32, GGML_TYPE_F16}) {