Minor speed gains for all quantization types

This commit is contained in:
Iwan Kawrakow 2023-09-07 11:18:48 +02:00
parent 15b67a66c2
commit 9a9010609b

View file

@ -1757,29 +1757,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
template <typename type4x4> template <typename type4x4>
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) { void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 1); device const uint16_t * qs = ((device const uint16_t *)xb + 1);
const half d = il ? (xb->d / 16.h) : xb->d; const float d1 = il ? (xb->d / 16.h) : xb->d;
const half m = il ? ( -8.h * 16.h) : -8.h; const float d2 = d1 / 256.f;
const float md = -8.h * xb->d;
const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00; const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) { for (int i=0;i<8;i++) {
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) + m) * d; reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d; reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
} }
} }
template <typename type4x4> template <typename type4x4>
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) { void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
device const uint16_t * qs = ((device const uint16_t *)xb + 2); device const uint16_t * qs = ((device const uint16_t *)xb + 2);
const half d = il ? (xb->d / 16.h) : xb->d; const float d1 = il ? (xb->d / 16.h) : xb->d;
const half m = xb->m; const float d2 = d1 / 256.f;
const float m = xb->m;
const ushort mask0 = il ? 0x00F0 : 0x000F; const ushort mask0 = il ? 0x00F0 : 0x000F;
const ushort mask1 = il ? 0xF000 : 0x0F00; const ushort mask1 = mask0 << 8;
for (int i=0;i<8;i++) { for (int i=0;i<8;i++) {
reg[i/2][2*(i%2)] = (((qs[i] & mask0) ) * d) + m; reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m; reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
} }
} }
@ -1815,7 +1820,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
template <typename type4x4> template <typename type4x4>
void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) { void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
const float d_all = (float)(xb->d); const half d_all = xb->d;
device const uint8_t * q = (device const uint8_t *)xb->qs; device const uint8_t * q = (device const uint8_t *)xb->qs;
device const uint8_t * h = (device const uint8_t *)xb->hmask; device const uint8_t * h = (device const uint8_t *)xb->hmask;
device const int8_t * scales = (device const int8_t *)xb->scales; device const int8_t * scales = (device const int8_t *)xb->scales;
@ -1828,17 +1833,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
((il/4)>0 ? 12 : 3); ((il/4)>0 ? 12 : 3);
uint16_t kmask2 = il/8 ? 0xF0 : 0x0F; uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4]; uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \ int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
(scale_2&kmask2) | ((scale_1&kmask1) << 4); : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f); half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
const half ml = 4.h * dl;
il = (il/2)%4; il = (il/2) & 3;
float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h); const half coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
dl *= coef;
for (int i = 0; i < 16; ++i) { for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef)); reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
} }
#else #else
float kcoef = il&1 ? 1.f/16.f : 1.f; float kcoef = il&1 ? 1.f/16.f : 1.f;
uint16_t kmask = il&1 ? 0xF0 : 0x0F; uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@ -1852,31 +1860,37 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
#endif #endif
} }
static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
: uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
}
template <typename type4x4> template <typename type4x4>
void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) { void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
device const uint8_t * q = xb->qs; device const uchar * q = xb->qs;
#if QK_K == 256 #if QK_K == 256
const float d = (float)(xb->d);
const float min = (float)(xb->dmin);
short is = (il/4) * 2; short is = (il/4) * 2;
q = q + (il/4) * 32 + 16 * (il&1); q = q + (il/4) * 32 + 16 * (il&1);
il = il%4; il = il & 3;
const uchar4 sc = get_scale_min_k4(is, xb->scales); const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; const half d = il < 2 ? xb->d : xb->d / 16.h;
const float ml = il<2 ? min * sc[1] : min * sc[3]; const half min = xb->dmin;
const half dl = d * sc[0];
const half ml = min * sc[1];
#else #else
q = q + 16 * (il&1); q = q + 16 * (il&1);
device const uint8_t * s = xb->scales; device const uint8_t * s = xb->scales;
device const half2 * dh = (device const half2 *)xb->d; device const half2 * dh = (device const half2 *)xb->d;
const float2 d = (float2)dh[0]; const float2 d = (float2)dh[0];
const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h; const float dl = il<2 ? d[0] * (s[0]&0xF) : d[0] * (s[1]&0xF)/16.h;
const float ml = il<2 ? d[1] * (s[0]>>4) : d[1 ]* (s[1]>>4); const float ml = il<2 ? d[1] * (s[0]>>4) : d[1] * (s[1]>>4);
#endif #endif
const ushort mask = il<2 ? 0x0F : 0xF0; const ushort mask = il<2 ? 0x0F : 0xF0;
for (int i = 0; i < 16; ++i) { for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * (q[i] & mask) - ml; reg[i/4][i%4] = dl * (q[i] & mask) - ml;
} }
} }
template <typename type4x4> template <typename type4x4>
@ -1885,19 +1899,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
device const uint8_t * qh = xb->qh; device const uint8_t * qh = xb->qh;
#if QK_K == 256 #if QK_K == 256
const float d = (float)(xb->d);
const float min = (float)(xb->dmin);
short is = (il/4) * 2; short is = (il/4) * 2;
q = q + 32 * (il/4) + 16 * (il&1); q = q + 32 * (il/4) + 16 * (il&1);
qh = qh + 16 * (il&1); qh = qh + 16 * (il&1);
uint8_t ul = 1 << (il/2); uint8_t ul = 1 << (il/2);
il = il%4; il = il & 3;
const uchar4 sc = get_scale_min_k4(is, xb->scales); const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h; const half d = il < 2 ? xb->d : xb->d / 16.h;
const float ml = il<2 ? min * sc[1] : min * sc[3]; const half min = xb->dmin;
const half dl = d * sc[0];
const half ml = min * sc[1];
const ushort mask = il<2 ? 0x0F : 0xF0; const ushort mask = il<2 ? 0x0F : 0xF0;
const float qh_val = il<2 ? 16.f : 256.f; const half qh_val = il<2 ? 16.h : 256.h;
for (int i = 0; i < 16; ++i) { for (int i = 0; i < 16; ++i) {
reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml; reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
} }
@ -1916,7 +1930,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
template <typename type4x4> template <typename type4x4>
void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) { void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
const float d_all = (float)(xb->d); const half d_all = xb->d;
device const uint8_t * ql = (device const uint8_t *)xb->ql; device const uint8_t * ql = (device const uint8_t *)xb->ql;
device const uint8_t * qh = (device const uint8_t *)xb->qh; device const uint8_t * qh = (device const uint8_t *)xb->qh;
device const int8_t * scales = (device const int8_t *)xb->scales; device const int8_t * scales = (device const int8_t *)xb->scales;
@ -1924,19 +1938,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
#if QK_K == 256 #if QK_K == 256
ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1); ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
qh = qh + 32*(il/8) + 16*(il&1); qh = qh + 32*(il/8) + 16*(il&1);
float sc = scales[(il%2) + 2 * ((il/2))]; half sc = scales[(il%2) + 2 * ((il/2))];
il = (il/2)%4; il = (il/2) & 3;
#else #else
ql = ql + 16 * (il&1); ql = ql + 16 * (il&1);
float sc = scales[il]; half sc = scales[il];
#endif #endif
const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
const half coef = il>1 ? 1.f/16.h : 1.h;
const half ml = d_all * sc * 32.h;
const half dl = d_all * sc * coef;
for (int i = 0; i < 16; ++i) { for (int i = 0; i < 16; ++i) {
uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3); const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
uint16_t kmask2 = il>1 ? 0xF0 : 0x0F; : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
const float coef = il>1 ? 1.f/16.f : 1.f; reg[i/4][i%4] = dl * q - ml;
float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
reg[i/4][i%4] = d_all * sc * q * coef;
} }
} }