metal: fix performance degradation from gqa

Integers are slow on the GPU, and 64-bit divides are extremely slow.
In the context of GQA, we introduce a 64-bit divide that cannot be
optimized out by the compiler, which results in a decrease of ~8% in
inference performance. This commit fixes that issue by calculating a
part of the offset with a 32-bit divide. Naturally, this limits the
size of a single matrix to ~4GB. However, this limitation should
suffice for the near future.
This commit is contained in:
lshzh-ww 2023-08-14 23:10:27 -04:00
parent 5f6de2a2bb
commit bfa455de43
2 changed files with 37 additions and 19 deletions

View file

@ -712,6 +712,7 @@ void ggml_metal_graph_compute(
GGML_ASSERT(ne00 == ne10); GGML_ASSERT(ne00 == ne10);
// GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere // GGML_ASSERT(ne02 == ne12); // Should be checked on individual data types until broadcast is implemented everywhere
uint gqa = ne12/ne02;
GGML_ASSERT(ne03 == ne13); GGML_ASSERT(ne03 == ne13);
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
@ -743,6 +744,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7]; [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:7];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:8];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:9];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:10];
[encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder setThreadgroupMemoryLength:8192 atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11+31)/32, (ne01+63) / 64, ne12) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
} }
@ -845,6 +847,7 @@ void ggml_metal_graph_compute(
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:14];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:15];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:16];
[encoder setBytes:&gqa length:sizeof(gqa) atIndex:17];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) {

View file

@ -343,14 +343,15 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
// N_DST, so this is another explicit assumption of the implementation. // N_DST, so this is another explicit assumption of the implementation.
template<typename block_q_type, int nr, int nsg, int nw> template<typename block_q_type, int nr, int nsg, int nw>
void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst,
int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, int64_t ne00, int64_t ne01, int64_t ne02, int64_t ne10, int64_t ne12, int64_t ne0, uint gqa,
uint3 tgpig, uint tiisg, uint sgitg) { uint3 tgpig, uint tiisg, uint sgitg) {
const int nb = ne00/QK4_0; const int nb = ne00/QK4_0;
const int r0 = tgpig.x; const int r0 = tgpig.x;
const int r1 = tgpig.y; const int r1 = tgpig.y;
const int im = tgpig.z; const int im = tgpig.z;
const int first_row = (r0 * nsg + sgitg) * nr; const int first_row = (r0 * nsg + sgitg) * nr;
device const block_q_type * x = (device const block_q_type *) src0 + first_row * nb + im/(ne12/ne02)*(ne02/QK4_0); const uint offset0 = first_row * nb + im/gqa*(ne02/QK4_0);
device const block_q_type * x = (device const block_q_type *) src0 + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + im*ne12; device const float * y = (device const float *) src1 + r1*ne10 + im*ne12;
float yl[16]; // src1 vector cache float yl[16]; // src1 vector cache
float sumf[nr]={0.f}; float sumf[nr]={0.f};
@ -397,10 +398,11 @@ kernel void kernel_mul_mat_q4_0_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_q4_1_f32( kernel void kernel_mul_mat_q4_1_f32(
@ -413,10 +415,11 @@ kernel void kernel_mul_mat_q4_1_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,tgpig,tiisg,sgitg); mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,gqa,tgpig,tiisg,sgitg);
} }
kernel void kernel_mul_mat_f16_f32( kernel void kernel_mul_mat_f16_f32(
@ -797,6 +800,7 @@ kernel void kernel_mul_mat_q2_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -808,7 +812,8 @@ kernel void kernel_mul_mat_q2_K_f32(
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb; const int ib_row = first_row * nb;
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K); const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
float yl[32]; float yl[32];
float sumf[N_DST]={0.f}, all_sum; float sumf[N_DST]={0.f}, all_sum;
@ -938,6 +943,7 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -946,11 +952,11 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r0 = tgpig.x; const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y; const int64_t r1 = tgpig.y;
const int64_t r2 = tgpig.x; const int64_t r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K); device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
float yl[16]; float yl[16];
@ -1054,6 +1060,7 @@ kernel void kernel_mul_mat_q3_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1062,11 +1069,11 @@ kernel void kernel_mul_mat_q3_K_f32(
const int64_t r0 = tgpig.x; const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y; const int64_t r1 = tgpig.y;
const int64_t r2 = tgpig.x; const int64_t r2 = tgpig.z;
const int row = 2 * r0 + sgitg; const int row = 2 * r0 + sgitg;
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + r2/(ne12/ne02)*(ne02/QK_K); device const block_q3_K * x = (device const block_q3_K *) src0 + row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne02;
const int ix = tiisg/4; const int ix = tiisg/4;
const int il = 4 * (tiisg%4);// 0, 4, 8, 12 const int il = 4 * (tiisg%4);// 0, 4, 8, 12
@ -1123,6 +1130,7 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1142,7 +1150,8 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r2 = tgpig.z; const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb; const int ib_row = first_row * nb;
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K); const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
float yl[16]; float yl[16];
float yh[16]; float yh[16];
@ -1225,6 +1234,7 @@ kernel void kernel_mul_mat_q4_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1238,7 +1248,8 @@ kernel void kernel_mul_mat_q4_K_f32(
const int r2 = tgpig.z; const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST;
const int ib_row = first_row * nb; const int ib_row = first_row * nb;
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + r2/(ne12/ne02)*(ne02/QK_K); const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row + offset0;
device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * y = (device const float *) src1 + r1*ne10 + r2*ne12;
float yl[8]; float yl[8];
float yh[8]; float yh[8];
@ -1311,6 +1322,7 @@ kernel void kernel_mul_mat_q5_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1322,8 +1334,8 @@ kernel void kernel_mul_mat_q5_K_f32(
const int r2 = tgpig.z; const int r2 = tgpig.z;
const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2;
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + r2/(ne12/ne02)*(ne02/QK_K); device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
float sumf[2]={0.f}; float sumf[2]={0.f};
@ -1474,6 +1486,7 @@ kernel void kernel_mul_mat_q6_K_f32(
constant int64_t & ne10[[buffer(9)]], constant int64_t & ne10[[buffer(9)]],
constant int64_t & ne12[[buffer(11)]], constant int64_t & ne12[[buffer(11)]],
constant int64_t & ne0[[buffer(15)]], constant int64_t & ne0[[buffer(15)]],
constant uint & gqa[[buffer(17)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiisg[[thread_index_in_simdgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint sgitg[[simdgroup_index_in_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
@ -1490,8 +1503,8 @@ kernel void kernel_mul_mat_q6_K_f32(
const int r2 = tgpig.z; const int r2 = tgpig.z;
const int row = 2 * r0 + sgitg; const int row = 2 * r0 + sgitg;
const uint offset0 = r2/gqa*(ne02/QK_K);
device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + r2/(ne12/ne02)*(ne02/QK_K); device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb + offset0;
device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12; device const float * yy = (device const float *) src1 + r1*ne10 + r2*ne12;
float sumf = 0; float sumf = 0;
@ -1792,6 +1805,7 @@ kernel void kernel_mul_mm(device const uchar * src0,
constant int64_t & ne12, constant int64_t & ne12,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
constant uint & gqa,
threadgroup uchar * shared_memory [[threadgroup(0)]], threadgroup uchar * shared_memory [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]], uint3 tgpig[[threadgroup_position_in_grid]],
uint tiitg[[thread_index_in_threadgroup]], uint tiitg[[thread_index_in_threadgroup]],
@ -1818,7 +1832,8 @@ kernel void kernel_mul_mm(device const uchar * src0,
} }
short il = (tiitg % THREAD_PER_ROW); short il = (tiitg % THREAD_PER_ROW);
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + im/(ne12/ne02)*nb02) + il/nl; uint offset0 = im/gqa*nb02; ushort offset1 = il/nl;
device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1;
device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \ device const float * y = src1 + (r1 * BLOCK_SIZE_N + thread_col) * ne00 \
+ BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12; + BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL) + im * ne12;
@ -1909,7 +1924,7 @@ template [[host_name("kernel_get_rows_q6_K")]] kernel get_rows_t kernel_get_rows
typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\ typedef void (mat_mm_t)(device const uchar *, device const float *, device float *, constant int64_t &,\
constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \ constant int64_t &, constant int64_t &, constant int64_t &, constant int64_t &, \
constant int64_t &, constant int64_t &, threadgroup uchar *, uint3, uint, uint); constant int64_t &, constant int64_t &, constant uint &, threadgroup uchar *, uint3, uint, uint);
template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>; template [[host_name("kernel_mul_mm_f16_f32")]] kernel mat_mm_t kernel_mul_mm<half4x4, 1, dequantize_f16>;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>; template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;