metal : add support for non-pow-2 argsort
This commit is contained in:
parent
c704c778f6
commit
fe62909618
2 changed files with 45 additions and 20 deletions
11
ggml-metal.m
11
ggml-metal.m
|
@ -2411,6 +2411,13 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
|
||||||
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
|
||||||
|
|
||||||
|
// bitonic sort requires the number of elements to be a power of 2, so round ne00 up to the next power of 2
|
||||||
|
int64_t ne00_padded = 1;
|
||||||
|
while (ne00_padded < ne00) {
|
||||||
|
ne00_padded *= 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
const int mem_size = ne00_padded*sizeof(int32_t);
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
switch (order) {
|
switch (order) {
|
||||||
|
@ -2423,8 +2430,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
|
||||||
|
[encoder setBytes:&ne00_padded length:sizeof( int64_t) atIndex:3];
|
||||||
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(1, nrows, 1) threadsPerThreadgroup:MTLSizeMake(ne00_padded, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
case GGML_OP_LEAKY_RELU:
|
case GGML_OP_LEAKY_RELU:
|
||||||
{
|
{
|
||||||
|
|
|
@ -13,8 +13,8 @@ using namespace metal;
|
||||||
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
|
||||||
|
|
||||||
enum ggml_sort_order {
|
enum ggml_sort_order {
|
||||||
GGML_SORT_ASC,
|
GGML_SORT_ORDER_ASC,
|
||||||
GGML_SORT_DESC,
|
GGML_SORT_ORDER_DESC,
|
||||||
};
|
};
|
||||||
|
|
||||||
// general-purpose kernel for addition, multiplication and division of two tensors
|
// general-purpose kernel for addition, multiplication and division of two tensors
|
||||||
|
@ -1976,6 +1976,8 @@ typedef void (argsort_t)(
|
||||||
device const float * x,
|
device const float * x,
|
||||||
device int32_t * dst,
|
device int32_t * dst,
|
||||||
constant int64_t & ncols,
|
constant int64_t & ncols,
|
||||||
|
constant int64_t & ncols_pad,
|
||||||
|
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]]);
|
uint3 tpitg[[thread_position_in_threadgroup]]);
|
||||||
|
|
||||||
|
@ -1984,33 +1986,42 @@ kernel void kernel_argsort_f32_i32(
|
||||||
device const float * x,
|
device const float * x,
|
||||||
device int32_t * dst,
|
device int32_t * dst,
|
||||||
constant int64_t & ncols,
|
constant int64_t & ncols,
|
||||||
|
constant int64_t & ncols_pad,
|
||||||
|
threadgroup int32_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
uint3 tpitg[[thread_position_in_threadgroup]]) {
|
||||||
// bitonic sort
|
// bitonic sort
|
||||||
int col = tpitg[0];
|
int col = tpitg[0];
|
||||||
int row = tgpig[1];
|
int row = tgpig[1];
|
||||||
|
|
||||||
if (col >= ncols) return;
|
if (col >= ncols_pad) return;
|
||||||
|
|
||||||
device const float * x_row = x + row * ncols;
|
device const float * x_row = x + row * ncols;
|
||||||
device int32_t * dst_row = dst + row * ncols;
|
threadgroup int32_t * dst_row = shared_values;
|
||||||
|
|
||||||
// initialize indices
|
// initialize indices
|
||||||
if (col < ncols) {
|
|
||||||
dst_row[col] = col;
|
dst_row[col] = col;
|
||||||
}
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
|
||||||
for (int k = 2; k <= ncols; k *= 2) {
|
for (int k = 2; k <= ncols_pad; k *= 2) {
|
||||||
for (int j = k / 2; j > 0; j /= 2) {
|
for (int j = k / 2; j > 0; j /= 2) {
|
||||||
int ixj = col ^ j;
|
int ixj = col ^ j;
|
||||||
if (ixj > col) {
|
if (ixj > col) {
|
||||||
if ((col & k) == 0) {
|
if ((col & k) == 0) {
|
||||||
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] > x_row[dst_row[ixj]] : x_row[dst_row[col]] < x_row[dst_row[ixj]]) {
|
if (dst_row[col] >= ncols ||
|
||||||
|
(dst_row[ixj] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||||
|
x_row[dst_row[col]] > x_row[dst_row[ixj]] :
|
||||||
|
x_row[dst_row[col]] < x_row[dst_row[ixj]]))
|
||||||
|
) {
|
||||||
SWAP(dst_row[col], dst_row[ixj]);
|
SWAP(dst_row[col], dst_row[ixj]);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
if (order == GGML_SORT_ASC ? x_row[dst_row[col]] < x_row[dst_row[ixj]] : x_row[dst_row[col]] > x_row[dst_row[ixj]]) {
|
if (dst_row[ixj] >= ncols ||
|
||||||
|
(dst_row[col] < ncols && (order == GGML_SORT_ORDER_ASC ?
|
||||||
|
x_row[dst_row[col]] < x_row[dst_row[ixj]] :
|
||||||
|
x_row[dst_row[col]] > x_row[dst_row[ixj]]))
|
||||||
|
) {
|
||||||
SWAP(dst_row[col], dst_row[ixj]);
|
SWAP(dst_row[col], dst_row[ixj]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -2018,10 +2029,15 @@ kernel void kernel_argsort_f32_i32(
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// copy the result to dst without the padding
|
||||||
|
if (col < ncols) {
|
||||||
|
dst[row * ncols + col] = dst_row[col];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ASC>;
|
template [[host_name("kernel_argsort_f32_i32_asc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_ASC>;
|
||||||
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_DESC>;
|
template [[host_name("kernel_argsort_f32_i32_desc")]] kernel argsort_t kernel_argsort_f32_i32<GGML_SORT_ORDER_DESC>;
|
||||||
|
|
||||||
kernel void kernel_leaky_relu_f32(
|
kernel void kernel_leaky_relu_f32(
|
||||||
device const float * src0,
|
device const float * src0,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue