From 76f9d41dd6bcc9012363ad7cdad8b899529d91ac Mon Sep 17 00:00:00 2001 From: Georgi Gerganov Date: Sun, 31 Dec 2023 18:14:02 +0200 Subject: [PATCH] metal : optimizing ggml_mul_mat_id (wip) --- ggml-metal.m | 31 ++++--- ggml-metal.metal | 220 +++++++++++++++++++++++++++++++++++++++-------- llama.cpp | 3 +- 3 files changed, 207 insertions(+), 47 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index cd9d00456..4167b6b14 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -1657,6 +1657,10 @@ void ggml_metal_graph_compute( } }; + if (ggml_is_quantized(src0t)) { + GGML_ASSERT(ne00 >= nth0*nth1); + } + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; @@ -1715,6 +1719,9 @@ void ggml_metal_graph_compute( // TODO: make this more general GGML_ASSERT(n_as <= 8); + // max size of the src1ids array in the kernel stack + GGML_ASSERT(ne11 <= 512); + struct ggml_tensor * src2 = gf->nodes[i]->src[2]; const int64_t ne20 = src2 ? src2->ne[0] : 0; @@ -1732,9 +1739,6 @@ void ggml_metal_graph_compute( GGML_ASSERT(!ggml_is_transposed(src2)); GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(ne20 % 32 == 0); - // !!!!!!!!! TODO: this assert is probably required but not sure! - //GGML_ASSERT(ne20 >= 64); GGML_ASSERT(src1t == GGML_TYPE_F32); const uint r2 = ne12/ne22; @@ -1742,22 +1746,22 @@ void ggml_metal_graph_compute( // find the break-even point where the matrix-matrix kernel becomes more efficient compared // to the matrix-vector kernel - int ne11_mm_min = 1; + int ne11_mm_min = n_as; const int idx = ((int32_t *) dst->op_params)[0]; // batch size GGML_ASSERT(ne01 == ne11); - const int64_t _ne1 = 1; // kernel_mul_mm_impl needs a reference in constant memory - // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel // !!! // TODO: for now, always use mat-vec kernels until we figure out how to improve the // indirect matrix multiplication // !!! - if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && _ne1 > ne11_mm_min) { + if ([ctx->device supportsFamily:MTLGPUFamilyApple7] && + ne20 % 32 == 0 && ne20 >= 64 && + ne11 > ne11_mm_min) { switch (src2->type) { case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f32_f32]; break; case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_mul_mm_id_f16_f32]; break; @@ -1787,7 +1791,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:14]; + [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; [encoder setBytes:&r2 length:sizeof(r2) atIndex:16]; [encoder setBytes:&r3 length:sizeof(r3) atIndex:17]; @@ -1805,8 +1809,7 @@ void ggml_metal_graph_compute( [encoder setThreadgroupMemoryLength:8192 atIndex:0]; - // TODO: processing one row at a time (ne11 -> 1) is not efficient - [encoder dispatchThreadgroups:MTLSizeMake( (_ne1 + 31)/32, (ne21 + 63)/64, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { int nth0 = 32; int nth1 = 1; @@ -1889,11 +1892,17 @@ void ggml_metal_graph_compute( } break; default: { - GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src0t); + GGML_METAL_LOG_ERROR("Asserting on type %d\n", (int)src2t); GGML_ASSERT(false && "not implemented"); } }; + if (ggml_is_quantized(src2t)) { + GGML_ASSERT(ne20 >= nth0*nth1); + } + + const int64_t _ne1 = 1; // kernels needs a reference in constant memory + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; diff --git a/ggml-metal.metal b/ggml-metal.metal index 1d5b8f6f4..9e9cb40e3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -846,7 +846,7 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre #define N_SIMDGROUP 2 // number of SIMD groups in a thread group //Note: This is a template, but strictly speaking it only applies to // quantizations where the block size is 32. It also does not -// giard against the number of rows not being divisible by +// guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. template void mul_vec_q_n_f32_impl( @@ -3973,6 +3973,146 @@ void kernel_mul_mm_impl(device const uchar * src0, } } +// same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in src1ids +// TODO: simplify and optimize constants +template +void kernel_mul_mm_id_impl( + device const uchar * src0, + device const uchar * src1, + thread short * src1ids, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne02, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant int64_t & ne12, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant int64_t & ne0, + int64_t ne1, + constant uint & r2, + constant uint & r3, + threadgroup uchar * shared_memory, + uint3 tgpig[[threadgroup_position_in_grid]], + uint tiitg[[thread_index_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + threadgroup half * sa = (threadgroup half *)(shared_memory); + threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + + const uint r0 = tgpig.y; + const uint r1 = tgpig.x; + const uint im = tgpig.z; + + if (r1 * BLOCK_SIZE_N >= ne1) return; + + // if this block is of 64x32 shape or smaller + short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (ne1 - r1 * BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1 * BLOCK_SIZE_N) : BLOCK_SIZE_N; + + // a thread shouldn't load data outside of the matrix + short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; + short thread_col = ((short)tiitg/THREAD_PER_COL) < n_cols ? ((short)tiitg/THREAD_PER_COL) : n_cols - 1; + + simdgroup_half8x8 ma[4]; + simdgroup_float8x8 mb[2]; + simdgroup_float8x8 c_res[8]; + for (int i = 0; i < 8; i++){ + c_res[i] = make_filled_simdgroup_matrix(0.f); + } + + short il = (tiitg % THREAD_PER_ROW); + + const uint i12 = im%ne12; + const uint i13 = im/ne12; + + uint offset0 = (i12/r2)*nb02 + (i13/r3)*(nb02*ne02); + ushort offset1 = il/nl; + + device const block_q * x = (device const block_q *)(src0 + (r0 * BLOCK_SIZE_M + thread_row) * nb01 + offset0) + offset1; + device const float * y = (device const float *)(src1 + + nb12 * im + + nb11 * src1ids[r1 * BLOCK_SIZE_N + thread_col] + + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + // load data and store to threadgroup memory + half4x4 temp_a; + dequantize_func(x, il, temp_a); + threadgroup_barrier(mem_flags::mem_threadgroup); + + #pragma unroll(16) + for (int i = 0; i < 16; i++) { + *(sa + SG_MAT_SIZE * ((tiitg / THREAD_PER_ROW / 8) \ + + (tiitg % THREAD_PER_ROW) * 16 + (i / 8) * 8) \ + + (tiitg / THREAD_PER_ROW) % 8 + (i & 7) * 8) = temp_a[i/4][i%4]; + } + + *(threadgroup float2x4 *)(sb + (tiitg % THREAD_PER_COL) * 8 * 32 + 8 * (tiitg / THREAD_PER_COL)) = *((device float2x4 *)y); + + il = (il + 2 < nl) ? il + 2 : il % 2; + x = (il < 2) ? x + (2+nl-1)/nl : x; + y += BLOCK_SIZE_K; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // load matrices from threadgroup memory and conduct outer products + threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); + threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + + #pragma unroll(4) + for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) + for (int i = 0; i < 4; i++) { + simdgroup_load(ma[i],lsma + SG_MAT_SIZE * i); + } + simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) + for (int i = 0; i < 2; i++) { + simdgroup_load(mb[i],lsmb + SG_MAT_SIZE * i); + } + + lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; + lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + + #pragma unroll(8) + for (int i = 0; i < 8; i++){ + simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + } + } + } + + // TODO: this branch is invalid - need to fix it + //if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { + // device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ + // + im*ne1*ne0; + // for (int i = 0; i < 8; i++) { + // simdgroup_store(c_res[i], C + 8 * (i%4) + ne0*src1ids[8*(i/4) + BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)], ne0); + // } + //} else { + { + // block is smaller than 64x32, we should avoid writing data outside of the matrix + threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; + for (int i = 0; i < 8; i++) { + simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + device float * C = dst + (BLOCK_SIZE_M * r0) + im*ne1*ne0; + if (sgitg == 0) { + for (int i = 0; i < n_rows; i++) { + for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { + *(C + i + src1ids[j + r1*BLOCK_SIZE_N] * ne0) = *(temp_str + i + j * BLOCK_SIZE_M); + } + } + } + } +} + template kernel void kernel_mul_mm(device const uchar * src0, device const uchar * src1, @@ -4019,7 +4159,7 @@ template( - src0[id], - src1 + bid*nb11, - (device float *) (dst + bid*nb1), + for (int64_t i1 = 0; i1 < ne1; i1++) { + if (((device int32_t *) (ids + i1*nbi1))[idx] == id) { + src1ids[_ne1++] = i1; + } + } + + kernel_mul_mm_id_impl( + src0s[id], + src1, + src1ids, + dst, ne00, ne02, nb01, @@ -4069,7 +4219,7 @@ kernel void kernel_mul_mm_id( nb11, nb12, ne0, - ne1, + _ne1, r2, r3, shared_memory, @@ -4158,7 +4308,7 @@ template [[host_name("kernel_mul_mm_q6_K_f32")]] kernel mat_mm_t kernel_mul_mm( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4471,7 +4621,7 @@ kernel void kernel_mul_mv_id_q4_0_f32( kernel void kernel_mul_mv_id_q4_1_f32( device const char * ids, device const char * src1, - device uchar * dst, + device float * dst, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4515,7 +4665,7 @@ kernel void kernel_mul_mv_id_q4_1_f32( mul_vec_q_n_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4534,7 +4684,7 @@ kernel void kernel_mul_mv_id_q4_1_f32( kernel void kernel_mul_mv_id_q5_0_f32( device const char * ids, device const char * src1, - device uchar * dst, + device float * dst, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4578,7 +4728,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( mul_vec_q_n_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4597,7 +4747,7 @@ kernel void kernel_mul_mv_id_q5_0_f32( kernel void kernel_mul_mv_id_q5_1_f32( device const char * ids, device const char * src1, - device uchar * dst, + device float * dst, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4641,7 +4791,7 @@ kernel void kernel_mul_mv_id_q5_1_f32( mul_vec_q_n_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4660,7 +4810,7 @@ kernel void kernel_mul_mv_id_q5_1_f32( kernel void kernel_mul_mv_id_q2_K_f32( device const char * ids, device const char * src1, - device uchar * dst, + device float * dst, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4704,7 +4854,7 @@ kernel void kernel_mul_mv_id_q2_K_f32( kernel_mul_mv_q2_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4723,7 +4873,7 @@ kernel void kernel_mul_mv_id_q2_K_f32( kernel void kernel_mul_mv_id_q3_K_f32( device const char * ids, device const char * src1, - device uchar * dst, + device float * dst, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4767,7 +4917,7 @@ kernel void kernel_mul_mv_id_q3_K_f32( kernel_mul_mv_q3_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4786,7 +4936,7 @@ kernel void kernel_mul_mv_id_q3_K_f32( kernel void kernel_mul_mv_id_q4_K_f32( device const char * ids, device const char * src1, - device uchar * dst, + device float * dst, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4830,7 +4980,7 @@ kernel void kernel_mul_mv_id_q4_K_f32( kernel_mul_mv_q4_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4849,7 +4999,7 @@ kernel void kernel_mul_mv_id_q4_K_f32( kernel void kernel_mul_mv_id_q5_K_f32( device const char * ids, device const char * src1, - device uchar * dst, + device float * dst, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4893,7 +5043,7 @@ kernel void kernel_mul_mv_id_q5_K_f32( kernel_mul_mv_q5_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, @@ -4912,7 +5062,7 @@ kernel void kernel_mul_mv_id_q5_K_f32( kernel void kernel_mul_mv_id_q6_K_f32( device const char * ids, device const char * src1, - device uchar * dst, + device float * dst, constant uint64_t & nbi1, constant int64_t & ne00, constant int64_t & ne01, @@ -4956,7 +5106,7 @@ kernel void kernel_mul_mv_id_q6_K_f32( kernel_mul_mv_q6_K_f32_impl( src0[id], (device const float *) (src1 + bid*nb11), - (device float *) ( dst + bid*nb1), + dst + bid*ne0, ne00, ne01, ne02, diff --git a/llama.cpp b/llama.cpp index a833d4c15..d4b7e95a0 100644 --- a/llama.cpp +++ b/llama.cpp @@ -9852,8 +9852,9 @@ int llama_model_meta_val_str_by_index(const struct llama_model * model, int i, c } int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) { - return snprintf(buf, buf_size, "%s %s %s", + return snprintf(buf, buf_size, "%s %s%s %s", llama_model_arch_name(model->arch).c_str(), + model->hparams.n_expert > 0 ? (std::to_string(model->hparams.n_expert) + "x").c_str() : "", llama_model_type_name(model->type), llama_model_ftype_name(model->ftype).c_str()); }