minor : spaces / formatting
This commit is contained in:
parent
9db276f0c2
commit
7a88522975
1 changed file with 21 additions and 9 deletions
|
@@ -414,8 +414,11 @@ kernel void kernel_rms_norm(
|
||||||
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||||
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
float d = qb_curr->d;
|
float d = qb_curr->d;
|
||||||
|
|
||||||
float2 acc = 0.f;
|
float2 acc = 0.f;
|
||||||
|
|
||||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
|
||||||
|
|
||||||
for (int i = 0; i < 8; i+=2) {
|
for (int i = 0; i < 8; i+=2) {
|
||||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
||||||
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||||
|
@@ -432,8 +435,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
|
||||||
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
float d = qb_curr->d;
|
float d = qb_curr->d;
|
||||||
float m = qb_curr->m;
|
float m = qb_curr->m;
|
||||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
|
||||||
float2 acc = 0.f;
|
float2 acc = 0.f;
|
||||||
|
|
||||||
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
|
||||||
|
|
||||||
for (int i = 0; i < 8; i+=2) {
|
for (int i = 0; i < 8; i+=2) {
|
||||||
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
|
||||||
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||||
|
@@ -449,9 +455,12 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
|
||||||
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
|
||||||
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
float d = qb_curr->d;
|
float d = qb_curr->d;
|
||||||
|
|
||||||
float2 acc = 0.f;
|
float2 acc = 0.f;
|
||||||
|
|
||||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
|
||||||
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
||||||
|
|
||||||
for (int i = 0; i < 8; i+=2) {
|
for (int i = 0; i < 8; i+=2) {
|
||||||
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
||||||
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
||||||
|
@@ -468,9 +477,12 @@ inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thre
|
||||||
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
|
inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
|
||||||
float d = qb_curr->d;
|
float d = qb_curr->d;
|
||||||
float m = qb_curr->m;
|
float m = qb_curr->m;
|
||||||
|
|
||||||
float2 acc = 0.f;
|
float2 acc = 0.f;
|
||||||
|
|
||||||
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
|
||||||
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
|
||||||
|
|
||||||
for (int i = 0; i < 8; i+=2) {
|
for (int i = 0; i < 8; i+=2) {
|
||||||
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
|
||||||
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
+ yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
|
||||||
|
@@ -2258,7 +2270,7 @@ void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg
|
||||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||||
|
|
||||||
// combine the 4-bits from qs with the 5th bit
|
// combine the 4-bits from qs with the 5th bit
|
||||||
const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0);
|
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||||
|
|
||||||
reg[i/2][2*(i%2)+0] = d * x0 + md;
|
reg[i/2][2*(i%2)+0] = d * x0 + md;
|
||||||
|
@@ -2286,7 +2298,7 @@ void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg
|
||||||
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
|
||||||
|
|
||||||
// combine the 4-bits from qs with the 5th bit
|
// combine the 4-bits from qs with the 5th bit
|
||||||
const int32_t x0 = (((qs[i] & mask) >> x_mv) | xh_0);
|
const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
|
||||||
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
|
||||||
|
|
||||||
reg[i/2][2*(i%2)+0] = d * x0 + m;
|
reg[i/2][2*(i%2)+0] = d * x0 + m;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue