From a47ca7ae7ad3dd2effd9dfa9d27481783c47e12e Mon Sep 17 00:00:00 2001 From: 0cc4m Date: Sun, 13 Aug 2023 11:01:27 +0200 Subject: [PATCH] Add runtime shader compilation, start transferring shaders to this approach --- Makefile | 6 +- ggml-vulkan-shaders.hpp | 225 +++++++++++++++++++++++ ggml-vulkan.cpp | 167 ++++++++++++----- vk_shaders/dequant_mul_mat_vec_q4_0.glsl | 73 -------- vk_shaders/dequant_q4_0.glsl | 53 ------ vk_shaders/f16_to_f32.glsl | 25 --- vk_shaders/f32_to_f16.glsl | 25 --- vk_shaders/mul_f32.glsl | 31 ---- 8 files changed, 343 insertions(+), 262 deletions(-) create mode 100644 ggml-vulkan-shaders.hpp delete mode 100644 vk_shaders/dequant_mul_mat_vec_q4_0.glsl delete mode 100644 vk_shaders/dequant_q4_0.glsl delete mode 100644 vk_shaders/f16_to_f32.glsl delete mode 100644 vk_shaders/f32_to_f16.glsl delete mode 100644 vk_shaders/mul_f32.glsl diff --git a/Makefile b/Makefile index 669078bd5..d669826a2 100644 --- a/Makefile +++ b/Makefile @@ -228,7 +228,7 @@ endif # LLAMA_METAL ifdef LLAMA_VULKAN CFLAGS += -DGGML_USE_VULKAN CXXFLAGS += -DGGML_USE_VULKAN - LDFLAGS += -lvulkan -lopenblas + LDFLAGS += -lvulkan -lopenblas -lglslang -lSPIRV -lSPIRV-Tools-opt -lSPIRV-Tools -lshaderc_combined OBJS += ggml-vulkan.o ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h $(CXX) $(CXXFLAGS) -c $< -o $@ @@ -240,13 +240,9 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_f32_aligned.glsl -o vk_shaders/matmul_f16_f32_aligned.spv & \ glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv & \ glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f32_to_f16.glsl -o vk_shaders/f32_to_f16.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_q4_0.glsl -o vk_shaders/dequant_q4_0.spv & \ 
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16.glsl -o vk_shaders/dequant_mul_mat_vec_f16.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0.spv & \ glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16_f32.glsl -o vk_shaders/dequant_mul_mat_vec_f16_f32.spv & \ glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv & \ - glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/mul_f32.glsl -o vk_shaders/mul_f32.spv & \ glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f32.glsl -o vk_shaders/add_f32.spv & \ glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f16_f32_f16.glsl -o vk_shaders/add_f16_f32_f16.spv & \ glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/scale_f32.glsl -o vk_shaders/scale_f32.spv & \ diff --git a/ggml-vulkan-shaders.hpp b/ggml-vulkan-shaders.hpp new file mode 100644 index 000000000..a2643aa26 --- /dev/null +++ b/ggml-vulkan-shaders.hpp @@ -0,0 +1,225 @@ +#include + +// DEQUANT SHADER +const std::string dequant_head = R"( +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +)"; + +const std::string dequant_glsl_fp16_ext = R"( +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +)"; + +const std::string dequant_output_fp16 = R"( +#define OUT_TYPE float16_t +)"; + +const std::string dequant_output_fp32 = R"( +#define OUT_TYPE float +)"; + +const std::string dequant_q4_0_defines = R"( +#define QUANT_K 32 +#define QUANT_R 2 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; + +#define A_TYPE block_q4_0 +)"; + +const std::string dequant_body = R"( +layout(local_size_x = 256, local_size_y = 
1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A { A_TYPE x[]; }; +layout (binding = 1) writeonly buffer D { OUT_TYPE y[]; }; + +layout (push_constant) uniform parameter +{ + int M; + int K; + int stride_a; + int stride_b; +} p; + +void main() { + const int i = int(gl_GlobalInvocationID.x); + + // Transposed + const int row = i % (p.K / QUANT_K); + const int col = i / (p.K / QUANT_K); + + if (row * QUANT_K >= p.K || col >= p.M) { + return; + } + + const int stride_a = p.stride_a / QUANT_K; + + const A_TYPE blk = x[col * stride_a + row]; + const OUT_TYPE d = blk.d; + + [[unroll]] for (int j = 0; j < QUANT_K/2; ++j) { + const OUT_TYPE x0 = OUT_TYPE((blk.qs[j] & 0x0F) - 8); + const OUT_TYPE x1 = OUT_TYPE((blk.qs[j] >> 4) - 8); + + y[col * p.stride_b + row*QUANT_K + j + 0 ] = x0*d; + y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = x1*d; + } +} +)"; + +// Mul Mat Vec +const std::string mul_mat_vec_head = R"( +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_8bit_storage : require +)"; + +const std::string mul_mat_vec_fp16 = R"( +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +)"; + +const std::string mul_mat_vec_q4_0_defines = R"( +#define QUANT_K 32 +#define QUANT_R 2 +#define BLOCK_SIZE 32 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; +#define A_TYPE block_q4_0 +#define B_TYPE float16_t +#define OUT_TYPE float + +#define DEQUANT_FUNC const float16_t d = x[ib].d; \ +const uint8_t vui = x[ib].qs[iqs]; \ +const int8_t vi0 = int8_t(vui & 0xF); \ +const int8_t vi1 = int8_t(vui >> 4); \ +float16_t v0 = float16_t(vi0 - 8)*d; \ +float16_t v1 = float16_t(vi1 - 8)*d; +)"; + +const std::string mul_mat_vec_body = R"( +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A { A_TYPE x[]; }; +layout 
(binding = 1) readonly buffer B { B_TYPE y[]; }; +layout (binding = 2) writeonly buffer D { OUT_TYPE dst[]; }; + +layout (push_constant) uniform parameter +{ + int ncols; +} p; + +shared float16_t tmp[BLOCK_SIZE]; + +void main() { + const int block_size = int(gl_WorkGroupSize.x); + const int row = int(gl_WorkGroupID.x); + const int tid = int(gl_LocalInvocationID.x); + + const int y_offset = QUANT_K/2; + + tmp[tid] = 0.0hf; + + [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) { + const int col = i*block_size + 2*tid; + const int ib = (row*p.ncols + col)/QUANT_K; // block index + const int iqs = (col%QUANT_K)/QUANT_R; // quant index + const int iybs = col - col%QUANT_K; // y block start index + + DEQUANT_FUNC + + // matrix multiplication + tmp[tid] += v0 * y[iybs + iqs + 0]; + tmp[tid] += v1 * y[iybs + iqs + y_offset]; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s=block_size/2; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + if (tid == 0) { + dst[row] = float(tmp[0]); + } +} +)"; + +// F32 to F16 +const std::string f32_to_f16_src = R"( +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A { float data_a[]; }; +layout (binding = 1) writeonly buffer D { float16_t data_b[]; }; + +layout (push_constant) uniform parameter +{ + int M; + int K; + int stride_a; + int stride_b; +} p; + +void main() { + const int row = int(gl_GlobalInvocationID.x % p.K); + const int col = int(gl_GlobalInvocationID.x / p.K); + + if (row < p.K && col < p.M) { + data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]); + } +} +)"; + +// MUL F32 +const std::string mul_f32_src = R"( +#version 450 + +layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; + +layout (binding = 0) buffer X { float data_x[]; }; +layout (binding = 1) buffer Y { float data_y[]; };
+layout (binding = 2) buffer D { float data_d[]; }; + +layout (push_constant) uniform parameter +{ + int M; + int N; + int stride_x; + int stride_y; + int stride_d; + int x_offset; + int y_offset; + int d_offset; + float scale; +} p; + +void main() { + const int x = int(gl_GlobalInvocationID.x); + const int y = int(gl_GlobalInvocationID.y); + + if (x >= p.M || y >= p.N) { + return; + } + + data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * data_y[p.y_offset + x]; +} +)"; diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 95e1edb4f..b0ea9c175 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -27,9 +27,14 @@ #include #include #include +#include + +#include #include "ggml.h" +#include "ggml-vulkan-shaders.hpp" + #define VK_API_VERSION VK_API_VERSION_1_2 #define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) @@ -160,37 +165,39 @@ vk_pipeline vk_pipeline_f32_to_f16, vk_pipeline_dequant_q4_0; static std::vector> vk_pinned_memory; -static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { +static std::vector ggml_vk_compile_shader(const std::string& name, const std::string& src) { #ifdef VK_DEBUG - std::cerr << "ggml_vk_create_pipeline(" << path << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl; + std::cerr << "ggml_vk_compile_shader(" << name << ", " << src << ")" << std::endl; +#endif + shaderc::Compiler compiler; + shaderc::CompileOptions options; + + shaderc::SpvCompilationResult module = compiler.CompileGlslToSpv(src, shaderc_compute_shader, name.c_str(), options); + + if (module.GetCompilationStatus() != shaderc_compilation_status_success) { + std::cerr << "ggml_vulkan: Shader compile error in " << name 
<< ": " << module.GetErrorMessage(); + return std::vector(); + } + + return {module.cbegin(), module.cend()}; +} + +static vk_pipeline ggml_vk_create_pipeline(const std::string& name, size_t spv_size, const uint32_t* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { +#ifdef VK_DEBUG + std::cerr << "ggml_vk_create_pipeline(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl; #endif GGML_ASSERT(parameter_count > 0); - GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); + GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT vk_pipeline pipeline; - pipeline.name = path; + pipeline.name = name; pipeline.parameter_count = parameter_count; pipeline.push_constant_size = push_constant_size; pipeline.wg_denoms = wg_denoms; pipeline.align = align; - std::vector matmul_shader_contents; - if (std::ifstream shader_file{ path, std::ios::binary | std::ios::ate }) { - const size_t file_size = shader_file.tellg(); - shader_file.seekg(0); - matmul_shader_contents.resize(file_size, '\0'); - shader_file.read(matmul_shader_contents.data(), file_size); - } else { - std::cerr << "ggml_vulkan: Invalid shader path " << path << std::endl; - abort(); - } - - vk::ShaderModuleCreateInfo shader_module_create_info( - {}, - matmul_shader_contents.size(), - reinterpret_cast(matmul_shader_contents.data()) - ); + vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, spv_data); vk::ShaderModule shader_module = vk_device.device.createShaderModule(shader_module_create_info); std::vector dsl_binding; @@ -279,6 +286,34 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::s return pipeline; } +static vk_pipeline 
ggml_vk_create_pipeline_from_string(const std::string& name, const std::string& src, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { +#ifdef VK_DEBUG + std::cerr << "ggml_vk_create_pipeline_from_string(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl; +#endif + + const std::vector spv = ggml_vk_compile_shader(name, src); + return ggml_vk_create_pipeline(name, spv.size() * sizeof(uint32_t), spv.data(), entrypoint, parameter_count, push_constant_size, wg_denoms, std::move(specialization_constants), align); +} + +static vk_pipeline ggml_vk_create_pipeline_from_file(const std::string& path, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array wg_denoms, std::vector&& specialization_constants, uint32_t align) { +#ifdef VK_DEBUG + std::cerr << "ggml_vk_create_pipeline_from_file(" << path << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl; +#endif + + std::vector matmul_shader_contents; + if (std::ifstream shader_file{ path, std::ios::binary | std::ios::ate }) { + const size_t file_size = shader_file.tellg(); + shader_file.seekg(0); + matmul_shader_contents.resize(file_size, '\0'); + shader_file.read(matmul_shader_contents.data(), file_size); + } else { + std::cerr << "ggml_vulkan: Invalid shader path " << path << std::endl; + abort(); + } + + return ggml_vk_create_pipeline(path, matmul_shader_contents.size(), reinterpret_cast(matmul_shader_contents.data()), entrypoint, parameter_count, push_constant_size, wg_denoms, std::move(specialization_constants), align); +} + static void 
ggml_vk_pipeline_allocate_descriptor_sets(vk_pipeline& pipeline, uint32_t n) { #ifdef VK_DEBUG std::cerr << "ggml_vk_pipeline_allocate_descriptor_sets(" << pipeline.name << ", " << n << ")" << std::endl; @@ -599,6 +634,42 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) { } } +static void ggml_vk_generate_shaders() { +#ifdef VK_DEBUG + std::cerr << "ggml_vk_generate_shaders()" << std::endl; +#endif + std::cerr << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl; + + + // Build dequant q4_0 + std::stringstream stream; + stream << dequant_head; + if (vk_device.fp16) { + stream << dequant_glsl_fp16_ext; + } + stream << dequant_q4_0_defines; + if (vk_device.fp16) { + stream << dequant_output_fp16; + } else { + stream << dequant_output_fp32; + } + stream << dequant_body; + + vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1); + + // mul mat vec + stream.str(""); + stream.clear(); + + stream << mul_mat_vec_head << mul_mat_vec_fp16 << mul_mat_vec_q4_0_defines << mul_mat_vec_body; + + vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + + // Static shaders + vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1); + vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); +} + void ggml_vk_test_transfer(size_t ne); void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int split_k, int shader_size); void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int split_k, int shader_size); @@ -741,44 +812,40 @@ void ggml_vk_init(void) { auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 }; // Shaders - vk_pipeline_matmul_f32_l = 
ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); if (vk_device.fp16) { - vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - 
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, 
warptile_m, 64); - vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); - vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); - vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); + vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); + vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); + vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); - vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_f16.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_q4_0.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + 
vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_f16.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); } - vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); + vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); - vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline("vk_shaders/f32_to_f16.spv", "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1); - vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1); + vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_f16_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_f16_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); - vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); + vk_pipeline_add_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/add_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline_from_file("vk_shaders/add_f16_f32_f16.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); - vk_pipeline_mul_f32 = ggml_vk_create_pipeline("vk_shaders/mul_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/scale_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 
1); - vk_pipeline_add_f32 = ggml_vk_create_pipeline("vk_shaders/add_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); - vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline("vk_shaders/add_f16_f32_f16.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); - - vk_pipeline_scale_f32 = ggml_vk_create_pipeline("vk_shaders/scale_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); + ggml_vk_generate_shaders(); // Queues uint32_t queue_index_offset = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; diff --git a/vk_shaders/dequant_mul_mat_vec_q4_0.glsl b/vk_shaders/dequant_mul_mat_vec_q4_0.glsl deleted file mode 100644 index e61e793bf..000000000 --- a/vk_shaders/dequant_mul_mat_vec_q4_0.glsl +++ /dev/null @@ -1,73 +0,0 @@ -#version 450 - -#extension GL_EXT_control_flow_attributes : enable -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require - -#define QUANT_K 32 -#define QUANT_R 2 -#define BLOCK_SIZE 32 - -layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; - -struct block_q4_0 -{ - float16_t d; - uint8_t qs[16]; -}; - -layout (binding = 0) readonly buffer A { block_q4_0 x[]; }; -layout (binding = 1) readonly buffer B { float16_t y[]; }; -layout (binding = 2) writeonly buffer D { float dst[]; }; - -layout (push_constant) uniform parameter -{ - int ncols; -} p; - -shared float16_t tmp[BLOCK_SIZE]; - -void main() { - const int block_size = int(gl_WorkGroupSize.x); - const int row = int(gl_WorkGroupID.x); - const int tid = int(gl_LocalInvocationID.x); - - const int y_offset = QUANT_K/2; - - tmp[tid] = 0.0hf; - - [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) { - const int col = i*block_size + 2*tid; - const int ib = (row*p.ncols + col)/QUANT_K; // block index - const int iqs = (col%QUANT_K)/QUANT_R; // quant index - const int iybs = col - 
col%QUANT_K; // y block start index - - // dequantize - const float16_t d = x[ib].d; - - const uint8_t vui = x[ib].qs[iqs]; - - const int8_t vi0 = int8_t(vui & 0xF); - const int8_t vi1 = int8_t(vui >> 4); - - float16_t v0 = float16_t(vi0 - 8)*d; - float16_t v1 = float16_t(vi1 - 8)*d; - - // matrix multiplication - tmp[tid] += v0 * y[iybs + iqs + 0]; - tmp[tid] += v1 * y[iybs + iqs + y_offset]; - } - - // sum up partial sums and write back result - barrier(); - [[unroll]] for (int s=block_size/2; s>0; s>>=1) { - if (tid < s) { - tmp[tid] += tmp[tid + s]; - } - barrier(); - } - if (tid == 0) { - dst[row] = float(tmp[0]); - } -} diff --git a/vk_shaders/dequant_q4_0.glsl b/vk_shaders/dequant_q4_0.glsl deleted file mode 100644 index 73a17b257..000000000 --- a/vk_shaders/dequant_q4_0.glsl +++ /dev/null @@ -1,53 +0,0 @@ -#version 450 - -#extension GL_EXT_control_flow_attributes : require -#extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require -#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require - -#define QUANT_K 32 -#define QUANT_R 2 - -layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; - -struct block_q4_0 -{ - float16_t d; - uint8_t qs[16]; -}; - -layout (binding = 0) readonly buffer A { block_q4_0 x[]; }; -layout (binding = 1) writeonly buffer D { float16_t y[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - -void main() { - const int i = int(gl_GlobalInvocationID.x); - - // Transposed - const int row = i % (p.K / QUANT_K); - const int col = i / (p.K / QUANT_K); - - if (row * QUANT_K >= p.K || col >= p.M) { - return; - } - - const int stride_a = p.stride_a / QUANT_K; - - const block_q4_0 blk = x[col * stride_a + row]; - const float16_t d = blk.d; - - [[unroll]] for (int j = 0; j < QUANT_K/2; ++j) { - const float16_t x0 = float16_t((blk.qs[j] & 0x0F) - 8); - const float16_t x1 = float16_t((blk.qs[j] >> 4) - 8); - - 
y[col * p.stride_b + row*QUANT_K + j + 0 ] = x0*d; - y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = x1*d; - } -} diff --git a/vk_shaders/f16_to_f32.glsl b/vk_shaders/f16_to_f32.glsl deleted file mode 100644 index 1c0658a2c..000000000 --- a/vk_shaders/f16_to_f32.glsl +++ /dev/null @@ -1,25 +0,0 @@ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A { float16_t data_a[]; }; -layout (binding = 1) writeonly buffer D { float data_b[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - -void main() { - const int row = int(gl_GlobalInvocationID.x % p.K); - const int col = int(gl_GlobalInvocationID.x / p.K); - - if (row < p.M && col < p.K) { - data_b[col * p.stride_b + row] = float(data_a[col * p.stride_a + row]); - } -} diff --git a/vk_shaders/f32_to_f16.glsl b/vk_shaders/f32_to_f16.glsl deleted file mode 100644 index 8c666cb86..000000000 --- a/vk_shaders/f32_to_f16.glsl +++ /dev/null @@ -1,25 +0,0 @@ -#version 450 - -#extension GL_EXT_shader_16bit_storage : require - -layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; - -layout (binding = 0) readonly buffer A { float data_a[]; }; -layout (binding = 1) writeonly buffer D { float16_t data_b[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int K; - int stride_a; - int stride_b; -} p; - -void main() { - const int row = int(gl_GlobalInvocationID.x % p.K); - const int col = int(gl_GlobalInvocationID.x / p.K); - - if (row < p.K && col < p.M) { - data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]); - } -} diff --git a/vk_shaders/mul_f32.glsl b/vk_shaders/mul_f32.glsl deleted file mode 100644 index 420c552a9..000000000 --- a/vk_shaders/mul_f32.glsl +++ /dev/null @@ -1,31 +0,0 @@ -#version 450 - -layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in; - -layout (binding = 0) buffer X 
{ float data_x[]; }; -layout (binding = 1) buffer Y { float data_y[]; }; -layout (binding = 2) buffer D { float data_d[]; }; - -layout (push_constant) uniform parameter -{ - int M; - int N; - int stride_x; - int stride_y; - int stride_d; - int x_offset; - int y_offset; - int d_offset; - float scale; -} p; - -void main() { - const int x = int(gl_GlobalInvocationID.x); - const int y = int(gl_GlobalInvocationID.y); - - if (x >= p.M || y >= p.N) { - return; - } - - data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * data_y[p.y_offset + x]; -}