From fe62909618949b4496afe2b918ad0dae0611a3a2 Mon Sep 17 00:00:00 2001 From: slaren Date: Tue, 2 Apr 2024 20:31:01 +0200 Subject: [PATCH] metal : add support for non-pow-2 argsort --- ggml-metal.m | 17 +++++++++++++---- ggml-metal.metal | 48 ++++++++++++++++++++++++++++++++---------------- 2 files changed, 45 insertions(+), 20 deletions(-) diff --git a/ggml-metal.m b/ggml-metal.m index 69ffb39b7..51a5fab3a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -2411,6 +2411,13 @@ static enum ggml_status ggml_metal_graph_compute( enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0]; + // bitonic sort requires the number of elements to be power of 2 + int64_t ne00_padded = 1; + while (ne00_padded < ne00) { + ne00_padded *= 2; + } + + const int mem_size = ne00_padded*sizeof(int32_t); id pipeline = nil; switch (order) { @@ -2420,11 +2427,13 @@ static enum ggml_status ggml_metal_graph_compute( }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; + [encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3]; + [encoder setThreadgroupMemoryLength:mem_size atIndex:0]; - [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)]; + [encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)]; } break; case GGML_OP_LEAKY_RELU: { diff --git a/ggml-metal.metal b/ggml-metal.metal index a876af365..9a29f57a3 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -13,8 +13,8 @@ using namespace metal; #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 enum ggml_sort_order { - GGML_SORT_ASC, - GGML_SORT_DESC, + GGML_SORT_ORDER_ASC, + GGML_SORT_ORDER_DESC, }; // general-purpose kernel for addition, multiplication and division of two tensors @@ -1973,9 +1973,11 @@ kernel void kernel_timestep_embedding_f32( // bitonic sort implementation following the CUDA kernels as reference typedef void (argsort_t)( - device const float * x, - device int32_t * dst, - constant int64_t & ncols, + device const float * x, + device int32_t * dst, + constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]); @@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32( device const float * x, device int32_t * dst, constant int64_t & ncols, + constant int64_t & ncols_pad, + threadgroup int32_t * shared_values [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], uint3 tpitg[[thread_position_in_threadgroup]]) { // bitonic sort int col = tpitg[0]; int row = tgpig[1]; - if (col >= ncols) return; + if (col >= ncols_pad) return; - device const float * x_row = x + row * ncols; - device int32_t * dst_row = dst + row * ncols; + device const float * x_row = x + row * ncols; + threadgroup int32_t * dst_row = shared_values; // initialize indices - if (col < ncols) { - dst_row[col] = col; - } + dst_row[col] = col; + threadgroup_barrier(mem_flags::mem_threadgroup); - for (int k = 2; k <= ncols; k *= 2) { + for (int k = 2; k <= ncols_pad; k *= 2) { for (int j = k / 2; j > 0; j /= 2) { int ixj = col ^ j; if (ixj > col) { if ((col & k) == 0) { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) { + if (dst_row[col] >= ncols || + (dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] > x_row[dst_row[ixj]] : + x_row[dst_row[col]] < x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } else { - if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) { + if (dst_row[ixj] >= ncols || + (dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ? + x_row[dst_row[col]] < x_row[dst_row[ixj]] : + x_row[dst_row[col]] > x_row[dst_row[ixj]])) + ) { SWAP(dst_row[col], dst_row[ixj]); } } @@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32( threadgroup_barrier(mem_flags::mem_threadgroup); } } + + // copy the result to dst without the padding + if (col < ncols) { + dst[row * ncols + col] = dst_row[col]; + } } -template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; -template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32; +template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32; kernel void kernel_leaky_relu_f32( device const float * src0,