metal : reduce the kernel launches for ggml_mul_mat_id
This commit is contained in:
parent
7e2006b0c0
commit
8c5b66eeaa
2 changed files with 49 additions and 27 deletions
|
@ -3474,19 +3474,22 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
device const int32_t * ids,
|
||||
device const uchar * ids,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
device uchar * dst,
|
||||
constant int64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & nb01,
|
||||
constant int64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant int64_t & nb10,
|
||||
constant int64_t & nb11,
|
||||
constant int64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & nb1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
|
@ -3504,10 +3507,16 @@ kernel void kernel_mul_mm_id(
|
|||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const uchar * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
|
||||
kernel_mul_mm_impl<block_q, nl, dequantize_func>(
|
||||
src0[ids[idx]],
|
||||
src1,
|
||||
dst,
|
||||
src0[id],
|
||||
src1 + bid*nb11,
|
||||
(device float *) (dst + bid*nb1),
|
||||
ne00,
|
||||
ne02,
|
||||
nb01,
|
||||
|
@ -3589,19 +3598,22 @@ template [[host_name("kernel_mul_mm_q5_K_f32")]] kernel mat_mm_t kernel_mul_mm<b
|
|||
template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
|
||||
typedef void (mat_mm_id_t)(
|
||||
device const int32_t * ids,
|
||||
device const uchar * ids,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
device uchar * dst,
|
||||
constant int64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & nb01,
|
||||
constant int64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant int64_t & nb10,
|
||||
constant int64_t & nb11,
|
||||
constant int64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant int64_t & nb1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue