q3_k optimizations
This commit is contained in:
parent
cc28742ca3
commit
fe71a8c4a1
5 changed files with 24 additions and 13 deletions
|
@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint itid8 = itid%8;
|
const uint itid8 = itid%8;
|
||||||
|
|
||||||
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
|
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
|
||||||
const uint v_in = itid - 8*v_im; // 0...15 or 0...7
|
const uint v_in = itid - 8*v_im; // 0...7
|
||||||
|
|
||||||
const uint l0 = 2*v_in; // 0...15
|
const uint l0 = 2*v_in; // 0...15
|
||||||
const uint q_offset = 32*v_im + l0;
|
const uint q_offset = 32*v_im + l0;
|
||||||
|
|
|
@ -21,7 +21,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint itid8 = itid%8;
|
const uint itid8 = itid%8;
|
||||||
|
|
||||||
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
|
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
|
||||||
const uint v_in = itid - 8*v_im; // 0...15 or 0...7
|
const uint v_in = itid - 8*v_im; // 0...7
|
||||||
|
|
||||||
const uint8_t m = uint8_t(1 << (4 * v_im));
|
const uint8_t m = uint8_t(1 << (4 * v_im));
|
||||||
|
|
||||||
|
@ -47,6 +47,17 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32);
|
sccache[ix][v_im][itid8] = FLOAT_TYPE(int8_t(((int8_t(data_a[ib0+i].scales[itid8]) >> 4*v_im) & 0xF) | ((int8_t(data_a[ib0+i].scales[itid8%4+8]) >> (4*v_im + 2*(itid8/4)) & 0x3) << 4)) - 32);
|
||||||
barrier();
|
barrier();
|
||||||
|
|
||||||
|
// 0, 1, 16, 17
|
||||||
|
uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
|
||||||
|
qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
|
||||||
|
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
|
||||||
|
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
|
||||||
|
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
|
||||||
|
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
|
||||||
|
|
||||||
|
const uvec2 hmk0 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in]));
|
||||||
|
const uvec2 hmk16 = uvec2(unpack8(data_a_packed16[ib0 + i].hmask[v_in + 8]));
|
||||||
|
|
||||||
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
|
||||||
|
|
||||||
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
|
B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2];
|
||||||
|
@ -60,14 +71,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
|
|
||||||
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
|
||||||
[[unroll]] for (int l = 0; l < 2; ++l) {
|
[[unroll]] for (int l = 0; l < 2; ++l) {
|
||||||
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m )) != 0) ? 0 : 4)),
|
sum = fma(FLOAT_TYPE( b0[l]) * sccache[ix][v_im][0], qs_u32_0[l ] - FLOAT_TYPE((( hmk0[l] & (m )) != 0) ? 0 : 4),
|
||||||
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] ) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m )) != 0) ? 0 : 4)),
|
fma(FLOAT_TYPE( b16[l]) * sccache[ix][v_im][1], qs_u32_0[l+2] - FLOAT_TYPE(((hmk16[l] & (m )) != 0) ? 0 : 4),
|
||||||
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 1)) != 0) ? 0 : 4)),
|
fma(FLOAT_TYPE( b32[l]) * sccache[ix][v_im][2], qs_u32_2[l ] - FLOAT_TYPE((( hmk0[l] & (m << 1)) != 0) ? 0 : 4),
|
||||||
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 2) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 1)) != 0) ? 0 : 4)),
|
fma(FLOAT_TYPE( b48[l]) * sccache[ix][v_im][3], qs_u32_2[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 1)) != 0) ? 0 : 4),
|
||||||
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 2)) != 0) ? 0 : 4)),
|
fma(FLOAT_TYPE( b64[l]) * sccache[ix][v_im][4], qs_u32_4[l ] - FLOAT_TYPE((( hmk0[l] & (m << 2)) != 0) ? 0 : 4),
|
||||||
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 4) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 2)) != 0) ? 0 : 4)),
|
fma(FLOAT_TYPE( b80[l]) * sccache[ix][v_im][5], qs_u32_4[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 2)) != 0) ? 0 : 4),
|
||||||
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l ] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l ] & (m << 3)) != 0) ? 0 : 4)),
|
fma(FLOAT_TYPE( b96[l]) * sccache[ix][v_im][6], qs_u32_6[l ] - FLOAT_TYPE((( hmk0[l] & (m << 3)) != 0) ? 0 : 4),
|
||||||
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], FLOAT_TYPE(((data_a[ib0 + i].qs[q_offset + l+16] >> 6) & 3) - (((data_a[ib0 + i].hmask[l0 + l+16] & (m << 3)) != 0) ? 0 : 4)), sum))))))));
|
fma(FLOAT_TYPE(b112[l]) * sccache[ix][v_im][7], qs_u32_6[l+2] - FLOAT_TYPE(((hmk16[l] & (m << 3)) != 0) ? 0 : 4), sum))))))));
|
||||||
}
|
}
|
||||||
temp[j][n] = fma(d, sum, temp[j][n]);
|
temp[j][n] = fma(d, sum, temp[j][n]);
|
||||||
}
|
}
|
||||||
|
|
|
@ -19,7 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint ix = tid/16;
|
const uint ix = tid/16;
|
||||||
|
|
||||||
const uint il = itid/4; // 0...3
|
const uint il = itid/4; // 0...3
|
||||||
const uint ir = itid - 4*il; // 0...7 or 0...3
|
const uint ir = itid - 4*il; // 0...3
|
||||||
const uint n = 4;
|
const uint n = 4;
|
||||||
|
|
||||||
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
||||||
|
|
|
@ -19,7 +19,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
|
||||||
const uint ix = tid/16;
|
const uint ix = tid/16;
|
||||||
|
|
||||||
const uint il = itid/4; // 0...3
|
const uint il = itid/4; // 0...3
|
||||||
const uint ir = itid - 4*il; // 0...7 or 0...3
|
const uint ir = itid - 4*il; // 0...3
|
||||||
|
|
||||||
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
|
||||||
const uint v_in = il % 2;
|
const uint v_in = il % 2;
|
||||||
|
|
|
@ -87,7 +87,7 @@ void compute_outputs(const uint first_row, const uint num_rows) {
|
||||||
const uint ix = tid/16;
|
const uint ix = tid/16;
|
||||||
|
|
||||||
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
|
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
|
||||||
const uint v_in = itid - 8*v_im; // 0...15 or 0...7
|
const uint v_in = itid - 8*v_im; // 0...7
|
||||||
|
|
||||||
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
|
const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28
|
||||||
const uint is = v_in / 4;
|
const uint is = v_in / 4;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue