Add runtime shader compilation, start transferring shaders to this approach
This commit is contained in:
parent
1132941cb3
commit
a47ca7ae7a
8 changed files with 343 additions and 262 deletions
6
Makefile
6
Makefile
|
@ -228,7 +228,7 @@ endif # LLAMA_METAL
|
|||
ifdef LLAMA_VULKAN
|
||||
CFLAGS += -DGGML_USE_VULKAN
|
||||
CXXFLAGS += -DGGML_USE_VULKAN
|
||||
LDFLAGS += -lvulkan -lopenblas
|
||||
LDFLAGS += -lvulkan -lopenblas -lglslang -lSPIRV -lSPIRV-Tools-opt -lSPIRV-Tools -lshaderc_combined
|
||||
OBJS += ggml-vulkan.o
|
||||
ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
|
||||
$(CXX) $(CXXFLAGS) -c $< -o $@
|
||||
|
@ -240,13 +240,9 @@ ggml-vulkan.o: ggml-vulkan.cpp ggml-vulkan.h
|
|||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_f16_f32_aligned.glsl -o vk_shaders/matmul_f16_f32_aligned.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/matmul_split_k_reduce.glsl -o vk_shaders/matmul_split_k_reduce.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f16_to_f32.glsl -o vk_shaders/f16_to_f32.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/f32_to_f16.glsl -o vk_shaders/f32_to_f16.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_q4_0.glsl -o vk_shaders/dequant_q4_0.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16.glsl -o vk_shaders/dequant_mul_mat_vec_f16.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_f16_f32.glsl -o vk_shaders/dequant_mul_mat_vec_f16_f32.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/dequant_mul_mat_vec_q4_0_f32.glsl -o vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/mul_f32.glsl -o vk_shaders/mul_f32.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f32.glsl -o vk_shaders/add_f32.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/add_f16_f32_f16.glsl -o vk_shaders/add_f16_f32_f16.spv & \
|
||||
glslc -fshader-stage=compute --target-env=vulkan1.2 vk_shaders/scale_f32.glsl -o vk_shaders/scale_f32.spv & \
|
||||
|
|
225
ggml-vulkan-shaders.hpp
Normal file
225
ggml-vulkan-shaders.hpp
Normal file
|
@ -0,0 +1,225 @@
|
|||
#include <string>
|
||||
|
||||
// DEQUANT SHADER
|
||||
const std::string dequant_head = R"(
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
)";
|
||||
|
||||
const std::string dequant_glsl_fp16_ext = R"(
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
)";
|
||||
|
||||
const std::string dequant_output_fp16 = R"(
|
||||
#define OUT_TYPE float16_t
|
||||
)";
|
||||
|
||||
const std::string dequant_output_fp32 = R"(
|
||||
#define OUT_TYPE float
|
||||
)";
|
||||
|
||||
const std::string dequant_q4_0_defines = R"(
|
||||
#define QUANT_K 32
|
||||
#define QUANT_R 2
|
||||
|
||||
struct block_q4_0
|
||||
{
|
||||
float16_t d;
|
||||
uint8_t qs[16];
|
||||
};
|
||||
|
||||
#define A_TYPE block_q4_0
|
||||
)";
|
||||
|
||||
const std::string dequant_body = R"(
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { A_TYPE x[]; };
|
||||
layout (binding = 1) writeonly buffer D { OUT_TYPE y[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
} p;
|
||||
|
||||
void main() {
|
||||
const int i = int(gl_GlobalInvocationID.x);
|
||||
|
||||
// Transposed
|
||||
const int row = i % (p.K / QUANT_K);
|
||||
const int col = i / (p.K / QUANT_K);
|
||||
|
||||
if (row * QUANT_K >= p.K || col >= p.M) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int stride_a = p.stride_a / QUANT_K;
|
||||
|
||||
const A_TYPE blk = x[col * stride_a + row];
|
||||
const OUT_TYPE d = blk.d;
|
||||
|
||||
[[unroll]] for (int j = 0; j < QUANT_K/2; ++j) {
|
||||
const OUT_TYPE x0 = OUT_TYPE((blk.qs[j] & 0x0F) - 8);
|
||||
const OUT_TYPE x1 = OUT_TYPE((blk.qs[j] >> 4) - 8);
|
||||
|
||||
y[col * p.stride_b + row*QUANT_K + j + 0 ] = x0*d;
|
||||
y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = x1*d;
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
// Mul Mat Vec
|
||||
const std::string mul_mat_vec_head = R"(
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_8bit_storage : require
|
||||
)";
|
||||
|
||||
const std::string mul_mat_vec_fp16 = R"(
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
)";
|
||||
|
||||
const std::string mul_mat_vec_q4_0_defines = R"(
|
||||
#define QUANT_K 32
|
||||
#define QUANT_R 2
|
||||
#define BLOCK_SIZE 32
|
||||
|
||||
struct block_q4_0
|
||||
{
|
||||
float16_t d;
|
||||
uint8_t qs[16];
|
||||
};
|
||||
#define A_TYPE block_q4_0
|
||||
#define B_TYPE float16_t
|
||||
#define OUT_TYPE float
|
||||
|
||||
#define DEQUANT_FUNC const float16_t d = x[ib].d; \
|
||||
const uint8_t vui = x[ib].qs[iqs]; \
|
||||
const int8_t vi0 = int8_t(vui & 0xF); \
|
||||
const int8_t vi1 = int8_t(vui >> 4); \
|
||||
float16_t v0 = float16_t(vi0 - 8)*d; \
|
||||
float16_t v1 = float16_t(vi1 - 8)*d;
|
||||
)";
|
||||
|
||||
const std::string mul_mat_vec_body = R"(
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { A_TYPE x[]; };
|
||||
layout (binding = 1) readonly buffer B { B_TYPE y[]; };
|
||||
layout (binding = 2) writeonly buffer D { OUT_TYPE dst[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int ncols;
|
||||
} p;
|
||||
|
||||
shared float16_t tmp[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const int block_size = int(gl_WorkGroupSize.x);
|
||||
const int row = int(gl_WorkGroupID.x);
|
||||
const int tid = int(gl_LocalInvocationID.x);
|
||||
|
||||
const int y_offset = QUANT_K/2;
|
||||
|
||||
tmp[tid] = 0.0hf;
|
||||
|
||||
[[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
|
||||
const int col = i*block_size + 2*tid;
|
||||
const int ib = (row*p.ncols + col)/QUANT_K; // block index
|
||||
const int iqs = (col%QUANT_K)/QUANT_R; // quant index
|
||||
const int iybs = col - col%QUANT_K; // y block start index
|
||||
|
||||
DEQUANT_FUNC
|
||||
|
||||
// matrix multiplication
|
||||
tmp[tid] += v0 * y[iybs + iqs + 0];
|
||||
tmp[tid] += v1 * y[iybs + iqs + y_offset];
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s=block_size/2; s>0; s>>=1) {
|
||||
if (tid < s) {
|
||||
tmp[tid] += tmp[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
dst[row] = float(tmp[0]);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
// F16 to F32
|
||||
const std::string f32_to_f16_src = R"(
|
||||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||
layout (binding = 1) writeonly buffer D { float16_t data_b[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
} p;
|
||||
|
||||
void main() {
|
||||
const int row = int(gl_GlobalInvocationID.x % p.K);
|
||||
const int col = int(gl_GlobalInvocationID.x / p.K);
|
||||
|
||||
if (row < p.K && col < p.M) {
|
||||
data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]);
|
||||
}
|
||||
}
|
||||
)";
|
||||
|
||||
// MUL F32
|
||||
const std::string mul_f32_src = R"(
|
||||
#version 450
|
||||
|
||||
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) buffer X { float data_x[]; };
|
||||
layout (binding = 1) buffer Y { float data_y[]; };
|
||||
layout (binding = 2) buffer D { float data_d[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int N;
|
||||
int stride_x;
|
||||
int stride_y;
|
||||
int stride_d;
|
||||
int x_offset;
|
||||
int y_offset;
|
||||
int d_offset;
|
||||
float scale;
|
||||
} p;
|
||||
|
||||
void main() {
|
||||
const int x = int(gl_GlobalInvocationID.x);
|
||||
const int y = int(gl_GlobalInvocationID.y);
|
||||
|
||||
if (x >= p.M || y >= p.N) {
|
||||
return;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * data_y[p.y_offset + x];
|
||||
}
|
||||
)";
|
167
ggml-vulkan.cpp
167
ggml-vulkan.cpp
|
@ -27,9 +27,14 @@
|
|||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <mutex>
|
||||
#include <sstream>
|
||||
|
||||
#include <shaderc/shaderc.hpp>
|
||||
|
||||
#include "ggml.h"
|
||||
|
||||
#include "ggml-vulkan-shaders.hpp"
|
||||
|
||||
#define VK_API_VERSION VK_API_VERSION_1_2
|
||||
|
||||
#define CEIL_DIV(M, N) (((M) + (N)-1) / (N))
|
||||
|
@ -160,37 +165,39 @@ vk_pipeline vk_pipeline_f32_to_f16, vk_pipeline_dequant_q4_0;
|
|||
|
||||
static std::vector<std::tuple<void*, size_t, vk_buffer>> vk_pinned_memory;
|
||||
|
||||
static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<int>&& specialization_constants, uint32_t align) {
|
||||
static std::vector<uint32_t> ggml_vk_compile_shader(const std::string& name, const std::string& src) {
|
||||
#ifdef VK_DEBUG
|
||||
std::cerr << "ggml_vk_create_pipeline(" << path << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl;
|
||||
std::cerr << "ggml_vk_compile_shader(" << name << ", " << src << ")" << std::endl;
|
||||
#endif
|
||||
shaderc::Compiler compiler;
|
||||
shaderc::CompileOptions options;
|
||||
|
||||
shaderc::SpvCompilationResult module = compiler.CompileGlslToSpv(src, shaderc_compute_shader, name.c_str(), options);
|
||||
|
||||
if (module.GetCompilationStatus() != shaderc_compilation_status_success) {
|
||||
std::cerr << "ggml_vulkan: Shader compile error in " << name << ": " << module.GetErrorMessage();
|
||||
return std::vector<uint32_t>();
|
||||
}
|
||||
|
||||
return {module.cbegin(), module.cend()};
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_create_pipeline(const std::string& name, size_t spv_size, const uint32_t* spv_data, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<int>&& specialization_constants, uint32_t align) {
|
||||
#ifdef VK_DEBUG
|
||||
std::cerr << "ggml_vk_create_pipeline(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl;
|
||||
#endif
|
||||
GGML_ASSERT(parameter_count > 0);
|
||||
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0);
|
||||
GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT
|
||||
|
||||
vk_pipeline pipeline;
|
||||
|
||||
pipeline.name = path;
|
||||
pipeline.name = name;
|
||||
pipeline.parameter_count = parameter_count;
|
||||
pipeline.push_constant_size = push_constant_size;
|
||||
pipeline.wg_denoms = wg_denoms;
|
||||
pipeline.align = align;
|
||||
|
||||
std::vector<char> matmul_shader_contents;
|
||||
if (std::ifstream shader_file{ path, std::ios::binary | std::ios::ate }) {
|
||||
const size_t file_size = shader_file.tellg();
|
||||
shader_file.seekg(0);
|
||||
matmul_shader_contents.resize(file_size, '\0');
|
||||
shader_file.read(matmul_shader_contents.data(), file_size);
|
||||
} else {
|
||||
std::cerr << "ggml_vulkan: Invalid shader path " << path << std::endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
vk::ShaderModuleCreateInfo shader_module_create_info(
|
||||
{},
|
||||
matmul_shader_contents.size(),
|
||||
reinterpret_cast<const uint32_t*>(matmul_shader_contents.data())
|
||||
);
|
||||
vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, spv_data);
|
||||
vk::ShaderModule shader_module = vk_device.device.createShaderModule(shader_module_create_info);
|
||||
|
||||
std::vector<vk::DescriptorSetLayoutBinding> dsl_binding;
|
||||
|
@ -279,6 +286,34 @@ static vk_pipeline ggml_vk_create_pipeline(const std::string& path, const std::s
|
|||
return pipeline;
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_create_pipeline_from_string(const std::string& name, const std::string& src, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<int>&& specialization_constants, uint32_t align) {
|
||||
#ifdef VK_DEBUG
|
||||
std::cerr << "ggml_vk_create_pipeline_from_string(" << name << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl;
|
||||
#endif
|
||||
|
||||
const std::vector<uint32_t> spv = ggml_vk_compile_shader(name, src);
|
||||
return ggml_vk_create_pipeline(name, spv.size() * sizeof(uint32_t), spv.data(), entrypoint, parameter_count, push_constant_size, wg_denoms, std::move(specialization_constants), align);
|
||||
}
|
||||
|
||||
static vk_pipeline ggml_vk_create_pipeline_from_file(const std::string& path, const std::string& entrypoint, uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, std::vector<int>&& specialization_constants, uint32_t align) {
|
||||
#ifdef VK_DEBUG
|
||||
std::cerr << "ggml_vk_create_pipeline_from_file(" << path << ", " << entrypoint << ", " << parameter_count << ", " << push_constant_size << ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << align << ")" << std::endl;
|
||||
#endif
|
||||
|
||||
std::vector<char> matmul_shader_contents;
|
||||
if (std::ifstream shader_file{ path, std::ios::binary | std::ios::ate }) {
|
||||
const size_t file_size = shader_file.tellg();
|
||||
shader_file.seekg(0);
|
||||
matmul_shader_contents.resize(file_size, '\0');
|
||||
shader_file.read(matmul_shader_contents.data(), file_size);
|
||||
} else {
|
||||
std::cerr << "ggml_vulkan: Invalid shader path " << path << std::endl;
|
||||
abort();
|
||||
}
|
||||
|
||||
return ggml_vk_create_pipeline(path, matmul_shader_contents.size(), reinterpret_cast<uint32_t *>(matmul_shader_contents.data()), entrypoint, parameter_count, push_constant_size, wg_denoms, std::move(specialization_constants), align);
|
||||
}
|
||||
|
||||
static void ggml_vk_pipeline_allocate_descriptor_sets(vk_pipeline& pipeline, uint32_t n) {
|
||||
#ifdef VK_DEBUG
|
||||
std::cerr << "ggml_vk_pipeline_allocate_descriptor_sets(" << pipeline.name << ", " << n << ")" << std::endl;
|
||||
|
@ -599,6 +634,42 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
|
|||
}
|
||||
}
|
||||
|
||||
static void ggml_vk_generate_shaders() {
|
||||
#ifdef VK_DEBUG
|
||||
std::cerr << "ggml_vk_generate_shaders()" << std::endl;
|
||||
#endif
|
||||
std::cerr << "ggml_vulkan: Generating and compiling shaders to SPIR-V" << std::endl;
|
||||
|
||||
|
||||
// Build dequant q4_0
|
||||
std::stringstream stream;
|
||||
stream << dequant_head;
|
||||
if (vk_device.fp16) {
|
||||
stream << dequant_glsl_fp16_ext;
|
||||
}
|
||||
stream << dequant_q4_0_defines;
|
||||
if (vk_device.fp16) {
|
||||
stream << dequant_output_fp16;
|
||||
} else {
|
||||
stream << dequant_output_fp32;
|
||||
}
|
||||
stream << dequant_body;
|
||||
|
||||
vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);
|
||||
|
||||
// mul mat vec
|
||||
stream.str("");
|
||||
stream.clear();
|
||||
|
||||
stream << mul_mat_vec_head << mul_mat_vec_fp16 << mul_mat_vec_q4_0_defines << mul_mat_vec_body;
|
||||
|
||||
vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
|
||||
// Static shaders
|
||||
vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
|
||||
vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
|
||||
}
|
||||
|
||||
void ggml_vk_test_transfer(size_t ne);
|
||||
void ggml_vk_test_matmul_f32(size_t m, size_t n, size_t k, size_t num_it, int split_k, int shader_size);
|
||||
void ggml_vk_test_matmul_f16(size_t m, size_t n, size_t k, size_t num_it, int split_k, int shader_size);
|
||||
|
@ -741,44 +812,40 @@ void ggml_vk_init(void) {
|
|||
auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 };
|
||||
|
||||
// Shaders
|
||||
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
if (vk_device.fp16) {
|
||||
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
|
||||
vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
|
||||
vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
|
||||
vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_f16_f32_aligned.spv", "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
|
||||
|
||||
vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_f16.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_q4_0.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_f16.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
}
|
||||
vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);
|
||||
vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_file("vk_shaders/matmul_split_k_reduce.spv", "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);
|
||||
|
||||
vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline("vk_shaders/f32_to_f16.spv", "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline("vk_shaders/dequant_q4_0.spv", "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_f16_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
|
||||
vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_f16_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline("vk_shaders/dequant_mul_mat_vec_q4_0_f32.spv", "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
|
||||
vk_pipeline_add_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/add_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
|
||||
vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline_from_file("vk_shaders/add_f16_f32_f16.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
|
||||
|
||||
vk_pipeline_mul_f32 = ggml_vk_create_pipeline("vk_shaders/mul_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
|
||||
vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_file("vk_shaders/scale_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
|
||||
|
||||
vk_pipeline_add_f32 = ggml_vk_create_pipeline("vk_shaders/add_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
|
||||
vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline("vk_shaders/add_f16_f32_f16.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
|
||||
|
||||
vk_pipeline_scale_f32 = ggml_vk_create_pipeline("vk_shaders/scale_f32.spv", "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
|
||||
ggml_vk_generate_shaders();
|
||||
|
||||
// Queues
|
||||
uint32_t queue_index_offset = compute_queue_family_index == transfer_queue_family_index ? 1 : 0;
|
||||
|
|
|
@ -1,73 +0,0 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : enable
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
|
||||
#define QUANT_K 32
|
||||
#define QUANT_R 2
|
||||
#define BLOCK_SIZE 32
|
||||
|
||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
struct block_q4_0
|
||||
{
|
||||
float16_t d;
|
||||
uint8_t qs[16];
|
||||
};
|
||||
|
||||
layout (binding = 0) readonly buffer A { block_q4_0 x[]; };
|
||||
layout (binding = 1) readonly buffer B { float16_t y[]; };
|
||||
layout (binding = 2) writeonly buffer D { float dst[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int ncols;
|
||||
} p;
|
||||
|
||||
shared float16_t tmp[BLOCK_SIZE];
|
||||
|
||||
void main() {
|
||||
const int block_size = int(gl_WorkGroupSize.x);
|
||||
const int row = int(gl_WorkGroupID.x);
|
||||
const int tid = int(gl_LocalInvocationID.x);
|
||||
|
||||
const int y_offset = QUANT_K/2;
|
||||
|
||||
tmp[tid] = 0.0hf;
|
||||
|
||||
[[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
|
||||
const int col = i*block_size + 2*tid;
|
||||
const int ib = (row*p.ncols + col)/QUANT_K; // block index
|
||||
const int iqs = (col%QUANT_K)/QUANT_R; // quant index
|
||||
const int iybs = col - col%QUANT_K; // y block start index
|
||||
|
||||
// dequantize
|
||||
const float16_t d = x[ib].d;
|
||||
|
||||
const uint8_t vui = x[ib].qs[iqs];
|
||||
|
||||
const int8_t vi0 = int8_t(vui & 0xF);
|
||||
const int8_t vi1 = int8_t(vui >> 4);
|
||||
|
||||
float16_t v0 = float16_t(vi0 - 8)*d;
|
||||
float16_t v1 = float16_t(vi1 - 8)*d;
|
||||
|
||||
// matrix multiplication
|
||||
tmp[tid] += v0 * y[iybs + iqs + 0];
|
||||
tmp[tid] += v1 * y[iybs + iqs + y_offset];
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
barrier();
|
||||
[[unroll]] for (int s=block_size/2; s>0; s>>=1) {
|
||||
if (tid < s) {
|
||||
tmp[tid] += tmp[tid + s];
|
||||
}
|
||||
barrier();
|
||||
}
|
||||
if (tid == 0) {
|
||||
dst[row] = float(tmp[0]);
|
||||
}
|
||||
}
|
|
@ -1,53 +0,0 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_control_flow_attributes : require
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||
|
||||
#define QUANT_K 32
|
||||
#define QUANT_R 2
|
||||
|
||||
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
struct block_q4_0
|
||||
{
|
||||
float16_t d;
|
||||
uint8_t qs[16];
|
||||
};
|
||||
|
||||
layout (binding = 0) readonly buffer A { block_q4_0 x[]; };
|
||||
layout (binding = 1) writeonly buffer D { float16_t y[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
} p;
|
||||
|
||||
void main() {
|
||||
const int i = int(gl_GlobalInvocationID.x);
|
||||
|
||||
// Transposed
|
||||
const int row = i % (p.K / QUANT_K);
|
||||
const int col = i / (p.K / QUANT_K);
|
||||
|
||||
if (row * QUANT_K >= p.K || col >= p.M) {
|
||||
return;
|
||||
}
|
||||
|
||||
const int stride_a = p.stride_a / QUANT_K;
|
||||
|
||||
const block_q4_0 blk = x[col * stride_a + row];
|
||||
const float16_t d = blk.d;
|
||||
|
||||
[[unroll]] for (int j = 0; j < QUANT_K/2; ++j) {
|
||||
const float16_t x0 = float16_t((blk.qs[j] & 0x0F) - 8);
|
||||
const float16_t x1 = float16_t((blk.qs[j] >> 4) - 8);
|
||||
|
||||
y[col * p.stride_b + row*QUANT_K + j + 0 ] = x0*d;
|
||||
y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = x1*d;
|
||||
}
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float16_t data_a[]; };
|
||||
layout (binding = 1) writeonly buffer D { float data_b[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
} p;
|
||||
|
||||
void main() {
|
||||
const int row = int(gl_GlobalInvocationID.x % p.K);
|
||||
const int col = int(gl_GlobalInvocationID.x / p.K);
|
||||
|
||||
if (row < p.M && col < p.K) {
|
||||
data_b[col * p.stride_b + row] = float(data_a[col * p.stride_a + row]);
|
||||
}
|
||||
}
|
|
@ -1,25 +0,0 @@
|
|||
#version 450
|
||||
|
||||
#extension GL_EXT_shader_16bit_storage : require
|
||||
|
||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) readonly buffer A { float data_a[]; };
|
||||
layout (binding = 1) writeonly buffer D { float16_t data_b[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int K;
|
||||
int stride_a;
|
||||
int stride_b;
|
||||
} p;
|
||||
|
||||
void main() {
|
||||
const int row = int(gl_GlobalInvocationID.x % p.K);
|
||||
const int col = int(gl_GlobalInvocationID.x / p.K);
|
||||
|
||||
if (row < p.K && col < p.M) {
|
||||
data_b[col * p.stride_b + row] = float16_t(data_a[col * p.stride_a + row]);
|
||||
}
|
||||
}
|
|
@ -1,31 +0,0 @@
|
|||
#version 450
|
||||
|
||||
layout(local_size_x = 32, local_size_y = 32, local_size_z = 1) in;
|
||||
|
||||
layout (binding = 0) buffer X { float data_x[]; };
|
||||
layout (binding = 1) buffer Y { float data_y[]; };
|
||||
layout (binding = 2) buffer D { float data_d[]; };
|
||||
|
||||
layout (push_constant) uniform parameter
|
||||
{
|
||||
int M;
|
||||
int N;
|
||||
int stride_x;
|
||||
int stride_y;
|
||||
int stride_d;
|
||||
int x_offset;
|
||||
int y_offset;
|
||||
int d_offset;
|
||||
float scale;
|
||||
} p;
|
||||
|
||||
void main() {
|
||||
const int x = int(gl_GlobalInvocationID.x);
|
||||
const int y = int(gl_GlobalInvocationID.y);
|
||||
|
||||
if (x >= p.M || y >= p.N) {
|
||||
return;
|
||||
}
|
||||
|
||||
data_d[p.d_offset + y * p.stride_d + x] = data_x[p.x_offset + y * p.stride_x + x] * data_y[p.y_offset + x];
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue