revert and update
This commit is contained in:
parent
be5295bd51
commit
ac3973bfc3
14 changed files with 681 additions and 195 deletions
|
@ -2,6 +2,15 @@
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#include "types.comp"
|
||||||
|
|
||||||
|
#if defined(A_TYPE_PACKED16)
|
||||||
|
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
|
||||||
|
#endif
|
||||||
|
#if defined(A_TYPE_PACKED32)
|
||||||
|
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
|
||||||
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_F32)
|
#if defined(DATA_A_F32)
|
||||||
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
|
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);
|
||||||
|
@ -20,6 +29,11 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||||
return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
|
return (vec2(vui & 0xF, vui >> 4) - 8.0f) * d;
|
||||||
}
|
}
|
||||||
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||||
|
const float d = float(data_a_packed16[a_offset + ib].d);
|
||||||
|
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||||
|
return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) - 8.0f) * d;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q4_1)
|
#if defined(DATA_A_Q4_1)
|
||||||
|
@ -29,6 +43,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||||
return vec2(vui & 0xF, vui >> 4) * d + m;
|
return vec2(vui & 0xF, vui >> 4) * d + m;
|
||||||
}
|
}
|
||||||
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||||
|
const float d = float(data_a_packed16[a_offset + ib].d);
|
||||||
|
const float m = float(data_a_packed16[a_offset + ib].m);
|
||||||
|
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||||
|
return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, (vui >> 12) & 0xF) * d + m;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q5_0)
|
#if defined(DATA_A_Q5_0)
|
||||||
|
@ -39,6 +59,14 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||||
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
|
return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f) * d;
|
||||||
}
|
}
|
||||||
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||||
|
const float d = float(data_a_packed16[a_offset + ib].d);
|
||||||
|
const uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0];
|
||||||
|
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
|
||||||
|
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
|
||||||
|
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||||
|
return (vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) - 16.0f) * d;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q5_1)
|
#if defined(DATA_A_Q5_1)
|
||||||
|
@ -50,6 +78,15 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||||
return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
|
return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) * d + m;
|
||||||
}
|
}
|
||||||
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||||
|
const float d = float(data_a_packed16[a_offset + ib].d);
|
||||||
|
const float m = float(data_a_packed16[a_offset + ib].m);
|
||||||
|
const uint uint_qh = data_a_packed16[a_offset + ib].qh;
|
||||||
|
const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10);
|
||||||
|
const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10);
|
||||||
|
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||||
|
return vec4(((vui >> 0) & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * d + m;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q8_0)
|
#if defined(DATA_A_Q8_0)
|
||||||
|
@ -57,6 +94,12 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
const float d = float(data_a[a_offset + ib].d);
|
const float d = float(data_a[a_offset + ib].d);
|
||||||
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
|
return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])) * d;
|
||||||
}
|
}
|
||||||
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||||
|
const float d = float(data_a_packed16[a_offset + ib].d);
|
||||||
|
uint32_t v0 = data_a_packed16[a_offset + ib].qs[iqs/2];
|
||||||
|
uint32_t v1 = data_a_packed16[a_offset + ib].qs[iqs/2 + 1];
|
||||||
|
return vec4(int8_t(v0 & 0xFF), int8_t((v0 >> 8) & 0xFF), int8_t(v1 & 0xFF), int8_t((v1 >> 8) & 0xFF)) * d;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_IQ4_NL)
|
#if defined(DATA_A_IQ4_NL)
|
||||||
|
@ -65,4 +108,9 @@ vec2 dequantize(uint ib, uint iqs, uint a_offset) {
|
||||||
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
|
||||||
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
|
return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]) * d;
|
||||||
}
|
}
|
||||||
|
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
|
||||||
|
const float d = float(data_a_packed16[a_offset + ib].d);
|
||||||
|
const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]);
|
||||||
|
return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[(vui >> 12) & 0xF]) * d;
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
|
@ -10,6 +10,8 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
|
||||||
void main() {
|
void main() {
|
||||||
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64;
|
||||||
|
|
||||||
|
init_iq4nl_shmem();
|
||||||
|
|
||||||
const uint tid = gl_LocalInvocationID.x % 64;
|
const uint tid = gl_LocalInvocationID.x % 64;
|
||||||
const uint il = tid/32;
|
const uint il = tid/32;
|
||||||
const uint ir = tid%32;
|
const uint ir = tid%32;
|
||||||
|
|
|
@ -12,6 +12,10 @@ void main() {
|
||||||
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
|
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
|
||||||
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
|
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
|
||||||
|
|
||||||
|
#if defined(DATA_A_IQ4_NL)
|
||||||
|
init_iq4nl_shmem();
|
||||||
|
#endif
|
||||||
|
|
||||||
if (i00 >= p.ne00) {
|
if (i00 >= p.ne00) {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,49 +1,177 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#extension GL_KHR_shader_subgroup_arithmetic: enable
|
|
||||||
|
|
||||||
#ifdef FLOAT16
|
#ifdef FLOAT16
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||||
#endif
|
#endif
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||||
|
|
||||||
#include "mul_mat_vec_base.comp"
|
#include "mul_mat_vec_base.comp"
|
||||||
|
|
||||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||||
|
layout (constant_id = 1) const uint NUM_ROWS = 1;
|
||||||
|
|
||||||
void main() {
|
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
|
||||||
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
#define K_PER_ITER 8
|
||||||
const uint tid = gl_LocalInvocationID.x;
|
#else
|
||||||
|
#define K_PER_ITER 2
|
||||||
|
#endif
|
||||||
|
|
||||||
// There are not enough cols to use all threads
|
|
||||||
if (tid >= p.ncols) {
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
const uint block_size = min(p.ncols, BLOCK_SIZE);
|
uint a_offset, b_offset, d_offset, y_offset;
|
||||||
|
|
||||||
uint a_offset, b_offset, d_offset;
|
shared FLOAT_TYPE tmpsh[NUM_ROWS][BLOCK_SIZE];
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
|
||||||
|
|
||||||
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
void iter(inout FLOAT_TYPE temp[NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
|
||||||
|
{
|
||||||
FLOAT_TYPE tmp = FLOAT_TYPE(0.0f);
|
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
|
||||||
|
|
||||||
[[unroll]] for (uint i = 0; i < p.ncols/block_size; i += 2) {
|
|
||||||
const uint col = i*block_size + 2*tid;
|
|
||||||
const uint ib = (row*p.ncols + col)/QUANT_K; // block index
|
|
||||||
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
|
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
|
||||||
const uint iybs = col - col%QUANT_K; // y block start index
|
const uint iybs = col - col%QUANT_K; // y block start index
|
||||||
|
|
||||||
vec2 v = dequantize(ib, iqs, a_offset / QUANT_K);
|
#if K_PER_ITER == 8
|
||||||
|
#if QUANT_R == 2
|
||||||
|
B_TYPE_VEC4 bv02 = data_b_v4[(b_offset + iybs + iqs) / 4];
|
||||||
|
B_TYPE_VEC4 bv13 = data_b_v4[(b_offset + iybs + iqs + y_offset) / 4];
|
||||||
|
FLOAT_TYPE b0 = FLOAT_TYPE(bv02.x);
|
||||||
|
FLOAT_TYPE b1 = FLOAT_TYPE(bv13.x);
|
||||||
|
FLOAT_TYPE b2 = FLOAT_TYPE(bv02.y);
|
||||||
|
FLOAT_TYPE b3 = FLOAT_TYPE(bv13.y);
|
||||||
|
FLOAT_TYPE b4 = FLOAT_TYPE(bv02.z);
|
||||||
|
FLOAT_TYPE b5 = FLOAT_TYPE(bv13.z);
|
||||||
|
FLOAT_TYPE b6 = FLOAT_TYPE(bv02.w);
|
||||||
|
FLOAT_TYPE b7 = FLOAT_TYPE(bv13.w);
|
||||||
|
#else
|
||||||
|
B_TYPE_VEC4 bv0 = data_b_v4[(b_offset + iybs + iqs) / 4];
|
||||||
|
B_TYPE_VEC4 bv1 = data_b_v4[(b_offset + iybs + iqs) / 4 + 1];
|
||||||
|
FLOAT_TYPE b0 = FLOAT_TYPE(bv0.x);
|
||||||
|
FLOAT_TYPE b1 = FLOAT_TYPE(bv0.y);
|
||||||
|
FLOAT_TYPE b2 = FLOAT_TYPE(bv0.z);
|
||||||
|
FLOAT_TYPE b3 = FLOAT_TYPE(bv0.w);
|
||||||
|
FLOAT_TYPE b4 = FLOAT_TYPE(bv1.x);
|
||||||
|
FLOAT_TYPE b5 = FLOAT_TYPE(bv1.y);
|
||||||
|
FLOAT_TYPE b6 = FLOAT_TYPE(bv1.z);
|
||||||
|
FLOAT_TYPE b7 = FLOAT_TYPE(bv1.w);
|
||||||
|
#endif
|
||||||
|
#else
|
||||||
|
// Check if the second of the pair of elements is OOB, and don't fetch B or
|
||||||
|
// accumulate it. We still fetch a pair of elements for A, which is fine for
|
||||||
|
// quantized formats since they'll be within the same block. We should
|
||||||
|
// probably skip fetching the second element for F16/F32, but as of now we
|
||||||
|
// still do.
|
||||||
|
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
|
||||||
|
|
||||||
|
FLOAT_TYPE b0 = 0, b1 = 0;
|
||||||
|
b0 = FLOAT_TYPE(data_b[b_offset + iybs + iqs]);
|
||||||
|
if (!OOB) {
|
||||||
|
b1 = FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
|
const uint ib = ((first_row + n)*p.ncols + col)/QUANT_K; // block index
|
||||||
|
|
||||||
|
#if K_PER_ITER == 8
|
||||||
|
const vec4 v = dequantize4(ib, iqs, a_offset);
|
||||||
|
const vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
|
||||||
|
|
||||||
// matrix multiplication
|
// matrix multiplication
|
||||||
tmp = fma(FLOAT_TYPE(v.x), FLOAT_TYPE(data_b[b_offset + iybs + iqs]), fma(FLOAT_TYPE(v.y), FLOAT_TYPE(data_b[b_offset + iybs + iqs + y_offset]), tmp));
|
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v.z), b2, temp[n]);
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v.w), b3, temp[n]);
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v2.x), b4, temp[n]);
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v2.y), b5, temp[n]);
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v2.z), b6, temp[n]);
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v2.w), b7, temp[n]);
|
||||||
|
#else
|
||||||
|
const vec2 v = dequantize(ib, iqs, a_offset);
|
||||||
|
|
||||||
|
// matrix multiplication
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v.x), b0, temp[n]);
|
||||||
|
if (!OOB) {
|
||||||
|
temp[n] = fma(FLOAT_TYPE(v.y), b1, temp[n]);
|
||||||
|
}
|
||||||
|
#endif
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
|
|
||||||
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
a_offset /= QUANT_K;
|
||||||
|
|
||||||
|
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
|
||||||
|
|
||||||
|
FLOAT_TYPE temp[NUM_ROWS];
|
||||||
|
|
||||||
|
for (uint i = 0; i < NUM_ROWS; ++i) {
|
||||||
|
temp[i] = FLOAT_TYPE(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
|
||||||
|
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
|
||||||
|
num_iters++;
|
||||||
|
}
|
||||||
|
int unroll_count = 4;
|
||||||
|
uint unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||||
|
|
||||||
|
uint i = 0;
|
||||||
|
while (i < unrolled_iters) {
|
||||||
|
// Manually partially unroll the loop
|
||||||
|
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||||
|
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
unroll_count = 2;
|
||||||
|
unrolled_iters = num_iters & ~(unroll_count - 1);
|
||||||
|
while (i < unrolled_iters) {
|
||||||
|
// Manually partially unroll the loop
|
||||||
|
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
|
||||||
|
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
while (i < num_iters) {
|
||||||
|
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
|
||||||
|
i++;
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
tmp = subgroupAdd(tmp);
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
if (tid == 0)
|
tmpsh[n][tid] = temp[n];
|
||||||
data_d[d_offset + row] = D_TYPE(tmp);
|
}
|
||||||
|
barrier();
|
||||||
|
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) {
|
||||||
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
|
tmpsh[n][tid] += tmpsh[n][tid + s];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
if (tid == 0) {
|
||||||
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
|
data_d[d_offset + first_row + n] = D_TYPE(tmpsh[n][0]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
|
||||||
|
|
||||||
|
#if defined(DATA_A_IQ4_NL)
|
||||||
|
init_iq4nl_shmem();
|
||||||
|
#endif
|
||||||
|
|
||||||
|
// do NUM_ROWS at a time, unless there aren't enough remaining rows
|
||||||
|
if (first_row + NUM_ROWS <= p.stride_d) {
|
||||||
|
compute_outputs(first_row, NUM_ROWS);
|
||||||
|
} else {
|
||||||
|
if (first_row >= p.stride_d) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
compute_outputs(first_row, p.stride_d - first_row);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,6 +12,9 @@
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
|
||||||
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
|
||||||
|
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
|
||||||
|
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
|
||||||
|
|
||||||
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
layout (binding = 3) readonly buffer IDS {int data_ids[];};
|
||||||
|
|
|
@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
|
||||||
void main() {
|
void main() {
|
||||||
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
||||||
|
|
||||||
|
if (row >= p.stride_d) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
|
||||||
|
|
|
@ -9,6 +9,10 @@ shared FLOAT_TYPE tmp[32];
|
||||||
void main() {
|
void main() {
|
||||||
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
||||||
|
|
||||||
|
if (row >= p.stride_d) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,21 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||||
|
|
||||||
#include "mul_mat_vec_base.comp"
|
#include "mul_mat_vec_base.comp"
|
||||||
|
|
||||||
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
shared FLOAT_TYPE tmp[32];
|
shared FLOAT_TYPE tmp[32];
|
||||||
|
|
||||||
|
// This shader assumes K_QUANTS_PER_ITERATION == 2 for alignment of loads
|
||||||
void main() {
|
void main() {
|
||||||
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
||||||
|
|
||||||
|
if (row >= p.stride_d) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
|
||||||
|
@ -31,79 +38,81 @@ void main() {
|
||||||
const uint q_offset = 32*v_im + l0;
|
const uint q_offset = 32*v_im + l0;
|
||||||
const uint y_offset = 64*v_im + l0;
|
const uint y_offset = 64*v_im + l0;
|
||||||
|
|
||||||
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
|
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
|
||||||
|
|
||||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
||||||
const uint y1_idx = i * QUANT_K + y_offset;
|
const uint y1_idx = i * QUANT_K + y_offset;
|
||||||
const uint y2_idx = y1_idx + 128;
|
const uint y2_idx = y1_idx + 128;
|
||||||
|
|
||||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
|
f16vec2 d = data_a[ib0 + i].d;
|
||||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
|
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||||
|
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||||
|
|
||||||
const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
|
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||||
const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
|
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||||
const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
|
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
|
||||||
const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
|
uvec4 scale0 = uvec4(unpack8(scale0_u32));
|
||||||
const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
|
uvec4 scale4 = uvec4(unpack8(scale4_u32));
|
||||||
const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
|
uvec4 scale8 = uvec4(unpack8(scale8_u32));
|
||||||
const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
|
|
||||||
const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
|
|
||||||
|
|
||||||
#if K_QUANTS_PER_ITERATION == 2
|
const uint32_t sc0 = ( scale0.x & 0x3f);
|
||||||
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
|
const uint32_t sc1 = ( scale0.y & 0x3f);
|
||||||
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
|
const uint32_t sc2 = ( scale4.x & 0x3f);
|
||||||
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] & 0xf);
|
const uint32_t sc3 = ( scale4.y & 0x3f);
|
||||||
const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] & 0xf);
|
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
|
||||||
const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
|
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
|
||||||
const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
|
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
|
||||||
const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 2] >> 4);
|
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
|
||||||
const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 3] >> 4);
|
|
||||||
const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
|
|
||||||
const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
|
|
||||||
const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] & 0xf);
|
|
||||||
const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] & 0xf);
|
|
||||||
const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
|
|
||||||
const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
|
|
||||||
const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 66] >> 4);
|
|
||||||
const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 67] >> 4);
|
|
||||||
|
|
||||||
const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx]), q4_0, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), q4_1, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 3]) * q4_3)));
|
uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4];
|
||||||
const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_4, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), q4_5, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), q4_6, FLOAT_TYPE(data_b[b_offset + y1_idx + 35]) * q4_7)));
|
uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16];
|
||||||
const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx]), q4_8, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), q4_9, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), q4_10, FLOAT_TYPE(data_b[b_offset + y2_idx + 3]) * q4_11)));
|
|
||||||
const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_12, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), q4_13, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), q4_14, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * q4_15)));
|
uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F;
|
||||||
|
uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F;
|
||||||
|
uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F;
|
||||||
|
uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F;
|
||||||
|
|
||||||
|
uvec4 qs0_lo4 = uvec4(unpack8(qs0_u32_lo4));
|
||||||
|
uvec4 qs64_lo4 = uvec4(unpack8(qs64_u32_lo4));
|
||||||
|
uvec4 qs0_hi4 = uvec4(unpack8(qs0_u32_hi4));
|
||||||
|
uvec4 qs64_hi4 = uvec4(unpack8(qs64_u32_hi4));
|
||||||
|
|
||||||
|
const uint32_t q4_0 = qs0_lo4.x;
|
||||||
|
const uint32_t q4_1 = qs0_lo4.y;
|
||||||
|
const uint32_t q4_2 = qs0_lo4.z;
|
||||||
|
const uint32_t q4_3 = qs0_lo4.w;
|
||||||
|
const uint32_t q4_4 = qs0_hi4.x;
|
||||||
|
const uint32_t q4_5 = qs0_hi4.y;
|
||||||
|
const uint32_t q4_6 = qs0_hi4.z;
|
||||||
|
const uint32_t q4_7 = qs0_hi4.w;
|
||||||
|
const uint32_t q4_8 = qs64_lo4.x;
|
||||||
|
const uint32_t q4_9 = qs64_lo4.y;
|
||||||
|
const uint32_t q4_10 = qs64_lo4.z;
|
||||||
|
const uint32_t q4_11 = qs64_lo4.w;
|
||||||
|
const uint32_t q4_12 = qs64_hi4.x;
|
||||||
|
const uint32_t q4_13 = qs64_hi4.y;
|
||||||
|
const uint32_t q4_14 = qs64_hi4.z;
|
||||||
|
const uint32_t q4_15 = qs64_hi4.w;
|
||||||
|
|
||||||
|
B_TYPE_VEC4 by10 = data_b_v4[(b_offset + y1_idx) / 4];
|
||||||
|
B_TYPE_VEC4 by132 = data_b_v4[(b_offset + y1_idx) / 4 + 8];
|
||||||
|
B_TYPE_VEC4 by20 = data_b_v4[(b_offset + y2_idx) / 4];
|
||||||
|
B_TYPE_VEC4 by232 = data_b_v4[(b_offset + y2_idx) / 4 + 8];
|
||||||
|
|
||||||
|
const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3)));
|
||||||
|
const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7)));
|
||||||
|
const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11)));
|
||||||
|
const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15)));
|
||||||
const FLOAT_TYPE smin =
|
const FLOAT_TYPE smin =
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
|
fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7,
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), sc7,
|
fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7,
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 2]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 34]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 2]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 34]), sc7,
|
fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7,
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 3]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 35]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 3]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 35]) * sc7)))))))))))))));
|
fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7)))))))))))))));
|
||||||
const uint tmp_idx = 16 * ix + tid;
|
temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
|
||||||
tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
|
|
||||||
#else
|
|
||||||
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
|
|
||||||
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
|
|
||||||
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
|
|
||||||
const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
|
|
||||||
const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
|
|
||||||
const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
|
|
||||||
const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
|
|
||||||
const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
|
|
||||||
|
|
||||||
const FLOAT_TYPE sx = fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), q4_0, FLOAT_TYPE(data_b[b_offset + y1_idx + 1]) * q4_1);
|
|
||||||
const FLOAT_TYPE sy = fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), q4_2, FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) * q4_3);
|
|
||||||
const FLOAT_TYPE sz = fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), q4_4, FLOAT_TYPE(data_b[b_offset + y2_idx + 1]) * q4_5);
|
|
||||||
const FLOAT_TYPE sw = fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), q4_6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * q4_7);
|
|
||||||
const FLOAT_TYPE smin =
|
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), sc6, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), sc7,
|
|
||||||
+ fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), sc2, fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), sc3, fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), sc6, FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) * sc7)))))));
|
|
||||||
|
|
||||||
tmp[16 * ix + tid] += FLOAT_TYPE(dall * (sx * FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f) + sy * FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f) +
|
|
||||||
sz * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)) + sw * FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))) - dmin * smin);
|
|
||||||
const uint tmp_idx = 16 * ix + tid;
|
|
||||||
tmp[tmp_idx] = fma(dall, (fma(sx, FLOAT_TYPE(data_a[ib0 + i].scales[v_im] & 0x3f), fma(sy, FLOAT_TYPE(data_a[ib0 + i].scales[v_im + 1] & 0x3f),
|
|
||||||
fma(sz, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 4] & 0x0f) | ((data_a[ib0 + i].scales[v_im] & 0xc0) >> 2)), fma(sw, FLOAT_TYPE((data_a[ib0 + i].scales[v_im + 5] & 0x0f) | ((data_a[ib0 + i].scales[v_im + 1] & 0xc0) >> 2))))))), fma(-dmin, smin, tmp[tmp_idx]));
|
|
||||||
#endif
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tmp[gl_LocalInvocationID.x] = temp;
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
barrier();
|
barrier();
|
||||||
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
|
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
|
||||||
|
|
|
@ -1,5 +1,7 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||||
|
|
||||||
#include "mul_mat_vec_base.comp"
|
#include "mul_mat_vec_base.comp"
|
||||||
|
|
||||||
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
@ -9,6 +11,10 @@ shared FLOAT_TYPE tmp[32];
|
||||||
void main() {
|
void main() {
|
||||||
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
||||||
|
|
||||||
|
if (row >= p.stride_d) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
|
||||||
|
@ -31,70 +37,106 @@ void main() {
|
||||||
const uint8_t hm1 = uint8_t(1 << (2*v_im));
|
const uint8_t hm1 = uint8_t(1 << (2*v_im));
|
||||||
const uint8_t hm2 = uint8_t(hm1 << 4);
|
const uint8_t hm2 = uint8_t(hm1 << 4);
|
||||||
|
|
||||||
tmp[16 * ix + tid] = FLOAT_TYPE(0.0); // partial sum for thread in warp
|
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
|
||||||
|
|
||||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
|
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
|
||||||
const uint y1_idx = i * QUANT_K + y_offset;
|
const uint y1_idx = i * QUANT_K + y_offset;
|
||||||
const uint y2_idx = y1_idx + 128;
|
const uint y2_idx = y1_idx + 128;
|
||||||
|
|
||||||
const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib0 + i].d.x);
|
f16vec2 d = data_a[ib0 + i].d;
|
||||||
const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib0 + i].d.y);
|
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||||
|
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||||
|
|
||||||
const uint8_t sc0 = uint8_t( data_a[ib0 + i].scales[v_im * 2 ] & 0x3f);
|
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||||
const uint8_t sc1 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 1] & 0x3f);
|
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||||
const uint8_t sc2 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 4] & 0x3f);
|
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
|
||||||
const uint8_t sc3 = uint8_t( data_a[ib0 + i].scales[v_im * 2 + 5] & 0x3f);
|
uvec4 scale0 = uvec4(unpack8(scale0_u32));
|
||||||
const uint8_t sc4 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 8] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 ] & 0xc0) >> 2));
|
uvec4 scale4 = uvec4(unpack8(scale4_u32));
|
||||||
const uint8_t sc5 = uint8_t(( data_a[ib0 + i].scales[v_im * 2 + 9] & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 1] & 0xc0) >> 2));
|
uvec4 scale8 = uvec4(unpack8(scale8_u32));
|
||||||
const uint8_t sc6 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 8] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 4] & 0xc0) >> 2));
|
|
||||||
const uint8_t sc7 = uint8_t(((data_a[ib0 + i].scales[v_im * 2 + 9] >> 4) & 0x0f) | ((data_a[ib0 + i].scales[v_im * 2 + 5] & 0xc0) >> 2));
|
|
||||||
|
|
||||||
const uint8_t q4_0 = uint8_t(data_a[ib0 + i].qs[q_offset ] & 0xf);
|
const uint32_t sc0 = ( scale0.x & 0x3f);
|
||||||
const uint8_t q4_1 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] & 0xf);
|
const uint32_t sc1 = ( scale0.y & 0x3f);
|
||||||
const uint8_t q4_2 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] & 0xf);
|
const uint32_t sc2 = ( scale4.x & 0x3f);
|
||||||
const uint8_t q4_3 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] & 0xf);
|
const uint32_t sc3 = ( scale4.y & 0x3f);
|
||||||
const uint8_t q4_4 = uint8_t(data_a[ib0 + i].qs[q_offset ] >> 4);
|
const uint32_t sc4 = (( scale8.x & 0x0f) | ((scale0.x & 0xc0) >> 2));
|
||||||
const uint8_t q4_5 = uint8_t(data_a[ib0 + i].qs[q_offset + 1] >> 4);
|
const uint32_t sc5 = (( scale8.y & 0x0f) | ((scale0.y & 0xc0) >> 2));
|
||||||
const uint8_t q4_6 = uint8_t(data_a[ib0 + i].qs[q_offset + 16] >> 4);
|
const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
|
||||||
const uint8_t q4_7 = uint8_t(data_a[ib0 + i].qs[q_offset + 17] >> 4);
|
const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));
|
||||||
const uint8_t q4_8 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] & 0xf);
|
|
||||||
const uint8_t q4_9 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] & 0xf);
|
uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
|
||||||
const uint8_t q4_10 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] & 0xf);
|
uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
|
||||||
const uint8_t q4_11 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] & 0xf);
|
|
||||||
const uint8_t q4_12 = uint8_t(data_a[ib0 + i].qs[q_offset + 64] >> 4);
|
uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
|
||||||
const uint8_t q4_13 = uint8_t(data_a[ib0 + i].qs[q_offset + 65] >> 4);
|
uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
|
||||||
const uint8_t q4_14 = uint8_t(data_a[ib0 + i].qs[q_offset + 80] >> 4);
|
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
|
||||||
const uint8_t q4_15 = uint8_t(data_a[ib0 + i].qs[q_offset + 81] >> 4);
|
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
|
||||||
|
|
||||||
|
uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
|
||||||
|
uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
|
||||||
|
uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
|
||||||
|
uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));
|
||||||
|
|
||||||
|
const uint32_t q4_0 = qs0_16_lo4.x;
|
||||||
|
const uint32_t q4_1 = qs0_16_lo4.y;
|
||||||
|
const uint32_t q4_2 = qs0_16_lo4.z;
|
||||||
|
const uint32_t q4_3 = qs0_16_lo4.w;
|
||||||
|
const uint32_t q4_4 = qs0_16_hi4.x;
|
||||||
|
const uint32_t q4_5 = qs0_16_hi4.y;
|
||||||
|
const uint32_t q4_6 = qs0_16_hi4.z;
|
||||||
|
const uint32_t q4_7 = qs0_16_hi4.w;
|
||||||
|
const uint32_t q4_8 = qs64_80_lo4.x;
|
||||||
|
const uint32_t q4_9 = qs64_80_lo4.y;
|
||||||
|
const uint32_t q4_10 = qs64_80_lo4.z;
|
||||||
|
const uint32_t q4_11 = qs64_80_lo4.w;
|
||||||
|
const uint32_t q4_12 = qs64_80_hi4.x;
|
||||||
|
const uint32_t q4_13 = qs64_80_hi4.y;
|
||||||
|
const uint32_t q4_14 = qs64_80_hi4.z;
|
||||||
|
const uint32_t q4_15 = qs64_80_hi4.w;
|
||||||
|
|
||||||
|
B_TYPE_VEC2 by10 = data_b_v2[(b_offset + y1_idx) / 2];
|
||||||
|
B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
|
||||||
|
B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
|
||||||
|
B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
|
||||||
|
B_TYPE_VEC2 by20 = data_b_v2[(b_offset + y2_idx) / 2];
|
||||||
|
B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
|
||||||
|
B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
|
||||||
|
B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];
|
||||||
|
|
||||||
|
uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
|
||||||
|
uint32_t qh1 = qh0 >> 8;
|
||||||
|
uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
|
||||||
|
uint32_t qh17 = qh16 >> 8;
|
||||||
|
|
||||||
const FLOAT_TYPE sx =
|
const FLOAT_TYPE sx =
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]), (q4_0 + (((data_a[ib0 + i].qh[l0 ] & hm1) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)),
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 1]), (q4_1 + (((data_a[ib0 + i].qh[l0 + 1] & hm1) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)),
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 16]), (q4_2 + (((data_a[ib0 + i].qh[l0 + 16] & hm1) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
|
||||||
FLOAT_TYPE(data_b[b_offset + y1_idx + 17]) * (q4_3 + (((data_a[ib0 + i].qh[l0 + 17] & hm1) != 0) ? 16 : 0)))));
|
FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
|
||||||
const FLOAT_TYPE sy =
|
const FLOAT_TYPE sy =
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]), (q4_4 + (((data_a[ib0 + i].qh[l0 ] & (hm1 << 1)) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)),
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 33]), (q4_5 + (((data_a[ib0 + i].qh[l0 + 1] & (hm1 << 1)) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)),
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 48]), (q4_6 + (((data_a[ib0 + i].qh[l0 + 16] & (hm1 << 1)) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
|
||||||
FLOAT_TYPE(data_b[b_offset + y1_idx + 49]) * (q4_7 + (((data_a[ib0 + i].qh[l0 + 17] & (hm1 << 1)) != 0) ? 16 : 0)))));
|
FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
|
||||||
const FLOAT_TYPE sz =
|
const FLOAT_TYPE sz =
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]), (q4_8 + (((data_a[ib0 + i].qh[l0 ] & hm2) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by20.x), (q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)),
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 1]), (q4_9 + (((data_a[ib0 + i].qh[l0 + 1] & hm2) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by20.y), (q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)),
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 16]), (q4_10 + (((data_a[ib0 + i].qh[l0 + 16] & hm2) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
|
||||||
FLOAT_TYPE(data_b[b_offset + y2_idx + 17]) * (q4_11 + (((data_a[ib0 + i].qh[l0 + 17] & hm2) != 0) ? 16 : 0)))));
|
FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
|
||||||
const FLOAT_TYPE sw =
|
const FLOAT_TYPE sw =
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]), (q4_12 + (((data_a[ib0 + i].qh[l0 ] & (hm2 << 1)) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)),
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 33]), (q4_13 + (((data_a[ib0 + i].qh[l0 + 1] & (hm2 << 1)) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)),
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y2_idx + 48]), (q4_14 + (((data_a[ib0 + i].qh[l0 + 16] & (hm2 << 1)) != 0) ? 16 : 0)),
|
fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
|
||||||
FLOAT_TYPE(data_b[b_offset + y2_idx + 49]) * (q4_15 + (((data_a[ib0 + i].qh[l0 + 17] & (hm2 << 1)) != 0) ? 16 : 0)))));
|
FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
|
||||||
const FLOAT_TYPE smin =
|
const FLOAT_TYPE smin =
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 17]), sc2,
|
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y1_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y1_idx + 49]), sc3,
|
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y2_idx ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 1 ]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 16]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 17]), sc6,
|
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
|
||||||
(FLOAT_TYPE(data_b[b_offset + y2_idx + 32]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 33]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 48]) + FLOAT_TYPE(data_b[b_offset + y2_idx + 49])) * sc7)));
|
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
|
||||||
const uint tmp_idx = 16 * ix + tid;
|
temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
|
||||||
tmp[tmp_idx] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, tmp[tmp_idx]));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
tmp[gl_LocalInvocationID.x] = temp;
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
barrier();
|
barrier();
|
||||||
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
|
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
|
||||||
|
|
|
@ -1,58 +1,110 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#extension GL_KHR_shader_subgroup_arithmetic: enable
|
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||||
|
|
||||||
#include "mul_mat_vec_base.comp"
|
#include "mul_mat_vec_base.comp"
|
||||||
|
|
||||||
layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in;
|
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
|
shared FLOAT_TYPE tmp[32];
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;
|
||||||
|
|
||||||
|
if (row >= p.stride_d) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
uint a_offset, b_offset, d_offset;
|
uint a_offset, b_offset, d_offset;
|
||||||
get_offsets(a_offset, b_offset, d_offset);
|
get_offsets(a_offset, b_offset, d_offset);
|
||||||
|
|
||||||
const uint num_blocks_per_row = p.ncols / QUANT_K;
|
const uint num_blocks_per_row = p.ncols / QUANT_K;
|
||||||
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
|
const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;
|
||||||
|
|
||||||
const uint tid_64 = gl_LocalInvocationID.x;
|
const uint tid = gl_LocalInvocationID.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
|
||||||
const uint tid_group = tid_64/32;
|
const uint ix = gl_LocalInvocationID.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
|
||||||
|
|
||||||
const uint tid = (tid_64%32)/2; // 0...31 or 0...16
|
const uint step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
|
||||||
const uint ix = (tid_64%32)%2; // 0 or 0, 1
|
|
||||||
|
|
||||||
const uint loop_start = 0 + tid_group*2;
|
|
||||||
const uint loop_end = 2 + tid_group*2;
|
|
||||||
|
|
||||||
const uint step = 16/2; // 16 or 8
|
|
||||||
|
|
||||||
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
|
const uint v_im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
|
||||||
const uint v_in = tid - step*v_im; // 0...15 or 0...7
|
const uint v_in = tid - step*v_im; // 0...15 or 0...7
|
||||||
|
|
||||||
|
#if K_QUANTS_PER_ITERATION == 1
|
||||||
|
const uint l0 = v_in; // 0...15
|
||||||
|
const uint is = 0;
|
||||||
|
#else
|
||||||
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
|
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
|
||||||
const uint is = v_in / 4;
|
const uint is = v_in / 4;
|
||||||
|
#endif
|
||||||
|
|
||||||
const uint ql_offset = 64*v_im + l0;
|
const uint ql_offset = 64*v_im + l0;
|
||||||
const uint qh_offset = 32*v_im + l0;
|
const uint qh_offset = 32*v_im + l0;
|
||||||
const uint s_offset = 8*v_im + is;
|
const uint s_offset = 8*v_im + is;
|
||||||
const uint y_offset = 128*v_im + l0;
|
const uint y_offset = 128*v_im + l0;
|
||||||
|
|
||||||
FLOAT_TYPE tmp = FLOAT_TYPE(0.0); // partial sum for thread in warp
|
FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for thread in warp
|
||||||
|
|
||||||
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
|
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
|
||||||
const uint y_idx = i * QUANT_K + y_offset;
|
const uint y_idx = i * QUANT_K + y_offset;
|
||||||
|
|
||||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
||||||
|
|
||||||
[[unroll]] for (uint l = loop_start; l < loop_end; ++l) {
|
FLOAT_TYPE scales[4];
|
||||||
tmp = fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+ 0]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 0) & 3) << 4)) - 32),
|
scales[0] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 0]);
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+32]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] & 0xF) | (((data_a[ib0 + i].qh[qh_offset + l] >> 2) & 3) << 4)) - 32),
|
scales[1] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 2]);
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+64]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+ 0] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 4) & 3) << 4)) - 32),
|
scales[2] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 4]);
|
||||||
fma(FLOAT_TYPE(data_b[b_offset + y_idx + l+96]) * FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]) * d, FLOAT_TYPE(int8_t((data_a[ib0 + i].ql[ql_offset + l+32] >> 4) | (((data_a[ib0 + i].qh[qh_offset + l] >> 6) & 3) << 4)) - 32), tmp))));
|
scales[3] = FLOAT_TYPE(data_a[ib0 + i].scales[s_offset + 6]);
|
||||||
|
|
||||||
|
uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16);
|
||||||
|
uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16);
|
||||||
|
|
||||||
|
uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F;
|
||||||
|
uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F;
|
||||||
|
uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F;
|
||||||
|
uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F;
|
||||||
|
|
||||||
|
uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16);
|
||||||
|
uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4;
|
||||||
|
uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2;
|
||||||
|
uint32_t qh4_u32 = (qh_u32 & 0x30303030) << 0;
|
||||||
|
uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2;
|
||||||
|
|
||||||
|
uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32;
|
||||||
|
uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32;
|
||||||
|
uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32;
|
||||||
|
uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32;
|
||||||
|
|
||||||
|
uvec4 q0 = uvec4(unpack8(q0_u32));
|
||||||
|
uvec4 q1 = uvec4(unpack8(q1_u32));
|
||||||
|
uvec4 q2 = uvec4(unpack8(q2_u32));
|
||||||
|
uvec4 q3 = uvec4(unpack8(q3_u32));
|
||||||
|
|
||||||
|
B_TYPE_VEC4 by0 = data_b_v4[(b_offset + y_idx) / 4];
|
||||||
|
B_TYPE_VEC4 by32 = data_b_v4[(b_offset + y_idx) / 4 + 8];
|
||||||
|
B_TYPE_VEC4 by64 = data_b_v4[(b_offset + y_idx) / 4 + 16];
|
||||||
|
B_TYPE_VEC4 by96 = data_b_v4[(b_offset + y_idx) / 4 + 24];
|
||||||
|
|
||||||
|
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||||
|
[[unroll]] for (int l = 0; l < 4; ++l) {
|
||||||
|
sum = fma(FLOAT_TYPE(by0[l]) * scales[0], FLOAT_TYPE(int8_t(q0[l]) - 32),
|
||||||
|
fma(FLOAT_TYPE(by32[l]) * scales[1], FLOAT_TYPE(int8_t(q1[l]) - 32),
|
||||||
|
fma(FLOAT_TYPE(by64[l]) * scales[2], FLOAT_TYPE(int8_t(q2[l]) - 32),
|
||||||
|
fma(FLOAT_TYPE(by96[l]) * scales[3], FLOAT_TYPE(int8_t(q3[l]) - 32), sum))));
|
||||||
}
|
}
|
||||||
|
temp += sum * d;
|
||||||
}
|
}
|
||||||
|
|
||||||
tmp = subgroupAdd(tmp);
|
tmp[gl_LocalInvocationID.x] = temp;
|
||||||
if (tid == 0)
|
|
||||||
data_d[d_offset + row] = D_TYPE(tmp);
|
// sum up partial sums and write back result
|
||||||
|
barrier();
|
||||||
|
[[unroll]] for (uint s = 16; s > 0; s >>= 1) {
|
||||||
|
if (tid < s) {
|
||||||
|
tmp[tid] += tmp[tid + s];
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
|
if (tid == 0) {
|
||||||
|
data_d[d_offset + row] = D_TYPE(tmp[0]);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -75,6 +75,10 @@ shared u16vec2 row_ids[3072];
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
void main() {
|
void main() {
|
||||||
|
#if defined(DATA_A_IQ4_NL)
|
||||||
|
init_iq4nl_shmem();
|
||||||
|
#endif
|
||||||
|
|
||||||
#ifdef MUL_MAT_ID
|
#ifdef MUL_MAT_ID
|
||||||
const uint expert_idx = gl_GlobalInvocationID.z;
|
const uint expert_idx = gl_GlobalInvocationID.z;
|
||||||
#else
|
#else
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
#version 450
|
#version 450
|
||||||
|
|
||||||
#extension GL_EXT_shader_16bit_storage : require
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
||||||
|
#extension GL_EXT_control_flow_attributes : enable
|
||||||
|
|
||||||
layout (push_constant) uniform parameter
|
layout (push_constant) uniform parameter
|
||||||
{
|
{
|
||||||
|
@ -11,14 +12,13 @@ layout (push_constant) uniform parameter
|
||||||
float m0;
|
float m0;
|
||||||
float m1;
|
float m1;
|
||||||
uint n_head_log2;
|
uint n_head_log2;
|
||||||
|
uint nrows_x;
|
||||||
} p;
|
} p;
|
||||||
|
|
||||||
#include "types.comp"
|
#include "types.comp"
|
||||||
|
|
||||||
#extension GL_EXT_control_flow_attributes : enable
|
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
|
||||||
#define BLOCK_SIZE 512
|
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||||
|
|
||||||
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
|
|
||||||
|
|
||||||
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
|
||||||
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
|
layout (binding = 1) readonly buffer Y {B_TYPE data_b[];};
|
||||||
|
@ -26,11 +26,18 @@ layout (binding = 2) buffer D {D_TYPE data_d[];};
|
||||||
|
|
||||||
shared FLOAT_TYPE vals[BLOCK_SIZE];
|
shared FLOAT_TYPE vals[BLOCK_SIZE];
|
||||||
|
|
||||||
void main() {
|
// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate
|
||||||
|
// over all the columns. The main function tries to pass a constant here,
|
||||||
|
// as if it were a template function, to allow unrolling.
|
||||||
|
void soft_max(uint num_iters) {
|
||||||
const uint tid = gl_LocalInvocationID.x;
|
const uint tid = gl_LocalInvocationID.x;
|
||||||
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
|
||||||
const uint rowy = rowx % p.KY;
|
const uint rowy = rowx % p.KY;
|
||||||
|
|
||||||
|
if (rowx >= p.nrows_x) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
float slope = 1.0f;
|
float slope = 1.0f;
|
||||||
|
|
||||||
// ALiBi
|
// ALiBi
|
||||||
|
@ -46,19 +53,39 @@ void main() {
|
||||||
// Find max
|
// Find max
|
||||||
FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
|
FLOAT_TYPE max_val = uintBitsToFloat(0xFF800000);
|
||||||
|
|
||||||
[[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
|
// Cache values while we compute the max, so we don't need to read them
|
||||||
|
// again when we're ready to compute exp(x-max).
|
||||||
|
const uint DATA_CACHE_SIZE = 16;
|
||||||
|
FLOAT_TYPE data_cache[DATA_CACHE_SIZE];
|
||||||
|
|
||||||
|
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
|
||||||
const uint col = col0 + tid;
|
const uint col = col0 + tid;
|
||||||
|
|
||||||
if (col >= p.KX) {
|
FLOAT_TYPE a = FLOAT_TYPE(0);
|
||||||
break;
|
if (col < p.KX) {
|
||||||
|
a = data_a[rowx * p.KX + col];
|
||||||
}
|
}
|
||||||
|
|
||||||
max_val = max(max_val, FLOAT_TYPE(data_a[rowx * p.KX + col]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)));
|
FLOAT_TYPE b = FLOAT_TYPE(0);
|
||||||
|
if (p.KY > 0 && col < p.KX) {
|
||||||
|
b = data_b[rowy * p.KX + col];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
FLOAT_TYPE v = a * p.scale + slope * b;
|
||||||
|
|
||||||
|
if (col < p.KX) {
|
||||||
|
max_val = max(max_val, v);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (idx < DATA_CACHE_SIZE) {
|
||||||
|
data_cache[idx] = v;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reduce across the workgroup
|
||||||
vals[tid] = max_val;
|
vals[tid] = max_val;
|
||||||
|
|
||||||
barrier();
|
barrier();
|
||||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||||
if (tid < s) {
|
if (tid < s) {
|
||||||
vals[tid] = max(vals[tid], vals[tid + s]);
|
vals[tid] = max(vals[tid], vals[tid + s]);
|
||||||
}
|
}
|
||||||
|
@ -68,39 +95,80 @@ void main() {
|
||||||
max_val = vals[0];
|
max_val = vals[0];
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
// Sum up values
|
FLOAT_TYPE sum = FLOAT_TYPE(0.0f);
|
||||||
vals[tid] = FLOAT_TYPE(0.0f);
|
|
||||||
|
|
||||||
[[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
|
// Compute sum{exp(x - max)}
|
||||||
|
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
|
||||||
const uint col = col0 + tid;
|
const uint col = col0 + tid;
|
||||||
|
|
||||||
if (col >= p.KX) {
|
if (col >= p.KX) {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// compute exp(a*scale+b*slope), add it to sum, and cache the new value
|
||||||
|
// in data_cache if possible.
|
||||||
const uint i = rowx * p.KX + col;
|
const uint i = rowx * p.KX + col;
|
||||||
const FLOAT_TYPE val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
|
FLOAT_TYPE val;
|
||||||
vals[tid] += val;
|
if (idx < DATA_CACHE_SIZE) {
|
||||||
|
val = exp(data_cache[idx] - max_val);
|
||||||
|
} else {
|
||||||
|
val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? slope * FLOAT_TYPE(data_b[rowy * p.KX + col]) : FLOAT_TYPE(0.0f)) - max_val);
|
||||||
|
}
|
||||||
|
sum += val;
|
||||||
|
if (idx < DATA_CACHE_SIZE) {
|
||||||
|
data_cache[idx] = val;
|
||||||
|
} else {
|
||||||
data_d[i] = D_TYPE(val);
|
data_d[i] = D_TYPE(val);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// reduce across the workgroup
|
||||||
|
vals[tid] = sum;
|
||||||
barrier();
|
barrier();
|
||||||
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
[[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
|
||||||
if (tid < s) {
|
if (tid < s) {
|
||||||
vals[tid] += vals[tid + s];
|
vals[tid] += vals[tid + s];
|
||||||
}
|
}
|
||||||
barrier();
|
barrier();
|
||||||
}
|
}
|
||||||
|
sum = vals[0];
|
||||||
|
|
||||||
const D_TYPE divisor = D_TYPE(vals[0]);
|
FLOAT_TYPE rcpdivisor = 1.0/sum;
|
||||||
|
|
||||||
[[unroll]] for (uint col0 = 0; col0 < p.KX; col0 += BLOCK_SIZE) {
|
[[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) {
|
||||||
const uint col = col0 + tid;
|
const uint col = col0 + tid;
|
||||||
|
|
||||||
if (col >= p.KX) {
|
if (col >= p.KX) {
|
||||||
break;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
data_d[rowx*p.KX + col] /= divisor;
|
if (idx < DATA_CACHE_SIZE) {
|
||||||
|
data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor);
|
||||||
|
} else {
|
||||||
|
data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void main() {
|
||||||
|
// instantiate the soft_max function for several different
|
||||||
|
// dimensions, to allow loop unrolling
|
||||||
|
uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE;
|
||||||
|
if (num_blocks > 32) {
|
||||||
|
soft_max(num_blocks);
|
||||||
|
} else if (num_blocks > 16) {
|
||||||
|
soft_max(32);
|
||||||
|
} else if (num_blocks > 8) {
|
||||||
|
soft_max(16);
|
||||||
|
} else if (num_blocks > 4) {
|
||||||
|
soft_max(8);
|
||||||
|
} else if (num_blocks == 4) {
|
||||||
|
soft_max(4);
|
||||||
|
} else if (num_blocks == 3) {
|
||||||
|
soft_max(3);
|
||||||
|
} else if (num_blocks == 2) {
|
||||||
|
soft_max(2);
|
||||||
|
} else if (num_blocks == 1) {
|
||||||
|
soft_max(1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
#if !defined(DATA_A_F32) && !defined(DATA_A_F16)
|
|
||||||
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
#if !defined(GGML_TYPES_COMP)
|
||||||
#endif
|
#define GGML_TYPES_COMP
|
||||||
|
|
||||||
|
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
||||||
|
|
||||||
#if defined(DATA_A_F32)
|
#if defined(DATA_A_F32)
|
||||||
#define QUANT_K 1
|
#define QUANT_K 1
|
||||||
|
@ -38,8 +40,14 @@ struct block_q4_0
|
||||||
float16_t d;
|
float16_t d;
|
||||||
uint8_t qs[16];
|
uint8_t qs[16];
|
||||||
};
|
};
|
||||||
|
struct block_q4_0_packed16
|
||||||
|
{
|
||||||
|
float16_t d;
|
||||||
|
uint16_t qs[16/2];
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q4_0
|
#define A_TYPE block_q4_0
|
||||||
|
#define A_TYPE_PACKED16 block_q4_0_packed16
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q4_1)
|
#if defined(DATA_A_Q4_1)
|
||||||
|
@ -54,7 +62,15 @@ struct block_q4_1
|
||||||
uint8_t qs[16];
|
uint8_t qs[16];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct block_q4_1_packed16
|
||||||
|
{
|
||||||
|
float16_t d;
|
||||||
|
float16_t m;
|
||||||
|
uint16_t qs[16/2];
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q4_1
|
#define A_TYPE block_q4_1
|
||||||
|
#define A_TYPE_PACKED16 block_q4_1_packed16
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q5_0)
|
#if defined(DATA_A_Q5_0)
|
||||||
|
@ -70,7 +86,15 @@ struct block_q5_0
|
||||||
uint8_t qs[16];
|
uint8_t qs[16];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct block_q5_0_packed16
|
||||||
|
{
|
||||||
|
float16_t d;
|
||||||
|
uint16_t qh[2];
|
||||||
|
uint16_t qs[16/2];
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q5_0
|
#define A_TYPE block_q5_0
|
||||||
|
#define A_TYPE_PACKED16 block_q5_0_packed16
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q5_1)
|
#if defined(DATA_A_Q5_1)
|
||||||
|
@ -87,7 +111,16 @@ struct block_q5_1
|
||||||
uint8_t qs[16];
|
uint8_t qs[16];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct block_q5_1_packed16
|
||||||
|
{
|
||||||
|
float16_t d;
|
||||||
|
float16_t m;
|
||||||
|
uint qh;
|
||||||
|
uint16_t qs[16/2];
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q5_1
|
#define A_TYPE block_q5_1
|
||||||
|
#define A_TYPE_PACKED16 block_q5_1_packed16
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q8_0)
|
#if defined(DATA_A_Q8_0)
|
||||||
|
@ -100,8 +133,14 @@ struct block_q8_0
|
||||||
float16_t d;
|
float16_t d;
|
||||||
int8_t qs[32];
|
int8_t qs[32];
|
||||||
};
|
};
|
||||||
|
struct block_q8_0_packed16
|
||||||
|
{
|
||||||
|
float16_t d;
|
||||||
|
uint16_t qs[32/2];
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q8_0
|
#define A_TYPE block_q8_0
|
||||||
|
#define A_TYPE_PACKED16 block_q8_0_packed16
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// K-quants
|
// K-quants
|
||||||
|
@ -116,7 +155,23 @@ struct block_q2_K
|
||||||
f16vec2 d;
|
f16vec2 d;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct block_q2_K_packed16
|
||||||
|
{
|
||||||
|
uint16_t scales[QUANT_K/16/2];
|
||||||
|
uint16_t qs[QUANT_K/4/2];
|
||||||
|
f16vec2 d;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct block_q2_K_packed32
|
||||||
|
{
|
||||||
|
uint32_t scales[QUANT_K/16/4];
|
||||||
|
uint32_t qs[QUANT_K/4/4];
|
||||||
|
f16vec2 d;
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q2_K
|
#define A_TYPE block_q2_K
|
||||||
|
#define A_TYPE_PACKED16 block_q2_K_packed16
|
||||||
|
#define A_TYPE_PACKED32 block_q2_K_packed32
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q3_K)
|
#if defined(DATA_A_Q3_K)
|
||||||
|
@ -131,7 +186,16 @@ struct block_q3_K
|
||||||
float16_t d;
|
float16_t d;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct block_q3_K_packed16
|
||||||
|
{
|
||||||
|
uint16_t hmask[QUANT_K/8/2];
|
||||||
|
uint16_t qs[QUANT_K/4/2];
|
||||||
|
uint16_t scales[12/2];
|
||||||
|
float16_t d;
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q3_K
|
#define A_TYPE block_q3_K
|
||||||
|
#define A_TYPE_PACKED16 block_q3_K_packed16
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q4_K)
|
#if defined(DATA_A_Q4_K)
|
||||||
|
@ -145,7 +209,23 @@ struct block_q4_K
|
||||||
uint8_t qs[QUANT_K/2];
|
uint8_t qs[QUANT_K/2];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct block_q4_K_packed16
|
||||||
|
{
|
||||||
|
f16vec2 d;
|
||||||
|
uint16_t scales[3*QUANT_K/64/2];
|
||||||
|
uint16_t qs[QUANT_K/2/2];
|
||||||
|
};
|
||||||
|
|
||||||
|
struct block_q4_K_packed32
|
||||||
|
{
|
||||||
|
f16vec2 d;
|
||||||
|
uint32_t scales[3*QUANT_K/64/4];
|
||||||
|
uint32_t qs[QUANT_K/2/4];
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q4_K
|
#define A_TYPE block_q4_K
|
||||||
|
#define A_TYPE_PACKED16 block_q4_K_packed16
|
||||||
|
#define A_TYPE_PACKED32 block_q4_K_packed32
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q5_K)
|
#if defined(DATA_A_Q5_K)
|
||||||
|
@ -160,7 +240,16 @@ struct block_q5_K
|
||||||
uint8_t qs[QUANT_K/2];
|
uint8_t qs[QUANT_K/2];
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct block_q5_K_packed16
|
||||||
|
{
|
||||||
|
f16vec2 d;
|
||||||
|
uint16_t scales[12/2];
|
||||||
|
uint16_t qh[QUANT_K/8/2];
|
||||||
|
uint16_t qs[QUANT_K/2/2];
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q5_K
|
#define A_TYPE block_q5_K
|
||||||
|
#define A_TYPE_PACKED16 block_q5_K_packed16
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
#if defined(DATA_A_Q6_K)
|
#if defined(DATA_A_Q6_K)
|
||||||
|
@ -175,7 +264,16 @@ struct block_q6_K
|
||||||
float16_t d;
|
float16_t d;
|
||||||
};
|
};
|
||||||
|
|
||||||
|
struct block_q6_K_packed16
|
||||||
|
{
|
||||||
|
uint16_t ql[QUANT_K/2/2];
|
||||||
|
uint16_t qh[QUANT_K/4/2];
|
||||||
|
int8_t scales[QUANT_K/16];
|
||||||
|
float16_t d;
|
||||||
|
};
|
||||||
|
|
||||||
#define A_TYPE block_q6_K
|
#define A_TYPE block_q6_K
|
||||||
|
#define A_TYPE_PACKED16 block_q6_K_packed16
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
// IQuants
|
// IQuants
|
||||||
|
@ -191,10 +289,30 @@ struct block_iq4_nl
|
||||||
uint8_t qs[QUANT_K/2];
|
uint8_t qs[QUANT_K/2];
|
||||||
};
|
};
|
||||||
|
|
||||||
#define A_TYPE block_iq4_nl
|
struct block_iq4_nl_packed16
|
||||||
|
{
|
||||||
|
float16_t d;
|
||||||
|
uint16_t qs[QUANT_K/2/2];
|
||||||
|
};
|
||||||
|
|
||||||
const int8_t kvalues_iq4nl[16] = {
|
#define A_TYPE block_iq4_nl
|
||||||
|
#define A_TYPE_PACKED16 block_iq4_nl_packed16
|
||||||
|
|
||||||
|
const int8_t kvalues_iq4nl_const[16] = {
|
||||||
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
|
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
|
||||||
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
|
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
|
||||||
};
|
};
|
||||||
|
|
||||||
|
shared FLOAT_TYPE kvalues_iq4nl[16];
|
||||||
|
|
||||||
|
void init_iq4nl_shmem()
|
||||||
|
{
|
||||||
|
// copy the table into shared memory and sync
|
||||||
|
if (gl_LocalInvocationIndex.x < 16) {
|
||||||
|
kvalues_iq4nl[gl_LocalInvocationIndex.x] = FLOAT_TYPE(kvalues_iq4nl_const[gl_LocalInvocationIndex.x]);
|
||||||
|
}
|
||||||
|
barrier();
|
||||||
|
}
|
||||||
#endif
|
#endif
|
||||||
|
|
||||||
|
#endif // !defined(GGML_TYPES_COMP)
|
||||||
|
|
|
@ -317,10 +317,10 @@ void process_shaders() {
|
||||||
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
std::string data_a_key = "DATA_A_" + to_uppercase(tname);
|
||||||
std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
std::string shader = (string_ends_with(tname, "_k")) ? "mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp";
|
||||||
|
|
||||||
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||||
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}));
|
string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}}));
|
||||||
|
|
||||||
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
|
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
|
||||||
|
|
||||||
// Dequant shaders
|
// Dequant shaders
|
||||||
if (tname != "f16") {
|
if (tname != "f16") {
|
||||||
|
@ -331,11 +331,11 @@ void process_shaders() {
|
||||||
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
|
shader = (tname == "f32" || tname == "f16") ? "get_rows.comp" : "get_rows_quant.comp";
|
||||||
|
|
||||||
if (tname == "f16") {
|
if (tname == "f16") {
|
||||||
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
|
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}));
|
||||||
} else {
|
} else {
|
||||||
string_to_spv("get_rows_" + tname, shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}});
|
string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}}));
|
||||||
}
|
}
|
||||||
string_to_spv("get_rows_" + tname + "_f32", shader, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}});
|
string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}}));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue