metal: yet another MUL mat-vec template

This commit is contained in:
lshzh-ww 2023-08-30 23:46:42 -04:00
parent aa4b7d29a2
commit 35cd10c173
2 changed files with 69 additions and 5 deletions

View file

@ -864,7 +864,10 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0];
// only for k-quants we use threadgroup memory
if (ggml_blck_size(src0t) >= 64){
[encoder setThreadgroupMemoryLength:8 * buffer_size_aligned atIndex:0];
}
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(64, 1, 1)];
} else {
switch (src0->type) {

View file

@ -1320,6 +1320,67 @@ kernel void kernel_mat_mv(device const void * src0,
}
}
// Matrix-vector product kernel that never touches threadgroup memory.
//
// Every SIMD group accumulates `nr` rows of dst. The rows owned by a SIMD group
// are first_row + nsg * row for row in [0, nr), where
// first_row = tgpig.x * nsg * nr + sgitg, so one threadgroup covers nsg * nr
// consecutive rows interleaved between its SIMD groups.
//
// Template parameters:
//   block_q_type - type src0 is reinterpreted as; one element covers nl * 16 weights
//   nr           - number of dst rows accumulated per SIMD group
//   nsg          - row interleave factor between SIMD groups (see above)
//   nl           - number of threads that share one src0 block (16 weights each)
//   n_shift      - distance, in floats, between the two 8-float pieces of src1
//                  that each thread caches in yl
//   quant_dri    - driver template providing init(), inner_product_pre() and
//                  inner_product()
//
// shared_memory is not referenced in this kernel; presumably it is kept so the
// signature stays identical to kernel_mat_mv (mat_mv_t) — confirm.
template<typename block_q_type, int nr, int nsg, int nl, int n_shift, template<typename, typename, typename> class quant_dri>
kernel void kernel_mat_mv_no_tg_mem(device const void * src0,
                                    device const float * src1,
                                    device float * dst,
                                    constant int64_t & ne00,
                                    constant int64_t & ne01,
                                    constant int64_t & ne02,
                                    constant int64_t & ne10,
                                    constant int64_t & ne12,
                                    constant int64_t & ne0,
                                    constant int64_t & ne1,
                                    constant uint & gqa,
                                    threadgroup uint * shared_memory[[threadgroup(0)]],
                                    uint3 tgpig[[threadgroup_position_in_grid]],
                                    uint tiisg[[thread_index_in_simdgroup]],
                                    uint sgitg[[simdgroup_index_in_threadgroup]]) {
    // number of src0 blocks in one row (each block covers nl * 16 weights)
    const int nb = ne00/(nl * 16);
    const int r0 = tgpig.x;
    const int r1 = tgpig.y;
    const int im = tgpig.z;

    const int ix = tiisg / nl; // block index handled by this thread within one pass
    const int il = tiisg % nl; // which 16-weight slice of that block

    const int first_row = (r0 * nsg) * nr + sgitg;

    const uint offset0 = first_row * nb + im/gqa*(nb*ne0) + ix;
    const uint offset1 = r1*ne10 + im*ne00*ne1 + ix * (nl * 16) + (il/(n_shift/8))*16*(n_shift/8) + (il%(n_shift/8)) * 8;

    device const block_q_type * x = (device const block_q_type *) src0 + offset0;
    device const float        * y = (device const float        *) src1 + offset1;

    float4x4 yl;             // src1 vector cache
    float sumf[nr] = {0.f};  // one partial sum per row owned by this thread

    quant_dri<device const uint16_t *, device const block_q_type *, half4x4> dequan_worker;
    dequan_worker.init(il);

    // each thread in a SIMD group deals with 16 dequantized weights.
    for (int ib = ix; ib < nb; ib += N_SIMDWIDTH / nl) {
        // floats [0, 8) and [n_shift, n_shift + 8) relative to y
        yl[0] = *((device const float4 *)y);
        yl[1] = *((device const float4 *)y + 1);
        yl[2] = *((device const float4 *)y + n_shift/4);
        yl[3] = *((device const float4 *)y + n_shift/4 + 1);

        dequan_worker.inner_product_pre(il, yl);

        // NOTE(review): rows at or past ne01 are still read here; only the store
        // at the end is bounds-checked — confirm src0 is safe to over-read.
        #pragma unroll(nr)
        for (int row = 0; row < nr; row++) {
            float sum_temp = 0.f;
            // Consecutive rows of this SIMD group are nsg rows apart (matching
            // the dst index below), i.e. nsg * nb blocks apart in src0.
            dequan_worker.inner_product(x + nsg * nb * row, il, yl, sum_temp);
            sumf[row] += sum_temp;
        }

        x += N_SIMDWIDTH / nl;
        y += N_SIMDWIDTH * 16;
    }

    // reduce the per-thread partial sums across the SIMD group; lane 0 stores
    for (int row = 0; row < nr; ++row) {
        const float tot = simd_sum(sumf[row]);
        if (tiisg == 0 && first_row + nsg * row < ne01) {
            dst[r1*ne0 + im*ne0*ne1 + first_row + nsg * row] = tot;
        }
    }
}
#define BLOCK_SIZE_M 64 // 8 simdgroup matrices from matrix A
#define BLOCK_SIZE_N 32 // 4 simdgroup matrices from matrix B
#define BLOCK_SIZE_K 32
@ -1487,10 +1548,10 @@ typedef void (mat_mv_t)(device const void *, device const float *, device float
#define N_DST 4       // passed as template argument nr
#define N_SIMDGROUP 2 // passed as template argument nsg

// Template arguments: <block type, nr, nsg, nl, n_shift, driver>.
// Each host_name is bound to exactly one instantiation.

// f16 and the non-k quant types use the variant that needs no threadgroup memory.
template [[host_name("kernel_mul_mv_f16_f32" )]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<half4x4, N_DST, N_SIMDGROUP, 1, 8, f16_driver>;
template [[host_name("kernel_mul_mv_q4_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q4_0, N_DST, N_SIMDGROUP, 2, 16, q4_0_driver>;
template [[host_name("kernel_mul_mv_q4_1_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q4_1, N_DST, N_SIMDGROUP, 2, 16, q4_1_driver>;
template [[host_name("kernel_mul_mv_q8_0_f32")]] kernel mat_mv_t kernel_mat_mv_no_tg_mem<block_q8_0, N_DST, N_SIMDGROUP, 2, 8, q8_0_driver>;

// k-quants stay on kernel_mat_mv.
template [[host_name("kernel_mul_mv_q2_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q2_K, N_DST, N_SIMDGROUP, QK_NL, 8, q2_K_driver>;
template [[host_name("kernel_mul_mv_q3_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q3_K, N_DST, N_SIMDGROUP, QK_NL, 8, q3_K_driver>;
template [[host_name("kernel_mul_mv_q4_K_f32")]] kernel mat_mv_t kernel_mat_mv<block_q4_K, N_DST, N_SIMDGROUP, QK_NL, 32, q4_K_driver>;