Fix q6_k for GPUs without fp16 support
This commit is contained in:
parent
39bd512dd1
commit
85c1a63a15
2 changed files with 19 additions and 21 deletions
|
@ -3,10 +3,14 @@
|
||||||
// Generic
|
// Generic
|
||||||
const std::string shader_f32 = R"(
|
const std::string shader_f32 = R"(
|
||||||
#define FLOAT_TYPE float
|
#define FLOAT_TYPE float
|
||||||
|
#define INT8_TYPE int
|
||||||
|
#define UINT8_TYPE uint
|
||||||
)";
|
)";
|
||||||
const std::string shader_f16 = R"(
|
const std::string shader_f16 = R"(
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||||
#define FLOAT_TYPE float16_t
|
#define FLOAT_TYPE float16_t
|
||||||
|
#define INT8_TYPE int8_t
|
||||||
|
#define UINT8_TYPE uint8_t
|
||||||
)";
|
)";
|
||||||
const std::string shader_int8_ext = R"(
|
const std::string shader_int8_ext = R"(
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||||
|
@ -598,25 +602,25 @@ void main() {
|
||||||
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
||||||
const int y_idx = i * QUANT_K + y_offset;
|
const int y_idx = i * QUANT_K + y_offset;
|
||||||
|
|
||||||
const FLOAT_TYPE d = x[ib0 + i].d;
|
const FLOAT_TYPE d = FLOAT_TYPE(x[ib0 + i].d);
|
||||||
|
|
||||||
#if K_QUANTS_PER_ITERATION == 1
|
#if K_QUANTS_PER_ITERATION == 1
|
||||||
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
|
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x03) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x03) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x0c) << 2)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x0c) << 2)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x30) >> 0)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x30) >> 0)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0xc0) >> 2)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
|
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0xc0) >> 2)) - 32);
|
||||||
tmp[16 * ix + tid] += sum;
|
tmp[16 * ix + tid] += sum;
|
||||||
#else
|
#else
|
||||||
FLOAT_TYPE sum = 0;
|
FLOAT_TYPE sum = 0;
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
|
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 0) & 3) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
|
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 2) & 3) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
|
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 4) & 3) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
|
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 6) & 3) << 4)) - 32);
|
||||||
}
|
}
|
||||||
tmp[16 * ix + tid] += sum;
|
tmp[16 * ix + tid] += sum;
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -788,10 +788,7 @@ static void ggml_vk_generate_shaders() {
|
||||||
stream.str("");
|
stream.str("");
|
||||||
stream.clear();
|
stream.clear();
|
||||||
|
|
||||||
stream << dequant_head << shader_float_type;
|
stream << dequant_head << shader_int8_ext << shader_float_type;
|
||||||
if (vk_device.fp16) {
|
|
||||||
stream << shader_int8_ext;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) {
|
if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) {
|
||||||
continue;
|
continue;
|
||||||
|
@ -814,10 +811,7 @@ static void ggml_vk_generate_shaders() {
|
||||||
stream.str("");
|
stream.str("");
|
||||||
stream.clear();
|
stream.clear();
|
||||||
|
|
||||||
stream << mul_mat_vec_head << shader_float_type;
|
stream << mul_mat_vec_head << shader_int8_ext << shader_float_type;
|
||||||
if (vk_device.fp16) {
|
|
||||||
stream << shader_int8_ext;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) {
|
if (!ggml_vk_build_shader_type_defines(stream, (ggml_type)i, !vk_device.fp16)) {
|
||||||
continue;
|
continue;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue