Simplify q6_k fp16 fix
This commit is contained in:
parent
85c1a63a15
commit
dad1cdb1ef
1 changed files with 12 additions and 16 deletions
|
@ -3,14 +3,10 @@
|
||||||
// Generic
|
// Generic
|
||||||
const std::string shader_f32 = R"(
|
const std::string shader_f32 = R"(
|
||||||
#define FLOAT_TYPE float
|
#define FLOAT_TYPE float
|
||||||
#define INT8_TYPE int
|
|
||||||
#define UINT8_TYPE uint
|
|
||||||
)";
|
)";
|
||||||
const std::string shader_f16 = R"(
|
const std::string shader_f16 = R"(
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||||
#define FLOAT_TYPE float16_t
|
#define FLOAT_TYPE float16_t
|
||||||
#define INT8_TYPE int8_t
|
|
||||||
#define UINT8_TYPE uint8_t
|
|
||||||
)";
|
)";
|
||||||
const std::string shader_int8_ext = R"(
|
const std::string shader_int8_ext = R"(
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||||
|
@ -605,22 +601,22 @@ void main() {
|
||||||
const FLOAT_TYPE d = FLOAT_TYPE(x[ib0 + i].d);
|
const FLOAT_TYPE d = FLOAT_TYPE(x[ib0 + i].d);
|
||||||
|
|
||||||
#if K_QUANTS_PER_ITERATION == 1
|
#if K_QUANTS_PER_ITERATION == 1
|
||||||
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x03) << 4)) - 32)
|
FLOAT_TYPE sum = FLOAT_TYPE(y[y_idx + 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x03) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x03) << 4)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 16]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 1]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x03) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x0c) << 2)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] & 0xF) | ((x[ib0 + i].qh[qh_offset + 0] & 0x0c) << 2)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) & 0xF) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x0c) << 2)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 48]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 3]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] & 0xF) | ((x[ib0 + i].qh[qh_offset + 16] & 0x0c) << 2)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 0]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0x30) >> 0)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 0] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0x30) >> 0)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 16]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0x30) >> 0)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 80]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 5]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 16] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0x30) >> 0)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 32]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 0]) & 0xc0) >> 2)) - 32)
|
+ FLOAT_TYPE(y[y_idx + 96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 32] >> 4) | ((x[ib0 + i].qh[qh_offset + 0] & 0xc0) >> 2)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + 48]) >> 4) | ((UINT8_TYPE(x[ib0 + i].qh[qh_offset + 16]) & 0xc0) >> 2)) - 32);
|
+ FLOAT_TYPE(y[y_idx +112]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 7]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + 48] >> 4) | ((x[ib0 + i].qh[qh_offset + 16] & 0xc0) >> 2)) - 32);
|
||||||
tmp[16 * ix + tid] += sum;
|
tmp[16 * ix + tid] += sum;
|
||||||
#else
|
#else
|
||||||
FLOAT_TYPE sum = 0;
|
FLOAT_TYPE sum = 0;
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 0) & 3) << 4)) - 32)
|
sum += FLOAT_TYPE(y[y_idx + l+ 0]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 0]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) & 0xF) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 2) & 3) << 4)) - 32)
|
+ FLOAT_TYPE(y[y_idx + l+32]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 2]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((x[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+ 0]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 4) & 3) << 4)) - 32)
|
+ FLOAT_TYPE(y[y_idx + l+64]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 4]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32)
|
||||||
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(INT8_TYPE((UINT8_TYPE(x[ib0 + i].ql[ql_offset + l+32]) >> 4) | (((UINT8_TYPE(x[ib0 + i].qh[qh_offset + l]) >> 6) & 3) << 4)) - 32);
|
+ FLOAT_TYPE(y[y_idx + l+96]) * FLOAT_TYPE(x[ib0 + i].scales[s_offset + 6]) * d * FLOAT_TYPE(int8_t((x[ib0 + i].ql[ql_offset + l+32] >> 4) | (((x[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32);
|
||||||
}
|
}
|
||||||
tmp[16 * ix + tid] += sum;
|
tmp[16 * ix + tid] += sum;
|
||||||
#endif
|
#endif
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue