make the caches happy
This commit is contained in:
parent
c9463641af
commit
51b5ac507d
2 changed files with 13 additions and 12 deletions
|
@ -40,9 +40,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||||
const f16vec2 d = data_a[ib0 + i].d;
|
|
||||||
const FLOAT_TYPE dall = d.x;
|
|
||||||
const FLOAT_TYPE dmin = d.y;
|
|
||||||
|
|
||||||
sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes
|
sccache[ix][0][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8]), int(v_im*4), 4)); // lower 8 bytes
|
||||||
sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes
|
sccache[ix][1][itid] = FLOAT_TYPE(bitfieldExtract(uint(data_a[ib0 + i].scales[itid8+8]), int(v_im*4), 4)); // upper 8 bytes
|
||||||
|
@ -54,6 +51,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
||||||
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
||||||
|
|
||||||
|
const f16vec2 d = data_a[ib0 + i].d;
|
||||||
|
const FLOAT_TYPE dall = d.x;
|
||||||
|
const FLOAT_TYPE dmin = d.y;
|
||||||
|
|
||||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
|
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
|
||||||
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
||||||
|
|
|
@ -46,10 +46,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
|
|
||||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||||
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
|
||||||
|
|
||||||
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> v_im4) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (v_im4 + 2*(itid8/4)) & 0x3) << 4)) - 32);
|
const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
|
||||||
barrier();
|
const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2));
|
||||||
|
const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
|
||||||
|
const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
|
||||||
|
const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
|
||||||
|
|
||||||
// 0, 1, 16, 17
|
// 0, 1, 16, 17
|
||||||
uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
|
uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
|
||||||
|
@ -59,14 +61,12 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
||||||
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
||||||
|
|
||||||
const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
|
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> v_im4) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (v_im4 + 2*(itid8/4)) & 0x3) << 4)) - 32);
|
||||||
const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2));
|
barrier();
|
||||||
const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
|
|
||||||
const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
|
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
|
||||||
const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
|
|
||||||
|
|
||||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
|
|
||||||
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
|
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
|
||||||
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8];
|
||||||
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
|
B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue