vulkan: Add GLSL structure aliases for quant types to allow larger loads

In Vulkan it's not possible to cast pointer types, so instead you have to
declare an aliased binding for the memory with a different type. This
commit adds aliases for the quant formats using 16b ints, and in a few
places where the struct size is a multiple of 4 also using 32b ints.
Currently only q4_k's aliases are used, but others will be used in
subsequent commits.
This commit is contained in:
Jeff Bolz 2024-11-17 23:34:45 -06:00
parent 000a03bb5b
commit 6c3ad9342d
5 changed files with 130 additions and 31 deletions

View file

@ -2,6 +2,15 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
#include "types.comp"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);

View file

@ -12,6 +12,9 @@
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};

View file

@ -8,26 +8,6 @@ layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE tmp[32];
// Declare aliased versions of A and B bindings that can use 16b/32b loads for
// the quantized values, and vec4 loads for B.
struct block_q4_K_u32
{
f16vec2 d;
uint32_t scales[3*QUANT_K/64/4];
uint32_t qs[QUANT_K/2/4];
};
struct block_q4_K_u16
{
f16vec2 d;
uint16_t scales[3*QUANT_K/64/2];
uint16_t qs[QUANT_K/2/2];
};
layout (binding = 0) readonly buffer A_u32 {block_q4_K_u32 data_a_u32[];};
layout (binding = 0) readonly buffer A_u16 {block_q4_K_u16 data_a_u16[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
void main() {
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
@ -68,9 +48,9 @@ void main() {
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
uint32_t scale0_u32 = data_a_u16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_u16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_u16[ib0 + i].scales[v_im + 4];
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
uvec4 scale0 = uvec4(unpack8(scale0_u32));
uvec4 scale4 = uvec4(unpack8(scale4_u32));
uvec4 scale8 = uvec4(unpack8(scale8_u32));
@ -84,8 +64,8 @@ void main() {
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
uint32_t qs0_u32 = data_a_u32[ib0 + i].qs[q_offset / 4];
uint32_t qs64_u32 = data_a_u32[ib0 + i].qs[q_offset / 4 + 16];
uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;

View file

@ -1,6 +1,8 @@
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#endif
#if !defined(GGML_TYPES_COMP)
#define GGML_TYPES_COMP
#extension GL_EXT_shader_explicit_arithmetic_types : require
#if defined(DATA_A_F32)
#define QUANT_K 1
@ -38,8 +40,14 @@ struct block_q4_0
float16_t d;
uint8_t qs[16];
};
struct block_q4_0_packed16
{
float16_t d;
uint16_t qs[16/2];
};
#define A_TYPE block_q4_0
#define A_TYPE_PACKED16 block_q4_0_packed16
#endif
#if defined(DATA_A_Q4_1)
@ -54,7 +62,15 @@ struct block_q4_1
uint8_t qs[16];
};
struct block_q4_1_packed16
{
float16_t d;
float16_t m;
uint16_t qs[16/2];
};
#define A_TYPE block_q4_1
#define A_TYPE_PACKED16 block_q4_1_packed16
#endif
#if defined(DATA_A_Q5_0)
@ -70,7 +86,15 @@ struct block_q5_0
uint8_t qs[16];
};
struct block_q5_0_packed16
{
float16_t d;
uint16_t qh[2];
uint16_t qs[16/2];
};
#define A_TYPE block_q5_0
#define A_TYPE_PACKED16 block_q5_0_packed16
#endif
#if defined(DATA_A_Q5_1)
@ -87,7 +111,16 @@ struct block_q5_1
uint8_t qs[16];
};
struct block_q5_1_packed16
{
float16_t d;
float16_t m;
uint qh;
uint16_t qs[16/2];
};
#define A_TYPE block_q5_1
#define A_TYPE_PACKED16 block_q5_1_packed16
#endif
#if defined(DATA_A_Q8_0)
@ -100,8 +133,14 @@ struct block_q8_0
float16_t d;
int8_t qs[32];
};
struct block_q8_0_packed16
{
float16_t d;
uint16_t qs[32/2];
};
#define A_TYPE block_q8_0
#define A_TYPE_PACKED16 block_q8_0_packed16
#endif
// K-quants
@ -116,7 +155,23 @@ struct block_q2_K
f16vec2 d;
};
struct block_q2_K_packed16
{
uint16_t scales[QUANT_K/16/2];
uint16_t qs[QUANT_K/4/2];
f16vec2 d;
};
struct block_q2_K_packed32
{
uint32_t scales[QUANT_K/16/4];
uint32_t qs[QUANT_K/4/4];
f16vec2 d;
};
#define A_TYPE block_q2_K
#define A_TYPE_PACKED16 block_q2_K_packed16
#define A_TYPE_PACKED32 block_q2_K_packed32
#endif
#if defined(DATA_A_Q3_K)
@ -131,7 +186,16 @@ struct block_q3_K
float16_t d;
};
struct block_q3_K_packed16
{
uint16_t hmask[QUANT_K/8/2];
uint16_t qs[QUANT_K/4/2];
uint16_t scales[12/2];
float16_t d;
};
#define A_TYPE block_q3_K
#define A_TYPE_PACKED16 block_q3_K_packed16
#endif
#if defined(DATA_A_Q4_K)
@ -145,7 +209,23 @@ struct block_q4_K
uint8_t qs[QUANT_K/2];
};
struct block_q4_K_packed16
{
f16vec2 d;
uint16_t scales[3*QUANT_K/64/2];
uint16_t qs[QUANT_K/2/2];
};
struct block_q4_K_packed32
{
f16vec2 d;
uint32_t scales[3*QUANT_K/64/4];
uint32_t qs[QUANT_K/2/4];
};
#define A_TYPE block_q4_K
#define A_TYPE_PACKED16 block_q4_K_packed16
#define A_TYPE_PACKED32 block_q4_K_packed32
#endif
#if defined(DATA_A_Q5_K)
@ -160,7 +240,16 @@ struct block_q5_K
uint8_t qs[QUANT_K/2];
};
struct block_q5_K_packed16
{
f16vec2 d;
uint16_t scales[12/2];
uint16_t qh[QUANT_K/8/2];
uint16_t qs[QUANT_K/2/2];
};
#define A_TYPE block_q5_K
#define A_TYPE_PACKED16 block_q5_K_packed16
#endif
#if defined(DATA_A_Q6_K)
@ -175,7 +264,16 @@ struct block_q6_K
float16_t d;
};
struct block_q6_K_packed16
{
uint16_t ql[QUANT_K/2/2];
uint16_t qh[QUANT_K/4/2];
int8_t scales[QUANT_K/16];
float16_t d;
};
#define A_TYPE block_q6_K
#define A_TYPE_PACKED16 block_q6_K_packed16
#endif
// IQuants
@ -191,10 +289,19 @@ struct block_iq4_nl
uint8_t qs[QUANT_K/2];
};
struct block_iq4_nl_packed16
{
float16_t d;
uint16_t qs[QUANT_K/2/2];
};
#define A_TYPE block_iq4_nl
#define A_TYPE_PACKED16 block_iq4_nl_packed16
const int8_t kvalues_iq4nl[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
};
#endif
#endif // !defined(GGML_TYPES_COMP)

View file

@ -317,10 +317,10 @@ void process_shaders() {
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
// Dequant shaders
if (tname != "f16") {