Compare commits

...
Sign in to create a new pull request.

3 commits

Author SHA1 Message Date
Georgi Gerganov
3c8a2a83fe
shmem experiments 2024-11-26 15:17:38 +02:00
Georgi Gerganov
dafedd33d2
4x4 -> 4x 2024-11-26 14:54:02 +02:00
Georgi Gerganov
bf3494345e
metal : some mul_mv experiments 2024-11-26 14:48:50 +02:00
3 changed files with 319 additions and 0 deletions

View file

@ -192,6 +192,29 @@ typedef struct {
int16_t r3;
} ggml_metal_kargs_mul_mv;
// Arguments for the "ext" mul_mv kernels. The host fills this struct and passes it by value
// (setBytes) to the shader, so the field order and types below must not change.
typedef struct {
    // src0: element counts and byte strides
    int32_t  ne00, ne01, ne02;
    uint64_t nb00, nb01, nb02, nb03;

    // src1: element counts and byte strides
    int32_t  ne10, ne11, ne12;
    uint64_t nb10, nb11, nb12, nb13;

    // dst: element counts of the first two dimensions
    int32_t  ne0, ne1;

    // divisors applied to the src1 batch indices (i12, i13) when addressing src0
    int16_t  r2, r3;

    // runtime selectors for the shader template instantiation
    int16_t  nsg, nxpsg;
} ggml_metal_kargs_mul_mv_ext;
typedef struct {
int32_t nei0;
int32_t nei1;

View file

@ -1,6 +1,8 @@
#import "ggml-metal.h"
#import "ggml-impl.h"
#define GGML_COMMON_DECL_C
#import "ggml-common.h"
#import "ggml-backend-impl.h"
#import "ggml-metal-impl.h"
@ -175,6 +177,7 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32,
GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32,
@ -699,6 +702,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32, mul_mv_ext_q8_0_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, has_simdgroup_reduction);
@ -1951,6 +1955,61 @@ static void ggml_metal_encode_node(
}
#endif
// Experimental Q8_0 x F32 mat-vec path for small src1 batches (4 <= ne11 < 32).
//
// With nxpsg == 32 the shader consumes 4*chpt*nxpsg = 1024 src0 elements per simdgroup
// iteration and has no tail handling, so this path is only taken when ne00 is a multiple
// of 1024. Every other shape falls through to the kernels selected below.
if (src0t == GGML_TYPE_Q8_0 && (ne00%1024 == 0) && (ne11 >= 4 && ne11 < 32)) {
    id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MV_EXT_Q8_0_F32].pipeline;

    const int nsg  = 4; // simdgroups per threadgroup
    const int r0pt = 4; // src0 rows per thread           (must match the shader)
    const int r1pt = 1; // src1 rows per threadgroup
    const int chpt = 8; // float4 chunks per thread/iter  (must match the shader)

    const int nxpsg  = 32;
    const int nypsg  = 32/nxpsg;
    const int nr0ptg = nypsg*r0pt*nsg;

    // Each simdgroup stages r0pt rows of (4*chpt*nxpsg)/32 Q8_0 blocks in threadgroup memory.
    // block_q8_0 comes from ggml-common.h (imported with GGML_COMMON_DECL_C).
    // The previous fixed 2*8192 bytes was smaller than the range the shader addresses.
    const size_t smem = (size_t) nsg*r0pt*((4*chpt*nxpsg)/32)*sizeof(block_q8_0);

    ggml_metal_kargs_mul_mv_ext args = {
        /*.ne00  =*/ ne00,
        /*.ne01  =*/ ne01,
        /*.ne02  =*/ ne02,
        /*.nb00  =*/ nb00,
        /*.nb01  =*/ nb01,
        /*.nb02  =*/ nb02,
        /*.nb03  =*/ nb03,
        /*.ne10  =*/ ne10,
        /*.ne11  =*/ ne11,
        /*.ne12  =*/ ne12,
        /*.nb10  =*/ nb10,
        /*.nb11  =*/ nb11,
        /*.nb12  =*/ nb12,
        /*.nb13  =*/ nb13,
        /*.ne0   =*/ ne0,
        /*.ne1   =*/ ne1,
        /*.r2    =*/ r2,
        /*.r3    =*/ r3,
        /*.nsg   =*/ nsg,
        /*.nxpsg =*/ nxpsg,
    };

    [encoder setComputePipelineState:pipeline];
    [encoder setBytes:&args length:sizeof(args) atIndex:0];
    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
    [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
    [encoder setBuffer:id_dst  offset:offs_dst  atIndex:3];

    // The threadgroup memory must be bound BEFORE the dispatch, and its length has to be a
    // multiple of 16 bytes. The kernel is dispatched exactly once.
    [encoder setThreadgroupMemoryLength:((smem + 15)/16)*16 atIndex:0];

    [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nr0ptg - 1)/nr0ptg, (ne11 + r1pt - 1)/r1pt, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
} else
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
if ([device supportsFamily:MTLGPUFamilyApple7] &&

View file

@ -170,6 +170,47 @@ void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg
reg = (type4x4) reg_f;
}
// Dequantize four consecutive weights of a Q4_0 block into `reg`.
//
// `il` selects which group of four is produced, using the same convention as
// dequantize_q8_0x: the element index inside the block is 4*(il%4) + i + 16*(il/4).
//
// The previous body was a placeholder: it ignored `il` and `d` and wrote the raw byte
// qs[0] into every lane.
//
// NOTE(review): assumes the usual ggml Q4_0 layout - 16 packed bytes where the low nibble
// of qs[j] holds element j and the high nibble holds element j + 16, values stored with an
// offset of 8 and scaled by d - confirm against the block_q4_0 declaration.
template <typename type4>
void dequantize_q4_0x(device const block_q4_0 *xb, short il, thread type4 & reg) {
    device const uint8_t * qs = ((device const uint8_t *)xb->qs);
    const half d = xb->d;

    const short j0   = 4*(il%4);     // first packed byte of this group
    const bool  high = (il/4) != 0;  // second half of the block lives in the high nibbles

    for (int i = 0; i < 4; i++) {
        const uint8_t q = qs[j0 + i];
        const short   v = (high ? (q >> 4) : (q & 0x0F)) - 8;

        reg[i] = (v * d);
    }
}
// Dequantize four consecutive int8 weights of a Q8_0 block (device memory) into `reg`.
// The group starts at element 4*(il%4) + 16*(il/4) of the block; each weight is multiplied
// by the block scale d in half precision.
template <typename type4>
void dequantize_q8_0x(device const block_q8_0 *xb, short il, thread type4 & reg) {
    const half scale = xb->d;

    // first of the four quants handled by this call
    device const int8_t * q = ((device const int8_t *)xb->qs) + (4*(il%4) + 16*(il/4));

    reg[0] = (q[0] * scale);
    reg[1] = (q[1] * scale);
    reg[2] = (q[2] * scale);
    reg[3] = (q[3] * scale);
}
// Same as dequantize_q8_0x, but the block is read from threadgroup memory and the scale is
// widened to float before the multiply.
template <typename type4>
void dequantize_q8_0s(threadgroup const block_q8_0 * xb, short il, thread type4 & reg) {
    const float scale = xb->d;

    // element offset of the four quants selected by il
    const int base = 4*(il%4) + 16*(il/4);

    threadgroup const int8_t * quants = ((threadgroup const int8_t *) xb->qs);

    int lane = 0;
    while (lane < 4) {
        reg[lane] = (quants[base + lane]*scale);
        ++lane;
    }
}
//template <typename type4>
//type4 dequantize_q8_0x(device const int8_t * qs, float d, short il) {
// thread type4 reg;
// for (int i = 0; i < 4; i++) {
// reg[i] = (qs[4*(il%4) + i + 16*(il/4)] * d);
// //reg[i] = qs[i/2];
// }
//
// return reg;
//}
template <typename type4x4>
void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg) {
const float d = xb->d;
@ -1752,6 +1793,202 @@ kernel void kernel_mul_mv_q8_0_f32(
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
}
// Q8_0 x F32 matrix-vector product that stages src0 blocks in threadgroup memory.
//
// Work split (as established by the indexing below):
//   - tgpig.x selects a tile of nypsg*nsg src0 rows, tgpig.y the src1 row (i11) and
//     tgpig.z the flattened (i12, i13) batch index
//   - every simdgroup covers nypsg = (32/nxpsg)*r0pt src0 rows, every thread accumulates
//     r0pt rows
//   - per loop iteration a thread consumes chpt float4 chunks of src1, strided by nxpsg
//   - rows at or beyond args.ne01 are read from the start of src0 and never written to dst
//
// Expectations on the caller:
//   - for nxpsg == 32, shmem has to hold nsg*r0pt*((4*chpt*nxpsg)/32) block_q8_0 entries
//   - NOTE(review): the staging copy uses a fixed 32-block row stride (tiisg%32, 32*iib),
//     which lines up with the read-side layout and the y4 advance only for nxpsg == 32 -
//     other nxpsg values look unsupported, confirm before enabling them
//   - NOTE(review): there is no tail handling, ne00 is expected to be a multiple of
//     4*chpt*nxpsg (1024 for nxpsg == 32)
template<short nsg, short nxpsg>
void kernel_mul_mv_ext_q8_0_f32_impl(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 ntg[[threads_per_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    const short chpt = 8; // float4 chunks consumed per thread and iteration
    const short r0pt = 4; // src0 rows accumulated per thread

    const short nypsg = (32/nxpsg)*r0pt; // src0 rows per simdgroup

    const short tx = tiisg%nxpsg;
    const short ty = tiisg/nxpsg;

    const int i01 = tgpig.x*(nypsg*nsg) + nypsg*sgitg + ty*r0pt;
    const int i11 = tgpig.y;
    const int i1m = tgpig.z;

    const int i12 = i1m%args.ne12;
    const int i13 = i1m/args.ne12;

    const uint64_t offset0 = i01*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
    const uint64_t offset1 = i11*args.nb11 + (i12        )*args.nb12 + (i13        )*args.nb13;

    // row base pointers; out-of-range rows fall back to the start of src0 so that the
    // staging loads below stay inside the buffer
    device const block_q8_0 * xq0[r0pt];

    for (short ir0 = 0; ir0 < r0pt; ++ir0) {
        xq0[ir0] = (i01 + ir0 < args.ne01) ? (device const block_q8_0 *) (src0 + offset0 + ir0*args.nb01) : (device const block_q8_0 *) src0;
    }

    device const float4 * y4 = (device const float4 *) (src1 + offset1) + tx;

    float sumf[r0pt] = { [0 ... r0pt - 1] = 0.0f };

    // per-simdgroup staging area
    threadgroup block_q8_0 * shmem_q = (threadgroup block_q8_0 *) shmem + (((4*chpt)*nxpsg)/32)*r0pt*sgitg;

    for (int iib = 0; (4*chpt)*(iib*nxpsg + tx) < args.ne00; ++iib) {
        // stage one block per thread for each of the r0pt rows
        for (short ir0 = 0; ir0 < r0pt; ++ir0) {
            shmem_q[(4*chpt)*(tiisg/32 + ir0) + tiisg%32] = xq0[tiisg/32 + ir0][32*iib + tiisg%32];
        }

        simdgroup_barrier(mem_flags::mem_threadgroup);

        for (short ir0 = 0; ir0 < r0pt; ++ir0) {
#pragma unroll(chpt)
            for (short ch = 0; ch < chpt; ++ch) {
                float4 lx;

                dequantize_q8_0s(shmem_q + (((4*chpt)*nxpsg)/32)*ir0 + (ch*nxpsg)/8 + tx/8, (tx)%8, lx);

                sumf[ir0] += dot(lx, y4[ch*nxpsg]);
            }
        }

        y4 += ((4*chpt)*nxpsg)/4;

        // all reads of the staged blocks must finish before the next iteration overwrites
        // them - without this barrier the loop has a write-after-read race inside the simdgroup
        simdgroup_barrier(mem_flags::mem_threadgroup);
    }

    // reduce the partial sums over the nxpsg threads that share a row
    for (short ir0 = 0; ir0 < r0pt; ++ir0) {
        if (nxpsg >= 32) {
            sumf[ir0] += simd_shuffle_down(sumf[ir0], 16);
        }
        if (nxpsg >= 16) {
            sumf[ir0] += simd_shuffle_down(sumf[ir0], 8);
        }
        if (nxpsg >= 8) {
            sumf[ir0] += simd_shuffle_down(sumf[ir0], 4);
        }
        if (nxpsg >= 4) {
            sumf[ir0] += simd_shuffle_down(sumf[ir0], 2);
        }
        if (nxpsg >= 2) {
            sumf[ir0] += simd_shuffle_down(sumf[ir0], 1);
        }
    }

    device float * dst_f32 = (device float *) dst + (uint64_t)i1m*args.ne0*args.ne1 + (uint64_t)i11*args.ne0;

    // the thread with tx == 0 holds the reduced sums; skip rows past the end of src0
    if (tx == 0) {
        for (short ir0 = 0; ir0 < r0pt && i01 + ir0 < args.ne01; ++ir0) {
            dst_f32[i01 + ir0] = sumf[ir0];
        }
    }
}
// Selects the nxpsg template argument at runtime for a fixed compile-time nsg.
// Values of args.nxpsg other than 4, 8, 16 and 32 are ignored (nothing is computed).
template<short nsg>
void kernel_mul_mv_ext_q8_0_f32_nxpsg(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem,
        uint3   tgpig,
        ushort3 ntg,
        ushort  tiisg,
        ushort  sgitg) {
    switch (args.nxpsg) {
        case 4:  kernel_mul_mv_ext_q8_0_f32_impl<nsg, 4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 8:  kernel_mul_mv_ext_q8_0_f32_impl<nsg, 8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 16: kernel_mul_mv_ext_q8_0_f32_impl<nsg, 16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 32: kernel_mul_mv_ext_q8_0_f32_impl<nsg, 32>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
    }
}

// Kernel entry point: maps the runtime (args.nsg, args.nxpsg) pair onto the matching
// template instantiation of kernel_mul_mv_ext_q8_0_f32_impl.
// Supported nsg values are 1, 2, 4, 6, 8, 12 and 16; anything else is a no-op.
[[host_name("kernel_mul_mv_ext_q8_0_f32")]]
kernel void kernel_mul_mv_ext_q8_0_f32(
        constant ggml_metal_kargs_mul_mv_ext & args,
        device const char * src0,
        device const char * src1,
        device       char * dst,
        threadgroup  char * shmem [[threadgroup(0)]],
        uint3   tgpig[[threadgroup_position_in_grid]],
        ushort3 ntg[[threads_per_threadgroup]],
        ushort  tiisg[[thread_index_in_simdgroup]],
        ushort  sgitg[[simdgroup_index_in_threadgroup]]) {
    switch (args.nsg) {
        case 1:  kernel_mul_mv_ext_q8_0_f32_nxpsg<1> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 2:  kernel_mul_mv_ext_q8_0_f32_nxpsg<2> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 4:  kernel_mul_mv_ext_q8_0_f32_nxpsg<4> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 6:  kernel_mul_mv_ext_q8_0_f32_nxpsg<6> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 8:  kernel_mul_mv_ext_q8_0_f32_nxpsg<8> (args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 12: kernel_mul_mv_ext_q8_0_f32_nxpsg<12>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
        case 16: kernel_mul_mv_ext_q8_0_f32_nxpsg<16>(args, src0, src1, dst, shmem, tgpig, ntg, tiisg, sgitg); break;
    }
}
#define N_MV_T_T 4
template<typename T0, typename T04, typename T1, typename T14, typename args_t>