16 bit unpack
This commit is contained in:
parent
d122d5c987
commit
6b06d16890
2 changed files with 11 additions and 13 deletions
|
@ -15,13 +15,11 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
// 16 threads are used to process each block
|
||||
const uint it_size = gl_WorkGroupSize.x/16;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint itid = tid%16; // 0...16
|
||||
const uint itid = tid%16; // 0...15
|
||||
const uint ix = tid/16;
|
||||
|
||||
const uint step = 4;
|
||||
|
||||
const uint il = itid/step; // 0...3
|
||||
const uint ir = itid - step*il; // 0...7 or 0...3
|
||||
const uint il = itid/4; // 0...3
|
||||
const uint ir = itid - 4*il; // 0...7 or 0...3
|
||||
const uint n = 4;
|
||||
|
||||
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
||||
|
@ -49,12 +47,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
|
||||
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
|
||||
|
||||
uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
|
||||
uvec4 scale0 = uvec4(unpack8(scale0_u32));
|
||||
uvec4 scale4 = uvec4(unpack8(scale4_u32));
|
||||
uvec4 scale8 = uvec4(unpack8(scale8_u32));
|
||||
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
|
||||
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
|
||||
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
|
||||
uvec2 scale0 = uvec2(unpack8(scale0_u32));
|
||||
uvec2 scale4 = uvec2(unpack8(scale4_u32));
|
||||
uvec2 scale8 = uvec2(unpack8(scale8_u32));
|
||||
|
||||
const uint32_t sc0 = ( scale0.x & 0x3f);
|
||||
const uint32_t sc1 = ( scale0.y & 0x3f);
|
||||
|
|
|
@ -14,7 +14,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
|||
|
||||
const uint num_blocks_per_row = p.ncols / QUANT_K;
|
||||
|
||||
// 16 thread groups are used to process each block
|
||||
// 16 threads are used to process each block
|
||||
const uint it_size = gl_WorkGroupSize.x/16;
|
||||
const uint tid = gl_LocalInvocationID.x;
|
||||
const uint itid = tid%16; // 0...15
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue