metal: use template to reduce size
Revert modifications on block_q4_0 and block_q4_1.
This commit is contained in:
parent
4088df14ca
commit
f3f2e8eee3
1 changed files with 95 additions and 186 deletions
281
ggml-metal.metal
281
ggml-metal.metal
|
@ -8,14 +8,14 @@ using namespace metal;
|
||||||
#define QR4_0 2
|
#define QR4_0 2
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
uint16_t qs[QK4_0 / 4]; // nibbles / quants
|
uint8_t qs[QK4_0 / 2]; // nibbles / quants
|
||||||
} block_q4_0;
|
} block_q4_0;
|
||||||
|
|
||||||
#define QK4_1 32
|
#define QK4_1 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d; // delta
|
half d; // delta
|
||||||
half m; // min
|
half m; // min
|
||||||
uint16_t qs[QK4_1 / 4]; // nibbles / quants
|
uint8_t qs[QK4_1 / 2]; // nibbles / quants
|
||||||
} block_q4_1;
|
} block_q4_1;
|
||||||
|
|
||||||
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
|
static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, int k) {
|
||||||
|
@ -28,16 +28,12 @@ static void dequantize_row_q4_0(device const block_q4_0 * x, device float * y, i
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
const half d = x[i].d;
|
const half d = x[i].d;
|
||||||
|
|
||||||
for (int j = 0; j < qk/4; ++j) {
|
for (int j = 0; j < qk/2; ++j) {
|
||||||
const int x0 = (x[i].qs[j] & 0x000F) - 8;
|
const int x0 = (x[i].qs[j] & 0x0F) - 8;
|
||||||
const int x1 = ((x[i].qs[j] & 0x00F0) >> 4) - 8;
|
const int x1 = (x[i].qs[j] >> 4) - 8;
|
||||||
const int x2 = ((x[i].qs[j] & 0x0F00) >> 8) - 8;
|
|
||||||
const int x3 = ((x[i].qs[j] & 0xF000) >> 12) - 8;
|
|
||||||
|
|
||||||
y[i*qk + 2 * j + 0 ] = x0*d;
|
y[i*qk + j + 0 ] = x0*d;
|
||||||
y[i*qk + 2 * j + qk/2 ] = x1*d;
|
y[i*qk + j + qk/2] = x1*d;
|
||||||
y[i*qk + 2 * j + 1 ] = x2*d;
|
|
||||||
y[i*qk + 2 * j + 1 + qk/2] = x3*d;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -53,16 +49,12 @@ static void dequantize_row_q4_1(device const block_q4_1 * x, device float * y, i
|
||||||
const half d = x[i].d;
|
const half d = x[i].d;
|
||||||
const half m = x[i].m;
|
const half m = x[i].m;
|
||||||
|
|
||||||
for (int j = 0; j < qk/4; ++j) {
|
for (int j = 0; j < qk/2; ++j) {
|
||||||
const int x0 = (x[i].qs[j] & 0x000F);
|
const int x0 = (x[i].qs[j] & 0x0F);
|
||||||
const int x1 = ((x[i].qs[j] & 0x00F0) >> 4);
|
const int x1 = (x[i].qs[j] >> 4);
|
||||||
const int x2 = ((x[i].qs[j] & 0x0F00) >> 8);
|
|
||||||
const int x3 = ((x[i].qs[j] & 0xF000) >> 12);
|
|
||||||
|
|
||||||
y[i*qk + 2 * j + 0 ] = x0*d + m;
|
y[i*qk + j + 0 ] = x0*d + m;
|
||||||
y[i*qk + 2 * j + qk/2 ] = x1*d + m;
|
y[i*qk + j + qk/2] = x1*d + m;
|
||||||
y[i*qk + 2 * j + 1 ] = x2*d + m;
|
|
||||||
y[i*qk + 2 * j + 1 + qk/2] = x3*d + m;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -384,10 +376,91 @@ kernel void kernel_rms_norm(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i])
|
||||||
|
float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) {
|
||||||
|
float d = qb_curr->d;
|
||||||
|
float4 acc = 0.f;
|
||||||
|
device uint16_t * qs = ((device uint16_t *)qb_curr + 1);
|
||||||
|
for (int i = 0; i < 16; i+=2) {
|
||||||
|
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
|
||||||
|
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
|
||||||
|
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||||
|
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
|
||||||
|
}
|
||||||
|
return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f);
|
||||||
|
}
|
||||||
|
|
||||||
|
// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i])
|
||||||
|
float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) {
|
||||||
|
float d = qb_curr->d;
|
||||||
|
float m = qb_curr->m;
|
||||||
|
float4 acc = 0.f;
|
||||||
|
device uint16_t * qs = ((device uint16_t *)qb_curr + 2);
|
||||||
|
for (int i = 0; i < 16; i+=2) {
|
||||||
|
acc[0] += yl[i] * (qs[i / 2] & 0x000F);
|
||||||
|
acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0);
|
||||||
|
acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00);
|
||||||
|
acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000);
|
||||||
|
}
|
||||||
|
return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m;
|
||||||
|
}
|
||||||
|
|
||||||
// putting them in the kernel cause a significant performance penalty
|
// putting them in the kernel cause a significant performance penalty
|
||||||
#define N_DST 4 // each SIMD group works on 4 rows
|
#define N_DST 4 // each SIMD group works on 4 rows
|
||||||
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
|
||||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||||
|
template<typename block_q_type>
|
||||||
|
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
|
||||||
|
int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01,
|
||||||
|
uint2 tgpig, uint tiisg, uint sgitg) {
|
||||||
|
const int nb = ne00/QK4_0;
|
||||||
|
const int r0 = tgpig.x;
|
||||||
|
const int r1 = tgpig.y;
|
||||||
|
device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
|
||||||
|
device const float * y = (device const float *) src1 + r1*ne10;
|
||||||
|
float4 y_curr[8]; // src1 vector cache
|
||||||
|
float sumf[N_DST]={0.f}, all_sum;
|
||||||
|
thread float * yl=(thread float *)y_curr;
|
||||||
|
|
||||||
|
// each thread in a SIMD group deals with 1 block.
|
||||||
|
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
||||||
|
float sumy = 0;
|
||||||
|
for (int i = 0; i < QK4_0 / 4; i++) {
|
||||||
|
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i);
|
||||||
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int row = 0; row < N_DST; row++) {
|
||||||
|
sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// from now loads two rows every time and 16 blocks per row
|
||||||
|
int ir = tiisg / (N_SIMDWIDTH / 2);
|
||||||
|
int ib = tiisg % (N_SIMDWIDTH / 2);
|
||||||
|
for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) {
|
||||||
|
int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start
|
||||||
|
float sumy = 0;
|
||||||
|
for (int i = 0; i < QK4_0 / 4; i++) {
|
||||||
|
y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i);
|
||||||
|
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int row = 0; row < N_DST; row+=2) {
|
||||||
|
if (nb_start + ib < nb) {
|
||||||
|
sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int row = 0; row < N_DST; ++row) {
|
||||||
|
all_sum = simd_sum(sumf[row]);
|
||||||
|
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
||||||
|
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q4_0_f32(
|
kernel void kernel_mul_mat_q4_0_f32(
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
|
@ -399,89 +472,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
uint2 tgpig[[threadgroup_position_in_grid]],
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const int nb = ne00/QK4_0;
|
mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
||||||
const int r0 = tgpig.x;
|
|
||||||
const int r1 = tgpig.y;
|
|
||||||
device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
|
|
||||||
device const float * y = (device const float *) src1 + r1*ne10;
|
|
||||||
block_q4_0 qb_curr, qb_next;
|
|
||||||
float4 y_curr[8]; // src1 vector cache
|
|
||||||
float sumf[N_DST]={0.f}, all_sum;
|
|
||||||
thread float * yl=(thread float *)y_curr;
|
|
||||||
|
|
||||||
// bootstrap
|
|
||||||
qb_curr = x[tiisg];
|
|
||||||
// each thread in a SIMD group deals with 1 block.
|
|
||||||
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
|
||||||
|
|
||||||
float sumy = 0;
|
|
||||||
for (int i = 0; i < QK4_0 / 4; i++) {
|
|
||||||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
|
|
||||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
|
||||||
}
|
|
||||||
sumy *= (-8.f);
|
|
||||||
// we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
|
|
||||||
for (int i = 0; i < 32; i++) {
|
|
||||||
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
|
||||||
// prefetch next x block
|
|
||||||
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
|
||||||
|
|
||||||
// calculate
|
|
||||||
float d = qb_curr.d;
|
|
||||||
float acc = sumy;
|
|
||||||
for (int i = 0; i < 16; i+=2) {
|
|
||||||
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
|
||||||
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
|
||||||
}
|
|
||||||
sumf[row] += d * acc;
|
|
||||||
qb_curr = qb_next;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (nb % N_SIMDWIDTH == 0) {
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
|
||||||
all_sum = simd_sum(sumf[row]);
|
|
||||||
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
|
||||||
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
float sumy = 0;
|
|
||||||
for (int i = 0; i < QK4_0 / 4; i++) {
|
|
||||||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
|
||||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
|
||||||
}
|
|
||||||
sumy *= (-8.f);
|
|
||||||
for (int i = 0; i < 32; i++) {
|
|
||||||
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
|
||||||
// prefetch next x block
|
|
||||||
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
|
||||||
|
|
||||||
// calculate
|
|
||||||
float d = qb_curr.d;
|
|
||||||
float acc = sumy;
|
|
||||||
for (int i = 0; i < 16; i+=2) {
|
|
||||||
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
|
||||||
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
|
||||||
}
|
|
||||||
if (tiisg < nb % N_SIMDWIDTH) {
|
|
||||||
sumf[row] += d * acc;
|
|
||||||
}
|
|
||||||
qb_curr = qb_next;
|
|
||||||
|
|
||||||
all_sum = simd_sum(sumf[row]);
|
|
||||||
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
|
||||||
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_q4_1_f32(
|
kernel void kernel_mul_mat_q4_1_f32(
|
||||||
|
@ -495,89 +486,7 @@ kernel void kernel_mul_mat_q4_1_f32(
|
||||||
uint2 tgpig[[threadgroup_position_in_grid]],
|
uint2 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
const int nb = ne00/QK4_0;
|
mul_vec_q_n_f32<block_q4_1>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
|
||||||
const int r0 = tgpig.x;
|
|
||||||
const int r1 = tgpig.y;
|
|
||||||
device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
|
|
||||||
device const float * y = (device const float *) src1 + r1*ne10;
|
|
||||||
block_q4_1 qb_curr, qb_next;
|
|
||||||
float4 y_curr[8]; // src1 vector cache
|
|
||||||
float sumf[N_DST]={0.f}, all_sum;
|
|
||||||
thread float * yl=(thread float *)y_curr;
|
|
||||||
|
|
||||||
// bootstrap
|
|
||||||
qb_curr = x[tiisg];
|
|
||||||
// each thread in a SIMD group deals with 1 block.
|
|
||||||
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
|
|
||||||
|
|
||||||
float sumy = 0;
|
|
||||||
for (int i = 0; i < QK4_0 / 4; i++) {
|
|
||||||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
|
|
||||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
|
||||||
}
|
|
||||||
// we don't right shift packed 4-bit weights, so we have to devide y by 16/256/4096 to conpensate this.
|
|
||||||
for (int i = 0; i < 32; i++) {
|
|
||||||
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
|
||||||
// prefetch next x block
|
|
||||||
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
|
||||||
|
|
||||||
// calculate
|
|
||||||
const float d = qb_curr.d;
|
|
||||||
const float m = qb_curr.m;
|
|
||||||
float acc = 0.f;
|
|
||||||
for (int i = 0; i < 16; i+=2) {
|
|
||||||
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
|
||||||
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
|
||||||
}
|
|
||||||
sumf[row] += d * acc + m * sumy;
|
|
||||||
qb_curr = qb_next;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (nb % N_SIMDWIDTH == 0) {
|
|
||||||
for (int row = 0; row < N_DST; ++row) {
|
|
||||||
all_sum = simd_sum(sumf[row]);
|
|
||||||
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
|
||||||
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
|
|
||||||
float sumy = 0;
|
|
||||||
for (int i = 0; i < QK4_0 / 4; i++) {
|
|
||||||
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
|
|
||||||
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
|
|
||||||
}
|
|
||||||
for (int i = 0; i < 32; i++) {
|
|
||||||
yl[i] *= pow(1.f/16.f, 2 * (i % 2) + i / 16);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
|
||||||
// prefetch next x block
|
|
||||||
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
|
|
||||||
|
|
||||||
// calculate
|
|
||||||
const float d = qb_curr.d;
|
|
||||||
const float m = qb_curr.m;
|
|
||||||
float acc = 0.f;
|
|
||||||
for (int i = 0; i < 16; i+=2) {
|
|
||||||
acc += yl[i] * (qb_curr.qs[i / 2] & 0x000F) + yl[i + 16] * (qb_curr.qs[i / 2] & 0x00F0);
|
|
||||||
acc += yl[i + 1] * (qb_curr.qs[i / 2] & 0x0F00) + yl[i + 17] * (qb_curr.qs[i / 2] & 0xF000);
|
|
||||||
}
|
|
||||||
if (tiisg < nb % N_SIMDWIDTH) {
|
|
||||||
sumf[row] += d * acc + m * sumy;
|
|
||||||
}
|
|
||||||
qb_curr = qb_next;
|
|
||||||
|
|
||||||
all_sum = simd_sum(sumf[row]);
|
|
||||||
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
|
|
||||||
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mat_f16_f32(
|
kernel void kernel_mul_mat_f16_f32(
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue