metal : minor refactor
This commit is contained in:
parent
bb9d36be78
commit
78416536c8
1 changed files with 39 additions and 34 deletions
|
@ -1064,17 +1064,18 @@ kernel void kernel_group_norm(
|
|||
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||
float d = qb_curr->d;
|
||||
|
||||
float2 acc = 0.f;
|
||||
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
|
||||
|
||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
||||
device const uint16_t * qs = ((device const uint16_t *) qb_curr + 1 + il/2);
|
||||
|
||||
for (int i = 0; i < 8; i+=2) {
|
||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
||||
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
|
||||
+ yl[i + 9] * (qs[i / 2] & 0xF000);
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
|
||||
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
|
||||
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
|
||||
}
|
||||
return d * (sumy * -8.f + acc[0] + acc[1]);
|
||||
|
||||
return d * (sumy * -8.f + acc[0] + acc[1] + acc[2] + acc[3]);
|
||||
}
|
||||
|
||||
// function to calculate the inner product between half a q4_1 block and 16 floats (yl); sumy is SUM(yl[i])
|
||||
|
@ -1085,17 +1086,18 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|||
float d = qb_curr->d;
|
||||
float m = qb_curr->m;
|
||||
|
||||
float2 acc = 0.f;
|
||||
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
|
||||
|
||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
||||
device const uint16_t * qs = ((device const uint16_t *) qb_curr + 2 + il/2);
|
||||
|
||||
for (int i = 0; i < 8; i+=2) {
|
||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
||||
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||
acc[1] += yl[i + 8] * (qs[i / 2] & 0x00F0)
|
||||
+ yl[i + 9] * (qs[i / 2] & 0xF000);
|
||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F);
|
||||
acc[1] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||
acc[2] += yl[i + 8] * (qs[i / 2] & 0x00F0);
|
||||
acc[3] += yl[i + 9] * (qs[i / 2] & 0xF000);
|
||||
}
|
||||
return d * (acc[0] + acc[1]) + sumy * m;
|
||||
|
||||
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
|
||||
}
|
||||
|
||||
// function to calculate the inner product between half a q5_0 block and 16 floats (yl); sumy is SUM(yl[i])
|
||||
|
@ -1105,18 +1107,19 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
|||
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||
float d = qb_curr->d;
|
||||
|
||||
float2 acc = 0.f;
|
||||
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
|
||||
|
||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
||||
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
||||
|
||||
for (int i = 0; i < 8; i+=2) {
|
||||
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
||||
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
||||
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
||||
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
||||
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
|
||||
acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
||||
acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
|
||||
acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
||||
}
|
||||
return d * (sumy * -16.f + acc[0] + acc[1]);
|
||||
|
||||
return d * (sumy * -16.f + acc[0] + acc[1] + acc[2] + acc[3]);
|
||||
}
|
||||
|
||||
// function to calculate the inner product between half a q5_1 block and 16 floats (yl); sumy is SUM(yl[i])
|
||||
|
@ -1127,18 +1130,19 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|||
float d = qb_curr->d;
|
||||
float m = qb_curr->m;
|
||||
|
||||
float2 acc = 0.f;
|
||||
float acc[4] = { 0.0f, 0.0f, 0.0f, 0.0f };
|
||||
|
||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
||||
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
||||
|
||||
for (int i = 0; i < 8; i+=2) {
|
||||
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
||||
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
||||
acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
|
||||
+ yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
||||
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010));
|
||||
acc[1] += yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
||||
acc[2] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100));
|
||||
acc[3] += yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
|
||||
}
|
||||
return d * (acc[0] + acc[1]) + sumy * m;
|
||||
|
||||
return d * (acc[0] + acc[1] + acc[2] + acc[3]) + sumy * m;
|
||||
}
|
||||
|
||||
// putting them in the kernel causes a significant performance penalty
|
||||
|
@ -1208,14 +1212,15 @@ void mul_vec_q_n_f32_impl(
|
|||
// each thread in a SIMD group deals with half a block.
|
||||
for (int ib = ix; ib < nb; ib += nw/2) {
|
||||
float sumy = 0;
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
sumy += yb[i] + yb[i+1];
|
||||
yl[i+0] = yb[i+ 0];
|
||||
yl[i+1] = yb[i+ 1]/256.f;
|
||||
|
||||
sumy += yb[i+16] + yb[i+17];
|
||||
yl[i+8] = yb[i+16]/16.f;
|
||||
yl[i+9] = yb[i+17]/4096.f;
|
||||
for (int i = 0; i < 8; i += 2) {
|
||||
sumy += yb[i + 0] + yb[i + 1];
|
||||
yl[i + 0] = yb[i + 0];
|
||||
yl[i + 1] = yb[i + 1]/256.f;
|
||||
|
||||
sumy += yb[i + 16] + yb[i + 17];
|
||||
yl[i + 8] = yb[i + 16]/16.f;
|
||||
yl[i + 9] = yb[i + 17]/4096.f;
|
||||
}
|
||||
|
||||
for (int row = 0; row < nr; row++) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue