Add support for q4_1, q5_0, q5_1 and q8_0

This commit is contained in:
0cc4m 2023-08-15 15:38:57 +02:00
parent e9be24f9ad
commit 7e88677af4
2 changed files with 255 additions and 142 deletions

View file

@ -12,11 +12,166 @@ const std::string shader_int8_ext = R"(
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
)"; )";
// Type-specific defines
const std::string shader_f16_defines = R"(
#define QUANT_K 32
#define QUANT_R 2
#define A_TYPE float16_t
)";
const std::string shader_q4_0_defines = R"(
#define QUANT_K 32
#define QUANT_R 2
struct block_q4_0
{
float16_t d;
uint8_t qs[16];
};
#define A_TYPE block_q4_0
)";
const std::string shader_q4_1_defines = R"(
#define QUANT_K 32
#define QUANT_R 2
struct block_q4_1
{
float16_t d;
float16_t m;
uint8_t qs[16];
};
#define A_TYPE block_q4_1
)";
const std::string shader_q5_0_defines = R"(
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#define QUANT_K 32
#define QUANT_R 2
struct block_q5_0
{
float16_t d;
uint16_t qh[2];
uint8_t qs[16];
};
#define A_TYPE block_q5_0
)";
const std::string shader_q5_1_defines = R"(
#define QUANT_K 32
#define QUANT_R 2
struct block_q5_1
{
float16_t d;
float16_t m;
uint qh;
uint8_t qs[16];
};
#define A_TYPE block_q5_1
)";
const std::string shader_q8_0_defines = R"(
#define QUANT_K 32
#define QUANT_R 1
struct block_q8_0
{
float16_t d;
int8_t qs[32];
};
#define A_TYPE block_q8_0
)";
// Dequant functions
const std::string shader_f16_dequant_func = R"(
#define DEQUANT_FUNC f16vec2 v = f16vec2(x[ib + 0], x[ib + 1]);
)";
const std::string shader_f16_dequant_func_compat = R"(
#define DEQUANT_FUNC vec2 v = vec2(x[ib + 0], x[ib + 1]);
)";
const std::string shader_q4_0_dequant_func = R"(
#define DEQUANT_FUNC const float16_t d = x[ib].d; \
const uint8_t vui = x[ib].qs[iqs]; \
f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \
v = (v - 8.0hf)*d;
)";
const std::string shader_q4_0_dequant_func_compat = R"(
#define DEQUANT_FUNC const float d = float(x[ib].d); \
const uint vui = uint(x[ib].qs[iqs]); \
vec2 v = vec2(vui & 0xF, vui >> 4); \
v = (v - 8.0f)*d;
)";
const std::string shader_q4_1_dequant_func = R"(
#define DEQUANT_FUNC const float16_t d = x[ib].d; \
const float16_t m = x[ib].m; \
const uint8_t vui = x[ib].qs[iqs]; \
f16vec2 v = f16vec2(vui & 0xF, vui >> 4); \
v = v*d + m;
)";
const std::string shader_q4_1_dequant_func_compat = R"(
#define DEQUANT_FUNC const float d = float(x[ib].d); \
const float m = float(x[ib].m); \
const uint vui = uint(x[ib].qs[iqs]); \
vec2 v = vec2(vui & 0xF, vui >> 4); \
v = v*d + m;
)";
const std::string shader_q5_0_dequant_func = R"(
#define DEQUANT_FUNC const float16_t d = x[ib].d; \
const uint uint_qh = uint(x[ib].qh[1]) << 16 | x[ib].qh[0]; \
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
const uint8_t vui = x[ib].qs[iqs]; \
f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
v = (v - 16.0hf) * d;
)";
const std::string shader_q5_0_dequant_func_compat = R"(
#define DEQUANT_FUNC const float d = float(x[ib].d); \
const uint uint_qh = uint(x[ib].qh[1]) << 16 | x[ib].qh[0]; \
const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); \
const uint vui = uint(x[ib].qs[iqs]); \
vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
v = (v - 16.0f) * d;
)";
const std::string shader_q5_1_dequant_func = R"(
#define DEQUANT_FUNC const float16_t d = x[ib].d; \
const float16_t m = x[ib].m; \
const ivec2 qh = ivec2(((x[ib].qh >> iqs) << 4) & 0x10, (x[ib].qh >> (iqs + 12)) & 0x10); \
const uint8_t vui = x[ib].qs[iqs]; \
f16vec2 v = f16vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
v = v*d + m;
)";
const std::string shader_q5_1_dequant_func_compat = R"(
#define DEQUANT_FUNC const float d = float(x[ib].d); \
const float m = float(x[ib].m); \
const ivec2 qh = ivec2(((x[ib].qh >> iqs) << 4) & 0x10, (x[ib].qh >> (iqs + 12)) & 0x10); \
const uint vui = uint(x[ib].qs[iqs]); \
vec2 v = vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); \
v = v*d + m;
)";
const std::string shader_q8_0_dequant_func = R"(
#define DEQUANT_FUNC const float16_t d = x[ib].d; \
f16vec2 v = f16vec2(x[ib].qs[iqs], x[ib].qs[iqs + 1]); \
v = v * d;
)";
const std::string shader_q8_0_dequant_func_compat = R"(
#define DEQUANT_FUNC const float d = float(x[ib].d); \
vec2 v = vec2(int(x[ib].qs[iqs]), int(x[ib].qs[iqs + 1])); \
v = v * d;
)";
// MULMAT // MULMAT
const std::string mulmat_head = R"( const std::string mulmat_head = R"(
#version 450 #version 450
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
@ -30,7 +185,7 @@ const std::string mulmat_head = R"(
const std::string mulmat_body = R"( const std::string mulmat_body = R"(
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { A_TYPE data_a[]; }; layout (binding = 0, scalar) readonly buffer A { A_TYPE data_a[]; };
layout (binding = 1) readonly buffer B { B_TYPE data_b[]; }; layout (binding = 1) readonly buffer B { B_TYPE data_b[]; };
layout (binding = 2) writeonly buffer D { D_TYPE data_d[]; }; layout (binding = 2) writeonly buffer D { D_TYPE data_d[]; };
@ -238,28 +393,16 @@ void main() {
const std::string dequant_head = R"( const std::string dequant_head = R"(
#version 450 #version 450
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_control_flow_attributes : require #extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
)"; )";
const std::string dequant_q4_0_defines = R"(
#define QUANT_K 32
#define QUANT_R 2
struct block_q4_0
{
float16_t d;
uint8_t qs[16];
};
#define A_TYPE block_q4_0
)";
const std::string dequant_body = R"( const std::string dequant_body = R"(
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { A_TYPE x[]; }; layout (binding = 0, scalar) readonly buffer A { A_TYPE x[]; };
layout (binding = 1) writeonly buffer D { D_TYPE y[]; }; layout (binding = 1) writeonly buffer D { D_TYPE y[]; };
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
@ -283,15 +426,16 @@ void main() {
const int stride_a = p.stride_a / QUANT_K; const int stride_a = p.stride_a / QUANT_K;
const int idx = col * stride_a + row; const int ib = col * stride_a + row;
const FLOAT_TYPE d = FLOAT_TYPE(x[idx].d);
[[unroll]] for (int j = 0; j < QUANT_K/2; ++j) { const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
const FLOAT_TYPE x0 = FLOAT_TYPE((x[idx].qs[j] & 0x0F) - 8); const int step = QUANT_R == 1 ? 2 : 1;
const FLOAT_TYPE x1 = FLOAT_TYPE((x[idx].qs[j] >> 4) - 8);
y[col * p.stride_b + row*QUANT_K + j + 0 ] = D_TYPE(x0*d); [[unroll]] for (int iqs = 0; iqs < QUANT_K/QUANT_R; iqs += step) {
y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = D_TYPE(x1*d); DEQUANT_FUNC
y[col * p.stride_b + row*QUANT_K + iqs + 0 ] = D_TYPE(v.x);
y[col * p.stride_b + row*QUANT_K + iqs + y_offset] = D_TYPE(v.y);
} }
} }
)"; )";
@ -300,64 +444,16 @@ void main() {
const std::string mul_mat_vec_head = R"( const std::string mul_mat_vec_head = R"(
#version 450 #version 450
#extension GL_EXT_scalar_block_layout : require
#extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require #extension GL_EXT_shader_8bit_storage : require
)"; )";
const std::string mul_mat_vec_f16_defines = R"(
#define QUANT_K 32
#define QUANT_R 2
#define BLOCK_SIZE 32
#define A_TYPE float16_t
)";
const std::string mul_mat_vec_f16_dequant_func = R"(
#define DEQUANT_FUNC float16_t v0 = x[ib + 0]; \
float16_t v1 = x[ib + 1];
)";
const std::string mul_mat_vec_f16_dequant_func_compat = R"(
#define DEQUANT_FUNC float v0 = float(x[ib + 0]); \
float v1 = float(x[ib + 1]);
)";
const std::string mul_mat_vec_q4_0_defines = R"(
#define QUANT_K 32
#define QUANT_R 2
#define BLOCK_SIZE 32
struct block_q4_0
{
float16_t d;
uint8_t qs[16];
};
#define A_TYPE block_q4_0
)";
const std::string mul_mat_vec_q4_0_dequant_func = R"(
#define DEQUANT_FUNC const float16_t d = x[ib].d; \
const uint8_t vui = x[ib].qs[iqs]; \
const int8_t vi0 = int8_t(vui & 0xF); \
const int8_t vi1 = int8_t(vui >> 4); \
float16_t v0 = float16_t(vi0 - 8)*d; \
float16_t v1 = float16_t(vi1 - 8)*d;
)";
const std::string mul_mat_vec_q4_0_dequant_func_compat = R"(
#define DEQUANT_FUNC const float d = float(x[ib].d); \
const uint vui = uint(x[ib].qs[iqs]); \
const int vi0 = int(vui) & 0xF; \
const int vi1 = int(vui) >> 4; \
float v0 = float(vi0 - 8)*d; \
float v1 = float(vi1 - 8)*d;
)";
const std::string mul_mat_vec_body = R"( const std::string mul_mat_vec_body = R"(
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = QUANT_K, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { A_TYPE x[]; }; layout (binding = 0, scalar) readonly buffer A { A_TYPE x[]; };
layout (binding = 1) readonly buffer B { B_TYPE y[]; }; layout (binding = 1) readonly buffer B { B_TYPE y[]; };
layout (binding = 2) writeonly buffer D { D_TYPE dst[]; }; layout (binding = 2) writeonly buffer D { D_TYPE dst[]; };
@ -366,14 +462,14 @@ layout (push_constant) uniform parameter
int ncols; int ncols;
} p; } p;
shared FLOAT_TYPE tmp[BLOCK_SIZE]; shared FLOAT_TYPE tmp[QUANT_K];
void main() { void main() {
const int block_size = int(gl_WorkGroupSize.x); const int block_size = int(gl_WorkGroupSize.x);
const int row = int(gl_WorkGroupID.x); const int row = int(gl_WorkGroupID.x);
const int tid = int(gl_LocalInvocationID.x); const int tid = int(gl_LocalInvocationID.x);
const int y_offset = QUANT_K/2; const int y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
tmp[tid] = FLOAT_TYPE(0.0f); tmp[tid] = FLOAT_TYPE(0.0f);
@ -386,8 +482,8 @@ void main() {
DEQUANT_FUNC DEQUANT_FUNC
// matrix multiplication // matrix multiplication
tmp[tid] += FLOAT_TYPE(v0) * FLOAT_TYPE(y[iybs + iqs + 0]); tmp[tid] += FLOAT_TYPE(v.x) * FLOAT_TYPE(y[iybs + iqs + 0]);
tmp[tid] += FLOAT_TYPE(v1) * FLOAT_TYPE(y[iybs + iqs + y_offset]); tmp[tid] += FLOAT_TYPE(v.y) * FLOAT_TYPE(y[iybs + iqs + y_offset]);
} }
// sum up partial sums and write back result // sum up partial sums and write back result

View file

@ -52,6 +52,8 @@
#define VK_SUBMIT_BATCH 3 #define VK_SUBMIT_BATCH 3
#define VK_NUM_TYPES 16
typedef void (*ggml_vk_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); typedef void (*ggml_vk_func_t)(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
struct vk_buffer { struct vk_buffer {
@ -157,12 +159,12 @@ vk_pipeline vk_pipeline_matmul_f16_aligned_l, vk_pipeline_matmul_f16_aligned_m,
vk_pipeline vk_pipeline_matmul_f16_f32_l, vk_pipeline_matmul_f16_f32_m, vk_pipeline_matmul_f16_f32_s; vk_pipeline vk_pipeline_matmul_f16_f32_l, vk_pipeline_matmul_f16_f32_m, vk_pipeline_matmul_f16_f32_s;
vk_pipeline vk_pipeline_matmul_f16_f32_aligned_l, vk_pipeline_matmul_f16_f32_aligned_m, vk_pipeline_matmul_f16_f32_aligned_s; vk_pipeline vk_pipeline_matmul_f16_f32_aligned_l, vk_pipeline_matmul_f16_f32_aligned_m, vk_pipeline_matmul_f16_f32_aligned_s;
vk_pipeline vk_pipeline_matmul_split_k_reduce; vk_pipeline vk_pipeline_matmul_split_k_reduce;
vk_pipeline vk_pipeline_dequant_mul_mat_vec_f16, vk_pipeline_dequant_mul_mat_vec_q4_0; vk_pipeline vk_pipeline_dequant[VK_NUM_TYPES];
vk_pipeline vk_pipeline_dequant_mul_mat_vec_f16_f32, vk_pipeline_dequant_mul_mat_vec_q4_0_f32; vk_pipeline vk_pipeline_dequant_mul_mat_vec[VK_NUM_TYPES];
vk_pipeline vk_pipeline_dequant_mul_mat_vec_f32[VK_NUM_TYPES];
vk_pipeline vk_pipeline_mul_f32; vk_pipeline vk_pipeline_mul_f32;
vk_pipeline vk_pipeline_add_f32, vk_pipeline_add_f16_f32_f16; vk_pipeline vk_pipeline_add_f32, vk_pipeline_add_f16_f32_f16;
vk_pipeline vk_pipeline_scale_f32; vk_pipeline vk_pipeline_scale_f32;
vk_pipeline vk_pipeline_f32_to_f16, vk_pipeline_dequant_q4_0;
static std::vector<std::tuple<void*, size_t, vk_buffer>> vk_pinned_memory; static std::vector<std::tuple<void*, size_t, vk_buffer>> vk_pinned_memory;
@ -651,6 +653,31 @@ static void ggml_vk_destroy_buffer(vk_buffer& buf) {
} }
} }
static inline bool ggml_vk_build_shader_type_defines(std::stringstream& stream, ggml_type type, bool compat) {
switch(type) {
case GGML_TYPE_F16:
stream << shader_f16_defines << (compat ? shader_f16_dequant_func_compat : shader_f16_dequant_func);
return true;
case GGML_TYPE_Q4_0:
stream << shader_q4_0_defines << (compat ? shader_q4_0_dequant_func_compat : shader_q4_0_dequant_func);
return true;
case GGML_TYPE_Q4_1:
stream << shader_q4_1_defines << (compat ? shader_q4_1_dequant_func_compat : shader_q4_1_dequant_func);
return true;
case GGML_TYPE_Q5_0:
stream << shader_q5_0_defines << (compat ? shader_q5_0_dequant_func_compat : shader_q5_0_dequant_func);
return true;
case GGML_TYPE_Q5_1:
stream << shader_q5_1_defines << (compat ? shader_q5_1_dequant_func_compat : shader_q5_1_dequant_func);
return true;
case GGML_TYPE_Q8_0:
stream << shader_q8_0_defines << (compat ? shader_q8_0_dequant_func_compat : shader_q8_0_dequant_func);
return true;
default:
return false;
}
}
static void ggml_vk_generate_shaders() { static void ggml_vk_generate_shaders() {
#ifdef VK_DEBUG #ifdef VK_DEBUG
std::cerr << "ggml_vk_generate_shaders()" << std::endl; std::cerr << "ggml_vk_generate_shaders()" << std::endl;
@ -705,65 +732,46 @@ static void ggml_vk_generate_shaders() {
vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_m", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_m", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_s", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_s", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
// Build dequant q4_0 // Build dequant shaders
stream.str(""); vk_pipeline_dequant[GGML_TYPE_F32] = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, {}, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
stream.clear();
stream << dequant_head << shader_float_type << dequant_q4_0_defines << dequant_body; for (int i = 0; i < VK_NUM_TYPES; i++) {
stream.str("");
stream.clear();
vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1); stream << dequant_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext;
}
if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) {
continue;
}
stream << dequant_body;
vk_pipeline_dequant[i] = ggml_vk_create_pipeline_from_string("dequant_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);
}
// mul mat vec // mul mat vec
stream.str(""); for (int i = 0; i < VK_NUM_TYPES; i++) {
stream.clear(); stream.str("");
stream.clear();
stream << mul_mat_vec_head << shader_float_type; stream << mul_mat_vec_head << shader_float_type;
if (vk_device.fp16) { if (vk_device.fp16) {
stream << shader_int8_ext << mul_mat_vec_q4_0_dequant_func; stream << shader_int8_ext;
} else { }
stream << mul_mat_vec_q4_0_dequant_func_compat;
if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) {
continue;
}
stream << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)), stream.str(), { "B_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
vk_pipeline_dequant_mul_mat_vec_f32[i] = ggml_vk_create_pipeline_from_string("mul_mat_vec_" + std::string(ggml_type_name((ggml_type)i)) + "_f32", stream.str(), { "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
} }
stream << mul_mat_vec_q4_0_defines << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), { "D_TYPE", "float", "B_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
stream.str("");
stream.clear();
stream << mul_mat_vec_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext << mul_mat_vec_q4_0_dequant_func;
} else {
stream << mul_mat_vec_q4_0_dequant_func_compat;
}
stream << mul_mat_vec_q4_0_defines << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0_f32", stream.str(), { "D_TYPE", "float", "B_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
stream.str("");
stream.clear();
stream << mul_mat_vec_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext << mul_mat_vec_f16_dequant_func;
} else {
stream << mul_mat_vec_f16_dequant_func_compat;
}
stream << mul_mat_vec_f16_defines << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16", stream.str(), { "D_TYPE", "float", "B_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
stream.str("");
stream.clear();
stream << mul_mat_vec_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext << mul_mat_vec_f16_dequant_func;
} else {
stream << mul_mat_vec_f16_dequant_func_compat;
}
stream << mul_mat_vec_f16_defines << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16_f32", stream.str(), { "D_TYPE", "float", "B_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
// add // add
stream.str(""); stream.str("");
@ -779,7 +787,6 @@ static void ggml_vk_generate_shaders() {
// Static shaders // Static shaders
vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_string("split_k_reduce", mulmat_split_k_reduce_src, {}, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1); vk_pipeline_matmul_split_k_reduce = ggml_vk_create_pipeline_from_string("split_k_reduce", mulmat_split_k_reduce_src, {}, "main", 1, 3 * sizeof(int), {32, 32, 1}, {}, 1);
vk_pipeline_f32_to_f16 = ggml_vk_create_pipeline_from_string("f32_to_f16", f32_to_f16_src, {}, "main", 2, 4 * sizeof(int), {64, 1, 1}, {}, 1);
vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); vk_pipeline_mul_f32 = ggml_vk_create_pipeline_from_string("mul_f32", mul_f32_src, { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_string("scale_f32", scale_src, { "X_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); vk_pipeline_scale_f32 = ggml_vk_create_pipeline_from_string("scale_f32", scale_src, { "X_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
@ -994,32 +1001,42 @@ void ggml_vk_init(void) {
#endif #endif
} }
static vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) { static inline vk_pipeline* ggml_vk_get_to_fp16(ggml_type type) {
#ifdef VK_DEBUG #ifdef VK_DEBUG
std::cerr << "ggml_vk_get_to_fp16()" << std::endl; std::cerr << "ggml_vk_get_to_fp16()" << std::endl;
#endif #endif
switch (type) { switch (type) {
case GGML_TYPE_Q4_0:
return &vk_pipeline_dequant_q4_0;
case GGML_TYPE_F32: case GGML_TYPE_F32:
return &vk_pipeline_f32_to_f16; case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
break;
default: default:
return nullptr; return nullptr;
} }
return &vk_pipeline_dequant[type];
} }
static vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bool f16_y) { static inline vk_pipeline* ggml_vk_get_dequantize_mul_mat_vec(ggml_type type, bool f16_y) {
#ifdef VK_DEBUG #ifdef VK_DEBUG
std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl; std::cerr << "ggml_vk_get_dequantize_mul_mat_vec()" << std::endl;
#endif #endif
switch (type) { switch (type) {
case GGML_TYPE_Q4_0:
return f16_y ? &vk_pipeline_dequant_mul_mat_vec_q4_0 : &vk_pipeline_dequant_mul_mat_vec_q4_0_f32;
case GGML_TYPE_F16: case GGML_TYPE_F16:
return f16_y ? &vk_pipeline_dequant_mul_mat_vec_f16 : &vk_pipeline_dequant_mul_mat_vec_f16_f32; case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
break;
default: default:
return nullptr; return nullptr;
} }
return f16_y ? &vk_pipeline_dequant_mul_mat_vec[type] : &vk_pipeline_dequant_mul_mat_vec_f32[type];
} }
// buffer pool for vulkan // buffer pool for vulkan