diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp
index 3c772b32f..05bd55f27 100644
--- a/ggml-vulkan.cpp
+++ b/ggml-vulkan.cpp
@@ -439,9 +439,9 @@ void ggml_vk_init(void) {
     vmaCreateAllocator(&allocator_info, &vk_allocator);
 
     // Shaders
-    vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 6, {128, 128, 1});
+    vk_pipeline_matmul_f32 = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 6, {64, 64, 1});
     if (vk_fp16_support) {
-        vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 6, {128, 128, 1});
+        vk_pipeline_matmul_f16 = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 6, {64, 64, 1});
     }
 
     vk_pipeline_f16_to_f32 = ggml_vk_create_pipeline("vk_shaders/f16_to_f32.spv", "main", 2, 1, {32, 1, 1});
diff --git a/vk_shaders/matmul_f16.glsl b/vk_shaders/matmul_f16.glsl
index 0abab4827..8fa3cda6c 100644
--- a/vk_shaders/matmul_f16.glsl
+++ b/vk_shaders/matmul_f16.glsl
@@ -1,15 +1,20 @@
 #version 450
 
-#define BM 128
-#define BN 128
-#define BK 8
-#define TM 8
-#define TN 8
+#define BM 64
+#define BN 64
+#define BK 16
+#define WM 32
+#define WN 32
+#define WMITER 2
+#define TM 4
+#define TN 2
+
+#define WARP 32
 
 #extension GL_EXT_control_flow_attributes : enable
 #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
 
-layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A { float16_t data_a[]; };
 layout (binding = 1) readonly buffer B { float16_t data_b[]; };
@@ -32,10 +37,17 @@ void main() {
     const int ir = int(gl_WorkGroupID.x);
     const int ic = int(gl_WorkGroupID.y);
 
-    const int rstride = BM / TM;
+    const int warp_i = int(gl_LocalInvocationID.x / WARP);
+    const int warp_r = warp_i % (BM / WM);
+    const int warp_c = warp_i / (BM / WM);
 
-    const int lr = int(gl_LocalInvocationID.x % rstride);
-    const int lc = int(gl_LocalInvocationID.x / rstride);
+    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
+    const int WSUBM = WM / WMITER;
+    const int WSUBN = WN / WNITER;
+
+    const int tiw = int(gl_LocalInvocationID.x % WARP);
+    const int tiwr = tiw % (WSUBM / TM);
+    const int tiwc = tiw / (WSUBM / TM);
 
     const int loadr = int(gl_LocalInvocationID.x % BK);
     const int loadc = int(gl_LocalInvocationID.x / BK);
@@ -45,11 +57,11 @@ void main() {
     int pos_a = ir * BM * p.stride_a;
     int pos_b = ic * BN * p.stride_b;
 
-    float sums[TM * TN];
-    float16_t cache_a[TM];
-    float16_t cache_b[TN];
+    float sums[WMITER * TM * WNITER * TN];
+    float16_t cache_a[WMITER * TM];
+    float16_t cache_b[WNITER * TN];
 
-    [[unroll]] for (int i = 0; i < TM*TN; i++) {
+    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
         sums[i] = 0.0hf;
     }
 
@@ -80,16 +92,24 @@ void main() {
 
         [[unroll]] for (int i = 0; i < BK; i++) {
             // Load from shared into cache
-            [[unroll]] for (int j = 0; j < BM; j++) {
-                cache_a[j] = buf_a[(lr + j*rstride) * (BK+1) + i];
+            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+                [[unroll]] for (int j = 0; j < TM; j++) {
+                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
+                }
             }
-            [[unroll]] for (int j = 0; j < TN; j++) {
-                cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i];
+            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (int j = 0; j < TN; j++) {
+                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
+                }
             }
 
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    sums[cc * TM + cr] += float(cache_a[cr]) * float(cache_b[cc]);
+            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
+                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
+                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
+                        }
+                    }
                 }
             }
         }
@@ -97,13 +117,20 @@ void main() {
         barrier();
     }
 
-    const int dr = ir * BM + lr;
-    const int dc = ic * BN + lc * TN;
+    const int dr = ir * BM + warp_r * WM;
+    const int dc = ic * BN + warp_c * WN;
 
-    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-            if (dr + cr*rstride < p.M && dc + cc < p.N) {
-                data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
+    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+
+            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
+            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
+            [[unroll]] for (int cc = 0; cc < TN; cc++) {
+                [[unroll]] for (int cr = 0; cr < TM; cr++) {
+                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
+                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
+                    }
+                }
             }
         }
     }
diff --git a/vk_shaders/matmul_f16_warptile.glsl b/vk_shaders/matmul_f16_warptile.glsl
deleted file mode 100644
index 8bf6c2b5a..000000000
--- a/vk_shaders/matmul_f16_warptile.glsl
+++ /dev/null
@@ -1,137 +0,0 @@
-#version 450
-
-#define BM 128
-#define BN 128
-#define BK 16
-#define WM 64
-#define WN 64
-#define WMITER 4
-#define TM 4
-#define TN 8
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-
-layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { float16_t data_a[]; };
-layout (binding = 1) readonly buffer B { float16_t data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-} p;
-
-shared float16_t buf_a[BM * (BK+1)];
-shared float16_t buf_b[BN * (BK+1)];
-
-void main() {
-    const int ir = int(gl_WorkGroupID.x);
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % BK);
-    const int loadc = int(gl_LocalInvocationID.x / BK);
-
-    const int loadstride = int(gl_WorkGroupSize.x);
-
-    int pos_a = ir * BM * p.stride_a;
-    int pos_b = ic * BN * p.stride_b;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float16_t cache_a[WMITER * TM];
-    float16_t cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0hf;
-    }
-
-    [[unroll]] for (int block = 0; block < p.K; block += BK) {
-        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
-            } else {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
-            }
-        }
-        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
-            } else {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
-            }
-        }
-
-        barrier();
-
-        pos_a += BK;
-        pos_b += BK;
-
-        [[unroll]] for (int i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
diff --git a/vk_shaders/matmul_f32.glsl b/vk_shaders/matmul_f32.glsl
index dfc572a6f..a353345af 100644
--- a/vk_shaders/matmul_f32.glsl
+++ b/vk_shaders/matmul_f32.glsl
@@ -1,14 +1,19 @@
 #version 450
 
-#define BM 128
-#define BN 128
-#define BK 8
-#define TM 8
-#define TN 8
+#define BM 64
+#define BN 64
+#define BK 16
+#define WM 32
+#define WN 32
+#define WMITER 2
+#define TM 4
+#define TN 2
+
+#define WARP 32
 
 #extension GL_EXT_control_flow_attributes : enable
 
-layout(local_size_x = (BM * BN) / (TM * TN), local_size_y = 1, local_size_z = 1) in;
+layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
 
 layout (binding = 0) readonly buffer A { float data_a[]; };
 layout (binding = 1) readonly buffer B { float data_b[]; };
@@ -31,10 +36,17 @@ void main() {
     const int ir = int(gl_WorkGroupID.x);
     const int ic = int(gl_WorkGroupID.y);
 
-    const int rstride = BM / TM;
+    const int warp_i = int(gl_LocalInvocationID.x / WARP);
+    const int warp_r = warp_i % (BM / WM);
+    const int warp_c = warp_i / (BM / WM);
 
-    const int lr = int(gl_LocalInvocationID.x % rstride);
-    const int lc = int(gl_LocalInvocationID.x / rstride);
+    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
+    const int WSUBM = WM / WMITER;
+    const int WSUBN = WN / WNITER;
+
+    const int tiw = int(gl_LocalInvocationID.x % WARP);
+    const int tiwr = tiw % (WSUBM / TM);
+    const int tiwc = tiw / (WSUBM / TM);
 
     const int loadr = int(gl_LocalInvocationID.x % BK);
     const int loadc = int(gl_LocalInvocationID.x / BK);
@@ -44,11 +56,11 @@ void main() {
     int pos_a = ir * BM * p.stride_a;
     int pos_b = ic * BN * p.stride_b;
 
-    float sums[TM * TN];
-    float cache_a[TM];
-    float cache_b[TN];
+    float sums[WMITER * TM * WNITER * TN];
+    float cache_a[WMITER * TM];
+    float cache_b[WNITER * TN];
 
-    [[unroll]] for (int i = 0; i < TM*TN; i++) {
+    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
         sums[i] = 0.0f;
     }
 
@@ -79,16 +91,24 @@ void main() {
 
         [[unroll]] for (int i = 0; i < BK; i++) {
             // Load from shared into cache
-            [[unroll]] for (int j = 0; j < BM; j++) {
-                cache_a[j] = buf_a[(lr + j*rstride) * (BK+1) + i];
+            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+                [[unroll]] for (int j = 0; j < TM; j++) {
+                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
+                }
             }
-            [[unroll]] for (int j = 0; j < TN; j++) {
-                cache_b[j] = buf_b[(lc * TN + j) * (BK+1) + i];
+            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (int j = 0; j < TN; j++) {
+                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
+                }
             }
 
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    sums[cc * TM + cr] += cache_a[cr] * cache_b[cc];
+            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
+                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
+                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
+                        }
+                    }
                 }
             }
         }
@@ -96,13 +116,20 @@ void main() {
         barrier();
     }
 
-    const int dr = ir * BM + lr;
-    const int dc = ic * BN + lc * TN;
+    const int dr = ir * BM + warp_r * WM;
+    const int dc = ic * BN + warp_c * WN;
 
-    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-            if (dr + cr*rstride < p.M && dc + cc < p.N) {
-                data_d[(dc + cc) * p.stride_d + dr + cr*rstride] = sums[cc * TM + cr];
+    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
+        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
+
+            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
+            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
+            [[unroll]] for (int cc = 0; cc < TN; cc++) {
+                [[unroll]] for (int cr = 0; cr < TM; cr++) {
+                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
+                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
+                    }
+                }
            }
         }
     }
diff --git a/vk_shaders/matmul_f32_warptile.glsl b/vk_shaders/matmul_f32_warptile.glsl
deleted file mode 100644
index 19b02c0c4..000000000
--- a/vk_shaders/matmul_f32_warptile.glsl
+++ /dev/null
@@ -1,136 +0,0 @@
-#version 450
-
-#define BM 128
-#define BN 128
-#define BK 16
-#define WM 64
-#define WN 64
-#define WMITER 4
-#define TM 4
-#define TN 8
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { float data_a[]; };
-layout (binding = 1) readonly buffer B { float data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-} p;
-
-shared float buf_a[BM * (BK+1)];
-shared float buf_b[BN * (BK+1)];
-
-void main() {
-    const int ir = int(gl_WorkGroupID.x);
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % BK);
-    const int loadc = int(gl_LocalInvocationID.x / BK);
-
-    const int loadstride = int(gl_WorkGroupSize.x);
-
-    int pos_a = ir * BM * p.stride_a;
-    int pos_b = ic * BN * p.stride_b;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float cache_a[WMITER * TM];
-    float cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = 0; block < p.K; block += BK) {
-        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
-            } else {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
-            }
-        }
-        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
-            } else {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
-            }
-        }
-
-        barrier();
-
-        pos_a += BK;
-        pos_b += BK;
-
-        [[unroll]] for (int i = 0; i < BK; i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[(dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
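Note on the new tile constants (not part of the patch): with BM = BN = 64, BK = 16, WM = WN = 32, WMITER = 2, TM = 4, TN = 2 and a workgroup of 128 invocations (four 32-wide warps), each warp owns one 32x32 quadrant of the 64x64 block tile and each invocation accumulates an 8x4 patch of results. The sketch below is a hypothetical host-side sanity check that mirrors the shader #defines and the derived WNITER/WSUBM/WSUBN values; the constant names simply copy the shader's, and nothing here is called by ggml-vulkan.cpp.

// matmul_tile_check.cpp -- hypothetical standalone sanity check, not part of the patch.
// Mirrors the #defines in vk_shaders/matmul_f32.glsl / matmul_f16.glsl after this change.
#include <cassert>
#include <cstdio>

int main() {
    const int BM = 64, BN = 64, BK = 16; // block tile staged in shared memory
    const int WM = 32, WN = 32;          // warp tile
    const int WMITER = 2;                // warp sub-tile iterations along M
    const int TM = 4, TN = 2;            // per-invocation tile
    const int WARP = 32;                 // subgroup size the shader assumes
    const int LOCAL_SIZE_X = 128;        // local_size_x in the shader layout

    // Derived the same way the shader derives them.
    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER); // 2
    const int WSUBM = WM / WMITER;                            // 16
    const int WSUBN = WN / WNITER;                            // 16

    // Four warps tile the 64x64 block as a 2x2 grid of 32x32 warp tiles.
    assert(LOCAL_SIZE_X / WARP == (BM / WM) * (BN / WN));
    // One warp's 32 invocations, each holding (WMITER*TM) x (WNITER*TN) sums, cover its warp tile.
    assert(WARP * (WMITER * TM) * (WNITER * TN) == WM * WN);
    // The whole workgroup covers the full BM x BN block per BK step.
    assert(LOCAL_SIZE_X * (WMITER * TM) * (WNITER * TN) == BM * BN);
    // loadr/loadc split the 128 invocations into 8 rows of BK=16 loads per pass.
    assert(LOCAL_SIZE_X % BK == 0);

    printf("WNITER=%d WSUBM=%d WSUBN=%d sums per invocation=%d\n",
           WNITER, WSUBM, WSUBN, WMITER * TM * WNITER * TN);
    return 0;
}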