From 592ebb044d0a63a92a08d20c670ace376f7f879a Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Mon, 14 Aug 2023 09:39:58 +0200 Subject: [PATCH] Transfer remaining shaders to header and compile on runtime --- CMakeLists.txt | 19 +- Makefile | 15 - ggml-vulkan-shaders.hpp | 355 +++++++++++++++++-- ggml-vulkan.cpp | 153 +++++--- vk_shaders/add_f16_f32_f16.glsl | 33 -- vk_shaders/add_f32.glsl | 31 -- vk_shaders/dequant_mul_mat_vec_f16.glsl | 59 --- vk_shaders/dequant_mul_mat_vec_f16_f32.glsl | 59 --- vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl | 73 ---- vk_shaders/matmul_f16.glsl | 145 -------- vk_shaders/matmul_f16_aligned.glsl | 149 -------- vk_shaders/matmul_f16_f32.glsl | 145 -------- vk_shaders/matmul_f16_f32_aligned.glsl | 149 -------- vk_shaders/matmul_f32.glsl | 144 -------- vk_shaders/matmul_f32_aligned.glsl | 140 -------- vk_shaders/matmul_f32_q4_0.glsl | 169 --------- vk_shaders/matmul_split_k_reduce.glsl | 31 -- vk_shaders/scale_f32.glsl | 30 -- 18 files changed, 429 insertions(+), 1470 deletions(-) delete mode 100644 vk_shaders/add_f16_f32_f16.glsl delete mode 100644 vk_shaders/add_f32.glsl delete mode 100644 vk_shaders/dequant_mul_mat_vec_f16.glsl delete mode 100644 vk_shaders/dequant_mul_mat_vec_f16_f32.glsl delete mode 100644 vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl delete mode 100644 vk_shaders/matmul_f16.glsl delete mode 100644 vk_shaders/matmul_f16_aligned.glsl delete mode 100644 vk_shaders/matmul_f16_f32.glsl delete mode 100644 vk_shaders/matmul_f16_f32_aligned.glsl delete mode 100644 vk_shaders/matmul_f32.glsl delete mode 100644 vk_shaders/matmul_f32_aligned.glsl delete mode 100644 vk_shaders/matmul_f32_q4_0.glsl delete mode 100644 vk_shaders/matmul_split_k_reduce.glsl delete mode 100644 vk_shaders/scale_f32.glsl diff --git a/CMakeLists.txt b/CMakeLists.txt index 7da296160..cd360a3f1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -347,27 +347,12 @@ if (LLAMA_CLBLAST) endif() if (LLAMA_VULKAN) - find_package(Vulkan COMPONENTS glslc) + find_package(Vulkan COMPONENTS glslc SPIRV-Tools) if (Vulkan_FOUND) message(STATUS "Vulkan found") add_library(ggml-vulkan STATIC ggml-vulkan.cpp ggml-vulkan.h) - target_link_libraries(ggml-vulkan PUBLIC Vulkan::Vulkan) - - set(GGML_VULKAN_SHADERS matmul_f32 matmul_f16 f16_to_f32 dequant_q4_0) - - foreach(s IN LISTS GGML_VULKAN_SHADERS) - add_custom_command( - OUTPUT "vk_shaders/${s}.spv" - COMMAND "${Vulkan_GLSLC_EXECUTABLE}" - -fshader-stage=compute - --target-env=vulkan1.2 - "${CMAKE_CURRENT_SOURCE_DIR}/vk_shaders/${s}.glsl" - -o "${CMAKE_CURRENT_BINARY_DIR}/vk_shaders/${s}.spv" - DEPENDS "vk_shaders/${s}.glsl" - ) - target_sources(ggml-vulkan PRIVATE "vk_shaders/${s}.spv") - endforeach() + target_link_libraries(ggml-vulkan PUBLIC Vulkan::Vulkan SPIRV SPIRV-Tools-opt SPIRV-Tools shaderc_combined) add_compile_definitions(GGML_USE_VULKAN) diff --git a/Makefile b/Makefile index d669826a2..f75bfb617 100644 --- a/Makefile +++ b/Makefile @@ -232,21 +232,6 @@ ifdef LLAMA_VULKAN OBJS += ggml-vulkan.o ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h $(CXX) $(CXXFLAGS) -c $< -o $@ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32.glsl -o vk_shaders/matmul_f32.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f32_aligned.glsl -o vk_shaders/matmul_f32_aligned.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16.glsl -o vk_shaders/matmul_f16.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_aligned.glsl -o 
vk_shaders/matmul_f16_aligned.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_f32.glsl -o vk_shaders/matmul_f16_f32.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_f32_aligned.glsl -o vk_shaders/matmul_f16_f32_aligned.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16.glsl -o vk_shaders/dequant_mul_mat_vec_f16.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16_f32.glsl -o vk_shaders/dequant_mul_mat_vec_f16_f32.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f32.glsl -o vk_shaders/add_f32.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f16_f32_f16.glsl -o vk_shaders/add_f16_f32_f16.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/scale_f32.glsl -o vk_shaders/scale_f32.spv & \ - wait endif ifneq ($(filter aarch64%,$(UNAME_M)),) diff --git a/ggml-vulkan-shaders.hpp b/ggml-vulkan-shaders.hpp index a2643aa26..2599bdf2d 100644 --- a/ggml-vulkan-shaders.hpp +++ b/ggml-vulkan-shaders.hpp @@ -1,5 +1,238 @@ #include +// Generic +const std::string shader_f32 = R"( +#define FLOAT_TYPE float +)"; +const std::string shader_f16 = R"( +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#define FLOAT_TYPE float16_t +)"; +const std::string shader_int8_ext = R"( +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +)"; +const std::string shader_output_f16 = R"( +#define OUT_TYPE float16_t +)"; +const std::string shader_output_f32 = R"( +#define OUT_TYPE float +)"; + +// MULMAT + +const std::string mulmat_head = R"( +#version 450 + +#define WARP 32 + +#extension GL_EXT_control_flow_attributes : enable +)"; + +const std::string mulmat_body = R"( +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#ifdef ALIGNED_INPUT +#define LOAD_VEC 8 +layout (binding = 0) readonly buffer A { A_TYPE data_a[]; }; +layout (binding = 1) readonly buffer B { B_TYPE data_b[]; }; + +#else + +#define LOAD_VEC 1 +layout (binding = 0) readonly buffer A { A_TYPE data_a[]; }; +layout (binding = 1) readonly buffer B { B_TYPE data_b[]; }; +#endif + +layout (binding = 2) writeonly buffer D { D_TYPE data_d[]; }; + +layout (push_constant) uniform parameter +{ + int M; + int N; + int K; + int stride_a; + int stride_b; + int stride_d; + int k_split; +} p; + +layout (constant_id = 1) const int BM = 64; +layout (constant_id = 2) const int BN = 64; +layout (constant_id = 3) const int BK = 16; +layout (constant_id = 4) const int WM = 32; +layout (constant_id = 5) const int WN = 32; +layout (constant_id = 6) const int WMITER = 2; +layout (constant_id = 7) const int TM = 4; +layout (constant_id = 8) const int TN = 2; + +shared FLOAT_TYPE buf_a[BM * (BK+1)]; +shared FLOAT_TYPE buf_b[BN * (BK+1)]; + +void main() { + const int blocks_x = (p.M + BM - 1) / BM; + const int ir = int(gl_WorkGroupID.x) % blocks_x; + const int ik = int(gl_WorkGroupID.x) / blocks_x; + const int ic = int(gl_WorkGroupID.y); + + const int warp_i = int(gl_LocalInvocationID.x / WARP); + const int 
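+    // row index of this warp's WM x WN sub-tile within the BM x BN workgroup tile: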
warp_r = warp_i % (BM / WM); + const int warp_c = warp_i / (BM / WM); + + const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const int WSUBM = WM / WMITER; + const int WSUBN = WN / WNITER; + + const int tiw = int(gl_LocalInvocationID.x % WARP); + const int tiwr = tiw % (WSUBM / TM); + const int tiwc = tiw / (WSUBM / TM); + + const int loadr = int(gl_LocalInvocationID.x % (BK / LOAD_VEC)); + const int loadc = int(gl_LocalInvocationID.x / (BK / LOAD_VEC)); + + const int loadstride = int(gl_WorkGroupSize.x * LOAD_VEC) / BK; + + const int start_k = ik * p.k_split; + const int end_k = (ik + 1) * p.k_split; + + int pos_a = ir * BM * p.stride_a / LOAD_VEC + start_k / LOAD_VEC; + int pos_b = ic * BN * p.stride_b / LOAD_VEC + start_k / LOAD_VEC; + + D_TYPE sums[WMITER * TM * WNITER * TN]; + FLOAT_TYPE cache_a[WMITER * TM]; + FLOAT_TYPE cache_b[WNITER * TN]; + + [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = 0.0f; + } + + [[unroll]] for (int block = start_k; block < end_k; block += BK) { + [[unroll]] for (int l = 0; l < BM; l += loadstride) { +#ifdef ALIGNED_INPUT + A_TYPE tmp = data_a[pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr]; + buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(tmp[0].x); + buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(tmp[0].y); + buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(tmp[0].z); + buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(tmp[0].w); + buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(tmp[1].x); + buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(tmp[1].y); + buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(tmp[1].z); + buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(tmp[1].w); +#else + if (ir * BM + loadc + l < p.M && block + loadr < p.K) { + buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]); + } else { + buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f); + } +#endif + } + [[unroll]] for (int l = 0; l < BN; l += loadstride) { +#ifdef ALIGNED_INPUT + B_TYPE tmp = data_b[pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr]; + buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(tmp[0].x); + buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(tmp[0].y); + buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(tmp[0].z); + buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(tmp[0].w); + buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(tmp[1].x); + buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(tmp[1].y); + buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(tmp[1].z); + buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(tmp[1].w); +#else + if (ic * BN + loadc + l < p.N && block + loadr < p.K) { + buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]); + } else { + buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(0.0f); + } +#endif + } + + barrier(); + + pos_a += BK / LOAD_VEC; + pos_b += BK / LOAD_VEC; + + for (int i = 0; i < min(BK, p.K - block); i++) { + // Load from shared into cache + [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (int j = 0; j < TM; j++) { + cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; + } + } + [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (int j = 0; j < TN; j++) { + cache_b[wsic * TN + j] = 
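+                    // one column element of B from shared memory, staged for the outer-product accumulation below: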
buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; + } + } + + [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (int cc = 0; cc < TN; cc++) { + [[unroll]] for (int cr = 0; cr < TM; cr++) { + sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += D_TYPE(cache_a[wsir * TM + cr]) * D_TYPE(cache_b[wsic * TN + cc]); + } + } + } + } + } + + barrier(); + } + + const int dr = ir * BM + warp_r * WM; + const int dc = ic * BN + warp_c * WN; + + const int k_split_offset = ik * p.M * p.N; + + [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { + + const int dr_warp = dr + wsir * WSUBM + tiwr * TM; + const int dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (int cc = 0; cc < TN; cc++) { + [[unroll]] for (int cr = 0; cr < TM; cr++) { + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]; + } + } + } + } + } +} +)"; + +const std::string mulmat_split_k_reduce_src = R"( +#version 450 + +layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; + +layout (binding = 0) buffer A { float data[]; }; + +layout (push_constant) uniform parameter +{ + int M; + int N; + int k_num; +} p; + +void main() { + const int glr = int(gl_GlobalInvocationID.x); + const int glc = int(gl_GlobalInvocationID.y); + + if (glr >= p.M || glc >= p.N) { + return; + } + + const int idx = glc * p.M + glr; + + float result = 0.0f; + + for (int i = 0; i < p.k_num; i++) { + result += data[i * p.M * p.N + idx]; + } + + data[idx] = result; +} +)"; + // DEQUANT SHADER const std::string dequant_head = R"( #version 450 @@ -9,18 +242,6 @@ const std::string dequant_head = R"( #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require )"; -const std::string dequant_glsl_fp16_ext = R"( -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -)"; - -const std::string dequant_output_fp16 = R"( -#define OUT_TYPE float16_t -)"; - -const std::string dequant_output_fp32 = R"( -#define OUT_TYPE float -)"; - const std::string dequant_q4_0_defines = R"( #define QUANT_K 32 #define QUANT_R 2 @@ -83,9 +304,23 @@ const std::string mul_mat_vec_head = R"( #extension GL_EXT_shader_8bit_storage : require )"; -const std::string mul_mat_vec_fp16 = R"( -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +const std::string mul_mat_vec_b_type_f32 = R"( +#define B_TYPE float +)"; + +const std::string mul_mat_vec_b_type_f16 = R"( +#define B_TYPE float16_t +)"; + +const std::string mul_mat_vec_f16_defines = R"( +#define QUANT_K 32 +#define QUANT_R 2 +#define BLOCK_SIZE 32 + +#define A_TYPE float16_t + +#define DEQUANT_FUNC float16_t v0 = x[ib + 0]; \ +float16_t v1 = x[ib + 1]; )"; const std::string mul_mat_vec_q4_0_defines = R"( @@ -99,8 +334,6 @@ struct block_q4_0 uint8_t qs[16]; }; #define A_TYPE block_q4_0 -#define B_TYPE float16_t -#define OUT_TYPE float #define DEQUANT_FUNC const float16_t d = x[ib].d; \ const uint8_t vui = x[ib].qs[iqs]; \ @@ -122,7 +355,7 @@ layout (push_constant) uniform parameter int ncols; } p; -shared float16_t tmp[BLOCK_SIZE]; +shared FLOAT_TYPE tmp[BLOCK_SIZE]; void main() { const int block_size = int(gl_WorkGroupSize.x); @@ -142,20 +375,20 @@ void main() { DEQUANT_FUNC // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += 
v1 * y[iybs + iqs + y_offset]; + tmp[tid] += FLOAT_TYPE(v0 * y[iybs + iqs + 0]); + tmp[tid] += FLOAT_TYPE(v1 * y[iybs + iqs + y_offset]); } // sum up partial sums and write back result barrier(); - [[unroll]] for (int s=block_size/2; s>0; s>>=1) { + [[unroll]] for (int s = block_size/2; s > 0; s >>= 1) { if (tid < s) { tmp[tid] += tmp[tid + s]; } barrier(); } if (tid == 0) { - dst[row] = float(tmp[0]); + dst[row] = OUT_TYPE(tmp[0]); } } )"; @@ -195,9 +428,9 @@ const std::string mul_f32_src = R"( layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; -layout (binding = 0) buffer X { float data_x[]; }; -layout (binding = 1) buffer Y { float data_y[]; }; -layout (binding = 2) buffer D { float data_d[]; }; +layout (binding = 0) buffer X { X_TYPE data_x[]; }; +layout (binding = 1) buffer Y { Y_TYPE data_y[]; }; +layout (binding = 2) buffer D { D_TYPE data_d[]; }; layout (push_constant) uniform parameter { @@ -220,6 +453,76 @@ void main() { return; } - data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * data_y[p.y_offset + x]; + data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(data_y[p.y_offset + x]); +} +)"; + +// ADD +const std::string add_head = R"( +#version 450 +)"; + +const std::string add_body = R"( +layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; +layout (binding = 0) buffer X { X_TYPE data_x[]; }; +layout (binding = 1) buffer Y { Y_TYPE data_y[]; }; +layout (binding = 2) buffer D { D_TYPE data_d[]; }; + +layout (push_constant) uniform parameter +{ + int M; + int N; + int stride_x; + int stride_y; + int stride_d; + int x_offset; + int y_offset; + int d_offset; + float scale; +} p; + +void main() { + const int x = int(gl_GlobalInvocationID.x); + const int y = int(gl_GlobalInvocationID.y); + + if (x >= p.M || y >= p.N) { + return; + } + + data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) + D_TYPE(data_y[p.y_offset + x]); +} +)"; + +// SCALE +const std::string scale_src = R"( +#version 450 + +layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; + +layout (binding = 0) buffer X { X_TYPE data_x[]; }; +layout (binding = 1) buffer D { D_TYPE data_d[]; }; + +layout (push_constant) uniform parameter +{ + int M; + int N; + int stride_x; + int stride_y; + int stride_d; + int x_offset; + int y_offset; + int d_offset; + float scale; +} p; + +void main() { + const int x = int(gl_GlobalInvocationID.x); + const int y = int(gl_GlobalInvocationID.y); + + if (x >= p.M || y >= p.N) { + return; + } + + data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) * D_TYPE(p.scale); } )"; diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index b0ea9c175..96f39ce2b 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -23,6 +23,7 @@ #include #include #include +#include #include #include #include @@ -165,18 +166,34 @@ vk_pipeline vk_pipeline_f32_to_f16, vk_pipeline_dequant_q4_0; static std::vector> vk_pinned_memory; -static std::vector ggml_vk_compile_shader(const std::string& name, const std::string& src) { +static std::vector ggml_vk_compile_shader(const std::string& name, const std::string& src, std::vector&& defines) { #ifdef VK_DEBUG std::cerr << "ggml_vk_compile_shader(" << name << ", " << src << ")" << std::endl; #endif + GGML_ASSERT(defines.size() % 2 == 0); + shaderc::Compiler compiler; shaderc::CompileOptions options; + for (size_t i = 0; i < defines.size(); i += 2) { + options.AddMacroDefinition(defines[i], 
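+        // defines holds (name, value) pairs; an empty value yields a bare flag macro such as ALIGNED_INPUT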
defines[i + 1]); + } + shaderc::SpvCompilationResult module = compiler.CompileGlslToSpv(src, shaderc_compute_shader, name.c_str(), options); if (module.GetCompilationStatus() != shaderc_compilation_status_success) { + shaderc::PreprocessedSourceCompilationResult prep_res = compiler.PreprocessGlsl(src, shaderc_compute_shader, name.c_str(), options); + + std::string prep_src = std::string{ prep_res.begin(), prep_res.end() }; + + std::stringstream ss(prep_src); + std::string line; + int counter = 1; + while(std::getline(ss, line, '\n')){ + std::cout << std::setw(3) << counter++ << std::setw(1) << ": " << line << std::endl; + } std::cerr << "ggml_vulkan: Shader compile error in " << name << ": " << module.GetErrorMessage(); - return std::vector(); + GGML_ASSERT(false); } return {module.cbegin(), module.cend()}; @@ -286,12 +303,12 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& name, size_t spv_s return pipeline; } -static vk_pipeline ggml_vk_create_pipeline_from_string(const std::string& name, const std::string& src, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { +static vk_pipeline ggml_vk_create_pipeline_from_string(const std::string& name, const std::string& src, std::vector&& defines, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { #ifdef VK_DEBUG std::cerr << "ggml_vk_create_pipeline_from_string(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl; #endif - const std::vector spv = ggml_vk_compile_shader(name, src); + const std::vector spv = ggml_vk_compile_shader(name, src, std::move(defines)); return ggml_vk_create_pipeline(name, spv.size() * sizeof(uint32_t), spv.data(), entrypoint, parameter_count, push_constant_size, wg_denoms, std::move(specialization_constants), align); } @@ -640,34 +657,98 @@ static void ggml_vk_generate_shaders() { #endif std::cerr << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; + // mulmat + auto warptile_l = { 128, 128, 128, 16, 64, 64, 2, 4, 4 }; + auto warptile_m = { 128, 64, 64, 16, 32, 32, 2, 4, 2 }; + auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 }; + + std::stringstream stream; + stream << mulmat_head << shader_f32 << mulmat_body; + vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline_from_string("matmul_f32_l", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline_from_string("matmul_f32_m", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline_from_string("matmul_f32_s", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f32_aligned_m = 
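+    // the ALIGNED_INPUT variants fetch whole mat2x4 values (8 floats per load):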
ggml_vk_create_pipeline_from_string("matmul_f32_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + + stream.str(""); + stream.clear(); + stream << mulmat_head << shader_f16 << mulmat_body; + vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline_from_string("matmul_f16_l", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline_from_string("matmul_f16_m", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline_from_string("matmul_f16_s", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + + vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + + vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline_from_string("matmul_f16_f32_l", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_m", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_s", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); // Build dequant q4_0 - std::stringstream stream; + 
stream.str(""); + stream.clear(); + stream << dequant_head; if (vk_device.fp16) { - stream << dequant_glsl_fp16_ext; - } - stream << dequant_q4_0_defines; - if (vk_device.fp16) { - stream << dequant_output_fp16; + stream << shader_f16 << shader_output_f16; } else { - stream << dequant_output_fp32; + stream << shader_output_f32; } - stream << dequant_body; + stream << dequant_q4_0_defines << dequant_body; - vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1); + vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), {}, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1); // mul mat vec stream.str(""); stream.clear(); - stream << mul_mat_vec_head << mul_mat_vec_fp16 << mul_mat_vec_q4_0_defines << mul_mat_vec_body; + stream << mul_mat_vec_head << shader_f16 << shader_int8_ext << shader_output_f32 << mul_mat_vec_b_type_f16 << mul_mat_vec_q4_0_defines << mul_mat_vec_body; - vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + + stream.str(""); + stream.clear(); + + stream << mul_mat_vec_head << shader_f16 << shader_int8_ext << shader_output_f32 << mul_mat_vec_b_type_f32 << mul_mat_vec_q4_0_defines << mul_mat_vec_body; + + vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0_f32", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + + stream.str(""); + stream.clear(); + + stream << mul_mat_vec_head << shader_f16 << shader_output_f32 << mul_mat_vec_b_type_f16 << mul_mat_vec_f16_defines << mul_mat_vec_body; + + vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + + stream.str(""); + stream.clear(); + + stream << mul_mat_vec_head << shader_f16 << shader_output_f32 << mul_mat_vec_b_type_f32 << mul_mat_vec_f16_defines << mul_mat_vec_body; + vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16_f32", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + + // add + stream.str(""); + stream.clear(); + + stream << add_head << add_body; + vk_pipeline_add_f32 = ggml_vk_create_pipeline_from_string("add_f32", stream.str(), { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + stream.str(""); + stream.clear(); + + stream << add_head << shader_f16 << add_body; + vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline_from_string("add_f16_f32_f16", stream.str(), { "X_TYPE", "float16_t", "Y_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); // Static shaders - vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1); - vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_string("split_k_reduce", mulmat_split_k_reduce_src, {}, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); + vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline_from_string("f32_to_f16", 
f32_to_f16_src, {}, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1); + vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + + vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_string("scale_f32", scale_src, { "X_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); } void ggml_vk_test_transfer(size_t ne); @@ -806,45 +887,7 @@ void ggml_vk_init(void) { vk_device.descriptor_set_mode = VK_DEVICE_DESCRIPTOR_POOL_MODE_UNKNOWN; - // Prepare matmul values - auto warptile_l = { 128, 128, 128, 16, 64, 64, 2, 4, 4 }; - auto warptile_m = { 128, 64, 64, 16, 32, 32, 2, 4, 2 }; - auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 }; - // Shaders - vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - if (vk_device.fp16) { - vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - - vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", 
"main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - - vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_f16.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - } - vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); - - vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_f16_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - - vk_pipeline_add_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/add_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); - vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline_from_file("vk_shaders/add_f16_f32_f16.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); - - vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/scale_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); - ggml_vk_generate_shaders(); // Queues diff --git a/vk_shaders/add_f16_f32_f16.glsl b/vk_shaders/add_f16_f32_f16.glsl deleted file mode 100644 index 88b9e9488..000000000 --- a/vk_shaders/add_f16_f32_f16.glsl +++ /dev/null @@ -1,33 +0,0 @@ -#version 450 - -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require - -layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; - -layout (binding = 0) buffer X { float16_t data_x[]; }; -layout (binding = 1) buffer Y { float data_y[]; }; -layout (binding = 2) buffer D { float16_t data_d[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int N; - int stride_x; - int stride_y; - int stride_d; - int x_offset; - int y_offset; - int d_offset; - float scale; -} p; - -void main() { - const int x = int(gl_GlobalInvocationID.x); - const int y = int(gl_GlobalInvocationID.y); - - if (x >= p.M || y >= p.N) { - return; - } - - data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] + float16_t(data_y[p.y_offset + x]); -} diff --git a/vk_shaders/add_f32.glsl b/vk_shaders/add_f32.glsl deleted file mode 100644 index d7006a450..000000000 --- a/vk_shaders/add_f32.glsl +++ /dev/null @@ -1,31 +0,0 @@ -#version 450 - -layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; - -layout (binding = 0) buffer X { float data_x[]; }; -layout (binding = 1) buffer Y { float data_y[]; }; -layout (binding = 2) buffer D { float data_d[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int N; - int stride_x; - int stride_y; - int stride_d; - int x_offset; - int y_offset; - int d_offset; - float scale; -} p; - -void main() { - const int x = int(gl_GlobalInvocationID.x); - const int y = int(gl_GlobalInvocationID.y); - - if (x >= p.M || y >= p.N) { - return; - } - - data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] + data_y[p.y_offset + x]; -} diff --git a/vk_shaders/dequant_mul_mat_vec_f16.glsl b/vk_shaders/dequant_mul_mat_vec_f16.glsl deleted file mode 100644 index 5c45819be..000000000 --- a/vk_shaders/dequant_mul_mat_vec_f16.glsl +++ /dev/null @@ -1,59 +0,0 @@ -#version 450 - -#extension GL_EXT_control_flow_attributes : 
enable -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require - -#define QUANT_K 32 -#define QUANT_R 2 -#define BLOCK_SIZE 32 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A { float16_t x[]; }; -layout (binding = 1) readonly buffer B { float16_t y[]; }; -layout (binding = 2) writeonly buffer D { float dst[]; }; - -layout (push_constant) uniform parameter -{ - int ncols; -} p; - -shared float16_t tmp[BLOCK_SIZE]; - -void main() { - const int block_size = int(gl_WorkGroupSize.x); - const int row = int(gl_WorkGroupID.x); - const int tid = int(gl_LocalInvocationID.x); - - const int y_offset = QUANT_K/2; - - tmp[tid] = 0.0hf; - - [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*p.ncols + col)/QUANT_K; // block index - const int iqs = (col%QUANT_K)/QUANT_R; // quant index - const int iybs = col - col%QUANT_K; // y block start index - - // dequantize - float16_t v0 = x[ib + 0]; - float16_t v1 = x[ib + 1]; - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - dst[row] = float(tmp[0]); - } -} diff --git a/vk_shaders/dequant_mul_mat_vec_f16_f32.glsl b/vk_shaders/dequant_mul_mat_vec_f16_f32.glsl deleted file mode 100644 index f365d8023..000000000 --- a/vk_shaders/dequant_mul_mat_vec_f16_f32.glsl +++ /dev/null @@ -1,59 +0,0 @@ -#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require - -#define QUANT_K 32 -#define QUANT_R 2 -#define BLOCK_SIZE 32 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A { float16_t x[]; }; -layout (binding = 1) readonly buffer B { float y[]; }; -layout (binding = 2) writeonly buffer D { float dst[]; }; - -layout (push_constant) uniform parameter -{ - int ncols; -} p; - -shared float tmp[BLOCK_SIZE]; - -void main() { - const int block_size = int(gl_WorkGroupSize.x); - const int row = int(gl_WorkGroupID.x); - const int tid = int(gl_LocalInvocationID.x); - - const int y_offset = QUANT_K/2; - - tmp[tid] = 0.0hf; - - [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*p.ncols + col)/QUANT_K; // block index - const int iqs = (col%QUANT_K)/QUANT_R; // quant index - const int iybs = col - col%QUANT_K; // y block start index - - // dequantize - float16_t v0 = x[ib + 0]; - float16_t v1 = x[ib + 1]; - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - dst[row] = tmp[0]; - } -} diff --git a/vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl b/vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl deleted file mode 100644 index 8aa6ac57a..000000000 --- a/vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl +++ /dev/null @@ -1,73 +0,0 @@ -#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#extension 
GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require - -#define QUANT_K 32 -#define QUANT_R 2 -#define BLOCK_SIZE 32 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -struct block_q4_0 -{ - float16_t d; - uint8_t qs[16]; -}; - -layout (binding = 0) readonly buffer A { block_q4_0 x[]; }; -layout (binding = 1) readonly buffer B { float y[]; }; -layout (binding = 2) writeonly buffer D { float dst[]; }; - -layout (push_constant) uniform parameter -{ - int ncols; -} p; - -shared float tmp[BLOCK_SIZE]; - -void main() { - const int block_size = int(gl_WorkGroupSize.x); - const int row = int(gl_WorkGroupID.x); - const int tid = int(gl_LocalInvocationID.x); - - const int y_offset = QUANT_K/2; - - tmp[tid] = 0.0hf; - - [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*p.ncols + col)/QUANT_K; // block index - const int iqs = (col%QUANT_K)/QUANT_R; // quant index - const int iybs = col - col%QUANT_K; // y block start index - - // dequantize - const float16_t d = x[ib].d; - - const uint8_t vui = x[ib].qs[iqs]; - - const int8_t vi0 = int8_t(vui & 0xF); - const int8_t vi1 = int8_t(vui >> 4); - - float16_t v0 = float16_t(vi0 - 8)*d; - float16_t v1 = float16_t(vi1 - 8)*d; - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - dst[row] = tmp[0]; - } -} diff --git a/vk_shaders/matmul_f16.glsl b/vk_shaders/matmul_f16.glsl deleted file mode 100644 index 02c88cc0f..000000000 --- a/vk_shaders/matmul_f16.glsl +++ /dev/null @@ -1,145 +0,0 @@ -#version 450 - -#define WARP 32 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A { float16_t data_a[]; }; -layout (binding = 1) readonly buffer B { float16_t data_b[]; }; -layout (binding = 2) writeonly buffer D { float data_d[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int N; - int K; - int stride_a; - int stride_b; - int stride_d; - int k_split; -} p; - -layout (constant_id = 1) const int BM = 64; -layout (constant_id = 2) const int BN = 64; -layout (constant_id = 3) const int BK = 16; -layout (constant_id = 4) const int WM = 32; -layout (constant_id = 5) const int WN = 32; -layout (constant_id = 6) const int WMITER = 2; -layout (constant_id = 7) const int TM = 4; -layout (constant_id = 8) const int TN = 2; - -shared float16_t buf_a[BM * (BK+1)]; -shared float16_t buf_b[BN * (BK+1)]; - -void main() { - const int blocks_x = (p.M + BM - 1) / BM; - const int ir = int(gl_WorkGroupID.x) % blocks_x; - const int ik = int(gl_WorkGroupID.x) / blocks_x; - const int ic = int(gl_WorkGroupID.y); - - const int warp_i = int(gl_LocalInvocationID.x / WARP); - const int warp_r = warp_i % (BM / WM); - const int warp_c = warp_i / (BM / WM); - - const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER); - const int WSUBM = WM / WMITER; - const int WSUBN = WN / WNITER; - - const int tiw = int(gl_LocalInvocationID.x % WARP); - const int tiwr = tiw % (WSUBM / TM); - const int tiwc = tiw / (WSUBM / TM); - - const int 
loadr = int(gl_LocalInvocationID.x % BK); - const int loadc = int(gl_LocalInvocationID.x / BK); - - const int loadstride = int(gl_WorkGroupSize.x); - - const int start_k = ik * p.k_split; - const int end_k = (ik + 1) * p.k_split; - - int pos_a = ir * BM * p.stride_a + start_k; - int pos_b = ic * BN * p.stride_b + start_k; - - float sums[WMITER * TM * WNITER * TN]; - float16_t cache_a[WMITER * TM]; - float16_t cache_b[WNITER * TN]; - - [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = 0.0f; - } - - [[unroll]] for (int block = start_k; block < end_k; block += BK) { - [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) { - const int lr = l % BK; - const int lc = l / BK; - if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) { - buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr]; - } else { - buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf; - } - } - [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) { - const int lr = l % BK; - const int lc = l / BK; - if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) { - buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr]; - } else { - buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf; - } - } - - barrier(); - - pos_a += BK; - pos_b += BK; - - for (int i = 0; i < min(BK, p.K - block); i++) { - // Load from shared into cache - [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (int j = 0; j < TM; j++) { - cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; - } - } - [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (int j = 0; j < TN; j++) { - cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; - } - } - - [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (int cc = 0; cc < TN; cc++) { - [[unroll]] for (int cr = 0; cr < TM; cr++) { - sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]); - } - } - } - } - } - - barrier(); - } - - const int dr = ir * BM + warp_r * WM; - const int dc = ic * BN + warp_c * WN; - - const int k_split_offset = ik * p.M * p.N; - - [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { - - const int dr_warp = dr + wsir * WSUBM + tiwr * TM; - const int dc_warp = dc + wsic * WSUBN + tiwc * TN; - [[unroll]] for (int cc = 0; cc < TN; cc++) { - [[unroll]] for (int cr = 0; cr < TM; cr++) { - if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]; - } - } - } - } - } -} diff --git a/vk_shaders/matmul_f16_aligned.glsl b/vk_shaders/matmul_f16_aligned.glsl deleted file mode 100644 index 2b3d71772..000000000 --- a/vk_shaders/matmul_f16_aligned.glsl +++ /dev/null @@ -1,149 +0,0 @@ -#version 450 - -#define WARP 32 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A { f16mat2x4 data_a[]; }; -layout (binding = 1) readonly buffer B { f16mat2x4 data_b[]; }; -layout (binding = 2) writeonly buffer D { float data_d[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int N; - int 
K; - int stride_a; - int stride_b; - int stride_d; - int k_split; -} p; - -layout (constant_id = 1) const int BM = 64; -layout (constant_id = 2) const int BN = 64; -layout (constant_id = 3) const int BK = 16; -layout (constant_id = 4) const int WM = 32; -layout (constant_id = 5) const int WN = 32; -layout (constant_id = 6) const int WMITER = 2; -layout (constant_id = 7) const int TM = 4; -layout (constant_id = 8) const int TN = 2; - -shared float16_t buf_a[BM * (BK+1)]; -shared float16_t buf_b[BN * (BK+1)]; - -void main() { - const int blocks_x = (p.M + BM - 1) / BM; - const int ir = int(gl_WorkGroupID.x) % blocks_x; - const int ik = int(gl_WorkGroupID.x) / blocks_x; - const int ic = int(gl_WorkGroupID.y); - - const int warp_i = int(gl_LocalInvocationID.x / WARP); - const int warp_r = warp_i % (BM / WM); - const int warp_c = warp_i / (BM / WM); - - const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER); - const int WSUBM = WM / WMITER; - const int WSUBN = WN / WNITER; - - const int tiw = int(gl_LocalInvocationID.x % WARP); - const int tiwr = tiw % (WSUBM / TM); - const int tiwc = tiw / (WSUBM / TM); - - const int loadr = int(gl_LocalInvocationID.x % (BK / 8)); - const int loadc = int(gl_LocalInvocationID.x / (BK / 8)); - - const int loadstride = int(gl_WorkGroupSize.x * 8) / BK; - - const int start_k = ik * p.k_split; - const int end_k = (ik + 1) * p.k_split; - - int pos_a = ir * BM * p.stride_a / 8 + start_k / 8; - int pos_b = ic * BN * p.stride_b / 8 + start_k / 8; - - float sums[WMITER * TM * WNITER * TN]; - float16_t cache_a[WMITER * TM]; - float16_t cache_b[WNITER * TN]; - - [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = 0.0f; - } - - [[unroll]] for (int block = start_k; block < end_k; block += BK) { - [[unroll]] for (int l = 0; l < BM; l += loadstride) { - f16mat2x4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 8 + loadr]; - buf_a[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x; - buf_a[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y; - buf_a[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z; - buf_a[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w; - buf_a[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x; - buf_a[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y; - buf_a[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z; - buf_a[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w; - } - [[unroll]] for (int l = 0; l < BN; l += loadstride) { - f16mat2x4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 8 + loadr]; - buf_b[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x; - buf_b[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y; - buf_b[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z; - buf_b[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w; - buf_b[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x; - buf_b[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y; - buf_b[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z; - buf_b[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w; - } - - barrier(); - - pos_a += BK / 8; - pos_b += BK / 8; - - for (int i = 0; i < min(BK, p.K - block); i++) { - // Load from shared into cache - [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (int j = 0; j < TM; j++) { - cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i]; - } - } - [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (int j = 0; j < TN; j++) { - cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i]; - } - } - - [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) 
{ - [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { - [[unroll]] for (int cc = 0; cc < TN; cc++) { - [[unroll]] for (int cr = 0; cr < TM; cr++) { - sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]); - } - } - } - } - } - - barrier(); - } - - const int dr = ir * BM + warp_r * WM; - const int dc = ic * BN + warp_c * WN; - - const int k_split_offset = ik * p.M * p.N; - - [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) { - [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) { - - const int dr_warp = dr + wsir * WSUBM + tiwr * TM; - const int dc_warp = dc + wsic * WSUBN + tiwc * TN; - [[unroll]] for (int cc = 0; cc < TN; cc++) { - [[unroll]] for (int cr = 0; cr < TM; cr++) { - if (dr_warp + cr < p.M && dc_warp + cc < p.N) { - data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]; - } - } - } - } - } -} diff --git a/vk_shaders/matmul_f16_f32.glsl b/vk_shaders/matmul_f16_f32.glsl deleted file mode 100644 index d88567932..000000000 --- a/vk_shaders/matmul_f16_f32.glsl +++ /dev/null @@ -1,145 +0,0 @@ -#version 450 - -#define WARP 32 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require - -layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A { float16_t data_a[]; }; -layout (binding = 1) readonly buffer B { float data_b[]; }; -layout (binding = 2) writeonly buffer D { float data_d[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int N; - int K; - int stride_a; - int stride_b; - int stride_d; - int k_split; -} p; - -layout (constant_id = 1) const int BM = 64; -layout (constant_id = 2) const int BN = 64; -layout (constant_id = 3) const int BK = 16; -layout (constant_id = 4) const int WM = 32; -layout (constant_id = 5) const int WN = 32; -layout (constant_id = 6) const int WMITER = 2; -layout (constant_id = 7) const int TM = 4; -layout (constant_id = 8) const int TN = 2; - -shared float16_t buf_a[BM * (BK+1)]; -shared float16_t buf_b[BN * (BK+1)]; - -void main() { - const int blocks_x = (p.M + BM - 1) / BM; - const int ir = int(gl_WorkGroupID.x) % blocks_x; - const int ik = int(gl_WorkGroupID.x) / blocks_x; - const int ic = int(gl_WorkGroupID.y); - - const int warp_i = int(gl_LocalInvocationID.x / WARP); - const int warp_r = warp_i % (BM / WM); - const int warp_c = warp_i / (BM / WM); - - const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER); - const int WSUBM = WM / WMITER; - const int WSUBN = WN / WNITER; - - const int tiw = int(gl_LocalInvocationID.x % WARP); - const int tiwr = tiw % (WSUBM / TM); - const int tiwc = tiw / (WSUBM / TM); - - const int loadr = int(gl_LocalInvocationID.x % BK); - const int loadc = int(gl_LocalInvocationID.x / BK); - - const int loadstride = int(gl_WorkGroupSize.x); - - const int start_k = ik * p.k_split; - const int end_k = (ik + 1) * p.k_split; - - int pos_a = ir * BM * p.stride_a + start_k; - int pos_b = ic * BN * p.stride_b + start_k; - - float sums[WMITER * TM * WNITER * TN]; - float16_t cache_a[WMITER * TM]; - float16_t cache_b[WNITER * TN]; - - [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) { - sums[i] = 0.0f; - } - - [[unroll]] for (int block = start_k; block < end_k; block += BK) { - [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) { - const int lr = l % BK; - const int lc = l / BK; - if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) { - 
buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
-            } else {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
-            }
-        }
-        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = float16_t(data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr]);
-            } else {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0hf;
-            }
-        }
-
-        barrier();
-
-        pos_a += BK;
-        pos_b += BK;
-
-        for (int i = 0; i < min(BK, p.K - block); i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    const int k_split_offset = ik * p.M * p.N;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
diff --git a/vk_shaders/matmul_f16_f32_aligned.glsl b/vk_shaders/matmul_f16_f32_aligned.glsl
deleted file mode 100644
index b34eaea1a..000000000
--- a/vk_shaders/matmul_f16_f32_aligned.glsl
+++ /dev/null
@@ -1,149 +0,0 @@
-#version 450
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { f16mat2x4 data_a[]; };
-layout (binding = 1) readonly buffer B { mat2x4 data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-    int k_split;
-} p;
-
-layout (constant_id = 1) const int BM = 64;
-layout (constant_id = 2) const int BN = 64;
-layout (constant_id = 3) const int BK = 16;
-layout (constant_id = 4) const int WM = 32;
-layout (constant_id = 5) const int WN = 32;
-layout (constant_id = 6) const int WMITER = 2;
-layout (constant_id = 7) const int TM = 4;
-layout (constant_id = 8) const int TN = 2;
-
-shared float16_t buf_a[BM * (BK+1)];
-shared float16_t buf_b[BN * (BK+1)];
-
-void main() {
-    const int blocks_x = (p.M + BM - 1) / BM;
-    const int ir = int(gl_WorkGroupID.x) % blocks_x;
-    const int ik = int(gl_WorkGroupID.x) / blocks_x;
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % (BK / 8));
-    const int loadc = int(gl_LocalInvocationID.x / (BK / 8));
-
-    const int loadstride = int(gl_WorkGroupSize.x * 8) / BK;
-
-    const int start_k = ik * p.k_split;
-    const int end_k = (ik + 1) * p.k_split;
-
-    int pos_a = ir * BM * p.stride_a / 8 + start_k / 8;
-    int pos_b = ic * BN * p.stride_b / 8 + start_k / 8;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float16_t cache_a[WMITER * TM];
-    float16_t cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
-            f16mat2x4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 8 + loadr];
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 0] = tmp[0].x;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 1] = tmp[0].y;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 2] = tmp[0].z;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 3] = tmp[0].w;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 4] = tmp[1].x;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 5] = tmp[1].y;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 6] = tmp[1].z;
-            buf_a[(loadc + l) * (BK+1) + loadr * 8 + 7] = tmp[1].w;
-        }
-        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
-            mat2x4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 8 + loadr];
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 0] = float16_t(tmp[0].x);
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 1] = float16_t(tmp[0].y);
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 2] = float16_t(tmp[0].z);
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 3] = float16_t(tmp[0].w);
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 4] = float16_t(tmp[1].x);
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 5] = float16_t(tmp[1].y);
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 6] = float16_t(tmp[1].z);
-            buf_b[(loadc + l) * (BK+1) + loadr * 8 + 7] = float16_t(tmp[1].w);
-        }
-
-        barrier();
-
-        pos_a += BK / 8;
-        pos_b += BK / 8;
-
-        for (int i = 0; i < min(BK, p.K - block); i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += float(cache_a[wsir * TM + cr]) * float(cache_b[wsic * TN + cc]);
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    const int k_split_offset = ik * p.M * p.N;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
diff --git a/vk_shaders/matmul_f32.glsl b/vk_shaders/matmul_f32.glsl
deleted file mode 100644
index 5cc268f35..000000000
--- a/vk_shaders/matmul_f32.glsl
+++ /dev/null
@@ -1,144 +0,0 @@
-#version 450
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { float data_a[]; };
-layout (binding = 1) readonly buffer B { float data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-    int k_split;
-} p;
-
-layout (constant_id = 1) const int BM = 64;
-layout (constant_id = 2) const int BN = 64;
-layout (constant_id = 3) const int BK = 16;
-layout (constant_id = 4) const int WM = 32;
-layout (constant_id = 5) const int WN = 32;
-layout (constant_id = 6) const int WMITER = 2;
-layout (constant_id = 7) const int TM = 4;
-layout (constant_id = 8) const int TN = 2;
-
-shared float buf_a[BM * (BK+1)];
-shared float buf_b[BN * (BK+1)];
-
-void main() {
-    const int blocks_x = (p.M + BM - 1) / BM;
-    const int ir = int(gl_WorkGroupID.x) % blocks_x;
-    const int ik = int(gl_WorkGroupID.x) / blocks_x;
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % BK);
-    const int loadc = int(gl_LocalInvocationID.x / BK);
-
-    const int loadstride = int(gl_WorkGroupSize.x);
-
-    const int start_k = ik * p.k_split;
-    const int end_k = (ik + 1) * p.k_split;
-
-    int pos_a = ir * BM * p.stride_a + start_k;
-    int pos_b = ic * BN * p.stride_b + start_k;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float cache_a[WMITER * TM];
-    float cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (int l = 0; l < BM * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ir * BM + loadc + lc < p.M && block + loadr + lr < p.K) {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = data_a[pos_a + (loadc + lc) * p.stride_a + loadr + lr];
-            } else {
-                buf_a[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
-            }
-        }
-        [[unroll]] for (int l = 0; l < BN * BK; l += loadstride) {
-            const int lr = l % BK;
-            const int lc = l / BK;
-            if (ic * BN + loadc + lc < p.N && block + loadr + lr < p.K) {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = data_b[pos_b + (loadc + lc) * p.stride_b + loadr + lr];
-            } else {
-                buf_b[(loadc + lc) * (BK+1) + loadr + lr] = 0.0f;
-            }
-        }
-
-        barrier();
-
-        pos_a += BK;
-        pos_b += BK;
-
-        for (int i = 0; i < min(BK, p.K - block); i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    const int k_split_offset = ik * p.M * p.N;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
diff --git a/vk_shaders/matmul_f32_aligned.glsl b/vk_shaders/matmul_f32_aligned.glsl
deleted file mode 100644
index e76c60d43..000000000
--- a/vk_shaders/matmul_f32_aligned.glsl
+++ /dev/null
@@ -1,140 +0,0 @@
-#version 450
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { vec4 data_a[]; };
-layout (binding = 1) readonly buffer B { vec4 data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-    int k_split;
-} p;
-
-layout (constant_id = 1) const int BM = 64;
-layout (constant_id = 2) const int BN = 64;
-layout (constant_id = 3) const int BK = 16;
-layout (constant_id = 4) const int WM = 32;
-layout (constant_id = 5) const int WN = 32;
-layout (constant_id = 6) const int WMITER = 2;
-layout (constant_id = 7) const int TM = 4;
-layout (constant_id = 8) const int TN = 2;
-
-shared float buf_a[BM * (BK+1)];
-shared float buf_b[BN * (BK+1)];
-
-void main() {
-    const int blocks_x = (p.M + BM - 1) / BM;
-    const int ir = int(gl_WorkGroupID.x) % blocks_x;
-    const int ik = int(gl_WorkGroupID.x) / blocks_x;
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % (BK / 4));
-    const int loadc = int(gl_LocalInvocationID.x / (BK / 4));
-
-    const int loadstride = int(gl_WorkGroupSize.x * 4) / BK;
-
-    const int start_k = ik * p.k_split;
-    const int end_k = (ik + 1) * p.k_split;
-
-    int pos_a = ir * BM * p.stride_a / 4 + start_k / 4;
-    int pos_b = ic * BN * p.stride_b / 4 + start_k / 4;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float cache_a[WMITER * TM];
-    float cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
-            vec4 tmp = data_a[pos_a + (loadc + l) * p.stride_a / 4 + loadr];
-            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
-            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
-            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
-            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
-        }
-        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
-            vec4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 4 + loadr];
-            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
-            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
-            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
-            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
-        }
-
-        barrier();
-
-        pos_a += BK / 4;
-        pos_b += BK / 4;
-
-        for (int i = 0; i < min(BK, p.K - block); i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    const int k_split_offset = ik * p.M * p.N;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
diff --git a/vk_shaders/matmul_f32_q4_0.glsl b/vk_shaders/matmul_f32_q4_0.glsl
deleted file mode 100644
index 094b099fc..000000000
--- a/vk_shaders/matmul_f32_q4_0.glsl
+++ /dev/null
@@ -1,169 +0,0 @@
-#version 450
-
-#define WARP 32
-
-#extension GL_EXT_control_flow_attributes : enable
-#extension GL_EXT_shader_16bit_storage : require
-#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
-#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
-
-#define QUANT_K 32
-#define QUANT_R 2
-
-struct block_q4_0
-{
-    float16_t d;
-    uint8_t qs[16];
-};
-
-layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
-
-layout (binding = 0) readonly buffer A { block_q4_0 data_a[]; };
-layout (binding = 1) readonly buffer B { vec4 data_b[]; };
-layout (binding = 2) writeonly buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int K;
-    int stride_a;
-    int stride_b;
-    int stride_d;
-    int k_split;
-} p;
-
-layout (constant_id = 1) const int BM = 64;
-layout (constant_id = 2) const int BN = 64;
-layout (constant_id = 3) const int BK = 16;
-layout (constant_id = 4) const int WM = 32;
-layout (constant_id = 5) const int WN = 32;
-layout (constant_id = 6) const int WMITER = 2;
-layout (constant_id = 7) const int TM = 4;
-layout (constant_id = 8) const int TN = 2;
-
-shared float buf_a[BM * (BK+1)];
-shared float buf_b[BN * (BK+1)];
-
-void main() {
-    const int blocks_x = (p.M + BM - 1) / BM;
-    const int ir = int(gl_WorkGroupID.x) % blocks_x;
-    const int ik = int(gl_WorkGroupID.x) / blocks_x;
-    const int ic = int(gl_WorkGroupID.y);
-
-    const int stride_a = p.stride_a / QUANT_K;
-
-    const int warp_i = int(gl_LocalInvocationID.x / WARP);
-    const int warp_r = warp_i % (BM / WM);
-    const int warp_c = warp_i / (BM / WM);
-
-    const int WNITER = (WM * WN) / (WARP * TM * TN * WMITER);
-    const int WSUBM = WM / WMITER;
-    const int WSUBN = WN / WNITER;
-
-    const int tiw = int(gl_LocalInvocationID.x % WARP);
-    const int tiwr = tiw % (WSUBM / TM);
-    const int tiwc = tiw / (WSUBM / TM);
-
-    const int loadr = int(gl_LocalInvocationID.x % (BK / 4));
-    const int loadc = int(gl_LocalInvocationID.x / (BK / 4));
-
-    const int loadstride = int(gl_WorkGroupSize.x * 4) / BK;
-
-    const int start_k = ik * p.k_split;
-    const int end_k = (ik + 1) * p.k_split;
-
-    int pos_b = ic * BN * p.stride_b / 4 + start_k / 4;
-
-    float sums[WMITER * TM * WNITER * TN];
-    float cache_a[WMITER * TM];
-    float cache_b[WNITER * TN];
-
-    [[unroll]] for (int i = 0; i < WMITER*TM*WNITER*TN; i++) {
-        sums[i] = 0.0f;
-    }
-
-    [[unroll]] for (int block = start_k; block < end_k; block += BK) {
-        [[unroll]] for (int l = 0; l < BM; l += loadstride) {
-            const int row = (block + loadr * 4) / QUANT_K;
-            const int qi = (block + loadr * 4) % QUANT_K;
-            const block_q4_0 blk = data_a[(ir * BM + loadc + l) * stride_a + row];
-            const float d = float(blk.d);
-
-            int x0, x1, x2, x3;
-            if (qi < 16) {
-                x0 = (blk.qs[qi + 0] & 0x0F) - 8;
-                x1 = (blk.qs[qi + 1] & 0x0F) - 8;
-                x2 = (blk.qs[qi + 2] & 0x0F) - 8;
-                x3 = (blk.qs[qi + 3] & 0x0F) - 8;
-            } else {
-                x0 = (blk.qs[qi + 0] >> 4) - 8;
-                x1 = (blk.qs[qi + 1] >> 4) - 8;
-                x2 = (blk.qs[qi + 2] >> 4) - 8;
-                x3 = (blk.qs[qi + 3] >> 4) - 8;
-            }
-
-            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 0] = x0*d;
-            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 1] = x1*d;
-            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 2] = x2*d;
-            buf_a[(loadc + l) * (BK+1) + loadr * 4 + 3] = x3*d;
-        }
-        [[unroll]] for (int l = 0; l < BN; l += loadstride) {
-            vec4 tmp = data_b[pos_b + (loadc + l) * p.stride_b / 4 + loadr];
-            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 0] = tmp.x;
-            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 1] = tmp.y;
-            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 2] = tmp.z;
-            buf_b[(loadc + l) * (BK+1) + loadr * 4 + 3] = tmp.w;
-        }
-
-        barrier();
-
-        pos_b += BK / 4;
-
-        for (int i = 0; i < min(BK, p.K - block); i++) {
-            // Load from shared into cache
-            [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                [[unroll]] for (int j = 0; j < TM; j++) {
-                    cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * (BK+1) + i];
-                }
-            }
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int j = 0; j < TN; j++) {
-                    cache_b[wsic * TN + j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * (BK+1) + i];
-                }
-            }
-
-            [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-                [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-                    [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                        [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                            sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr] += cache_a[wsir * TM + cr] * cache_b[wsic * TN + cc];
-                        }
-                    }
-                }
-            }
-        }
-
-        barrier();
-    }
-
-    const int dr = ir * BM + warp_r * WM;
-    const int dc = ic * BN + warp_c * WN;
-
-    const int k_split_offset = ik * p.M * p.N;
-
-    [[unroll]] for (int wsic = 0; wsic < WNITER; wsic++) {
-        [[unroll]] for (int wsir = 0; wsir < WMITER; wsir++) {
-
-            const int dr_warp = dr + wsir * WSUBM + tiwr * TM;
-            const int dc_warp = dc + wsic * WSUBN + tiwc * TN;
-            [[unroll]] for (int cc = 0; cc < TN; cc++) {
-                [[unroll]] for (int cr = 0; cr < TM; cr++) {
-                    if (dr_warp + cr < p.M && dc_warp + cc < p.N) {
-                        data_d[k_split_offset + (dc_warp + cc) * p.stride_d + dr_warp + cr] = sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr];
-                    }
-                }
-            }
-        }
-    }
-}
diff --git a/vk_shaders/matmul_split_k_reduce.glsl b/vk_shaders/matmul_split_k_reduce.glsl
deleted file mode 100644
index 006f8cbf9..000000000
--- a/vk_shaders/matmul_split_k_reduce.glsl
+++ /dev/null
@@ -1,31 +0,0 @@
-#version 450
-
-layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
-
-layout (binding = 0) buffer A { float data[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int k_num;
-} p;
-
-void main() {
-    const int glr = int(gl_GlobalInvocationID.x);
-    const int glc = int(gl_GlobalInvocationID.y);
-
-    if (glr >= p.M || glc >= p.N) {
-        return;
-    }
-
-    const int idx = glc * p.M + glr;
-
-    float result = 0.0f;
-
-    for (int i = 0; i < p.k_num; i++) {
-        result += data[i * p.M * p.N + idx];
-    }
-
-    data[idx] = result;
-}
diff --git a/vk_shaders/scale_f32.glsl b/vk_shaders/scale_f32.glsl
deleted file mode 100644
index f1fc4d10e..000000000
--- a/vk_shaders/scale_f32.glsl
+++ /dev/null
@@ -1,30 +0,0 @@
-#version 450
-
-layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
-
-layout (binding = 0) buffer X { float data_x[]; };
-layout (binding = 1) buffer D { float data_d[]; };
-
-layout (push_constant) uniform parameter
-{
-    int M;
-    int N;
-    int stride_x;
-    int stride_y;
-    int stride_d;
-    int x_offset;
-    int y_offset;
-    int d_offset;
-    float scale;
-} p;
-
-void main() {
-    const int x = int(gl_GlobalInvocationID.x);
-    const int y = int(gl_GlobalInvocationID.y);
-
-    if (x >= p.M || y >= p.N) {
-        return;
-    }
-
-    data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * p.scale;
-}