metal : block_q_n_dot_y for block_q5_0 (broken)
This commit is contained in:
parent
9c3e05d524
commit
7ebd4acb0a
1 changed files with 15 additions and 5 deletions
|
@ -447,7 +447,17 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
||||||
// we assume that the yl's have been multiplied with the appropriate scale factor
|
// we assume that the yl's have been multiplied with the appropriate scale factor
|
||||||
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||||
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
// TODO
|
float d = qb_curr->d;
|
||||||
|
float2 acc = 0.f;
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
||||||
|
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
||||||
|
for (int i = 0; i < 8; i+=2) {
|
||||||
|
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x0010))
|
||||||
|
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x1000));
|
||||||
|
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 4 ) & 0x0010))
|
||||||
|
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 12) & 0x1000));
|
||||||
|
}
|
||||||
|
return d * (sumy * -16.f + acc[0] + acc[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
|
||||||
|
@ -2225,13 +2235,13 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
|
||||||
|
|
||||||
const int x_mv = (il ? 4 : 0);
|
const int x_mv = (il ? 4 : 0);
|
||||||
|
|
||||||
const int gh_mv = (il ? 12 : 0);
|
const int qh_mv = (il ? 12 : 0);
|
||||||
const int gh_bk = (il ? 0 : 4);
|
const int qh_bk = (il ? 0 : 4);
|
||||||
|
|
||||||
for (int i = 0; i < 8; i++) {
|
for (int i = 0; i < 8; i++) {
|
||||||
// extract the 5-th bits for x0 and x1
|
// extract the 5-th bits for x0 and x1
|
||||||
const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
|
const uint8_t xh_0 = ((qh >> (qh_mv + 2*i )) << qh_bk) & 0x10;
|
||||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
const uint8_t xh_1 = ((qh >> (qh_mv + 2*i+1)) << qh_bk) & 0x10;
|
||||||
|
|
||||||
// combine the 4-bits from qs with the 5th bit
|
// combine the 4-bits from qs with the 5th bit
|
||||||
const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0);
|
const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue