From aa17d321b36e2d8ab912c0b05e05088793221567 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9my=20O?= Date: Sun, 26 Jan 2025 14:50:47 +0100 Subject: [PATCH] vulkan: avoid using workgroup size before it is referenced --- .../vulkan-shaders/copy_from_quant.comp | 2 +- .../vulkan-shaders/copy_to_quant.comp | 2 +- .../vulkan-shaders/dequant_iq2_s.comp | 2 +- .../vulkan-shaders/dequant_iq2_xs.comp | 2 +- .../vulkan-shaders/dequant_iq2_xxs.comp | 2 +- .../vulkan-shaders/dequant_iq3_s.comp | 2 +- .../vulkan-shaders/dequant_iq3_xxs.comp | 2 +- .../vulkan-shaders/dequant_iq4_nl.comp | 2 +- .../vulkan-shaders/flash_attn_cm2.comp | 2 +- .../vulkan-shaders/get_rows_quant.comp | 2 +- .../vulkan-shaders/mul_mat_vec.comp | 2 +- .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 2 +- .../vulkan-shaders/mul_mm_cm2.comp | 2 +- .../src/ggml-vulkan/vulkan-shaders/types.comp | 26 +++++++++---------- 14 files changed, 26 insertions(+), 26 deletions(-) diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp index 4ab50dc1e..aeae5400d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -13,7 +13,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; void main() { #if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp index 79d723ea2..d4b068e61 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -218,7 +218,7 @@ void quantize(uint dst_idx, uint src_idx) void main() { #if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); if (gl_LocalInvocationIndex.x != 0) { return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp index a0b29106d..48f6b65bc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -11,7 +11,7 @@ void main() { // Each thread handles 1 subblock (32 values with 2 scales) const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); if (ib >= p.nel / 256) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp index f6af019d1..a08331c40 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -11,7 +11,7 @@ void main() { // Each thread handles 1 subblock (32 values with 2 scales) const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); if (ib >= p.nel / 256) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp index 334f9af8d..e370690bc 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -12,7 +12,7 @@ void main() { // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); if (ib >= p.nel / 256) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp index 418aaa887..c3f4bca5d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -12,7 +12,7 @@ void main() { // Each block contains 4 scale bytes (8 scales) for 256 output values. const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); if (ib >= p.nel / 256) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp index 54bc13107..a92b82961 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -12,7 +12,7 @@ void main() { // 8 threads handle 1 superblock const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); if (ib >= p.nel / 256) { return; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp index c5d05925f..46d9ad15e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -10,7 +10,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; void main() { const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); const uint tid = gl_LocalInvocationID.x % 64; const uint il = tid/32; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 586704e30..043a53023 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -105,7 +105,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele void main() { #if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); #endif const uint32_t N = p.N; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp index 247c85342..09dc43d8d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -13,7 +13,7 @@ void main() { const uint i12 = (gl_GlobalInvocationID.z)%p.ne12; #if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); #endif if (i00 >= p.ne00) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp index 596f15611..48156e7ba 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -134,7 +134,7 @@ void main() { const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); #if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); #endif // do NUM_ROWS at a time, unless there aren't enough remaining rows diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index d083a464c..d0559aac8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -96,7 +96,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; void main() { #if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); #endif #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index c545189b2..27c5d68b3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -107,7 +107,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem void main() { #if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL) - init_iq_shmem(); + init_iq_shmem(gl_WorkGroupSize); #endif #ifdef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp index 6b874aad4..9e56a3530 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.comp @@ -380,10 +380,10 @@ const uvec2[256] iq2xxs_grid_const = { shared uvec2 iq2xxs_grid[256]; -void init_iq_shmem() +void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += gl_WorkGroupSize.x) { + for (uint i = gl_LocalInvocationIndex.x; i < iq2xxs_grid.length(); i += wgsize.x) { iq2xxs_grid[i] = iq2xxs_grid_const[i]; } barrier(); @@ -547,10 +547,10 @@ const uvec2 iq2xs_grid_const[512] = { shared uvec2 iq2xs_grid[512]; -void init_iq_shmem() +void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += gl_WorkGroupSize.x) { + for (uint i = gl_LocalInvocationIndex.x; i < iq2xs_grid.length(); i += wgsize.x) { iq2xs_grid[i] = iq2xs_grid_const[i]; } barrier(); @@ -836,10 +836,10 @@ const uvec2 iq2s_grid_const[1024] = { shared uvec2 iq2s_grid[1024]; -void init_iq_shmem() +void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += gl_WorkGroupSize.x) { + for (uint i = gl_LocalInvocationIndex.x; i < iq2s_grid.length(); i += wgsize.x) { iq2s_grid[i] = iq2s_grid_const[i]; } barrier(); @@ -904,10 +904,10 @@ const uint32_t iq3xxs_grid_const[256] = { shared uint32_t iq3xxs_grid[256]; -void init_iq_shmem() +void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += gl_WorkGroupSize.x) { + for (uint i = gl_LocalInvocationIndex.x; i < iq3xxs_grid.length(); i += wgsize.x) { iq3xxs_grid[i] = iq3xxs_grid_const[i]; } barrier(); @@ -1011,10 +1011,10 @@ const uint32_t iq3s_grid_const[512] = { shared uint32_t iq3s_grid[512]; -void init_iq_shmem() +void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += gl_WorkGroupSize.x) { + for (uint i = gl_LocalInvocationIndex.x; i < iq3s_grid.length(); i += wgsize.x) { iq3s_grid[i] = iq3s_grid_const[i]; } barrier(); @@ -1050,11 +1050,11 @@ const int8_t kvalues_iq4nl_const[16] = { shared FLOAT_TYPE kvalues_iq4nl[16]; -void init_iq_shmem() +void init_iq_shmem(uvec3 wgsize) { // copy the table into shared memory and sync - if (gl_LocalInvocationIndex.x < 16) { - kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]); + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) { + kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]); } barrier(); }