metal : relax conditions on fast matrix multiplication kernel (#3168)
* metal : relax conditions on fast matrix multiplication kernel * metal : revert the concurrnecy change because it was wrong * llama : remove experimental stuff
This commit is contained in:
parent
76164fe2e6
commit
a51b687657
4 changed files with 100 additions and 51 deletions
|
@ -38,7 +38,7 @@ kernel void kernel_add_row(
|
|||
device const float4 * src0,
|
||||
device const float4 * src1,
|
||||
device float4 * dst,
|
||||
constant int64_t & nb,
|
||||
constant int64_t & nb,
|
||||
uint tpig[[thread_position_in_grid]]) {
|
||||
dst[tpig] = src0[tpig] + src1[tpig % nb];
|
||||
}
|
||||
|
@ -1321,7 +1321,6 @@ kernel void kernel_mul_mat_q3_K_f32(
|
|||
dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
#else
|
||||
kernel void kernel_mul_mat_q3_K_f32(
|
||||
|
@ -1865,6 +1864,15 @@ kernel void kernel_mul_mat_q6_K_f32(
|
|||
|
||||
//============================= templates and their specializations =============================
|
||||
|
||||
// NOTE: this is not dequantizing - we are simply fitting the template
|
||||
template <typename type4x4>
|
||||
void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) {
|
||||
float4x4 temp = *(((device float4x4 *)src));
|
||||
for (int i = 0; i < 16; i++){
|
||||
reg[i/4][i%4] = temp[i/4][i%4];
|
||||
}
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg) {
|
||||
half4x4 temp = *(((device half4x4 *)src));
|
||||
|
@ -1875,7 +1883,6 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
|
|||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
|
||||
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 1);
|
||||
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
|
@ -1887,12 +1894,10 @@ void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg
|
|||
reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
|
||||
reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
template <typename type4x4>
|
||||
void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
|
||||
|
||||
device const uint16_t * qs = ((device const uint16_t *)xb + 2);
|
||||
const float d1 = il ? (xb->d / 16.h) : xb->d;
|
||||
const float d2 = d1 / 256.f;
|
||||
|
@ -1964,7 +1969,6 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
|
|||
for (int i = 0; i < 16; ++i) {
|
||||
reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
|
||||
}
|
||||
|
||||
#else
|
||||
float kcoef = il&1 ? 1.f/16.f : 1.f;
|
||||
uint16_t kmask = il&1 ? 0xF0 : 0x0F;
|
||||
|
@ -2110,22 +2114,25 @@ kernel void kernel_get_rows(
|
|||
// each block_q contains 16*nl weights
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm(device const uchar * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & nb01,
|
||||
constant int64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & gqa,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & nb01,
|
||||
constant int64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & nb10,
|
||||
constant int64_t & nb11,
|
||||
constant int64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & gqa,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
threadgroup half * sa = ((threadgroup half *)shared_memory);
|
||||
threadgroup half * sa = (threadgroup half *)(shared_memory);
|
||||
threadgroup float * sb = (threadgroup float *)(shared_memory + 4096);
|
||||
|
||||
const uint r0 = tgpig.y;
|
||||
|
@ -2138,7 +2145,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||
short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1;
|
||||
short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1;
|
||||
|
||||
simdgroup_half8x8 ma[4];
|
||||
simdgroup_half8x8 ma[4];
|
||||
simdgroup_float8x8 mb[2];
|
||||
simdgroup_float8x8 c_res[8];
|
||||
for (int i = 0; i < 8; i++){
|
||||
|
@ -2146,10 +2153,15 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||
}
|
||||
|
||||
short il = (tiitg % THREAD_PER_ROW);
|
||||
uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
|
||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
||||
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
|
||||
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne00 * ne1;
|
||||
|
||||
uint offset0 = im/gqa*nb02;
|
||||
ushort offset1 = il/nl;
|
||||
|
||||
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
|
||||
device const float * y = (device const float *)(src1
|
||||
+ nb12 * im
|
||||
+ nb11 * (r1 * BLOCK_SIZE_N + thread_col)
|
||||
+ nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL)));
|
||||
|
||||
for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) {
|
||||
//load data and store to threadgroup memory
|
||||
|
@ -2229,6 +2241,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||
typedef void (get_rows_t)(device const void *, device const int *, device float *, constant int64_t &, \
|
||||
constant uint64_t &, constant uint64_t &, uint, uint, uint);
|
||||
|
||||
template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows<float4x4, 1, dequantize_f32>;
|
||||
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
|
||||
|
@ -2239,14 +2252,27 @@ template [[host_name("kernel_get_rows_q4_K")]] kernel get_rows_t kernel_get_rows
|
|||
template [[host_name("kernel_get_rows_q5_K")]] kernel get_rows_t kernel_get_rows<block_q5_K, QK_NL, dequantize_q5_K>;
|
||||
template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows<block_q6_K, QK_NL, dequantize_q6_K>;
|
||||
|
||||
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
|
||||
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
|
||||
constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
|
||||
typedef void (mat_mm_t)(
|
||||
device const uchar * src0,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
constant int64_t & nb01,
|
||||
constant int64_t & nb02,
|
||||
constant int64_t & ne12,
|
||||
constant int64_t & nb10,
|
||||
constant int64_t & nb11,
|
||||
constant int64_t & nb12,
|
||||
constant int64_t & ne0,
|
||||
constant int64_t & ne1,
|
||||
constant uint & gqa,
|
||||
threadgroup uchar *, uint3, uint, uint);
|
||||
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
|
||||
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
|
||||
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
|
||||
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
|
||||
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
|
||||
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
|
||||
template [[host_name("kernel_mul_mm_q4_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_K, QK_NL, dequantize_q4_K>;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue