From 1198ae7749c82925e2b42e17c3f6e857c43eda7a Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 15:28:55 +0200 Subject: [PATCH 01/21] metal : add kernel arg structs (wip) --- ggml/src/ggml-common.h | 30 ++++++++ ggml/src/ggml-metal.m | 68 ++++++++++-------- ggml/src/ggml-metal.metal | 143 +++++++++++++------------------------- 3 files changed, 116 insertions(+), 125 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 050161393..f3d12df51 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -418,6 +418,36 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); +#if defined(GGML_COMMON_DECL_METAL_KARGS) +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} ggml_metal_kargs_rope; +#endif + #endif // GGML_COMMON_DECL #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 04ec5117f..96d65b1d7 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3,6 +3,10 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" +#define GGML_COMMON_DECL_C +#define GGML_COMMON_DECL_METAL_KARGS +#include "ggml-common.h" + #import #import @@ -2702,40 +2706,44 @@ static void ggml_metal_encode_node( }; } + ggml_metal_kargs_rope args = { + .ne00 = ne00, + .ne01 = ne01, + .ne02 = ne02, + .ne03 = ne03, + .nb00 = nb00, + .nb01 = nb01, + .nb02 = nb02, + .nb03 = nb03, + .ne0 = ne0, + .ne1 = ne1, + .ne2 = ne2, + .ne3 = ne3, + .nb0 = nb0, + .nb1 = nb1, + .nb2 = nb2, + .nb3 = nb3, + .n_past = n_past, + .n_dims = n_dims, + .n_ctx_orig = n_ctx_orig, + .freq_base = freq_base, + .freq_scale = freq_scale, + .ext_factor = ext_factor, + .attn_factor = attn_factor, + .beta_fast = beta_fast, + .beta_slow = beta_slow, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( 
int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:4]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index e8b71a9f8..85702fce7 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1,4 +1,5 @@ #define GGML_COMMON_DECL_METAL +#define GGML_COMMON_DECL_METAL_KARGS #define GGML_COMMON_IMPL_METAL #include "ggml-common.h" @@ -2229,7 +2230,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale, thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; @@ -2261,65 +2262,41 @@ static void rope_yarn_corr_dims( template kernel void kernel_rope_norm( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + constant ggml_metal_kargs_rope & args, uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], + uint3 tptg [[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + 
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); - device const int32_t * pos = src1; + device const int32_t * pos = (device const int32_t *) src1; const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; + const float inv_ndims = -1.f/args.n_dims; float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); const float x0 = src[0]; const float x1 = src[1]; @@ -2327,8 +2304,8 @@ kernel void kernel_rope_norm( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_data[0] = src[0]; dst_data[1] = src[1]; @@ -2338,74 +2315,50 @@ kernel void kernel_rope_norm( template kernel void kernel_rope_neox( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + constant ggml_metal_kargs_rope & args, uint tiitg[[thread_index_in_threadgroup]], uint3 tptg[[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; float corr_dims[2]; - rope_yarn_corr_dims(n_dims, 
n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); - device const int32_t * pos = src1; + device const int32_t * pos = (device const int32_t *) src1; const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; + const float inv_ndims = -1.f/args.n_dims; float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); const float x0 = src[0]; - const float x1 = src[n_dims/2]; + const float x1 = src[args.n_dims/2]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_data[0] = src[0]; dst_data[1] = src[1]; From 9e07bcc06ebcbbd19f426aeec83ee3e38fb09fbe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 16:09:31 +0200 Subject: [PATCH 02/21] metal : fattn args ggml-ci --- ggml/src/ggml-common.h | 24 ++++++ ggml/src/ggml-metal.m | 58 +++++++------- ggml/src/ggml-metal.metal | 156 ++++++++++++++------------------------ 3 files changed, 113 insertions(+), 125 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index f3d12df51..6529d71eb 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -446,6 +446,30 @@ typedef struct { float beta_fast; float beta_slow; } ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; #endif #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 96d65b1d7..379358b10 100644 --- 
a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -3228,37 +3228,41 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_flash_attn_ext args = { + .ne01 = ne01, + .ne02 = ne02, + .ne03 = ne03, + .nb01 = nb01, + .nb02 = nb02, + .nb03 = nb03, + .ne11 = ne11, + .ne_12_2 = ne12, + .ne_12_3 = ne13, + .nb_12_1 = nb11, + .nb_12_2 = nb12, + .nb_12_3 = nb13, + .nb31 = nb31, + .ne1 = ne1, + .ne2 = ne2, + .scale = scale, + .max_bias = max_bias, + .m0 = m0, + .m1 = m1, + .n_head_log2 = n_head_log2, + .logit_softcap = logit_softcap, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19]; - [encoder setBytes:&scale length:sizeof( float) atIndex:20]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:21]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:22]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:23]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24]; - [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 85702fce7..0466f2acd 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2267,9 +2267,9 @@ kernel void kernel_rope_norm( device const char * src2, device char * dst, constant ggml_metal_kargs_rope & args, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { const int i3 = tgpig[2]; const int i2 = tgpig[1]; const int i1 = tgpig[0]; @@ -2320,9 +2320,9 @@ kernel void kernel_rope_neox( device const char * src2, device char * dst, constant ggml_metal_kargs_rope & args, - uint 
tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { const int i3 = tgpig[2]; const int i2 = tgpig[1]; const int i1 = tgpig[0]; @@ -2761,32 +2761,12 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device float * dst, - constant int32_t & ne01, - constant int32_t & ne02, - constant int32_t & ne03, - constant uint32_t & nb01, - constant uint32_t & nb02, - constant uint32_t & nb03, - constant int32_t & ne11, - constant int32_t & ne_12_2, // assume K and V are same shape - constant int32_t & ne_12_3, - constant uint32_t & nb_12_1, - constant uint32_t & nb_12_2, - constant uint32_t & nb_12_3, - constant uint32_t & nb31, - constant int32_t & ne1, - constant int32_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint16_t & n_head_log2, - constant float & logit_softcap, + constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], - ushort3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -2819,10 +2799,10 @@ kernel void kernel_flash_attn_ext( // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { + if (iq1 + j < args.ne01) { sq4[j*D4 + i] = (q4_t) q4[i]; } else { sq4[j*D4 + i] = (q4_t) 0.0f; @@ -2855,11 +2835,11 @@ kernel void kernel_flash_attn_ext( const short ty = tiisg/4; // broadcast kv - //const short rk2 = ne02/ne12; - //const short rk3 = ne03/ne13; + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short ikv2 = iq2/(ne02/ne_12_2); - const short ikv3 = iq3/(ne03/ne_12_3); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); // load the queries from shared memory into local memory q8x8_t mq[D8]; @@ -2873,20 +2853,20 @@ kernel void kernel_flash_attn_ext( half slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < n_head_log2 ? m0 : m1; - const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? 
h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } @@ -2897,7 +2877,7 @@ kernel void kernel_flash_attn_ext( // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31); + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); const half m = pm[ic + tiisg]; @@ -2920,18 +2900,18 @@ kernel void kernel_flash_attn_ext( // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D8) for (short i = 0; i < D8; ++i) { k8x8_t mk; - simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } } else { for (short ii = 0; ii < D16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); if (D16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -2990,10 +2970,10 @@ kernel void kernel_flash_attn_ext( const half m = M[j]; // scale and apply the logitcap / mask - half s = ss[j*TS + tiisg]*scale; + half s = ss[j*TS + tiisg]*args.scale; - if (logit_softcap != 0.0f) { - s = logit_softcap*precise::tanh(s); + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); } // mqk = mqk + mask*slope @@ -3035,18 +3015,18 @@ kernel void kernel_flash_attn_ext( if (is_same::value) { // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D8) for (short i = 0; i < D8; ++i) { v8x8_t mv; - simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 + simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); } } else { for (short ii = 0; ii < D16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); if (D16%4 == 0) { // no need for bound checks @@ -3175,11 +3155,11 @@ kernel void kernel_flash_attn_ext( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + for 
(short j = 0; j < Q && iq1 + j < args.ne01; ++j) { const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3276,33 +3256,13 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device float * dst, - constant int32_t & ne01, - constant int32_t & ne02, - constant int32_t & ne03, - constant uint32_t & nb01, - constant uint32_t & nb02, - constant uint32_t & nb03, - constant int32_t & ne11, - constant int32_t & ne_12_2, // assume K and V are same shape - constant int32_t & ne_12_3, - constant uint32_t & nb_12_1, - constant uint32_t & nb_12_2, - constant uint32_t & nb_12_3, - constant uint32_t & nb31, - constant int32_t & ne1, - constant int32_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint16_t & n_head_log2, - constant float & logit_softcap, + constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], - ushort3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -3329,10 +3289,10 @@ kernel void kernel_flash_attn_ext_vec( o4x4_t lo[D16/NL]; // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { + if (iq1 < args.ne01) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; @@ -3360,11 +3320,11 @@ kernel void kernel_flash_attn_ext_vec( const short ty = tiisg/NL; // broadcast kv - //const short rk2 = ne02/ne12; - //const short rk3 = ne03/ne13; + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short ikv2 = iq2/(ne02/ne_12_2); - const short ikv3 = iq3/(ne03/ne_12_3); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); // load the queries from shared memory into local memory q4x4_t mq[D16/NL]; @@ -3377,25 +3337,25 @@ kernel void kernel_flash_attn_ext_vec( const bool has_mask = mask != q; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31); half slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < n_head_log2 ? m0 : m1; - const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? 
h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } @@ -3409,7 +3369,7 @@ kernel void kernel_flash_attn_ext_vec( for (short cc = 0; cc < C/4; ++cc) { qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; - device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D16/NL) for (short ii = 0; ii < D16; ii += NL) { @@ -3445,10 +3405,10 @@ kernel void kernel_flash_attn_ext_vec( // mqk = mqk*scale + mask*slope if (tx == 0) { - mqk *= scale; + mqk *= args.scale; - if (logit_softcap != 0.0f) { - mqk = logit_softcap*precise::tanh(mqk); + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); } mqk += sm[4*cc + ty]*slope; @@ -3487,7 +3447,7 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/4; ++cc) { - device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); const s4x4_t ms(ss[4*cc + ty]); @@ -3592,7 +3552,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } From 4fd6fc5ab874023fd28fc94bb52fd2c6a467abe0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 16:39:36 +0200 Subject: [PATCH 03/21] metal : cont + avoid potential int overflow [no ci] --- ggml/src/ggml-metal.metal | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 0466f2acd..a5fbed869 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -2756,11 +2756,11 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -3159,7 +3159,7 @@ kernel void kernel_flash_attn_ext( const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3251,11 +3251,11 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( - device const char * q, - device const char * k, - device const char * v, - 
device const char * mask, - device float * dst, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -3552,7 +3552,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (iq1)*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } From 7670809ad47fd750af8dd5b5ab42e33c8f35469e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 17:54:40 +0200 Subject: [PATCH 04/21] metal : mul mat struct (wip) --- ggml/src/ggml-common.h | 17 ++++++ ggml/src/ggml-metal.m | 125 ++++++++++++++++++++------------------ ggml/src/ggml-metal.metal | 73 ++++++++++------------ 3 files changed, 113 insertions(+), 102 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 6529d71eb..b03e1a6bf 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -470,6 +470,23 @@ typedef struct { uint16_t n_head_log2; float logit_softcap; } ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; #endif #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 379358b10..46ae1574b 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1959,24 +1959,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL MAT-MAT not implemented"); } + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -2707,31 +2712,31 @@ static void ggml_metal_encode_node( } ggml_metal_kargs_rope args = 
{ - .ne00 = ne00, - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb00 = nb00, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .ne0 = ne0, - .ne1 = ne1, - .ne2 = ne2, - .ne3 = ne3, - .nb0 = nb0, - .nb1 = nb1, - .nb2 = nb2, - .nb3 = nb3, - .n_past = n_past, - .n_dims = n_dims, - .n_ctx_orig = n_ctx_orig, - .freq_base = freq_base, - .freq_scale = freq_scale, - .ext_factor = ext_factor, - .attn_factor = attn_factor, - .beta_fast = beta_fast, - .beta_slow = beta_slow, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, }; [encoder setComputePipelineState:pipeline]; @@ -3229,27 +3234,27 @@ static void ggml_metal_encode_node( } ggml_metal_kargs_flash_attn_ext args = { - .ne01 = ne01, - .ne02 = ne02, - .ne03 = ne03, - .nb01 = nb01, - .nb02 = nb02, - .nb03 = nb03, - .ne11 = ne11, - .ne_12_2 = ne12, - .ne_12_3 = ne13, - .nb_12_1 = nb11, - .nb_12_2 = nb12, - .nb_12_3 = nb13, - .nb31 = nb31, - .ne1 = ne1, - .ne2 = ne2, - .scale = scale, - .max_bias = max_bias, - .m0 = m0, - .m1 = m1, - .n_head_log2 = n_head_log2, - .logit_softcap = logit_softcap, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb_12_1 =*/ nb11, + /*.nb_12_2 =*/ nb12, + /*.nb_12_3 =*/ nb13, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, }; [encoder setComputePipelineState:pipeline]; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index a5fbed869..35a25eaf6 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3259,7 +3259,6 @@ kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext & args, threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { @@ -6210,38 +6209,26 @@ kernel void kernel_get_rows_i32( // each block_q contains 16*nl weights template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +kernel void kernel_mul_mm( + device const char * src0, + device const char * src1, + device char * dst, + constant 
ggml_metal_kargs_mul_mm & args, + threadgroup char * shared_memory [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { threadgroup T * sa = (threadgroup T *)(shared_memory); threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; + const int r0 = tgpig.y; + const int r1 = tgpig.x; + const int im = tgpig.z; // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; @@ -6257,20 +6244,20 @@ kernel void kernel_mul_mm(device const uchar * src0, short il = (tiitg % THREAD_PER_ROW); - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; - uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; - ushort offset1 = il/nl; + int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + short offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1; + device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 - + nb13 * i13 - + nb12 * i12 - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1 * BLOCK_SIZE_N + thread_col) + + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory T4x4 temp_a; dequantize_func(x, il, temp_a); @@ -6317,11 +6304,13 @@ kernel void kernel_mul_mm(device const uchar * src0, } } - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0); } } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix @@ -6336,7 +6325,7 @@ kernel void kernel_mul_mm(device const uchar * src0, if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0; + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0; device float4 * D4 = (device float4 *) D; 
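                // D/D4 point into this block's slice of the output in device memory,
                // while C (below) points into the temp_str staging buffer in threadgroup memory;
                // copying through threadgroup memory lets a partial (< 64x32) block be written
                // back without storing outside the bounds of the destination matrix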
threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); From c59a13d93fdfa567158a2dbbd96fa034b73c3915 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sat, 9 Nov 2024 22:56:39 +0200 Subject: [PATCH 05/21] cont : mul mat vec --- ggml/src/ggml-common.h | 43 ++ ggml/src/ggml-metal.m | 82 ++- ggml/src/ggml-metal.metal | 1346 +++++++++---------------------------- 3 files changed, 402 insertions(+), 1069 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index b03e1a6bf..fcf0d997c 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -487,6 +487,49 @@ typedef struct { int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; #endif #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 46ae1574b..922d0d67a 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2159,28 +2159,32 @@ static void ggml_metal_encode_node( } }; + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; + [encoder setBytes:&args length:sizeof(args) atIndex:3]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || @@ -2472,30 +2476,34 @@ static void ggml_metal_encode_node( GGML_ASSERT(ne00 >= 
nth0*nth1); } + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + }; + [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + [encoder setBytes:&args length:sizeof(args) atIndex:4]; const int64_t _ne1 = 1; const int tgz = dst_rows; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 35a25eaf6..d8a65d411 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1629,26 +1629,12 @@ void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - const int nb = ne00/QK4_0; + uint3 tgpig, + uint tiisg, + uint sgitg) { + const int nb = args.ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -1656,11 +1642,11 @@ void mul_vec_q_n_f32_impl( const int first_row = (r0 * nsg + sgitg) * nr; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -1668,7 +1654,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[nr]; for (int row = 0; row < nr; 
++row) { - const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -1706,8 +1692,8 @@ void mul_vec_q_n_f32_impl( for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; + if (tiisg == 0 && first_row + row < args.ne01) { + dst[im*args.ne0*args.ne1 + r1*args.ne0 + first_row + row] = tot; } } } @@ -1716,136 +1702,53 @@ kernel void kernel_mul_mv_q4_0_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, 
nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } - #define NB_Q8_0 8 void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1854,18 +1757,18 @@ void kernel_mul_mv_q8_0_f32_impl( const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; - const int nb = ne00/QK8_0; + const int nb = args.ne00/QK8_0; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * nsg + sgitg) * nr; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -1873,7 +1776,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -1906,8 +1809,8 @@ void kernel_mul_mv_q8_0_f32_impl( for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + if (tiisg == 0 && first_row + row < args.ne01) { + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot; } } } @@ -1917,28 +1820,11 @@ kernel void kernel_mul_mv_q8_0_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant 
uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 @@ -1948,80 +1834,63 @@ void kernel_mul_mv_impl( device const char * src0, device const char * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { + ggml_metal_kargs_mul_mv args, + uint3 tgpig, + uint tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_MV_T_T; const int64_t im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); - if (ne00 < 128) { + if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; - if (r1 >= ne11) { + if (r1 >= args.ne11) { break; } - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { + for (int i = tiisg; i < args.ne00; i += 32) { sumf += (T0) x[i] * (T1) y[i]; } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } } else { device const T04 * x4 = (device const T04 *) x; for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; - if (r1 >= ne11) { + if (r1 >= args.ne11) { break; } - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { + for (int i = tiisg; i < args.ne00/4; i += 32) { for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } } @@ -2032,48 +1901,14 @@ kernel void kernel_mul_mv( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant 
uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( src0, src1, dst, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - nb03, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - nb13, - ne0, - ne1, - r2, - r3, + args, tgpig, tiisg); } @@ -2093,24 +1928,7 @@ kernel void kernel_mul_mv_1row( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -2118,37 +1936,37 @@ kernel void kernel_mul_mv_1row( const int64_t r1 = tgpig.y; const int64_t im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { + if (args.ne00 < 128) { + for (int i = tiisg; i < args.ne00; i += 32) { sumf += (float) x[i] * (float) y[i]; } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } else { device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { + for (int i = tiisg; i < args.ne00/4; i += 32) { for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } } @@ -2166,51 +1984,34 @@ kernel void kernel_mul_mv_l4( device const char * src0, device const char * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint 
tiisg[[thread_index_in_simdgroup]]) { - const int nrows = ne11; + const int nrows = args.ne11; const int64_t r0 = tgpig.x; const int64_t im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); for (int r1 = 0; r1 < nrows; ++r1) { - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { + for (int i = tiisg; i < args.ne00/4; i += 32) { for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; } } } @@ -4129,38 +3930,24 @@ void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -4212,9 +3999,9 @@ void kernel_mul_mv_q2_K_f32_impl( (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - qs += nb01/2; - sc += nb01; - dh += nb01/2; + qs += args.nb01/2; + sc += args.nb01; + dh += args.nb01/2; } y4 += 4 * QK_K; @@ -4223,7 +4010,7 @@ void kernel_mul_mv_q2_K_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -4233,56 +4020,25 @@ kernel void kernel_mul_mv_q2_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & 
ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -4290,11 +4046,11 @@ void kernel_mul_mv_q3_K_f32_impl( const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); device const float * yy = (device const float *) ((device char *) src1 + offset1); @@ -4398,10 +4154,10 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] += d1 * (scales[1] - 32); sumf2[row] += d2 * (scales[3] - 32); - q += nb01/2; - h += nb01/2; - a += nb01/2; - dh += nb01/2; + q += args.nb01/2; + h += args.nb01/2; + a += args.nb01/2; + dh += args.nb01/2; } y1 += 4 * QK_K; @@ -4413,7 +4169,7 @@ void kernel_mul_mv_q3_K_f32_impl( } if (tiisg == 0) { for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = sumf1[row]; } } } @@ -4423,54 +4179,23 @@ kernel void kernel_mul_mv_q3_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - 
int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -4481,18 +4206,18 @@ void kernel_mul_mv_q4_K_f32_impl( const int iq = it/4; // 0 or 1 const int ir = it%4; // 0...3 - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = r0 * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -4548,9 +4273,9 @@ void kernel_mul_mv_q4_K_f32_impl( (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += nb01/2; - sc += nb01/2; - dh += nb01/2; + q1 += args.nb01/2; + sc += args.nb01/2; + dh += args.nb01/2; } y4 += 4 * QK_K; @@ -4559,7 +4284,7 @@ void kernel_mul_mv_q4_K_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -4569,56 +4294,25 @@ kernel void kernel_mul_mv_q4_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = 
args.ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -4626,11 +4320,11 @@ void kernel_mul_mv_q5_K_f32_impl( const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); device const float * yy = (device const float *) ((device char *) src1 + offset1); @@ -4707,10 +4401,10 @@ void kernel_mul_mv_q5_K_f32_impl( sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += nb01; - qh += nb01; - dh += nb01/2; - a += nb01/2; + q1 += args.nb01; + qh += args.nb01; + dh += args.nb01/2; + a += args.nb01/2; } y1 += 4 * QK_K; @@ -4719,7 +4413,7 @@ void kernel_mul_mv_q5_K_f32_impl( for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot; } } } @@ -4729,61 +4423,30 @@ kernel void kernel_mul_mv_q5_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; const uint8_t kmask3 = 0x30; const uint8_t kmask4 = 0xC0; - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -4791,11 +4454,11 @@ void kernel_mul_mv_q6_K_f32_impl( const int row = 2 * r0 + sgitg; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = row*args.nb01 + 
(i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); device const float * yy = (device const float *) ((device char *) src1 + offset1); @@ -4839,7 +4502,7 @@ void kernel_mul_mv_q6_K_f32_impl( const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + row] = tot; } } @@ -4848,29 +4511,12 @@ kernel void kernel_mul_mv_q6_K_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit @@ -4879,38 +4525,24 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -4966,8 +4598,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( } sumf[row] += d * sum; - dh += nb01/2; - q2 += nb01/2; + dh += args.nb01/2; + q2 += args.nb01/2; } y4 += 32 * 32; @@ -4976,7 +4608,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; } } } @@ -4986,68 +4618,37 @@ kernel void kernel_mul_mv_iq2_xxs_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant 
int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5112,9 +4713,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( } sumf[row] += d1 * sum1 + d2 * sum2; - dh += nb01/2; - q2 += nb01/2; - sc += nb01; + dh += args.nb01/2; + q2 += args.nb01/2; + sc += args.nb01; } y4 += 32 * 32; @@ -5123,7 +4724,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; } } } @@ -5133,68 +4734,37 @@ kernel void kernel_mul_mv_iq2_xs_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], 
uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5252,9 +4822,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb01/2; - q3 += nb01; - gas += nb01/2; + dh += args.nb01/2; + q3 += args.nb01; + gas += args.nb01/2; } y4 += 32 * 32; @@ -5263,7 +4833,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.5f; } } } @@ -5273,68 +4843,37 @@ kernel void kernel_mul_mv_iq3_xxs_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq3_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - 
int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5390,11 +4929,11 @@ void kernel_mul_mv_iq3_s_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb01/2; - qs += nb01; - qh += nb01; - sc += nb01; - signs += nb01; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; } y4 += 32 * 32; @@ -5403,7 +4942,7 @@ void kernel_mul_mv_iq3_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5413,68 +4952,37 @@ kernel void kernel_mul_mv_iq3_s_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq2_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const 
uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5531,11 +5039,11 @@ void kernel_mul_mv_iq2_s_f32_impl( } sumf[row] += d1 * sum[0] + d2 * sum[1]; - dh += nb01/2; - qs += nb01; - qh += nb01; - sc += nb01; - signs += nb01; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; } y4 += 32 * 32; @@ -5544,7 +5052,7 @@ void kernel_mul_mv_iq2_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; } } } @@ -5554,68 +5062,37 @@ kernel void kernel_mul_mv_iq2_s_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } void kernel_mul_mv_iq1_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5661,9 +5138,9 @@ void kernel_mul_mv_iq1_s_f32_impl( } sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? 
-1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - dh += nb01/2; - qs += nb01; - qh += nb01/2; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01/2; } y4 += 32 * 32; @@ -5672,7 +5149,7 @@ void kernel_mul_mv_iq1_s_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5681,38 +5158,24 @@ void kernel_mul_mv_iq1_m_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5767,9 +5230,9 @@ void kernel_mul_mv_iq1_m_f32_impl( sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); - sc += nb01/2; - qs += nb01; - qh += nb01; + sc += args.nb01/2; + qs += args.nb01; + qh += args.nb01; } y4 += 32 * 32; @@ -5778,7 +5241,7 @@ void kernel_mul_mv_iq1_m_f32_impl( for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5787,38 +5250,24 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK4_NL; + const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = 
r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5844,7 +5293,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5876,10 +5325,10 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5888,38 +5337,24 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + uint tiisg, + uint sgitg) { threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); device const float * y = (device const float *) ((device char *) src1 + offset1); @@ -5981,7 +5416,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; } } } @@ -5991,29 +5426,12 @@ kernel void kernel_mul_mv_iq1_s_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, 
ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -6021,29 +5439,12 @@ kernel void kernel_mul_mv_iq1_m_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -6051,30 +5452,13 @@ kernel void kernel_mul_mv_iq4_nl_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -6082,30 +5466,13 @@ kernel void kernel_mul_mv_iq4_xs_f32( device const void * src0, device const float * src1, device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, + constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, 
shared_values, tgpig, tiisg, sgitg); } template @@ -6648,46 +6015,15 @@ typedef void (kernel_mul_mv_impl_t)( device const char * src0, device const char * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg); + ggml_metal_kargs_mul_mv args, + uint3 tgpig, + uint tiisg); typedef void (kernel_mul_mv2_impl_t)( device const void * src0, device const float * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -6698,32 +6034,13 @@ void mmv_fn( device const char * src0, device const char * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg); + impl_fn(src0, src1, dst, args, tgpig, tiisg); } template @@ -6731,32 +6048,13 @@ void mmv_fn( device const char * src0, device const char * src1, device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, + ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); + impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6767,71 +6065,55 @@ kernel void kernel_mul_mv_id( device const char * src1, device float * dst, device const char * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, + constant ggml_metal_kargs_mul_mv_id & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int iid1 = tgpig.z/nei0; - const int idx = tgpig.z%nei0; + const int iid1 
= tgpig.z/args.nei0; + const int idx = tgpig.z%args.nei0; tgpig.z = 0; - const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; + const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx]; - const int64_t i11 = idx % ne11; + const int64_t i11 = idx % args.ne11; const int64_t i12 = iid1; const int64_t i1 = idx; const int64_t i2 = i12; - device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + device const char * src0_cur = src0s + i02*args.nb02; + device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; + device float * dst_cur = dst + i1*args.ne0 + i2*args.ne1*args.ne0; + + ggml_metal_kargs_mul_mv args0 = { + /*.ne00 =*/ args.ne00, + /*.ne01 =*/ args.ne01, + /*.ne02 =*/ 1, // args.ne02, + /*.nb00 =*/ args.nb00, + /*.nb01 =*/ args.nb01, + /*.nb02 =*/ args.nb02, + /*.nb03 =*/ args.nb02, // args.ne02 == 1 + /*.ne10 =*/ args.ne10, + /*.ne11 =*/ 1, // args.ne11, + /*.ne12 =*/ 1, // args.ne12, + /*.nb10 =*/ args.nb10, + /*.nb11 =*/ args.nb11, + /*.nb12 =*/ args.nb12, + /*.nb13 =*/ args.nb12, // ne12 == 1 + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ 1, // args.ne1, + /*.r2 =*/ 1, + /*.r3 =*/ 1, + }; impl_fn( /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - /* ne00 */ ne00, - /* ne01 */ ne01, - /* ne02 */ 1, // ne02, - /* nb00 */ nb00, - /* nb01 */ nb01, - /* nb02 */ nb02, - /* nb03 */ nb02, // ne02 == 1 - /* ne10 */ ne10, - /* ne11 */ 1, // ne11, - /* ne12 */ 1, // ne12, - /* ne13 */ 1, // ne13, - /* nb10 */ nb10, - /* nb11 */ nb11, - /* nb12 */ nb12, - /* ne13 */ nb12, // ne12 == 1 - /* ne0 */ ne0, - /* ne1 */ 1, // ne1, - /* nb1 */ nb1, - /* r2 */ 1, - /* r3 */ 1, + args0, shared_values, tgpig, tiitg, From b65e4c1e105ca56d0959abb358626fa030f41bb0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 08:10:22 +0200 Subject: [PATCH 06/21] cont : pass by reference --- ggml/src/ggml-metal.metal | 99 ++++++++++++++++++++++----------------- 1 file changed, 57 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index d8a65d411..7526643e8 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1624,12 +1624,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // quantizations where the block size is 32. It also does not // guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. 
-template +template void mul_vec_q_n_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1706,7 +1706,7 @@ kernel void kernel_mul_mv_q4_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( @@ -1715,9 +1715,9 @@ kernel void kernel_mul_mv_q4_1_f32( device float * dst, constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( @@ -1728,7 +1728,7 @@ kernel void kernel_mul_mv_q5_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( @@ -1739,16 +1739,17 @@ kernel void kernel_mul_mv_q5_1_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 +template void kernel_mul_mv_q8_0_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1824,17 +1825,17 @@ kernel void kernel_mul_mv_q8_0_f32( uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 -template +template void kernel_mul_mv_impl( device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, uint3 tgpig, uint tiisg) { const int64_t r0 = tgpig.x; @@ -1904,7 +1905,7 @@ kernel void kernel_mul_mv( constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( + kernel_mul_mv_impl( src0, src1, dst, @@ -3926,11 +3927,12 @@ kernel void kernel_concat( } } +template void kernel_mul_mv_q2_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4025,14 +4027,15 @@ kernel void kernel_mul_mv_q2_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } 
+template void kernel_mul_mv_q3_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4184,14 +4187,15 @@ kernel void kernel_mul_mv_q3_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q4_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4299,14 +4303,15 @@ kernel void kernel_mul_mv_q4_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q5_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4428,14 +4433,15 @@ kernel void kernel_mul_mv_q5_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q6_K_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4516,16 +4522,17 @@ kernel void kernel_mul_mv_q6_K_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit +template void kernel_mul_mv_iq2_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4624,14 +4631,15 @@ kernel void kernel_mul_mv_iq2_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4740,14 +4748,15 @@ kernel void kernel_mul_mv_iq2_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_xxs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4849,14 
+4858,15 @@ kernel void kernel_mul_mv_iq3_xxs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4958,14 +4968,15 @@ kernel void kernel_mul_mv_iq3_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -5068,14 +5079,15 @@ kernel void kernel_mul_mv_iq2_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq1_s_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5154,11 +5166,12 @@ void kernel_mul_mv_iq1_s_f32_impl( } } +template void kernel_mul_mv_iq1_m_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5246,11 +5259,12 @@ void kernel_mul_mv_iq1_m_f32_impl( } } +template void kernel_mul_mv_iq4_nl_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5333,11 +5347,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } +template void kernel_mul_mv_iq4_xs_f32_impl( device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, + A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5431,7 +5446,7 @@ kernel void kernel_mul_mv_iq1_s_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] @@ -5444,7 +5459,7 @@ kernel void kernel_mul_mv_iq1_m_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] @@ -5458,7 +5473,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, 
tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5472,7 +5487,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); } template @@ -6057,7 +6072,7 @@ void mmv_fn( impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( From c81640a5fca628b686354a3abb87a1c5017b59a1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 08:47:30 +0200 Subject: [PATCH 07/21] cont : args is first argument --- ggml/src/ggml-metal.m | 52 +++++------ ggml/src/ggml-metal.metal | 182 +++++++++++++++++++------------------- 2 files changed, 117 insertions(+), 117 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 922d0d67a..d3497233e 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1977,10 +1977,10 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; @@ -2181,10 +2181,10 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&args length:sizeof(args) atIndex:3]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || @@ -2499,11 +2499,11 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&args length:sizeof(args) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; const int tgz = dst_rows; @@ -2748,15 +2748,15 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - 
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&args length:sizeof(args) atIndex:4]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -3266,16 +3266,16 @@ static void ggml_metal_encode_node( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&args length:sizeof(args) atIndex:5]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 7526643e8..f2bd66f6c 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1624,12 +1624,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // quantizations where the block size is 32. It also does not // guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. 
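Patch 07 then moves the kargs struct to the front of every signature, so the host side always binds it at buffer index 0 (as the ggml-metal.m hunks above show) and the Metal side always receives it as the first parameter. A declaration-only sketch of the resulting convention for the block-quant vector kernel; the template parameter names here are an assumption reconstructed from the surrounding call sites:

// declaration-only sketch; parameter order after this change: args first, then buffers
template <typename block_q_type, int nr, int nsg, int nw, typename args_t>   // names assumed
void mul_vec_q_n_f32_impl(
        args_t               args,        // kernel-arg struct, now the first argument
        device const void  * src0,
        device const float * src1,
        device       float * dst,
        threadgroup int8_t * shared_values,
        uint3 tgpig,
        uint  tiisg,
        uint  sgitg);

The fixed ordering (args, src0, src1, dst, ...) is what later lets kernel_mul_mv_id forward its rewritten args0 struct to any of these impls without per-kernel glue.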
-template +template void mul_vec_q_n_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1699,57 +1699,57 @@ void mul_vec_q_n_f32_impl( } kernel void kernel_mul_mv_q4_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define NB_Q8_0 8 -template +template void kernel_mul_mv_q8_0_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1818,24 +1818,24 @@ void kernel_mul_mv_q8_0_f32_impl( [[host_name("kernel_mul_mv_q8_0_f32")]] kernel void kernel_mul_mv_q8_0_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 -template +template void kernel_mul_mv_impl( + args_t args, device const char * src0, device const char * src1, device float * dst, - A args, uint3 tgpig, uint tiisg) { const int64_t r0 = tgpig.x; @@ -1899,17 +1899,17 @@ void kernel_mul_mv_impl( template kernel void kernel_mul_mv( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint 
tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( + args, src0, src1, dst, - args, tgpig, tiisg); } @@ -1926,10 +1926,10 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -1982,10 +1982,10 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne // Assumes row size (ne00) is a multiple of 4 template kernel void kernel_mul_mv_l4( + constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { @@ -2064,11 +2064,11 @@ static void rope_yarn_corr_dims( template kernel void kernel_rope_norm( + constant ggml_metal_kargs_rope & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_rope & args, ushort tiitg[[thread_index_in_threadgroup]], ushort3 tptg [[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -2117,11 +2117,11 @@ kernel void kernel_rope_norm( template kernel void kernel_rope_neox( + constant ggml_metal_kargs_rope & args, device const char * src0, device const char * src1, device const char * src2, device char * dst, - constant ggml_metal_kargs_rope & args, ushort tiitg[[thread_index_in_threadgroup]], ushort3 tptg [[threads_per_threadgroup]], uint3 tgpig[[threadgroup_position_in_grid]]) { @@ -2558,13 +2558,13 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( + constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, device const char * v, device const char * mask, device char * dst, - constant ggml_metal_kargs_flash_attn_ext & args, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3053,13 +3053,13 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( + constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, device const char * k, device const char * v, device const char * mask, device char * dst, - constant ggml_metal_kargs_flash_attn_ext & args, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shared [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3927,12 +3927,12 @@ kernel void kernel_concat( } } -template +template void kernel_mul_mv_q2_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4019,23 +4019,23 @@ void kernel_mul_mv_q2_K_f32_impl( [[host_name("kernel_mul_mv_q2_K_f32")]] kernel void kernel_mul_mv_q2_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 
tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q3_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4179,23 +4179,23 @@ void kernel_mul_mv_q3_K_f32_impl( [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q4_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4295,23 +4295,23 @@ void kernel_mul_mv_q4_K_f32_impl( [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q5_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4425,23 +4425,23 @@ void kernel_mul_mv_q5_K_f32_impl( [[host_name("kernel_mul_mv_q5_K_f32")]] kernel void kernel_mul_mv_q5_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_q6_K_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4514,25 +4514,25 @@ void kernel_mul_mv_q6_K_f32_impl( [[host_name("kernel_mul_mv_q6_K_f32")]] kernel void kernel_mul_mv_q6_K_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // 
======================= "True" 2-bit -template +template void kernel_mul_mv_iq2_xxs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4622,24 +4622,24 @@ void kernel_mul_mv_iq2_xxs_f32_impl( [[host_name("kernel_mul_mv_iq2_xxs_f32")]] kernel void kernel_mul_mv_iq2_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_xs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4739,24 +4739,24 @@ void kernel_mul_mv_iq2_xs_f32_impl( [[host_name("kernel_mul_mv_iq2_xs_f32")]] kernel void kernel_mul_mv_iq2_xs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_xxs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4849,24 +4849,24 @@ void kernel_mul_mv_iq3_xxs_f32_impl( [[host_name("kernel_mul_mv_iq3_xxs_f32")]] kernel void kernel_mul_mv_iq3_xxs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq3_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4959,24 +4959,24 @@ void kernel_mul_mv_iq3_s_f32_impl( [[host_name("kernel_mul_mv_iq3_s_f32")]] kernel void kernel_mul_mv_iq3_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, 
dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq2_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -5070,24 +5070,24 @@ void kernel_mul_mv_iq2_s_f32_impl( [[host_name("kernel_mul_mv_iq2_s_f32")]] kernel void kernel_mul_mv_iq2_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq1_s_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5166,12 +5166,12 @@ void kernel_mul_mv_iq1_s_f32_impl( } } -template +template void kernel_mul_mv_iq1_m_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5259,12 +5259,12 @@ void kernel_mul_mv_iq1_m_f32_impl( } } -template +template void kernel_mul_mv_iq4_nl_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5347,12 +5347,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( } } -template +template void kernel_mul_mv_iq4_xs_f32_impl( + args_t args, device const void * src0, device const float * src1, device float * dst, - A args, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5438,56 +5438,56 @@ void kernel_mul_mv_iq4_xs_f32_impl( [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] kernel void kernel_mul_mv_iq1_m_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - 
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( + constant ggml_metal_kargs_mul_mv & args, device const void * src0, device const float * src1, device float * dst, - constant ggml_metal_kargs_mul_mv & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } template @@ -5592,10 +5592,10 @@ kernel void kernel_get_rows_i32( // each block_q contains 16*nl weights template kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, device const char * src0, device const char * src1, device char * dst, - constant ggml_metal_kargs_mul_mm & args, threadgroup char * shared_memory [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], @@ -6027,18 +6027,18 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel // typedef void (kernel_mul_mv_impl_t)( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, uint3 tgpig, uint tiisg); typedef void (kernel_mul_mv2_impl_t)( + ggml_metal_kargs_mul_mv args, device const void * src0, device const float * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -6046,41 +6046,41 @@ typedef void (kernel_mul_mv2_impl_t)( template void mmv_fn( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0, src1, dst, args, tgpig, tiisg); + impl_fn(args, src0, src1, dst, tgpig, tiisg); } template void mmv_fn( + ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, device float * dst, - ggml_metal_kargs_mul_mv args, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg); + impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( + constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, device const char * src1, device float * dst, device const char * ids, - constant ggml_metal_kargs_mul_mv_id & args, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -6125,10 +6125,10 @@ kernel void kernel_mul_mv_id( }; impl_fn( + args0, /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - args0, shared_values, tgpig, tiitg, From a1a201c1a90577914e615b71de2b66015707c2b0 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 09:26:53 +0200 Subject: [PATCH 08/21] cont : use char ptr --- ggml/src/ggml-metal.metal | 395 +++++++++++++++++++++----------------- 1 file changed, 216 insertions(+), 179 deletions(-) diff --git a/ggml/src/ggml-metal.metal 
b/ggml/src/ggml-metal.metal index f2bd66f6c..95c07659e 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1627,9 +1627,9 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre template void mul_vec_q_n_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1648,8 +1648,8 @@ void mul_vec_q_n_f32_impl( //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows device const block_q_type * ax[nr]; @@ -1690,19 +1690,22 @@ void mul_vec_q_n_f32_impl( yb += QK4_0 * 16; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < args.ne01) { - dst[im*args.ne0*args.ne1 + r1*args.ne0 + first_row + row] = tot; + dst_f32[first_row + row] = tot; } } } kernel void kernel_mul_mv_q4_0_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1711,9 +1714,9 @@ kernel void kernel_mul_mv_q4_0_f32( kernel void kernel_mul_mv_q4_1_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1722,9 +1725,9 @@ kernel void kernel_mul_mv_q4_1_f32( kernel void kernel_mul_mv_q5_0_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1733,9 +1736,9 @@ kernel void kernel_mul_mv_q5_0_f32( kernel void kernel_mul_mv_q5_1_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1747,9 +1750,9 @@ kernel void kernel_mul_mv_q5_1_f32( template void kernel_mul_mv_q8_0_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -1771,8 +1774,8 @@ void kernel_mul_mv_q8_0_f32_impl( //const uint offset0 = 
first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows device const block_q8_0 * ax[nr]; @@ -1808,10 +1811,12 @@ void kernel_mul_mv_q8_0_f32_impl( yb += NB_Q8_0 * nw; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0 && first_row + row < args.ne01) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot; + dst_f32[first_row + row] = tot; } } } @@ -1819,9 +1824,9 @@ void kernel_mul_mv_q8_0_f32_impl( [[host_name("kernel_mul_mv_q8_0_f32")]] kernel void kernel_mul_mv_q8_0_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -1833,9 +1838,9 @@ kernel void kernel_mul_mv_q8_0_f32( template void kernel_mul_mv_impl( args_t args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig, uint tiisg) { const int64_t r0 = tgpig.x; @@ -1849,6 +1854,8 @@ void kernel_mul_mv_impl( device const T0 * x = (device const T0 *) (src0 + offset0); + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1; + if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; @@ -1867,7 +1874,7 @@ void kernel_mul_mv_impl( float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r1*args.ne0 + r0] = all_sum; } } } else { @@ -1891,7 +1898,7 @@ void kernel_mul_mv_impl( float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r1*args.ne0 + r0] = all_sum; } } } @@ -1900,9 +1907,9 @@ void kernel_mul_mv_impl( template kernel void kernel_mul_mv( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( @@ -3930,9 +3937,9 @@ kernel void kernel_concat( template void kernel_mul_mv_q2_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -3951,8 +3958,8 @@ void kernel_mul_mv_q2_K_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); - device const float * y = (device const float *) 
((device char *) src1 + offset1); + device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4009,10 +4016,12 @@ void kernel_mul_mv_q2_K_f32_impl( y4 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -4020,9 +4029,9 @@ void kernel_mul_mv_q2_K_f32_impl( [[host_name("kernel_mul_mv_q2_K_f32")]] kernel void kernel_mul_mv_q2_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4033,9 +4042,9 @@ kernel void kernel_mul_mv_q2_K_f32( template void kernel_mul_mv_q3_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4055,8 +4064,8 @@ void kernel_mul_mv_q3_K_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float yl[32]; @@ -4170,9 +4179,12 @@ void kernel_mul_mv_q3_K_f32_impl( const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); sumf1[row] = simd_sum(sumf); } + + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + if (tiisg == 0) { for (int row = 0; row < 2; ++row) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = sumf1[row]; + dst_f32[first_row + row] = sumf1[row]; } } } @@ -4180,9 +4192,9 @@ void kernel_mul_mv_q3_K_f32_impl( [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4193,9 +4205,9 @@ kernel void kernel_mul_mv_q3_K_f32( template void kernel_mul_mv_q4_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4223,8 +4235,8 @@ void kernel_mul_mv_q4_K_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) 
src1 + offset1); + device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[16]; float yh[16]; @@ -4285,10 +4297,12 @@ void kernel_mul_mv_q4_K_f32_impl( y4 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -4296,9 +4310,9 @@ void kernel_mul_mv_q4_K_f32_impl( [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4309,9 +4323,9 @@ kernel void kernel_mul_mv_q4_K_f32( template void kernel_mul_mv_q5_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4331,8 +4345,8 @@ void kernel_mul_mv_q5_K_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float sumf[2]={0.f}; @@ -4415,10 +4429,12 @@ void kernel_mul_mv_q5_K_f32_impl( y1 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot; + dst_f32[first_row + row] = tot; } } } @@ -4426,9 +4442,9 @@ void kernel_mul_mv_q5_K_f32_impl( [[host_name("kernel_mul_mv_q5_K_f32")]] kernel void kernel_mul_mv_q5_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4439,9 +4455,9 @@ kernel void kernel_mul_mv_q5_K_f32( template void kernel_mul_mv_q6_K_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4466,8 +4482,8 @@ void kernel_mul_mv_q6_K_f32_impl( const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); 
+ device const float * yy = (device const float *) (src1 + offset1); float sumf = 0; @@ -4506,18 +4522,20 @@ void kernel_mul_mv_q6_K_f32_impl( } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + row] = tot; + dst_f32[row] = tot; } } [[host_name("kernel_mul_mv_q6_K_f32")]] kernel void kernel_mul_mv_q6_K_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -4530,9 +4548,9 @@ kernel void kernel_mul_mv_q6_K_f32( template void kernel_mul_mv_iq2_xxs_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4551,8 +4569,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4612,10 +4630,12 @@ void kernel_mul_mv_iq2_xxs_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } @@ -4623,9 +4643,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl( [[host_name("kernel_mul_mv_iq2_xxs_f32")]] kernel void kernel_mul_mv_iq2_xxs_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -4637,9 +4657,9 @@ kernel void kernel_mul_mv_iq2_xxs_f32( template void kernel_mul_mv_iq2_xs_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4658,8 +4678,8 @@ void kernel_mul_mv_iq2_xs_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4729,10 +4749,12 @@ void 
kernel_mul_mv_iq2_xs_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } @@ -4740,9 +4762,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( [[host_name("kernel_mul_mv_iq2_xs_f32")]] kernel void kernel_mul_mv_iq2_xs_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -4754,9 +4776,9 @@ kernel void kernel_mul_mv_iq2_xs_f32( template void kernel_mul_mv_iq3_xxs_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4775,8 +4797,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4839,10 +4861,12 @@ void kernel_mul_mv_iq3_xxs_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.5f; + dst_f32[first_row + row] = all_sum * 0.5f; } } } @@ -4850,9 +4874,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( [[host_name("kernel_mul_mv_iq3_xxs_f32")]] kernel void kernel_mul_mv_iq3_xxs_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -4864,9 +4888,9 @@ kernel void kernel_mul_mv_iq3_xxs_f32( template void kernel_mul_mv_iq3_s_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4885,8 +4909,8 @@ void kernel_mul_mv_iq3_s_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float 
sumf[N_DST]={0.f}, all_sum; @@ -4949,10 +4973,12 @@ void kernel_mul_mv_iq3_s_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -4960,9 +4986,9 @@ void kernel_mul_mv_iq3_s_f32_impl( [[host_name("kernel_mul_mv_iq3_s_f32")]] kernel void kernel_mul_mv_iq3_s_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -4974,9 +5000,9 @@ kernel void kernel_mul_mv_iq3_s_f32( template void kernel_mul_mv_iq2_s_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -4995,8 +5021,8 @@ void kernel_mul_mv_iq2_s_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5060,10 +5086,12 @@ void kernel_mul_mv_iq2_s_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } @@ -5071,9 +5099,9 @@ void kernel_mul_mv_iq2_s_f32_impl( [[host_name("kernel_mul_mv_iq2_s_f32")]] kernel void kernel_mul_mv_iq2_s_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -5085,9 +5113,9 @@ kernel void kernel_mul_mv_iq2_s_f32( template void kernel_mul_mv_iq1_s_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5106,8 +5134,8 @@ void kernel_mul_mv_iq1_s_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float 
yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5158,10 +5186,12 @@ void kernel_mul_mv_iq1_s_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -5169,9 +5199,9 @@ void kernel_mul_mv_iq1_s_f32_impl( template void kernel_mul_mv_iq1_m_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_value, uint3 tgpig, uint tiisg, @@ -5190,8 +5220,8 @@ void kernel_mul_mv_iq1_m_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5251,10 +5281,12 @@ void kernel_mul_mv_iq1_m_f32_impl( y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -5262,9 +5294,9 @@ void kernel_mul_mv_iq1_m_f32_impl( template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5283,8 +5315,8 @@ void kernel_mul_mv_iq4_nl_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 @@ -5339,10 +5371,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -5350,9 +5384,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values_i8, uint3 tgpig, uint tiisg, @@ -5371,8 +5405,8 @@ void kernel_mul_mv_iq4_xs_f32_impl( const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + 
(i13/args.r3)*args.nb03; const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 @@ -5428,10 +5462,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( yb += 2 * QK_K; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } @@ -5439,9 +5475,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -5452,9 +5488,9 @@ kernel void kernel_mul_mv_iq1_s_f32( [[host_name("kernel_mul_mv_iq1_m_f32")]] kernel void kernel_mul_mv_iq1_m_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { @@ -5465,9 +5501,9 @@ kernel void kernel_mul_mv_iq1_m_f32( [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -5479,9 +5515,9 @@ kernel void kernel_mul_mv_iq4_nl_f32( [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( constant ggml_metal_kargs_mul_mv & args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], @@ -6028,17 +6064,17 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel typedef void (kernel_mul_mv_impl_t)( ggml_metal_kargs_mul_mv args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig, uint tiisg); typedef void (kernel_mul_mv2_impl_t)( ggml_metal_kargs_mul_mv args, - device const void * src0, - device const float * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiisg, @@ -6047,9 +6083,9 @@ typedef void (kernel_mul_mv2_impl_t)( template void mmv_fn( ggml_metal_kargs_mul_mv args, - device const char * src0, - device const char * 
src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, @@ -6061,15 +6097,15 @@ void mmv_fn( template void mmv_fn( ggml_metal_kargs_mul_mv args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, threadgroup int8_t * shared_values, uint3 tgpig, uint tiitg, uint tiisg, uint sgitg) { - impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg); + impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6077,10 +6113,10 @@ typedef decltype(mmv_fn kernel void kernel_mul_mv_id( constant ggml_metal_kargs_mul_mv_id & args, - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, threadgroup int8_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiitg[[thread_index_in_threadgroup]], @@ -6101,7 +6137,8 @@ kernel void kernel_mul_mv_id( device const char * src0_cur = src0s + i02*args.nb02; device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; - device float * dst_cur = dst + i1*args.ne0 + i2*args.ne1*args.ne0; + + device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float); ggml_metal_kargs_mul_mv args0 = { /*.ne00 =*/ args.ne00, From cacc4c225fe0995fcfe7b5956edb82113a823067 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 09:45:06 +0200 Subject: [PATCH 09/21] cont : shmem style --- ggml/src/ggml-metal.metal | 215 +++++++++++++++++++------------------- 1 file changed, 107 insertions(+), 108 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 95c07659e..3d3cd769a 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1630,7 +1630,7 @@ void mul_vec_q_n_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -1753,7 +1753,7 @@ void kernel_mul_mv_q8_0_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -2571,7 +2571,7 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device char * dst, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -2591,17 +2591,17 @@ kernel void kernel_flash_attn_ext( const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short T = D + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix + 
threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix - threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o8x8_t lo[D8]; @@ -3066,7 +3066,7 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device char * dst, - threadgroup half * shared [[threadgroup(0)]], + threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 ntg[[threads_per_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], @@ -3085,13 +3085,13 @@ kernel void kernel_flash_attn_ext_vec( const short T = D + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask - threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask + threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results // store the 
result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o4x4_t lo[D16/NL]; @@ -3940,7 +3940,7 @@ void kernel_mul_mv_q2_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4045,7 +4045,7 @@ void kernel_mul_mv_q3_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4208,7 +4208,7 @@ void kernel_mul_mv_q4_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4326,7 +4326,7 @@ void kernel_mul_mv_q5_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4458,7 +4458,7 @@ void kernel_mul_mv_q6_K_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4551,7 +4551,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4577,15 +4577,15 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4615,8 +4615,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( float sum = 0; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); } @@ -4646,12 +4646,11 @@ kernel void kernel_mul_mv_iq2_xxs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4660,7 +4659,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4686,15 +4685,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512); { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4726,15 +4725,15 @@ void kernel_mul_mv_iq2_xs_f32_impl( float sum1 = 0, sum2 = 0; for (int l = 0; l < 2; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } for (int l = 2; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? 
-1.f : 1.f); } @@ -4765,12 +4764,12 @@ kernel void kernel_mul_mv_iq2_xs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4779,7 +4778,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4805,15 +4804,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4843,9 +4842,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? 
-1.f : 1.f); @@ -4877,12 +4876,12 @@ kernel void kernel_mul_mv_iq3_xxs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -4891,7 +4890,7 @@ void kernel_mul_mv_iq3_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -4917,11 +4916,11 @@ void kernel_mul_mv_iq3_s_f32_impl( const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem; { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -4952,8 +4951,8 @@ void kernel_mul_mv_iq3_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; - const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); for (int j = 0; j < 4; ++j) { @@ -4989,12 +4988,12 @@ kernel void kernel_mul_mv_iq3_s_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -5003,7 +5002,7 @@ void kernel_mul_mv_iq2_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5029,11 +5028,11 @@ void kernel_mul_mv_iq2_s_f32_impl( const int nb32 = nb * (QK_K / 32); - //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem; //{ // int nval = 32; // int pos = (32*sgitg + tiisg)*nval; - // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i]; // threadgroup_barrier(mem_flags::mem_threadgroup); //} @@ -5065,8 +5064,8 @@ void kernel_mul_mv_iq2_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 2; ++l) { - //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t 
*)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); for (int j = 0; j < 8; ++j) { @@ -5102,12 +5101,12 @@ kernel void kernel_mul_mv_iq2_s_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -5116,7 +5115,7 @@ void kernel_mul_mv_iq1_s_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_value, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5202,7 +5201,7 @@ void kernel_mul_mv_iq1_m_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_value, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { @@ -5297,12 +5296,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values_i8, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -5321,7 +5320,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -5349,16 +5348,16 @@ void kernel_mul_mv_iq4_nl_f32_impl( aux32[0] = q4[0] | (q4[1] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[2] | (q4[3] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -5387,12 +5386,12 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values_i8, + threadgroup 
char * shmem, uint3 tgpig, uint tiisg, uint sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; + threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -5413,7 +5412,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( const int ib = it/2; const int il = it%2; - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -5440,15 +5439,15 @@ void kernel_mul_mv_iq4_xs_f32_impl( aux32[0] = q4[0] & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[1] & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -5504,12 +5503,12 @@ kernel void kernel_mul_mv_iq4_nl_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] @@ -5518,12 +5517,12 @@ kernel void kernel_mul_mv_iq4_xs_f32( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -6075,7 +6074,7 @@ typedef void (kernel_mul_mv2_impl_t)( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, + threadgroup char * shmem, uint3 tgpig, uint tiisg, uint sgitg); @@ -6086,11 +6085,11 @@ void mmv_fn( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + threadgroup char * shmem, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { impl_fn(args, src0, src1, dst, tgpig, tiisg); } @@ -6100,12 +6099,12 @@ void mmv_fn( device const char * src0, device const char * src1, device char * dst, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg); + threadgroup char 
* shmem, + uint3 tgpig, + uint tiitg, + uint tiisg, + uint sgitg) { + impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } typedef decltype(mmv_fn>) mul_mv_impl_fn_t; @@ -6117,11 +6116,11 @@ kernel void kernel_mul_mv_id( device const char * src1, device char * dst, device const char * ids, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int iid1 = tgpig.z/args.nei0; const int idx = tgpig.z%args.nei0; @@ -6166,7 +6165,7 @@ kernel void kernel_mul_mv_id( /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - shared_values, + shmem, tgpig, tiitg, tiisg, From 15a7105967289fe6755625450b1a475f22ca156b Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 09:57:41 +0200 Subject: [PATCH 10/21] cont : thread counters style --- ggml/src/ggml-metal.metal | 281 +++++++++++++++++++------------------- 1 file changed, 141 insertions(+), 140 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 3d3cd769a..f8591239a 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1631,9 +1631,9 @@ void mul_vec_q_n_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK4_0; const int r0 = tgpig.x; @@ -1706,9 +1706,9 @@ kernel void kernel_mul_mv_q4_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1717,9 +1717,9 @@ kernel void kernel_mul_mv_q4_1_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1728,9 +1728,9 @@ kernel void kernel_mul_mv_q5_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1739,9 +1739,9 @@ kernel void kernel_mul_mv_q5_1_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort 
tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1754,9 +1754,9 @@ void kernel_mul_mv_q8_0_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; @@ -1766,7 +1766,7 @@ void kernel_mul_mv_q8_0_f32_impl( const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; + const int first_row = (r0*nsg + sgitg)*nr; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -1786,12 +1786,12 @@ void kernel_mul_mv_q8_0_f32_impl( } float yl[NB_Q8_0]; - float sumf[nr]={0.f}; + float sumf[nr] = { 0.f }; const int ix = tiisg/4; const int il = tiisg%4; - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { @@ -1800,7 +1800,7 @@ void kernel_mul_mv_q8_0_f32_impl( } for (int row = 0; row < nr; row++) { - device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il; + device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; for (int iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; @@ -1808,13 +1808,14 @@ void kernel_mul_mv_q8_0_f32_impl( sumf[row] += sumq*ax[row][ib].d; } - yb += NB_Q8_0 * nw; + yb += nw*NB_Q8_0; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); + if (tiisg == 0 && first_row + row < args.ne01) { dst_f32[first_row + row] = tot; } @@ -1827,9 +1828,9 @@ kernel void kernel_mul_mv_q8_0_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -1841,8 +1842,8 @@ void kernel_mul_mv_impl( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig, - uint tiisg) { + uint3 tgpig, + ushort tiisg) { const int64_t r0 = tgpig.x; const int64_t rb = tgpig.y*N_MV_T_T; const int64_t im = tgpig.z; @@ -1910,8 +1911,8 @@ kernel void kernel_mul_mv( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { kernel_mul_mv_impl( args, src0, @@ -1937,8 +1938,8 @@ kernel void kernel_mul_mv_1row( device const char * src0, device const char * src1, device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; @@ -1993,8 +1994,8 @@ kernel void kernel_mul_mv_l4( device const char * src0, device const char * src1, device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], 
- uint tiisg[[thread_index_in_simdgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { const int nrows = args.ne11; const int64_t r0 = tgpig.x; @@ -3941,9 +3942,9 @@ void kernel_mul_mv_q2_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4032,9 +4033,9 @@ kernel void kernel_mul_mv_q2_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4046,9 +4047,9 @@ void kernel_mul_mv_q3_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; @@ -4195,9 +4196,9 @@ kernel void kernel_mul_mv_q3_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4209,9 +4210,9 @@ void kernel_mul_mv_q4_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -4313,9 +4314,9 @@ kernel void kernel_mul_mv_q4_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4327,9 +4328,9 @@ void kernel_mul_mv_q5_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; @@ -4445,9 +4446,9 @@ kernel void kernel_mul_mv_q5_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4459,9 +4460,9 @@ void kernel_mul_mv_q6_K_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -4536,9 +4537,9 @@ 
kernel void kernel_mul_mv_q6_K_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -4552,9 +4553,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4647,9 +4648,9 @@ kernel void kernel_mul_mv_iq2_xxs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4660,9 +4661,9 @@ void kernel_mul_mv_iq2_xs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4765,9 +4766,9 @@ kernel void kernel_mul_mv_iq2_xs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4779,9 +4780,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4877,9 +4878,9 @@ kernel void kernel_mul_mv_iq3_xxs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -4891,9 +4892,9 @@ void kernel_mul_mv_iq3_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -4989,9 +4990,9 @@ kernel void kernel_mul_mv_iq3_s_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort 
sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5003,9 +5004,9 @@ void kernel_mul_mv_iq2_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5102,9 +5103,9 @@ kernel void kernel_mul_mv_iq2_s_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5116,9 +5117,9 @@ void kernel_mul_mv_iq1_s_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5202,9 +5203,9 @@ void kernel_mul_mv_iq1_m_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nb = args.ne00/QK_K; const int r0 = tgpig.x; @@ -5297,9 +5298,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK4_NL; @@ -5387,9 +5388,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiisg, + ushort sgitg) { threadgroup float * shmem_f32 = (threadgroup float *) shmem; const int nb = args.ne00/QK_K; @@ -5477,9 +5478,9 @@ kernel void kernel_mul_mv_iq1_s_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -5490,9 +5491,9 @@ kernel void kernel_mul_mv_iq1_m_f32( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } @@ -5504,9 +5505,9 @@ kernel void kernel_mul_mv_iq4_nl_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { 
kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5518,9 +5519,9 @@ kernel void kernel_mul_mv_iq4_xs_f32( device const char * src1, device char * dst, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -5631,13 +5632,13 @@ kernel void kernel_mul_mm( device const char * src0, device const char * src1, device char * dst, - threadgroup char * shared_memory [[threadgroup(0)]], + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiitg[[thread_index_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); const int r0 = tgpig.y; const int r1 = tgpig.x; @@ -5732,7 +5733,7 @@ kernel void kernel_mul_mm( } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M; for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); @@ -6066,8 +6067,8 @@ typedef void (kernel_mul_mv_impl_t)( device const char * src0, device const char * src1, device char * dst, - uint3 tgpig, - uint tiisg); + uint3 tgpig, + ushort tiisg); typedef void (kernel_mul_mv2_impl_t)( ggml_metal_kargs_mul_mv args, @@ -6075,9 +6076,9 @@ typedef void (kernel_mul_mv2_impl_t)( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiisg, - uint sgitg); + uint3 tgpig, + ushort tiisg, + ushort sgitg); template void mmv_fn( @@ -6086,10 +6087,10 @@ void mmv_fn( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { impl_fn(args, src0, src1, dst, tgpig, tiisg); } @@ -6100,10 +6101,10 @@ void mmv_fn( device const char * src1, device char * dst, threadgroup char * shmem, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } @@ -6117,10 +6118,10 @@ kernel void kernel_mul_mv_id( device char * dst, device const char * ids, threadgroup char * shmem [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int iid1 = tgpig.z/args.nei0; const int idx = tgpig.z%args.nei0; From c5cf1d74f03330bcf1b9a76f5d60f0125c00cc05 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 10:32:15 +0200 
Subject: [PATCH 11/21] cont : mul mm id ggml-ci --- ggml/src/ggml-common.h | 18 +++++ ggml/src/ggml-metal.m | 43 ++++++----- ggml/src/ggml-metal.metal | 151 +++++++++++++++++++------------------- 3 files changed, 118 insertions(+), 94 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index fcf0d997c..a20703269 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -509,6 +509,24 @@ typedef struct { int16_t r3; } ggml_metal_kargs_mul_mv; +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + typedef struct { int32_t nei0; int32_t nei1; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index d3497233e..4c83aa7b7 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2297,27 +2297,30 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL_MAT_ID not implemented"); } + ggml_metal_kargs_mul_mm_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index f8591239a..561476551 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -5764,31 +5764,32 @@ kernel void kernel_mul_mm( } // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +// TODO: this kernel needs to be reimplemented from scratch for better performance template void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, + int32_t ne00, + int32_t ne02, + uint64_t nb01, + 
uint64_t nb02, + int32_t ne11, + int32_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int32_t ne0, + int32_t ne1, + int64_t ne0ne1, + device const char * src0, + device const char * src1, threadgroup ushort2 * rowids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - int64_t ne0ne1, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + device char * dst, + threadgroup char * shmem, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + threadgroup half * sa = (threadgroup half *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); const uint r0 = tgpig.y; const uint r1 = tgpig.x; @@ -5805,9 +5806,9 @@ void kernel_mul_mm_id_impl( simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; + simdgroup_float8x8 mc[8]; for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); + mc[i] = make_filled_simdgroup_matrix(0.f); } short il = (tiitg % THREAD_PER_ROW); @@ -5845,11 +5846,14 @@ void kernel_mul_mm_id_impl( threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(BLOCK_SIZE_K/8) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) for (int i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) for (int i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } @@ -5857,29 +5861,42 @@ void kernel_mul_mm_id_impl( lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } } } { threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int joff = jid[0] * ne0 + jid[1] * ne0ne1; - for (int i = 0; i < n_rows; i++) { - *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); + int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; + + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; 
i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); } } } @@ -5888,48 +5905,34 @@ void kernel_mul_mm_id_impl( template kernel void kernel_mul_mm_id( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int32_t i02 = tgpig.z; + tgpig.z = 0; - device const uchar * src0 = src0s + i02*nb02; + device const char * src0 = src0s + i02*args.nb02; // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop int64_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < nei1; ii1++) { - for (ushort ii0 = 0; ii0 < nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { + for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; if (id == i02) { - //if (tiitg == 0) { + if (tiitg == 0) { rowids[_ne1] = ushort2(ii0, ii1); - //} + } _ne1++; } } @@ -5938,23 +5941,23 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); kernel_mul_mm_id_impl( + args.ne00, + args.ne02, + args.nb01, + args.nb02, + args.ne11, + args.ne12, + args.nb10, + args.nb11, + args.nb12, + args.ne0, + _ne1, + (int64_t)args.ne0*args.ne1, src0, src1, rowids, dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, + shmem, tgpig, tiitg, sgitg); From bb821e48548d1ffbe5d5235d87650d5acd1d78fe Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 11:05:10 +0200 Subject: [PATCH 12/21] cont : int safety + register optimizations ggml-ci --- ggml/src/ggml-metal.metal | 226 +++++++++++++++++++------------------- 1 file changed, 113 insertions(+), 113 deletions(-) diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 561476551..d86458b1a 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1645,8 +1645,8 @@ void mul_vec_q_n_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q_type * x = (device const block_q_type 
*) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -1654,7 +1654,7 @@ void mul_vec_q_n_f32_impl( // pointers to src0 rows device const block_q_type * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -1662,10 +1662,10 @@ void mul_vec_q_n_f32_impl( float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; - device const float * yb = y + ix * QK4_0 + il; + device const float * yb = y + ix*QK4_0 + il; // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { @@ -1771,8 +1771,8 @@ void kernel_mul_mv_q8_0_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - //const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -1780,7 +1780,7 @@ void kernel_mul_mv_q8_0_f32_impl( // pointers to src0 rows device const block_q8_0 * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } @@ -1788,21 +1788,21 @@ void kernel_mul_mv_q8_0_f32_impl( float yl[NB_Q8_0]; float sumf[nr] = { 0.f }; - const int ix = tiisg/4; - const int il = tiisg%4; + const short ix = tiisg/4; + const short il = tiisg%4; device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { + for (short i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { + for (short iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*ax[row][ib].d; @@ -1811,7 +1811,7 @@ void kernel_mul_mv_q8_0_f32_impl( yb += nw*NB_Q8_0; } - device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); @@ -1844,18 +1844,18 @@ void kernel_mul_mv_impl( device char * dst, uint3 tgpig, ushort tiisg) { - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_MV_T_T; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = 
r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); - device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { @@ -1864,7 +1864,7 @@ void kernel_mul_mv_impl( break; } - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); @@ -1875,7 +1875,7 @@ void kernel_mul_mv_impl( float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst_f32[r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } else { @@ -1886,20 +1886,20 @@ void kernel_mul_mv_impl( break; } - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((T14) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst_f32[r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -1935,25 +1935,27 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + float sumf = 0; if (args.ne00 < 128) { for (int i = tiisg; i < args.ne00; i += 32) { @@ -1961,21 +1963,21 @@ kernel void kernel_mul_mv_1row( } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } else { device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } } @@ -1991,36 
+1993,38 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne template kernel void kernel_mul_mv_l4( constant ggml_metal_kargs_mul_mv & args, - device const char * src0, - device const char * src1, - device float * dst, + device const char * src0, + device const char * src1, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { const int nrows = args.ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int im = tgpig.z; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + for (int r1 = 0; r1 < nrows; ++r1) { - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -2969,7 +2973,7 @@ kernel void kernel_flash_attn_ext( const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3361,7 +3365,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*args.ne2*args.ne1 + iq2 + (int64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } @@ -3956,8 +3960,8 @@ void kernel_mul_mv_q2_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4017,7 +4021,7 @@ void kernel_mul_mv_q2_K_f32_impl( y4 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4053,17 +4057,17 @@ void kernel_mul_mv_q3_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; const uint i12 = im%args.ne12; const uint i13 = 
im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4096,9 +4100,10 @@ void kernel_mul_mv_q3_K_f32_impl( const ushort4 hm = mm[2*ip + il/2]; - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; + const short shift = 2*il; + + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; @@ -4181,7 +4186,7 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] = simd_sum(sumf); } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; if (tiisg == 0) { for (int row = 0; row < 2; ++row) { @@ -4233,8 +4238,8 @@ void kernel_mul_mv_q4_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4298,7 +4303,7 @@ void kernel_mul_mv_q4_K_f32_impl( y4 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4334,8 +4339,8 @@ void kernel_mul_mv_q5_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int r0 = tgpig.x; + const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; @@ -4343,8 +4348,8 @@ void kernel_mul_mv_q5_K_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4430,7 +4435,7 @@ void kernel_mul_mv_q5_K_f32_impl( y1 += 4 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); @@ -4471,17 +4476,17 @@ void kernel_mul_mv_q6_K_f32_impl( const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; + 
const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const int row = 2 * r0 + sgitg; + const int row = 2*r0 + sgitg; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); device const float * yy = (device const float *) (src1 + offset1); @@ -4501,7 +4506,6 @@ void kernel_mul_mv_q6_K_f32_impl( const int q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { - device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; @@ -4523,7 +4527,7 @@ void kernel_mul_mv_q6_K_f32_impl( } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; const float tot = simd_sum(sumf); if (tiisg == 0) { @@ -4567,8 +4571,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4631,7 +4635,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4675,8 +4679,8 @@ void kernel_mul_mv_iq2_xs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4749,7 +4753,7 @@ void kernel_mul_mv_iq2_xs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4794,8 +4798,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + 
(i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4822,7 +4826,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -4836,7 +4839,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); @@ -4861,7 +4863,7 @@ void kernel_mul_mv_iq3_xxs_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -4906,8 +4908,8 @@ void kernel_mul_mv_iq3_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -4973,7 +4975,7 @@ void kernel_mul_mv_iq3_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5018,8 +5020,8 @@ void kernel_mul_mv_iq2_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5086,7 +5088,7 @@ void kernel_mul_mv_iq2_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5131,8 +5133,8 @@ void kernel_mul_mv_iq1_s_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ 
-5186,7 +5188,7 @@ void kernel_mul_mv_iq1_s_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5217,8 +5219,8 @@ void kernel_mul_mv_iq1_m_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5281,7 +5283,7 @@ void kernel_mul_mv_iq1_m_f32_impl( y4 += 32 * 32; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); @@ -5312,8 +5314,8 @@ void kernel_mul_mv_iq4_nl_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5371,7 +5373,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); @@ -5402,8 +5404,8 @@ void kernel_mul_mv_iq4_xs_f32_impl( const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; - const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); @@ -5427,25 +5429,23 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; for (int ibl = ix; ibl < nb; ibl += 2) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; for (int row = 0; row < 2; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; - aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[0] = (q4[0] ) & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], 
shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; - aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[0] = (q4[1] ) & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; @@ -5462,7 +5462,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( yb += 2 * QK_K; } - device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); @@ -5665,8 +5665,8 @@ kernel void kernel_mul_mm( const int i12 = im%args.ne12; const int i13 = im/args.ne12; - int offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - short offset1 = il/nl; + uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + short offset1 = il/nl; device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 @@ -5791,10 +5791,10 @@ void kernel_mul_mm_id_impl( threadgroup half * sa = (threadgroup half *)(shmem); threadgroup float * sb = (threadgroup float *)(shmem + 4096); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; + const int r0 = tgpig.y; + const int r1 = tgpig.x; - if (r1 * BLOCK_SIZE_N >= ne1) return; + if (r1*BLOCK_SIZE_N >= ne1) return; // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; @@ -5925,7 +5925,7 @@ kernel void kernel_mul_mm_id( threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop - int64_t _ne1 = 0; + int32_t _ne1 = 0; for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; From 9058c51d9dcae3139ab546aace36f41f02b71681 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:03:25 +0200 Subject: [PATCH 13/21] metal : GGML_OP_CONCAT ggml-ci --- ggml/src/ggml-common.h | 28 ++++++++++++++++++ ggml/src/ggml-metal.m | 60 +++++++++++++++++++++------------------ ggml/src/ggml-metal.metal | 54 ++++++++++------------------------- 3 files changed, 75 insertions(+), 67 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index a20703269..cd54c18ed 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -419,6 +419,34 @@ typedef struct { static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); #if defined(GGML_COMMON_DECL_METAL_KARGS) +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 4c83aa7b7..2058ea25e 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1193,35 +1193,39 @@ static void ggml_metal_encode_node( const int32_t dim = ((const int32_t *) dst->op_params)[0]; + 
ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN(1024, ne0); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index d86458b1a..1c7bef770 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1893,7 +1893,7 @@ void kernel_mul_mv_impl( float sumf = 0; for (int i = tiisg; i < args.ne00/4; i += 32) { - sumf += dot((T14) x4[i], y4[i]); + sumf += dot((float4) x4[i], (float4) y4[i]); } float all_sum = simd_sum(sumf); @@ -3885,55 +3885,31 @@ kernel void kernel_cpy_f32_iq4_nl( } kernel void kernel_concat( + constant ggml_metal_kargs_concat & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t 
& ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & dim, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + int o[4] = {0, 0, 0, 0}; + o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); device const float * x; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); } else { - x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); } - device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); *y = *x; } From 3250c98bf69ac43a33bf958676ef4ac9a7c89ee1 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:16:54 +0200 Subject: [PATCH 14/21] metal : GGML_OP_ADD, GGML_OP_SUB, GGML_OP_MUL, GGML_OP_DIV --- ggml/src/ggml-common.h | 28 +++++ ggml/src/ggml-metal.m | 124 ++++++++++---------- ggml/src/ggml-metal.metal | 234 ++++++++++++-------------------------- 3 files changed, 164 insertions(+), 222 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index cd54c18ed..6d7e07ee6 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -447,6 +447,34 @@ typedef struct { int32_t dim; } ggml_metal_kargs_concat; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 2058ea25e..8a96e09ba 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1243,8 +1243,6 @@ static void ggml_metal_encode_node( bool bcast_row = false; - int64_t nb = ne00; // used by the "row" kernels - id pipeline = nil; if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { @@ -1253,7 +1251,6 @@ static void ggml_metal_encode_node( // src1 is a row GGML_ASSERT(ne11 == 1); - nb = ne00 / 4; switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; case 
GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; @@ -1273,36 +1270,39 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (bcast_row) { const int64_t n = ggml_nelements(dst)/4; @@ -1400,35 +1400,39 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst 
atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 1c7bef770..79e0f07c4 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -493,200 +493,106 @@ enum ggml_sort_order { // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = 
i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_sub( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_mul( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & 
ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_div( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = 
i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); } } @@ -740,38 +646,42 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] + src1[tpig % nb]; } kernel void kernel_sub_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] - src1[tpig % nb]; } kernel void kernel_mul_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] * src1[tpig % nb]; } kernel void kernel_div_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] / src1[tpig % nb]; } From f46f710ca633c0f8dea546b5fb3ee99847cc3200 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:21:59 +0200 Subject: [PATCH 15/21] metal : GGML_OP_REPEAT --- ggml/src/ggml-common.h | 19 +++++++++++++++++ ggml/src/ggml-metal.m | 40 ++++++++++++++++++---------------- ggml/src/ggml-metal.metal | 45 +++++++++++++-------------------------- 3 files changed, 56 insertions(+), 48 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 6d7e07ee6..a0720e982 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -475,6 +475,25 @@ typedef struct { uint64_t offs; } ggml_metal_kargs_bin; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 8a96e09ba..128376a50 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1326,25 +1326,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); } + ggml_metal_kargs_repeat args = { + /*.ne00 
=*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 79e0f07c4..eeb20f18d 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -598,41 +598,26 @@ kernel void kernel_div( template kernel void kernel_repeat( + constant ggml_metal_kargs_repeat & args, device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; - const int64_t i03 = i3 % ne03; - const int64_t i02 = i2 % ne02; - const int64_t i01 = i1 % ne01; + const int i03 = i3%args.ne03; + const int i02 = i2%args.ne02; + const int i01 = i1%args.ne01; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i00 = i0 % ne00; - *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i00 = 
i0%args.ne00; + *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); } } From 647a7044f5d14837958624ffc0f11d2f858b7910 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 13:55:26 +0200 Subject: [PATCH 16/21] metal : GGML_OP_CPY --- ggml/src/ggml-common.h | 19 +++ ggml/src/ggml-metal.m | 80 +++++---- ggml/src/ggml-metal.metal | 344 +++++++++++++------------------------- 3 files changed, 182 insertions(+), 261 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index a0720e982..2d45311a3 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -494,6 +494,25 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_repeat; +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 128376a50..c2542149a 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1377,25 +1377,29 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -3425,25 +3429,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder 
setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index eeb20f18d..14526b214 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -3302,42 +3302,27 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ template kernel void kernel_cpy( - device const void * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); - device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + 
i0*nb0); + device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { + device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; } } @@ -3357,42 +3342,27 @@ template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy 0 ? sumqx/sumq2 : d; - } } From e9ecd5d4defe54218b9f929d315f3a00022daec4 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 15:31:43 +0200 Subject: [PATCH 17/21] metal : GGML_OP_RMS_NORM --- ggml/src/ggml-common.h | 7 +++++ ggml/src/ggml-metal.m | 22 ++++++++----- ggml/src/ggml-metal.metal | 65 ++++++++++++++++++--------------------- 3 files changed, 52 insertions(+), 42 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 2d45311a3..86f4753b1 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -642,6 +642,13 @@ typedef struct { int32_t ne1; uint64_t nb1; } ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; #endif #endif // GGML_COMMON_DECL diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index c2542149a..f41f48b86 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2618,20 +2618,28 @@ static void ggml_metal_encode_node( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + int nth = 32; // SIMD width - while (nth < ne00/4 && nth < 1024) { + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { nth *= 2; } - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index 14526b214..a9d0bf334 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1288,50 +1288,45 @@ kernel void kernel_norm( } kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + constant ggml_metal_kargs_rms_norm & args, + device const char * src0, + device 
char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } - float4 sumf = 0; - float all_sum = 0; + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } + sumf = simd_sum(sumf); - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - if (tiisg == 0) { - buf[sgitg] = all_sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - all_sum = buf[tiisg]; - all_sum = simd_sum(all_sum); + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; } - const float mean = all_sum/ne00; - const float scale = 1.0f/sqrt(mean + eps); + threadgroup_barrier(mem_flags::mem_threadgroup); - device float4 * y = (device float4 *) (dst + tgpig*ne00); - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + const float scale = 1.0f/sqrt(mean + args.eps); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = x[i00] * scale; } } From 964206a780695059475160ef89c0a625cc55d065 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 17:17:18 +0200 Subject: [PATCH 18/21] metal : GGML_OP_NORM --- ggml/src/ggml-common.h | 7 +++ ggml/src/ggml-metal.m | 29 ++++++++---- ggml/src/ggml-metal.metal | 93 +++++++++++++++++++++++---------------- 3 files changed, 82 insertions(+), 47 deletions(-) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index 86f4753b1..d25100693 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -643,6 +643,13 @@ typedef struct { uint64_t nb1; } ggml_metal_kargs_mul_mv_id; +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + typedef struct { int32_t ne00; int32_t ne00_4; diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index f41f48b86..77bc936ea 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -2681,22 +2681,35 @@ static void ggml_metal_encode_node( } break; case GGML_OP_NORM: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = MIN(256, ne00); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - 
[encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal index a9d0bf334..c7acb49da 100644 --- a/ggml/src/ggml-metal.metal +++ b/ggml/src/ggml-metal.metal @@ -1236,53 +1236,68 @@ kernel void kernel_ssm_scan_f32( } kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; + constant ggml_metal_kargs_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float mean = sum[0] / ne00; - // recenter and VARIANCE + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float4 sumf4(0.0f); + + float sumf = 0.0f; + + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf4 += x[i00]; + } + sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = simd_sum(sumf); + threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + + sumf = 0.0f; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; + sumf += dot(y[i00], y[i00]); } + sumf = simd_sum(sumf); - // reduce threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float variance = sum[0] / ne00; - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float variance = sumf/args.ne00; + + const float scale = 1.0f/sqrt(variance + args.eps); + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = y[i00] * scale; } } From 63bab93c48acc2464f327f2d618c592d1a8b7c50 Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 17:56:12 
+0200 Subject: [PATCH 19/21] metal : add TODOs for rest of ops --- ggml/src/ggml-metal.m | 25 +++++++++++++++++++++---- 1 file changed, 21 insertions(+), 4 deletions(-) diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m index 77bc936ea..6af5ef2d1 100644 --- a/ggml/src/ggml-metal.m +++ b/ggml/src/ggml-metal.m @@ -1481,10 +1481,10 @@ static void ggml_metal_encode_node( memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; const int64_t n = ggml_nelements(dst); @@ -1656,6 +1656,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1731,6 +1732,8 @@ static void ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO: add ggml_metal_kargs struct + // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { @@ -1747,6 +1750,7 @@ static void ggml_metal_encode_node( [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -1763,6 +1767,7 @@ static void ggml_metal_encode_node( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1787,6 +1792,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1857,6 +1863,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2595,6 +2602,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2664,6 +2672,7 @@ static void ggml_metal_encode_node( id pipeline = 
ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2853,6 +2862,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2893,6 +2903,7 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2927,6 +2938,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2963,6 +2975,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; @@ -2984,6 +2997,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3022,6 +3036,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3040,6 +3055,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3517,6 +3533,7 @@ static void ggml_metal_encode_node( const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; From 86ed72d20cfe4119a5a57846273c1f0d23127c6e Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 10 Nov 2024 18:29:09 +0200 Subject: [PATCH 20/21] ggml : add ggml-metal-impl.h ggml-ci --- Makefile | 7 +- ggml/src/CMakeLists.txt | 21 ++-- ggml/src/ggml-common.h | 240 ----------------------------------- ggml/src/ggml-metal-impl.h | 249 +++++++++++++++++++++++++++++++++++++ ggml/src/ggml-metal.m | 5 +- ggml/src/ggml-metal.metal | 2 +- 6 files changed, 269 insertions(+), 255 deletions(-) create mode 100644 ggml/src/ggml-metal-impl.h diff --git a/Makefile b/Makefile index dfa32d516..e92335c7c 100644 --- a/Makefile +++ b/Makefile @@ -901,9 +901,12 @@ 
ggml/src/ggml-metal.o: \ ifdef GGML_METAL_EMBED_LIBRARY ggml/src/ggml-metal-embed.o: \ ggml/src/ggml-metal.metal \ - ggml/src/ggml-common.h + ggml/src/ggml-common.h \ + ggml/src/ggml-metal-impl.h @echo "Embedding Metal library" - @sed -e '/#include "ggml-common.h"/r ggml/src/ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml/src/ggml-metal.metal > ggml/src/ggml-metal-embed.metal + @sed -e '/#include "ggml-common.h"/r ggml/src/ggml-common.h' -e '/#include "ggml-common.h"/d' < ggml/src/ggml-metal.metal > ggml/src/ggml-metal-embed.metal.tmp + @sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal-embed.metal + $(eval TEMP_ASSEMBLY=$(shell mktemp -d)) @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index a05f8c505..670389569 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -62,9 +62,10 @@ if (GGML_METAL) add_compile_definitions(GGML_METAL_USE_BF16) endif() - # copy ggml-common.h and ggml-metal.metal to bin directory - configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) - configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) + # copy metal files to bin directory + configure_file(ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) + configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) + configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -72,25 +73,28 @@ if (GGML_METAL) add_compile_definitions(GGML_METAL_EMBED_LIBRARY) set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/ggml-common.h") + set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT ${METALLIB_EMBED_ASM} COMMAND echo "Embedding Metal library" - COMMAND sed -e '/\#include \"ggml-common.h\"/r ${METALLIB_COMMON}' -e '/\#include \"ggml-common.h\"/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED} + COMMAND sed -e '/\#include \"ggml-common.h\"/r ${METALLIB_COMMON}' -e '/\#include \"ggml-common.h\"/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP} + COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED} COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} COMMAND echo 
"_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} - DEPENDS ggml-metal.metal ggml-common.h + DEPENDS ggml-metal.metal ggml-common.h ggml-metal-impl.h COMMENT "Generate assembly for embedded Metal library" ) @@ -128,8 +132,9 @@ if (GGML_METAL) COMMAND xcrun -sdk macosx metallib ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air -o ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/default.metallib COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.air COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h + COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COMMAND rm -f ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal - DEPENDS ggml-metal.metal ggml-common.h + DEPENDS ggml-metal.metal ggml-common.h ggml-metal-impl.h COMMENT "Compiling Metal kernels" ) diff --git a/ggml/src/ggml-common.h b/ggml/src/ggml-common.h index d25100693..050161393 100644 --- a/ggml/src/ggml-common.h +++ b/ggml/src/ggml-common.h @@ -418,246 +418,6 @@ typedef struct { } block_iq4_xs; static_assert(sizeof(block_iq4_xs) == sizeof(ggml_half) + sizeof(uint16_t) + QK_K/64 + QK_K/2, "wrong iq4_xs block size/padding"); -#if defined(GGML_COMMON_DECL_METAL_KARGS) -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - int32_t dim; -} ggml_metal_kargs_concat; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne10; - int32_t ne11; - int32_t ne12; - int32_t ne13; - uint64_t nb10; - uint64_t nb11; - uint64_t nb12; - uint64_t nb13; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - uint64_t offs; -} ggml_metal_kargs_bin; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; -} ggml_metal_kargs_repeat; - -typedef struct { - int64_t ne00; - int64_t ne01; - int64_t ne02; - int64_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int64_t ne0; - int64_t ne1; - int64_t ne2; - int64_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; -} ggml_metal_kargs_cpy; - -typedef struct { - int32_t ne00; - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb00; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne0; - int32_t ne1; - int32_t ne2; - int32_t ne3; - uint64_t nb0; - uint64_t nb1; - uint64_t nb2; - uint64_t nb3; - int32_t n_past; - int32_t n_dims; - int32_t n_ctx_orig; - float freq_base; - float freq_scale; - float ext_factor; - float attn_factor; - float beta_fast; - float beta_slow; -} ggml_metal_kargs_rope; - -typedef struct { - int32_t ne01; - int32_t ne02; - int32_t ne03; - uint64_t nb01; - uint64_t nb02; - uint64_t nb03; - int32_t ne11; - int32_t ne_12_2; // assume K and V are same shape - int32_t ne_12_3; - uint64_t nb_12_1; - uint64_t nb_12_2; - uint64_t nb_12_3; - uint64_t nb31; - int32_t ne1; - int32_t ne2; - float scale; - float max_bias; - float m0; - float m1; - uint16_t n_head_log2; - float logit_softcap; 
-} ggml_metal_kargs_flash_attn_ext;
-
-typedef struct {
-    int32_t ne00;
-    int32_t ne02;
-    uint64_t nb01;
-    uint64_t nb02;
-    uint64_t nb03;
-    int32_t ne12;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    uint64_t nb13;
-    int32_t ne0;
-    int32_t ne1;
-    int16_t r2;
-    int16_t r3;
-} ggml_metal_kargs_mul_mm;
-
-typedef struct {
-    int32_t ne00;
-    int32_t ne01;
-    int32_t ne02;
-    uint64_t nb00;
-    uint64_t nb01;
-    uint64_t nb02;
-    uint64_t nb03;
-    int32_t ne10;
-    int32_t ne11;
-    int32_t ne12;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    uint64_t nb13;
-    int32_t ne0;
-    int32_t ne1;
-    int16_t r2;
-    int16_t r3;
-} ggml_metal_kargs_mul_mv;
-
-typedef struct {
-    int32_t nei0;
-    int32_t nei1;
-    uint64_t nbi1;
-    int32_t ne00;
-    int32_t ne02;
-    uint64_t nb01;
-    uint64_t nb02;
-    int32_t ne11;
-    int32_t ne12;
-    int32_t ne13;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    int32_t ne0;
-    int32_t ne1;
-} ggml_metal_kargs_mul_mm_id;
-
-typedef struct {
-    int32_t nei0;
-    int32_t nei1;
-    uint64_t nbi1;
-    int32_t ne00;
-    int32_t ne01;
-    int32_t ne02;
-    uint64_t nb00;
-    uint64_t nb01;
-    uint64_t nb02;
-    int32_t ne10;
-    int32_t ne11;
-    int32_t ne12;
-    int32_t ne13;
-    uint64_t nb10;
-    uint64_t nb11;
-    uint64_t nb12;
-    int32_t ne0;
-    int32_t ne1;
-    uint64_t nb1;
-} ggml_metal_kargs_mul_mv_id;
-
-typedef struct {
-    int32_t ne00;
-    int32_t ne00_4;
-    uint64_t nb01;
-    float eps;
-} ggml_metal_kargs_norm;
-
-typedef struct {
-    int32_t ne00;
-    int32_t ne00_4;
-    uint64_t nb01;
-    float eps;
-} ggml_metal_kargs_rms_norm;
-#endif
-
 #endif // GGML_COMMON_DECL

 #endif // GGML_COMMON_DECL
diff --git a/ggml/src/ggml-metal-impl.h b/ggml/src/ggml-metal-impl.h
new file mode 100644
index 000000000..53c135496
--- /dev/null
+++ b/ggml/src/ggml-metal-impl.h
@@ -0,0 +1,249 @@
+#ifndef GGML_METAL_IMPL
+#define GGML_METAL_IMPL
+
+// kernel argument structs
+//
+// - element counters (e.g. ne00) typically use int32_t to reduce register usage
+//   however, be careful of int overflows when using those in the kernel implementation
+//
+// - strides (e.g. nb00) use uint64_t
+
+typedef struct {
+    int32_t ne00;
+    int32_t ne01;
+    int32_t ne02;
+    int32_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne10;
+    int32_t ne11;
+    int32_t ne12;
+    int32_t ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t ne0;
+    int32_t ne1;
+    int32_t ne2;
+    int32_t ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t dim;
+} ggml_metal_kargs_concat;
+
+typedef struct {
+    int32_t ne00;
+    int32_t ne01;
+    int32_t ne02;
+    int32_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne10;
+    int32_t ne11;
+    int32_t ne12;
+    int32_t ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t ne0;
+    int32_t ne1;
+    int32_t ne2;
+    int32_t ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    uint64_t offs;
+} ggml_metal_kargs_bin;
+
+typedef struct {
+    int32_t ne00;
+    int32_t ne01;
+    int32_t ne02;
+    int32_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne0;
+    int32_t ne1;
+    int32_t ne2;
+    int32_t ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_repeat;
+
+typedef struct {
+    int64_t ne00;
+    int64_t ne01;
+    int64_t ne02;
+    int64_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int64_t ne0;
+    int64_t ne1;
+    int64_t ne2;
+    int64_t ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+} ggml_metal_kargs_cpy;
+
+typedef struct {
+    int32_t ne00;
+    int32_t ne01;
+    int32_t ne02;
+    int32_t ne03;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne0;
+    int32_t ne1;
+    int32_t ne2;
+    int32_t ne3;
+    uint64_t nb0;
+    uint64_t nb1;
+    uint64_t nb2;
+    uint64_t nb3;
+    int32_t n_past;
+    int32_t n_dims;
+    int32_t n_ctx_orig;
+    float freq_base;
+    float freq_scale;
+    float ext_factor;
+    float attn_factor;
+    float beta_fast;
+    float beta_slow;
+} ggml_metal_kargs_rope;
+
+typedef struct {
+    int32_t ne01;
+    int32_t ne02;
+    int32_t ne03;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne11;
+    int32_t ne_12_2; // assume K and V are same shape
+    int32_t ne_12_3;
+    uint64_t nb_12_1;
+    uint64_t nb_12_2;
+    uint64_t nb_12_3;
+    uint64_t nb31;
+    int32_t ne1;
+    int32_t ne2;
+    float scale;
+    float max_bias;
+    float m0;
+    float m1;
+    uint16_t n_head_log2;
+    float logit_softcap;
+} ggml_metal_kargs_flash_attn_ext;
+
+typedef struct {
+    int32_t ne00;
+    int32_t ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t ne0;
+    int32_t ne1;
+    int16_t r2;
+    int16_t r3;
+} ggml_metal_kargs_mul_mm;
+
+typedef struct {
+    int32_t ne00;
+    int32_t ne01;
+    int32_t ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    uint64_t nb03;
+    int32_t ne10;
+    int32_t ne11;
+    int32_t ne12;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    uint64_t nb13;
+    int32_t ne0;
+    int32_t ne1;
+    int16_t r2;
+    int16_t r3;
+} ggml_metal_kargs_mul_mv;
+
+typedef struct {
+    int32_t nei0;
+    int32_t nei1;
+    uint64_t nbi1;
+    int32_t ne00;
+    int32_t ne02;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t ne11;
+    int32_t ne12;
+    int32_t ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t ne0;
+    int32_t ne1;
+} ggml_metal_kargs_mul_mm_id;
+
+typedef struct {
+    int32_t nei0;
+    int32_t nei1;
+    uint64_t nbi1;
+    int32_t ne00;
+    int32_t ne01;
+    int32_t ne02;
+    uint64_t nb00;
+    uint64_t nb01;
+    uint64_t nb02;
+    int32_t ne10;
+    int32_t ne11;
+    int32_t ne12;
+    int32_t ne13;
+    uint64_t nb10;
+    uint64_t nb11;
+    uint64_t nb12;
+    int32_t ne0;
+    int32_t ne1;
+    uint64_t nb1;
+} ggml_metal_kargs_mul_mv_id;
+
+typedef struct {
+    int32_t ne00;
+    int32_t ne00_4;
+    uint64_t nb01;
+    float eps;
+} ggml_metal_kargs_norm;
+
+typedef struct {
+    int32_t ne00;
+    int32_t ne00_4;
+    uint64_t nb01;
+    float eps;
+} ggml_metal_kargs_rms_norm;
+
+#endif // GGML_METAL_IMPL
diff --git a/ggml/src/ggml-metal.m b/ggml/src/ggml-metal.m
index 6af5ef2d1..d9060400f 100644
--- a/ggml/src/ggml-metal.m
+++ b/ggml/src/ggml-metal.m
@@ -2,10 +2,7 @@
 #import "ggml-impl.h"
 #import "ggml-backend-impl.h"
-
-#define GGML_COMMON_DECL_C
-#define GGML_COMMON_DECL_METAL_KARGS
-#include "ggml-common.h"
+#import "ggml-metal-impl.h"

 #import
diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index c7acb49da..ff6aff28e 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -1,7 +1,7 @@
 #define GGML_COMMON_DECL_METAL
-#define GGML_COMMON_DECL_METAL_KARGS
 #define GGML_COMMON_IMPL_METAL
 #include "ggml-common.h"
+#include "ggml-metal-impl.h"

 #include

From 8c1b186cb5eac4abd26b73a8cb2a2248808491e7 Mon Sep 17 00:00:00 2001
From: Georgi Gerganov
Date: Tue, 12 Nov 2024 15:30:51 +0200
Subject: [PATCH 21/21] metal : minor Q4_0 optimization

---
 ggml/src/ggml-metal.metal | 35 +++++++++++++++++------------------
 1 file changed, 17 insertions(+), 18 deletions(-)

diff --git a/ggml/src/ggml-metal.metal b/ggml/src/ggml-metal.metal
index ff6aff28e..a5b631cf9 100644
--- a/ggml/src/ggml-metal.metal
+++ b/ggml/src/ggml-metal.metal
@@ -1441,18 +1441,18 @@ kernel void kernel_group_norm(
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;

-    float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
+    float acc = -8.0f*sumy;

     device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);

-    for (int i = 0; i < 8; i += 2) {
-        acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
-        acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
-        acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
-        acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
+    for (short i = 0; i < 4; ++i) {
+        acc += yl[2*i + 0] * (qs[i] & 0x000F);
+        acc += yl[2*i + 1] * (qs[i] & 0x0F00);
+        acc += yl[2*i + 8] * (qs[i] & 0x00F0);
+        acc += yl[2*i + 9] * (qs[i] & 0xF000);
     }

-    return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
+    return d * acc;
 }

 // function for calculate inner product between half a q4_1 block and 16 floats (yl), sumy is SUM(yl[i])
@@ -1567,29 +1567,28 @@ void mul_vec_q_n_f32_impl(
     float yl[16]; // src1 vector cache
     float sumf[nr] = {0.f};

-    const short ix = (tiisg/2);
-    const short il = (tiisg%2)*8;
+    const short ix = (tiisg%16);
+    const short il = (tiisg/16)*8;

     device const float * yb = y + ix*QK4_0 + il;

     // each thread in a SIMD group deals with half a block.
     for (int ib = ix; ib < nb; ib += nw/2) {
-        float sumy[2] = { 0.f, 0.f };
+        float sumy = 0.0f;

-#pragma unroll
+#pragma unroll(4)
         for (int i = 0; i < 8; i += 2) {
-            sumy[0] += yb[i + 0] + yb[i + 1];
+            sumy += yb[i + 0] + yb[i + 1] + yb[i + 16] + yb[i + 17];
+
             yl[i + 0] = yb[i + 0];
             yl[i + 1] = yb[i + 1]/256.f;
-
-            sumy[1] += yb[i + 16] + yb[i + 17];
             yl[i + 8] = yb[i + 16]/16.f;
             yl[i + 9] = yb[i + 17]/4096.f;
         }

-#pragma unroll
-        for (int row = 0; row < nr; row++) {
-            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy[0] + sumy[1], yl, il);
+#pragma unroll(nr)
+        for (short row = 0; row < nr; row++) {
+            sumf[row] += block_q_n_dot_y(ax[row] + ib, sumy, yl, il);
         }

         yb += QK4_0 * 16;
@@ -1597,7 +1596,7 @@ void mul_vec_q_n_f32_impl(

     device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;

-    for (int row = 0; row < nr; ++row) {
+    for (short row = 0; row < nr; ++row) {
         const float tot = simd_sum(sumf[row]);

         if (tiisg == 0 && first_row + row < args.ne01) {