wipwipwiwpip
This commit is contained in:
parent
fc59407efe
commit
ddc59e8e0a
4 changed files with 132 additions and 1 deletions
56
ggml-metal.m
56
ggml-metal.m
|
@@ -187,6 +187,7 @@ enum ggml_metal_kernel_type {
|
|||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H128,
|
||||
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H256,
|
||||
GGML_METAL_KERNEL_TYPE_SSM_CONV_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
|
@@ -771,6 +772,8 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
|
|||
return true;
|
||||
case GGML_OP_FLASH_ATTN_EXT:
|
||||
return ctx->support_simdgroup_mm; // TODO: over-restricted for vec-kernels
|
||||
case GGML_OP_SSM_CONV:
|
||||
return true;
|
||||
case GGML_OP_MUL_MAT:
|
||||
case GGML_OP_MUL_MAT_ID:
|
||||
return ctx->support_simdgroup_reduction &&
|
||||
|
@@ -968,6 +971,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
// GGML_METAL_LOG_INFO("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
|
||||
// ggml_is_contiguous(src1), src1->name);
|
||||
//}
|
||||
//if (src2) {
|
||||
// GGML_METAL_LOG_INFO("%s: src2 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne20, ne21, ne22,
|
||||
// ggml_is_contiguous(src2), src2->name);
|
||||
//}
|
||||
//if (dst) {
|
||||
// GGML_METAL_LOG_INFO("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
|
||||
// dst->name);
|
||||
|
@@ -2688,6 +2695,55 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
|
||||
}
|
||||
} break;
|
||||
case GGML_OP_SSM_CONV:
|
||||
{
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
//pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline;
|
||||
|
||||
//[encoder setComputePipelineState:pipeline];
|
||||
//[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
//[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
//[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
//[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
|
||||
//[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
|
||||
//[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5];
|
||||
//[encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6];
|
||||
//[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7];
|
||||
//[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8];
|
||||
//[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9];
|
||||
//[encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10];
|
||||
//[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
||||
//[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12];
|
||||
//[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
||||
//[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
||||
//[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
||||
//[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
||||
//[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
||||
//[encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18];
|
||||
//[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19];
|
||||
//[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20];
|
||||
//[encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21];
|
||||
//[encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22];
|
||||
//[encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23];
|
||||
//[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24];
|
||||
//[encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25];
|
||||
//[encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26];
|
||||
//[encoder setBytes:&offs length:sizeof(offs) atIndex:27];
|
||||
//[encoder setBytes:&nb length:sizeof(nb) atIndex:28];
|
||||
|
||||
//if (bcast_row) {
|
||||
// const int64_t n = ggml_nelements(dst)/4;
|
||||
|
||||
// [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||
//} else {
|
||||
// const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0);
|
||||
|
||||
// [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
//}
|
||||
} break;
|
||||
case GGML_OP_DUP:
|
||||
case GGML_OP_CPY:
|
||||
case GGML_OP_CONT:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue