fix more formatting and enhance readability
Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
This commit is contained in:
parent
746e79e9a5
commit
3c2b87df4b
1 changed files with 21 additions and 19 deletions
|
@ -2585,7 +2585,6 @@ static void ggml_metal_encode_node(
|
|||
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
||||
const uint64_t M = pipeline.maxTotalThreadsPerThreadgroup;
|
||||
|
||||
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
|
||||
|
||||
|
@ -2606,27 +2605,30 @@ static void ggml_metal_encode_node(
|
|||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ofs0 length:sizeof( int32_t) atIndex:2];
|
||||
[encoder setBytes:&ofs1 length:sizeof( int32_t) atIndex:3];
|
||||
[encoder setBytes:&IW length:sizeof( int32_t) atIndex:4];
|
||||
[encoder setBytes:&IH length:sizeof( int32_t) atIndex:5];
|
||||
[encoder setBytes:&CHW length:sizeof( int32_t) atIndex:6];
|
||||
[encoder setBytes:&s0 length:sizeof( int32_t) atIndex:7];
|
||||
[encoder setBytes:&s1 length:sizeof( int32_t) atIndex:8];
|
||||
[encoder setBytes:&p0 length:sizeof( int32_t) atIndex:9];
|
||||
[encoder setBytes:&p1 length:sizeof( int32_t) atIndex:10];
|
||||
[encoder setBytes:&d0 length:sizeof( int32_t) atIndex:11];
|
||||
[encoder setBytes:&d1 length:sizeof( int32_t) atIndex:12];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:0];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||
[encoder setBytes:&ofs0 length:sizeof(int32_t) atIndex:2];
|
||||
[encoder setBytes:&ofs1 length:sizeof(int32_t) atIndex:3];
|
||||
[encoder setBytes:&IW length:sizeof(int32_t) atIndex:4];
|
||||
[encoder setBytes:&IH length:sizeof(int32_t) atIndex:5];
|
||||
[encoder setBytes:&CHW length:sizeof(int32_t) atIndex:6];
|
||||
[encoder setBytes:&s0 length:sizeof(int32_t) atIndex:7];
|
||||
[encoder setBytes:&s1 length:sizeof(int32_t) atIndex:8];
|
||||
[encoder setBytes:&p0 length:sizeof(int32_t) atIndex:9];
|
||||
[encoder setBytes:&p1 length:sizeof(int32_t) atIndex:10];
|
||||
[encoder setBytes:&d0 length:sizeof(int32_t) atIndex:11];
|
||||
[encoder setBytes:&d1 length:sizeof(int32_t) atIndex:12];
|
||||
|
||||
if (is_gt_mttpt) {
|
||||
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
||||
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
||||
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
||||
[encoder setBytes:&N length:sizeof(int32_t) atIndex:13];
|
||||
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
||||
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
||||
|
||||
const int64_t D = N / M + (N % M > 0 ? 1 : 0);
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(D * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(MIN((uint64_t)N, M), 1, 1)];
|
||||
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
||||
|
||||
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} else {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue