Fix q6_k for GPUs without fp16 support

This commit is contained in:
0cc4m 2023-10-08 11:52:44 +02:00
parent 39bd512dd1
commit 85c1a63a15
2 changed files with 19 additions and 21 deletions

View file

@ -3,10 +3,14 @@
// Generic
const std::string shader_f32 = R"(
#define FLOAT_TYPE float
#define INT8_TYPE int
#define UINT8_TYPE uint
)";
const std::string shader_f16 = R"(
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
#define FLOAT_TYPE float16_t
#define INT8_TYPE int8_t
#define UINT8_TYPE uint8_t
)";
const std::string shader_int8_ext = R"(
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
@ -598,25 +602,25 @@ void main() {
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const int y_idx = i * QUANT_K + y_offset;
const FLOAT_TYPE d = x[ib0 + i].d;
const FLOAT_TYPE d = FLOAT_TYPE(x[ib0 + i].d);
#if K_QUANTS_PER_ITERATION == 1
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x03) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x0c) << 2)) - 32)
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x30) >> 0)) - 32)
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0xc0) >> 2)) - 32)
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0xc0) >> 2)) - 32);
tmp[16 * ix + tid] += sum;
#else
FLOAT_TYPE sum = 0;
for (int l = 0; l < 4; ++l) {
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 0) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 2) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 4) & 3) << 4)) - 32)
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 6) & 3) << 4)) - 32);
}
tmp[16 * ix + tid] += sum;
#endif

View file

@ -788,10 +788,7 @@ static void ggml_vk_generate_shaders() {
stream.str("");
stream.clear();
stream << dequant_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext;
}
stream << dequant_head << shader_int8_ext << shader_float_type;
if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) {
continue;
@ -814,10 +811,7 @@ static void ggml_vk_generate_shaders() {
stream.str("");
stream.clear();
stream << mul_mat_vec_head << shader_float_type;
if (vk_device.fp16) {
stream << shader_int8_ext;
}
stream << mul_mat_vec_head << shader_int8_ext << shader_float_type;
if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) {
continue;