diff --git a/ggml-vulkan-shaders.hpp b/ggml-vulkan-shaders.hpp index a7cb149ca..fde390b88 100644 --- a/ggml-vulkan-shaders.hpp +++ b/ggml-vulkan-shaders.hpp @@ -3,10 +3,14 @@ // Generic const std::string shader_f32 = R"( #define FLOAT_TYPE float +#define INT8_TYPE int +#define UINT8_TYPE uint )"; const std::string shader_f16 = R"( #extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #define FLOAT_TYPE float16_t +#define INT8_TYPE int8_t +#define UINT8_TYPE uint8_t )"; const std::string shader_int8_ext = R"( #extension GL_EXT_shader_explicit_arithmetic_types_int8 : require @@ -598,25 +602,25 @@ void main() { for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { const int y_idx = i * QUANT_K + y_offset; - const FLOAT_TYPE d = x[ib0 + i].d; + const FLOAT_TYPE d = FLOAT_TYPE(x[ib0 + i].d); #if K_QUANTS_PER_ITERATION == 1 - FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32) - + FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32) - + FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32) - + FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32); + FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x03) << 4)) - 32) + + FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x03) << 4)) - 32) + + FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x0c) << 2)) - 32) + + FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x0c) << 2)) - 32) + + FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x30) >> 0)) - 32) + + FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x30) >> 0)) - 32) + + FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0xc0) >> 2)) - 32) + + FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0xc0) >> 2)) - 32); tmp[16 * ix + tid] += sum; #else FLOAT_TYPE sum = 0; for (int l = 0; l < 4; ++l) { - sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32) - + FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32); + sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 0) & 3) << 4)) - 32) + + FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 2) & 3) << 4)) - 32) + + FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 4) & 3) << 4)) - 32) + + FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 6) & 3) << 4)) - 32); } tmp[16 * ix + tid] += sum; #endif diff --git a/ggml-vulkan.cpp b/ggml-vulkan.cpp index 683add489..6030fac48 100644 --- a/ggml-vulkan.cpp +++ b/ggml-vulkan.cpp @@ -788,10 +788,7 @@ static void ggml_vk_generate_shaders() { stream.str(""); stream.clear(); - stream << dequant_head << shader_float_type; - if (vk_device.fp16) { - stream << shader_int8_ext; - } + stream << dequant_head << shader_int8_ext << shader_float_type; if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) { continue; @@ -814,10 +811,7 @@ static void ggml_vk_generate_shaders() { stream.str(""); stream.clear(); - stream << mul_mat_vec_head << shader_float_type; - if (vk_device.fp16) { - stream << shader_int8_ext; - } + stream << mul_mat_vec_head << shader_int8_ext << shader_float_type; if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) { continue;