Fix fp32 fallback if device doesn't support fp16, add force disable env var GGML_VULKAN_DISABLE_F16

This commit is contained in:
0cc4m 2023-08-14 11:07:55 +02:00
parent 01d22a4a10
commit e9be24f9ad
2 changed files with 141 additions and 89 deletions

View file

@ -11,38 +11,27 @@ const std::string shader_f16 = R"(
const std::string shader_int8_ext = R"( const std::string shader_int8_ext = R"(
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
)"; )";
const std::string shader_output_f16 = R"(
#define OUT_TYPE float16_t
)";
const std::string shader_output_f32 = R"(
#define OUT_TYPE float
)";
// MULMAT // MULMAT
const std::string mulmat_head = R"( const std::string mulmat_head = R"(
#version 450 #version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#define WARP 32 #define WARP 32
#extension GL_EXT_control_flow_attributes : enable #ifndef LOAD_VEC
#define LOAD_VEC 1
#endif
)"; )";
const std::string mulmat_body = R"( const std::string mulmat_body = R"(
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#ifdef ALIGNED_INPUT
#define LOAD_VEC 8
layout (binding = 0) readonly buffer A { A_TYPE data_a[]; }; layout (binding = 0) readonly buffer A { A_TYPE data_a[]; };
layout (binding = 1) readonly buffer B { B_TYPE data_b[]; }; layout (binding = 1) readonly buffer B { B_TYPE data_b[]; };
#else
#define LOAD_VEC 1
layout (binding = 0) readonly buffer A { A_TYPE data_a[]; };
layout (binding = 1) readonly buffer B { B_TYPE data_b[]; };
#endif
layout (binding = 2) writeonly buffer D { D_TYPE data_d[]; }; layout (binding = 2) writeonly buffer D { D_TYPE data_d[]; };
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
@ -107,16 +96,22 @@ void main() {
[[unroll]] for (int block = start_k; block < end_k; block += BK) { [[unroll]] for (int block = start_k; block < end_k; block += BK) {
[[unroll]] for (int l = 0; l < BM; l += loadstride) { [[unroll]] for (int l = 0; l < BM; l += loadstride) {
#ifdef ALIGNED_INPUT #if LOAD_VEC == 8
A_TYPE tmp = data_a[pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr]; const int idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(tmp[0].x); buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx][0].x);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(tmp[0].y); buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx][0].y);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(tmp[0].z); buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx][0].z);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(tmp[0].w); buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx][0].w);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(tmp[1].x); buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_a[idx][1].x);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(tmp[1].y); buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_a[idx][1].y);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(tmp[1].z); buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_a[idx][1].z);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(tmp[1].w); buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_a[idx][1].w);
#elif LOAD_VEC == 4
const int idx = pos_a + (loadc + l) * p.stride_a / LOAD_VEC + loadr;
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_a[idx].x);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_a[idx].y);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_a[idx].z);
buf_a[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_a[idx].w);
#else #else
if (ir * BM + loadc + l < p.M && block + loadr < p.K) { if (ir * BM + loadc + l < p.M && block + loadr < p.K) {
buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]); buf_a[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_a[pos_a + (loadc + l) * p.stride_a + loadr]);
@ -126,16 +121,22 @@ void main() {
#endif #endif
} }
[[unroll]] for (int l = 0; l < BN; l += loadstride) { [[unroll]] for (int l = 0; l < BN; l += loadstride) {
#ifdef ALIGNED_INPUT #if LOAD_VEC == 8
B_TYPE tmp = data_b[pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr]; const int idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(tmp[0].x); buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx][0].x);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(tmp[0].y); buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx][0].y);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(tmp[0].z); buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx][0].z);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(tmp[0].w); buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx][0].w);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(tmp[1].x); buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 4] = FLOAT_TYPE(data_b[idx][1].x);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(tmp[1].y); buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 5] = FLOAT_TYPE(data_b[idx][1].y);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(tmp[1].z); buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 6] = FLOAT_TYPE(data_b[idx][1].z);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(tmp[1].w); buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC == 4
const int idx = pos_b + (loadc + l) * p.stride_b / LOAD_VEC + loadr;
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 0] = FLOAT_TYPE(data_b[idx].x);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 1] = FLOAT_TYPE(data_b[idx].y);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 2] = FLOAT_TYPE(data_b[idx].z);
buf_b[(loadc + l) * (BK+1) + loadr * LOAD_VEC + 3] = FLOAT_TYPE(data_b[idx].w);
#else #else
if (ic * BN + loadc + l < p.N && block + loadr < p.K) { if (ic * BN + loadc + l < p.N && block + loadr < p.K) {
buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]); buf_b[(loadc + l) * (BK+1) + loadr] = FLOAT_TYPE(data_b[pos_b + (loadc + l) * p.stride_b + loadr]);
@ -259,7 +260,7 @@ const std::string dequant_body = R"(
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { A_TYPE x[]; }; layout (binding = 0) readonly buffer A { A_TYPE x[]; };
layout (binding = 1) writeonly buffer D { OUT_TYPE y[]; }; layout (binding = 1) writeonly buffer D { D_TYPE y[]; };
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
@ -282,15 +283,15 @@ void main() {
const int stride_a = p.stride_a / QUANT_K; const int stride_a = p.stride_a / QUANT_K;
const A_TYPE blk = x[col * stride_a + row]; const int idx = col * stride_a + row;
const OUT_TYPE d = blk.d; const FLOAT_TYPE d = FLOAT_TYPE(x[idx].d);
[[unroll]] for (int j = 0; j < QUANT_K/2; ++j) { [[unroll]] for (int j = 0; j < QUANT_K/2; ++j) {
const OUT_TYPE x0 = OUT_TYPE((blk.qs[j] & 0x0F) - 8); const FLOAT_TYPE x0 = FLOAT_TYPE((x[idx].qs[j] & 0x0F) - 8);
const OUT_TYPE x1 = OUT_TYPE((blk.qs[j] >> 4) - 8); const FLOAT_TYPE x1 = FLOAT_TYPE((x[idx].qs[j] >> 4) - 8);
y[col * p.stride_b + row*QUANT_K + j + 0 ] = x0*d; y[col * p.stride_b + row*QUANT_K + j + 0 ] = D_TYPE(x0*d);
y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = x1*d; y[col * p.stride_b + row*QUANT_K + j + QUANT_K/2] = D_TYPE(x1*d);
} }
} }
)"; )";
@ -304,25 +305,24 @@ const std::string mul_mat_vec_head = R"(
#extension GL_EXT_shader_8bit_storage : require #extension GL_EXT_shader_8bit_storage : require
)"; )";
const std::string mul_mat_vec_b_type_f32 = R"(
#define B_TYPE float
)";
const std::string mul_mat_vec_b_type_f16 = R"(
#define B_TYPE float16_t
)";
const std::string mul_mat_vec_f16_defines = R"( const std::string mul_mat_vec_f16_defines = R"(
#define QUANT_K 32 #define QUANT_K 32
#define QUANT_R 2 #define QUANT_R 2
#define BLOCK_SIZE 32 #define BLOCK_SIZE 32
#define A_TYPE float16_t #define A_TYPE float16_t
)";
const std::string mul_mat_vec_f16_dequant_func = R"(
#define DEQUANT_FUNC float16_t v0 = x[ib + 0]; \ #define DEQUANT_FUNC float16_t v0 = x[ib + 0]; \
float16_t v1 = x[ib + 1]; float16_t v1 = x[ib + 1];
)"; )";
const std::string mul_mat_vec_f16_dequant_func_compat = R"(
#define DEQUANT_FUNC float v0 = float(x[ib + 0]); \
float v1 = float(x[ib + 1]);
)";
const std::string mul_mat_vec_q4_0_defines = R"( const std::string mul_mat_vec_q4_0_defines = R"(
#define QUANT_K 32 #define QUANT_K 32
#define QUANT_R 2 #define QUANT_R 2
@ -334,7 +334,9 @@ struct block_q4_0
uint8_t qs[16]; uint8_t qs[16];
}; };
#define A_TYPE block_q4_0 #define A_TYPE block_q4_0
)";
const std::string mul_mat_vec_q4_0_dequant_func = R"(
#define DEQUANT_FUNC const float16_t d = x[ib].d; \ #define DEQUANT_FUNC const float16_t d = x[ib].d; \
const uint8_t vui = x[ib].qs[iqs]; \ const uint8_t vui = x[ib].qs[iqs]; \
const int8_t vi0 = int8_t(vui & 0xF); \ const int8_t vi0 = int8_t(vui & 0xF); \
@ -343,12 +345,21 @@ float16_t v0 = float16_t(vi0 - 8)*d; \
float16_t v1 = float16_t(vi1 - 8)*d; float16_t v1 = float16_t(vi1 - 8)*d;
)"; )";
const std::string mul_mat_vec_q4_0_dequant_func_compat = R"(
#define DEQUANT_FUNC const float d = float(x[ib].d); \
const uint vui = uint(x[ib].qs[iqs]); \
const int vi0 = int(vui) & 0xF; \
const int vi1 = int(vui) >> 4; \
float v0 = float(vi0 - 8)*d; \
float v1 = float(vi1 - 8)*d;
)";
const std::string mul_mat_vec_body = R"( const std::string mul_mat_vec_body = R"(
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A { A_TYPE x[]; }; layout (binding = 0) readonly buffer A { A_TYPE x[]; };
layout (binding = 1) readonly buffer B { B_TYPE y[]; }; layout (binding = 1) readonly buffer B { B_TYPE y[]; };
layout (binding = 2) writeonly buffer D { OUT_TYPE dst[]; }; layout (binding = 2) writeonly buffer D { D_TYPE dst[]; };
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
@ -364,7 +375,7 @@ void main() {
const int y_offset = QUANT_K/2; const int y_offset = QUANT_K/2;
tmp[tid] = 0.0hf; tmp[tid] = FLOAT_TYPE(0.0f);
[[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) { [[unroll]] for (int i = 0; i < p.ncols/block_size; i += 2) {
const int col = i*block_size + 2*tid; const int col = i*block_size + 2*tid;
@ -375,8 +386,8 @@ void main() {
DEQUANT_FUNC DEQUANT_FUNC
// matrix multiplication // matrix multiplication
tmp[tid] += FLOAT_TYPE(v0 * y[iybs + iqs + 0]); tmp[tid] += FLOAT_TYPE(v0) * FLOAT_TYPE(y[iybs + iqs + 0]);
tmp[tid] += FLOAT_TYPE(v1 * y[iybs + iqs + y_offset]); tmp[tid] += FLOAT_TYPE(v1) * FLOAT_TYPE(y[iybs + iqs + y_offset]);
} }
// sum up partial sums and write back result // sum up partial sums and write back result
@ -388,7 +399,7 @@ void main() {
barrier(); barrier();
} }
if (tid == 0) { if (tid == 0) {
dst[row] = OUT_TYPE(tmp[0]); dst[row] = D_TYPE(tmp[0]);
} }
} }
)"; )";
@ -460,6 +471,8 @@ void main() {
// ADD // ADD
const std::string add_head = R"( const std::string add_head = R"(
#version 450 #version 450
#extension GL_EXT_shader_16bit_storage : require
)"; )";
const std::string add_body = R"( const std::string add_body = R"(
@ -489,7 +502,7 @@ void main() {
return; return;
} }
data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(data_x[p.x_offset + y * p.stride_x + x]) + D_TYPE(data_y[p.y_offset + x]); data_d[p.d_offset + y * p.stride_d + x] = D_TYPE(FLOAT_TYPE(data_x[p.x_offset + y * p.stride_x + x]) + FLOAT_TYPE(data_y[p.y_offset + x]));
} }
)"; )";

View file

@ -662,85 +662,119 @@ static void ggml_vk_generate_shaders() {
auto warptile_m = { 128, 64, 64, 16, 32, 32, 2, 4, 2 }; auto warptile_m = { 128, 64, 64, 16, 32, 32, 2, 4, 2 };
auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 }; auto warptile_s = { 32, 32, 32, 8, 32, 32, 2, 2, 2 };
std::string shader_float_type;
std::string load_vec;
std::string vec_type_f16;
std::string vec_type;
if (vk_device.fp16) {
shader_float_type = shader_f16;
load_vec = "8";
vec_type_f16 = "f16mat2x4";
vec_type = "mat2x4";
} else {
shader_float_type = shader_f32;
load_vec = "4";
vec_type_f16 = "f16vec4";
vec_type = "vec4";
}
std::stringstream stream; std::stringstream stream;
stream << mulmat_head << shader_f32 << mulmat_body; stream << mulmat_head << shader_float_type << mulmat_body;
vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline_from_string("matmul_f32_l", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); vk_pipeline_matmul_f32_l = ggml_vk_create_pipeline_from_string("matmul_f32_l", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline_from_string("matmul_f32_m", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f32_m = ggml_vk_create_pipeline_from_string("matmul_f32_m", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline_from_string("matmul_f32_s", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); vk_pipeline_matmul_f32_s = ggml_vk_create_pipeline_from_string("matmul_f32_s", stream.str(), { "A_TYPE", "float", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); vk_pipeline_matmul_f32_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_l", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_m", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); vk_pipeline_matmul_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f32_aligned_s", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
stream.str(""); stream.str("");
stream.clear(); stream.clear();
stream << mulmat_head << shader_f16 << mulmat_body; stream << mulmat_head << shader_float_type << mulmat_body;
vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline_from_string("matmul_f16_l", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); vk_pipeline_matmul_f16_l = ggml_vk_create_pipeline_from_string("matmul_f16_l", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline_from_string("matmul_f16_m", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f16_m = ggml_vk_create_pipeline_from_string("matmul_f16_m", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline_from_string("matmul_f16_s", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); vk_pipeline_matmul_f16_s = ggml_vk_create_pipeline_from_string("matmul_f16_s", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float16_t", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); vk_pipeline_matmul_f16_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_l", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type_f16, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f16_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_m", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type_f16, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "f16mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); vk_pipeline_matmul_f16_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_aligned_s", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type_f16, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline_from_string("matmul_f16_f32_l", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); vk_pipeline_matmul_f16_f32_l = ggml_vk_create_pipeline_from_string("matmul_f16_f32_l", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_m", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f16_f32_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_m", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_s", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); vk_pipeline_matmul_f16_f32_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_s", stream.str(), { "A_TYPE", "float16_t", "B_TYPE", "float", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_l", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128); vk_pipeline_matmul_f16_f32_aligned_l = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_l", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), {128, 128, 1}, warptile_l, 128);
vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_m", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64); vk_pipeline_matmul_f16_f32_aligned_m = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_m", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 64, 64, 1}, warptile_m, 64);
vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_s", stream.str(), { "ALIGNED_INPUT", "", "A_TYPE", "f16mat2x4", "B_TYPE", "mat2x4", "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32); vk_pipeline_matmul_f16_f32_aligned_s = ggml_vk_create_pipeline_from_string("matmul_f16_f32_aligned_s", stream.str(), { "LOAD_VEC", load_vec, "A_TYPE", vec_type_f16, "B_TYPE", vec_type, "D_TYPE", "float" }, "main", 3, 7 * sizeof(int), { 32, 32, 1}, warptile_s, 32);
// Build dequant q4_0 // Build dequant q4_0
stream.str(""); stream.str("");
stream.clear(); stream.clear();
stream << dequant_head; stream << dequant_head << shader_float_type << dequant_q4_0_defines << dequant_body;
if (vk_device.fp16) {
stream << shader_f16 << shader_output_f16;
} else {
stream << shader_output_f32;
}
stream << dequant_q4_0_defines << dequant_body;
vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), {}, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1); vk_pipeline_dequant_q4_0 = ggml_vk_create_pipeline_from_string("dequant_q4_0", stream.str(), { "D_TYPE", "float16_t" }, "main", 2, 4 * sizeof(int), {256*32, 1, 1}, {}, 1);
// mul mat vec // mul mat vec
stream.str(""); stream.str("");
stream.clear(); stream.clear();
stream << mul_mat_vec_head << shader_f16 << shader_int8_ext << shader_output_f32 << mul_mat_vec_b_type_f16 << mul_mat_vec_q4_0_defines << mul_mat_vec_body; stream << mul_mat_vec_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext << mul_mat_vec_q4_0_dequant_func;
} else {
stream << mul_mat_vec_q4_0_dequant_func_compat;
}
stream << mul_mat_vec_q4_0_defines << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); vk_pipeline_dequant_mul_mat_vec_q4_0 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0", stream.str(), { "D_TYPE", "float", "B_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
stream.str(""); stream.str("");
stream.clear(); stream.clear();
stream << mul_mat_vec_head << shader_f16 << shader_int8_ext << shader_output_f32 << mul_mat_vec_b_type_f32 << mul_mat_vec_q4_0_defines << mul_mat_vec_body; stream << mul_mat_vec_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext << mul_mat_vec_q4_0_dequant_func;
} else {
stream << mul_mat_vec_q4_0_dequant_func_compat;
}
stream << mul_mat_vec_q4_0_defines << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0_f32", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); vk_pipeline_dequant_mul_mat_vec_q4_0_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_q4_0_f32", stream.str(), { "D_TYPE", "float", "B_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
stream.str(""); stream.str("");
stream.clear(); stream.clear();
stream << mul_mat_vec_head << shader_f16 << shader_output_f32 << mul_mat_vec_b_type_f16 << mul_mat_vec_f16_defines << mul_mat_vec_body; stream << mul_mat_vec_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext << mul_mat_vec_f16_dequant_func;
} else {
stream << mul_mat_vec_f16_dequant_func_compat;
}
stream << mul_mat_vec_f16_defines << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); vk_pipeline_dequant_mul_mat_vec_f16 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16", stream.str(), { "D_TYPE", "float", "B_TYPE", "float16_t" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
stream.str(""); stream.str("");
stream.clear(); stream.clear();
stream << mul_mat_vec_head << shader_f16 << shader_output_f32 << mul_mat_vec_b_type_f32 << mul_mat_vec_f16_defines << mul_mat_vec_body; stream << mul_mat_vec_head << shader_float_type;
vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16_f32", stream.str(), {}, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1); if (vk_device.fp16) {
stream << shader_int8_ext << mul_mat_vec_f16_dequant_func;
} else {
stream << mul_mat_vec_f16_dequant_func_compat;
}
stream << mul_mat_vec_f16_defines << mul_mat_vec_body;
vk_pipeline_dequant_mul_mat_vec_f16_f32 = ggml_vk_create_pipeline_from_string("mul_mat_vec_f16_f32", stream.str(), { "D_TYPE", "float", "B_TYPE", "float" }, "main", 3, 1 * sizeof(int), {1, 1, 1}, {}, 1);
// add // add
stream.str(""); stream.str("");
stream.clear(); stream.clear();
stream << add_head << add_body; stream << add_head << shader_float_type << add_body;
vk_pipeline_add_f32 = ggml_vk_create_pipeline_from_string("add_f32", stream.str(), { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); vk_pipeline_add_f32 = ggml_vk_create_pipeline_from_string("add_f32", stream.str(), { "X_TYPE", "float", "Y_TYPE", "float", "D_TYPE", "float" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
stream.str(""); stream.str("");
stream.clear(); stream.clear();
stream << add_head << shader_f16 << add_body; stream << add_head << shader_float_type << add_body;
vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline_from_string("add_f16_f32_f16", stream.str(), { "X_TYPE", "float16_t", "Y_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1); vk_pipeline_add_f16_f32_f16 = ggml_vk_create_pipeline_from_string("add_f16_f32_f16", stream.str(), { "X_TYPE", "float16_t", "Y_TYPE", "float", "D_TYPE", "float16_t" }, "main", 3, sizeof(vk_op_push_constants), {32, 32, 1}, {}, 1);
// Static shaders // Static shaders
@ -761,7 +795,7 @@ void ggml_vk_init(void) {
#ifdef VK_DEBUG #ifdef VK_DEBUG
std::cerr << "ggml_vk_init()" << std::endl; std::cerr << "ggml_vk_init()" << std::endl;
#endif #endif
char* GGML_VULKAN_DEVICE = getenv("GGML_VULKAN_DEVICE"); const char* GGML_VULKAN_DEVICE = getenv("GGML_VULKAN_DEVICE");
int dev_num = (GGML_VULKAN_DEVICE == NULL ? 0 : atoi(GGML_VULKAN_DEVICE)); int dev_num = (GGML_VULKAN_DEVICE == NULL ? 0 : atoi(GGML_VULKAN_DEVICE));
vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION }; vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, VK_API_VERSION };
@ -806,7 +840,10 @@ void ggml_vk_init(void) {
} }
} }
vk_device.fp16 = fp16_storage && fp16_compute; const char* GGML_VULKAN_DISABLE_F16 = getenv("GGML_VULKAN_DISABLE_F16");
bool force_disable_f16 = GGML_VULKAN_DISABLE_F16 != NULL;
vk_device.fp16 = !force_disable_f16 && fp16_storage && fp16_compute;
std::vector<vk::QueueFamilyProperties> queue_family_props = vk_device.physical_device.getQueueFamilyProperties(); std::vector<vk::QueueFamilyProperties> queue_family_props = vk_device.physical_device.getQueueFamilyProperties();
@ -875,6 +912,8 @@ void ggml_vk_init(void) {
if (vk_device.fp16) { if (vk_device.fp16) {
std::cerr << "ggml_vulkan: 16-bit enabled" << std::endl; std::cerr << "ggml_vulkan: 16-bit enabled" << std::endl;
device_extensions.push_back("VK_KHR_shader_float16_int8"); device_extensions.push_back("VK_KHR_shader_float16_int8");
} else if (force_disable_f16) {
std::cerr << "ggml_vulkan: 16-bit force-disabled" << std::endl;
} }
device_create_info = { device_create_info = {
vk::DeviceCreateFlags(), vk::DeviceCreateFlags(),