metal : clean-up
This commit is contained in:
parent
703c6e6528
commit
97eaece7d6
1 changed files with 178 additions and 175 deletions
13
ggml-metal.m
13
ggml-metal.m
|
@ -2560,7 +2560,9 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
|
||||||
id<MTLComputePipelineState> pipeline = nil;
|
id<MTLComputePipelineState> pipeline = nil;
|
||||||
|
|
||||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
bool use_vec_kernel = false;
|
||||||
|
|
||||||
|
if (ne01 >= 4 || (ne00%128 != 0)) {
|
||||||
switch (ne00) {
|
switch (ne00) {
|
||||||
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
|
||||||
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
|
||||||
|
@ -2576,6 +2578,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
use_vec_kernel = true;
|
||||||
|
|
||||||
switch (ne00) {
|
switch (ne00) {
|
||||||
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128].pipeline; break;
|
||||||
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256].pipeline; break;
|
||||||
|
@ -2588,7 +2592,6 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: extend if necessary
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||||
|
@ -2619,8 +2622,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26];
|
||||||
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
[encoder setBytes:&scale length:sizeof( float) atIndex:27];
|
||||||
|
|
||||||
|
if (!use_vec_kernel) {
|
||||||
// half8x8 kernel
|
// half8x8 kernel
|
||||||
if (ne01 > 1 || (ne00%128 != 0)) {
|
|
||||||
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !!
|
||||||
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
|
||||||
|
|
||||||
|
@ -2635,7 +2638,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
||||||
|
|
||||||
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
//printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
|
||||||
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
|
||||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
[encoder setThreadgroupMemoryLength:GGML_PAD(smem, 16) atIndex:0];
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||||
} else {
|
} else {
|
||||||
|
@ -2926,7 +2929,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
|
||||||
return NULL;
|
return NULL;
|
||||||
}
|
}
|
||||||
|
|
||||||
ggml_backend_metal_log_allocated_size(device, size_aligned);
|
//ggml_backend_metal_log_allocated_size(device, size_aligned);
|
||||||
|
|
||||||
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue