vulkan: Add VK_NV_cooperative_matrix2 support for mul_mat and flash attention (#10206)
This commit is contained in:
parent
6fe6247831
commit
c9c6e01dae
6 changed files with 1665 additions and 97 deletions
305
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Normal file
305
ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.comp
Normal file
|
@ -0,0 +1,305 @@
|
|||
|
||||
#include "types.comp"
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 {
|
||||
block_q4_0_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint shift = (idx & 0x10) >> 2;
|
||||
uint32_t qs = unpack8(uint32_t(bl.block.qs[(idx & 0xE) >> 1]))[idx & 1];
|
||||
qs >>= shift;
|
||||
qs &= 0xF;
|
||||
float16_t ret = (float16_t(qs) - float16_t(8)) * d;
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 {
|
||||
block_q4_1 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const float16_t m = bl.block.m;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx & 0xF;
|
||||
const uint shift = (idx & 0x10) >> 2;
|
||||
uint32_t qs = bl.block.qs[iqs];
|
||||
qs >>= shift;
|
||||
qs &= 0xF;
|
||||
float16_t ret = float16_t(qs) * d + m;
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 {
|
||||
block_q5_0 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx & 0xF;
|
||||
|
||||
const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0];
|
||||
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
|
||||
|
||||
const uint shift = (idx & 0x10) >> 2;
|
||||
uint32_t qs = bl.block.qs[iqs];
|
||||
qs >>= shift;
|
||||
qs &= 0xF;
|
||||
|
||||
float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d;
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 {
|
||||
block_q5_1 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const float16_t m = bl.block.m;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx & 0xF;
|
||||
|
||||
const uint uint_qh = bl.block.qh;
|
||||
const uint qh = ((uint_qh >> idx) << 4) & 0x10;
|
||||
|
||||
const uint shift = (idx & 0x10) >> 2;
|
||||
uint32_t qs = bl.block.qs[iqs];
|
||||
qs >>= shift;
|
||||
qs &= 0xF;
|
||||
|
||||
float16_t ret = float16_t(qs | qh) * d + m;
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 {
|
||||
block_q8_0_packed16 block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx;
|
||||
|
||||
// Load 16b and select the byte for this element
|
||||
int32_t qs = unpack8(int32_t(bl.block.qs[(iqs & 0x1E) >> 1]))[iqs & 1];
|
||||
float16_t ret = float16_t(qs) * d;
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K {
|
||||
block_q2_K block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const f16vec2 d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx;
|
||||
|
||||
const uint qsi = (iqs / 128) * 32 + (iqs % 32); // 0..31
|
||||
const uint scalesi = iqs / 16; // 0..15
|
||||
const uint qsshift = ((iqs % 128) / 32) * 2; // 0,2,4,6
|
||||
|
||||
uint32_t qs = bl.block.qs[qsi];
|
||||
const uint scales = bl.block.scales[scalesi];
|
||||
float16_t ret = d.x * float16_t(scales & 0xF) * float16_t((qs >> qsshift) & 3) - d.y * float16_t(scales >> 4);
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K {
|
||||
block_q3_K block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx;
|
||||
|
||||
const uint n = iqs / 128; // 0,1
|
||||
const uint qsi = n * 32 + (iqs % 32); // 0..63
|
||||
const uint hmi = (iqs % 32); // 0..31
|
||||
const uint j = (iqs % 128) / 8; // 0..15
|
||||
const uint is = iqs / 16; // 0..15
|
||||
const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3
|
||||
const uint qsshift = halfsplit * 2; // 0,2,4,6
|
||||
const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128
|
||||
|
||||
uint32_t scaleidx0 = (is < 8) ? is : (is-8);
|
||||
uint32_t scaleidx0shift = (is < 8) ? 0 : 4;
|
||||
uint32_t scaleidx1 = is + 8 - (is/4)*4;
|
||||
uint32_t scaleidx1shift = (is/4)*2;
|
||||
|
||||
const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4));
|
||||
|
||||
const float16_t dl = bl.block.d * float16_t(us - 32);
|
||||
|
||||
float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4));
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K {
|
||||
block_q4_K block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx;
|
||||
|
||||
const uint n = iqs / 64; // 0,1,2,3
|
||||
const uint b = (iqs % 64) / 32; // 0,1
|
||||
const uint is = (idx & 0xE0) >> 5; // 0..7
|
||||
const uint qsi = n * 32 + (iqs % 32); // 0..127
|
||||
|
||||
const f16vec2 loadd = bl.block.d;
|
||||
|
||||
uint32_t sc;
|
||||
uint32_t mbyte;
|
||||
|
||||
uint32_t scidx0 = (is < 4) ? is : (is + 4);
|
||||
uint32_t scidx1 = (is < 4) ? is : (is - 4);
|
||||
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
|
||||
uint32_t mbidx0 = is + 4;
|
||||
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
|
||||
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
|
||||
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
|
||||
|
||||
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
|
||||
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
||||
|
||||
const float16_t d = loadd.x * float16_t(sc);
|
||||
const float16_t m = loadd.y * float16_t(mbyte);
|
||||
|
||||
uint32_t dmask = 0xF << (b * 4);
|
||||
|
||||
float16_t ret = d * float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) - m;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K {
|
||||
block_q5_K block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx;
|
||||
|
||||
const uint n = iqs / 64; // 0,1,2,3
|
||||
const uint b = (iqs % 64) / 32; // 0,1
|
||||
const uint is = (idx & 0xE0) >> 5; // 0..7
|
||||
const uint qsi = n * 32 + (iqs % 32); // 0..127
|
||||
const uint qhi = (iqs % 32); // 0..31
|
||||
|
||||
const uint8_t hm = uint8_t(1 << (iqs / 32));
|
||||
|
||||
const f16vec2 loadd = bl.block.d;
|
||||
|
||||
uint32_t sc;
|
||||
uint32_t mbyte;
|
||||
|
||||
uint32_t scidx0 = (is < 4) ? is : (is + 4);
|
||||
uint32_t scidx1 = (is < 4) ? is : (is - 4);
|
||||
uint32_t scidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
uint32_t scidxshift1 = (is < 4) ? 0 : 2;
|
||||
uint32_t mbidx0 = is + 4;
|
||||
uint32_t mbidx1 = (is < 4) ? is + 4 : is;
|
||||
uint32_t mbidxmask0 = (is < 4) ? 0xF : 0xF0;
|
||||
uint32_t mbidxshift0 = (is < 4) ? 0 : 4;
|
||||
uint32_t mbidxmask1 = (is < 4) ? 0x30 : 0xC0;
|
||||
uint32_t mbidxshift1 = (is < 4) ? 0 : 2;
|
||||
|
||||
sc = uint8_t((bl.block.scales[scidx0] & 0xF) | ((bl.block.scales[scidx1] & scidxmask1) >> scidxshift1));
|
||||
mbyte = uint8_t(((bl.block.scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((bl.block.scales[mbidx1] & mbidxmask1) >> mbidxshift1));
|
||||
|
||||
const float16_t d = loadd.x * float16_t(sc);
|
||||
const float16_t m = loadd.y * float16_t(mbyte);
|
||||
|
||||
uint32_t dmask = 0xF << (b * 4);
|
||||
|
||||
float16_t ret = d * (float16_t((bl.block.qs[qsi ] & dmask) >> (b * 4)) + float16_t((bl.block.qh[qhi ] & hm) != 0 ? 16 : 0)) - m;
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
|
||||
block_q6_K block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx;
|
||||
|
||||
const uint n = iqs / 128; // 0,1
|
||||
const uint b = (iqs % 128) / 64; // 0,1
|
||||
const uint is_b = (iqs % 32) / 16; // 0,1
|
||||
const uint qhshift = ((iqs % 128) / 32) * 2;// 0,2,4,6
|
||||
const uint is = 8 * n + qhshift + is_b; // 0..15
|
||||
const uint qsi = n * 64 + (iqs % 64); // 0..127
|
||||
const uint qhi = n * 32 + (iqs % 32); // 0..63
|
||||
|
||||
const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]);
|
||||
|
||||
float16_t ret = dscale * float16_t(int8_t(((bl.block.ql[qsi ] >> (b * 4)) & 0xF) | (((bl.block.qh[qhi ] >> qhshift) & 3) << 4)) - 32);
|
||||
|
||||
return ret;
|
||||
}
|
||||
|
||||
#if defined(DATA_A_IQ4_NL)
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
|
||||
block_iq4_nl block;
|
||||
};
|
||||
|
||||
float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
const float16_t d = bl.block.d;
|
||||
const uint idx = coordInBlock[1];
|
||||
const uint iqs = idx & 0xF;
|
||||
const uint shift = (idx & 0x10) >> 2;
|
||||
uint32_t qs = bl.block.qs[iqs];
|
||||
qs >>= shift;
|
||||
qs &= 0xF;
|
||||
float16_t ret = float16_t(kvalues_iq4nl[qs]) * d;
|
||||
return ret;
|
||||
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_0)
|
||||
#define dequantFuncA dequantFuncQ4_0
|
||||
#elif defined(DATA_A_Q4_1)
|
||||
#define dequantFuncA dequantFuncQ4_1
|
||||
#elif defined(DATA_A_Q5_0)
|
||||
#define dequantFuncA dequantFuncQ5_0
|
||||
#elif defined(DATA_A_Q5_1)
|
||||
#define dequantFuncA dequantFuncQ5_1
|
||||
#elif defined(DATA_A_Q8_0)
|
||||
#define dequantFuncA dequantFuncQ8_0
|
||||
#elif defined(DATA_A_Q2_K)
|
||||
#define dequantFuncA dequantFuncQ2_K
|
||||
#elif defined(DATA_A_Q3_K)
|
||||
#define dequantFuncA dequantFuncQ3_K
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
#define dequantFuncA dequantFuncQ4_K
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
#define dequantFuncA dequantFuncQ5_K
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define dequantFuncA dequantFuncQ6_K
|
||||
#elif defined(DATA_A_IQ4_NL)
|
||||
#define dequantFuncA dequantFuncIQ4_NL
|
||||
#endif
|
Loading…
Add table
Add a link
Reference in a new issue