Merge branch 'master' into gg/flash-attn
This commit is contained in:
commit
013721df2b
157 changed files with 19090 additions and 15488 deletions
234
ggml-metal.metal
234
ggml-metal.metal
|
@ -4861,6 +4861,114 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
}
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq1_m_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
const int nb = ne00/QK_K;
|
||||
const int r0 = tgpig.x;
|
||||
const int r1 = tgpig.y;
|
||||
const int im = tgpig.z;
|
||||
|
||||
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
|
||||
const int ib_row = first_row * nb;
|
||||
|
||||
const uint i12 = im%ne12;
|
||||
const uint i13 = im/ne12;
|
||||
|
||||
const uint offset0 = (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
|
||||
device const block_iq1_m * x = (device const block_iq1_m *) src0 + ib_row + offset0;
|
||||
device const float * y = (device const float *) src1 + r1*ne10 + im*ne00*ne1;
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
||||
const int nb32 = nb * (QK_K / 32);
|
||||
|
||||
const int ix = tiisg;
|
||||
|
||||
device const float * y4 = y + 32 * ix;
|
||||
|
||||
#if QK_K != 64
|
||||
iq1m_scale_t scale;
|
||||
#endif
|
||||
|
||||
for (int ib32 = ix; ib32 < nb32; ib32 += 32) {
|
||||
|
||||
float4 sumy = {0.f};
|
||||
for (int i = 0; i < 8; ++i) {
|
||||
yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0];
|
||||
yl[i+ 8] = y4[i+ 8]; sumy[1] += yl[i+ 8];
|
||||
yl[i+16] = y4[i+16]; sumy[2] += yl[i+16];
|
||||
yl[i+24] = y4[i+24]; sumy[3] += yl[i+24];
|
||||
}
|
||||
|
||||
const int ibl = ib32 / (QK_K / 32);
|
||||
const int ib = ib32 % (QK_K / 32);
|
||||
|
||||
device const block_iq1_m * xr = x + ibl;
|
||||
device const uint8_t * qs = xr->qs + 4 * ib;
|
||||
device const uint8_t * qh = xr->qh + 2 * ib;
|
||||
device const uint16_t * sc = (device const uint16_t *)xr->scales;
|
||||
|
||||
for (int row = 0; row < N_DST; row++) {
|
||||
|
||||
#if QK_K != 64
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
#endif
|
||||
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
||||
constant uint8_t * grid3 = (constant uint8_t *)(iq1s_grid_gpu + (qs[2] | ((qh[1] << 8) & 0x700)));
|
||||
constant uint8_t * grid4 = (constant uint8_t *)(iq1s_grid_gpu + (qs[3] | ((qh[1] << 4) & 0x700)));
|
||||
|
||||
float2 sum = {0.f};
|
||||
for (int j = 0; j < 4; ++j) {
|
||||
sum[0] += yl[j+ 0] * (grid1[j] & 0xf) + yl[j+ 4] * (grid1[j] >> 4)
|
||||
+ yl[j+ 8] * (grid2[j] & 0xf) + yl[j+12] * (grid2[j] >> 4);
|
||||
sum[1] += yl[j+16] * (grid3[j] & 0xf) + yl[j+20] * (grid3[j] >> 4)
|
||||
+ yl[j+24] * (grid4[j] & 0xf) + yl[j+28] * (grid4[j] >> 4);
|
||||
}
|
||||
const float delta1 = sumy[0] * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[1] * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
||||
const float delta2 = sumy[2] * (qh[1] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA) + sumy[3] * (qh[1] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
||||
#if QK_K == 64
|
||||
const float d = (float) *((device const half *)(sc - 1));
|
||||
sumf[row] += d * ((sum[0] + delta1) * (2*((sc[0] >> (8*(ib%2)+0)) & 0xf) + 1) +
|
||||
(sum[1] + delta2) * (2*((sc[0] >> (8*(ib%2)+4)) & 0xf) + 1));
|
||||
#else
|
||||
sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) +
|
||||
(sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1));
|
||||
#endif
|
||||
|
||||
sc += nb*sizeof(block_iq1_m)/2;
|
||||
qs += nb*sizeof(block_iq1_m);
|
||||
qh += nb*sizeof(block_iq1_m);
|
||||
}
|
||||
|
||||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
|
@ -5078,6 +5186,34 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|||
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
||||
kernel void kernel_mul_mv_iq1_m_f32(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
||||
kernel void kernel_mul_mv_iq4_nl_f32(
|
||||
device const void * src0,
|
||||
|
@ -5551,6 +5687,38 @@ void dequantize_iq1_s(device const block_iq1_s * xb, short il, thread type4x4 &
|
|||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq1_m(device const block_iq1_m * xb, short il, thread type4x4 & reg) {
|
||||
// il is 0...15 for QK_K = 256 => index of block of 32 is il/2
|
||||
const int ib32 = il/2;
|
||||
il = il%2;
|
||||
device const uint16_t * sc = (device const uint16_t *)xb->scales;
|
||||
#if QK_K == 64
|
||||
const float d = xb->d;
|
||||
#else
|
||||
iq1m_scale_t scale;
|
||||
scale.u16 = (sc[0] >> 12) | ((sc[1] >> 8) & 0x00f0) | ((sc[2] >> 4) & 0x0f00) | (sc[3] & 0xf000);
|
||||
const float d = scale.f16;
|
||||
#endif
|
||||
device const uint8_t * qs = xb->qs + 4*ib32 + 2*il;
|
||||
device const uint8_t * qh = xb->qh + 2*ib32 + il;
|
||||
#if QK_K == 64
|
||||
const float dl = d * (2*((sc[ib32/2] >> (8*(ib32%2)+4*il)) & 0xf) + 1);
|
||||
#else
|
||||
const float dl = d * (2*((sc[ib32/2] >> (6*(ib32%2)+3*il)) & 7) + 1);
|
||||
#endif
|
||||
const float ml1 = dl * (qh[0] & 0x08 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
||||
const float ml2 = dl * (qh[0] & 0x80 ? -1 - IQ1M_DELTA : -1 + IQ1M_DELTA);
|
||||
constant uint8_t * grid1 = (constant uint8_t *)(iq1s_grid_gpu + (qs[0] | ((qh[0] << 8) & 0x700)));
|
||||
constant uint8_t * grid2 = (constant uint8_t *)(iq1s_grid_gpu + (qs[1] | ((qh[0] << 4) & 0x700)));
|
||||
for (int i = 0; i < 4; ++i) {
|
||||
reg[0][i] = dl * (grid1[i] & 0xf) + ml1;
|
||||
reg[1][i] = dl * (grid1[i] >> 4) + ml1;
|
||||
reg[2][i] = dl * (grid2[i] & 0xf) + ml2;
|
||||
reg[3][i] = dl * (grid2[i] >> 4) + ml2;
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_iq4_nl(device const block_iq4_nl * xb, short il, thread type4x4 & reg) {
|
||||
device const uint16_t * q4 = (device const uint16_t *)xb->qs;
|
||||
|
@ -6135,6 +6303,7 @@ template [[host_name("kernel_get_rows_iq3_xxs")]] kernel get_rows_t kernel_get_r
|
|||
template [[host_name("kernel_get_rows_iq3_s")]] kernel get_rows_t kernel_get_rows<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_get_rows_iq2_s")]] kernel get_rows_t kernel_get_rows<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_s")]] kernel get_rows_t kernel_get_rows<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_get_rows_iq1_m")]] kernel get_rows_t kernel_get_rows<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_get_rows_iq4_nl")]] kernel get_rows_t kernel_get_rows<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
#if QK_K == 64
|
||||
template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_t kernel_get_rows<block_iq4_xs, 2, dequantize_iq4_xs>;
|
||||
|
@ -6183,6 +6352,7 @@ template [[host_name("kernel_mul_mm_iq3_xxs_f32")]] kernel mat_mm_t kernel_mul_m
|
|||
template [[host_name("kernel_mul_mm_iq3_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_iq2_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_s_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_iq1_m_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_iq4_nl_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
#if QK_K == 64
|
||||
template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_mm<block_iq4_nl, 2, dequantize_iq4_xs>;
|
||||
|
@ -6243,6 +6413,7 @@ template [[host_name("kernel_mul_mm_id_iq3_xxs_f32")]] kernel mat_mm_id_t kernel
|
|||
template [[host_name("kernel_mul_mm_id_iq3_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq3_s, QK_NL, dequantize_iq3_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq2_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq2_s, QK_NL, dequantize_iq2_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_s_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_s, QK_NL, dequantize_iq1_s>;
|
||||
template [[host_name("kernel_mul_mm_id_iq1_m_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq1_m, QK_NL, dequantize_iq1_m>;
|
||||
template [[host_name("kernel_mul_mm_id_iq4_nl_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_nl, 2, dequantize_iq4_nl>;
|
||||
#if QK_K == 64
|
||||
template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel_mul_mm_id<block_iq4_xs, 2, dequantize_iq4_xs>;
|
||||
|
@ -7410,6 +7581,69 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|||
sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq1_m_f32(
|
||||
device const char * ids,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
constant int64_t & ne02,
|
||||
constant uint64_t & nb00,
|
||||
constant uint64_t & nb01,
|
||||
constant uint64_t & nb02,
|
||||
constant int64_t & ne10,
|
||||
constant int64_t & ne11,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne13,
|
||||
constant uint64_t & nb10,
|
||||
constant uint64_t & nb11,
|
||||
constant uint64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint64_t & nb1,
|
||||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
|
||||
kernel_mul_mv_iq1_m_f32_impl(
|
||||
src0[id],
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
ne01,
|
||||
ne02,
|
||||
ne10,
|
||||
ne12,
|
||||
ne0,
|
||||
ne1,
|
||||
r2,
|
||||
r3,
|
||||
tgpig,
|
||||
tiisg,
|
||||
sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
||||
device const char * ids,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue