fix more formatting and enhance readability
Signed-off-by: Junhee Yoo <junhee.yoo@navercorp.com>
This commit is contained in:
parent
746e79e9a5
commit
3c2b87df4b
1 changed files with 21 additions and 19 deletions
|
@ -2585,7 +2585,6 @@ static void ggml_metal_encode_node(
|
|||
const int32_t ofs1 = src1->nb[is_2D ? 2 : 1] / 4;
|
||||
|
||||
id<MTLComputePipelineState> pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_IM2COL_F32].pipeline;
|
||||
const uint64_t M = pipeline.maxTotalThreadsPerThreadgroup;
|
||||
|
||||
const bool is_gt_mttpt = ((size_t)(N * KH * KW)) > pipeline.maxTotalThreadsPerThreadgroup;
|
||||
|
||||
|
@ -2625,8 +2624,11 @@ static void ggml_metal_encode_node(
|
|||
[encoder setBytes:&KH length:sizeof(int32_t) atIndex:14];
|
||||
[encoder setBytes:&KW length:sizeof(int32_t) atIndex:15];
|
||||
|
||||
const int64_t D = N / M + (N % M > 0 ? 1 : 0);
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(D * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(MIN((uint64_t)N, M), 1, 1)];
|
||||
const uint64_t n_threads = MIN(pipeline.maxTotalThreadsPerThreadgroup, (uint64_t)N);
|
||||
|
||||
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(quotient * CHW, OH, OW) threadsPerThreadgroup:MTLSizeMake(n_threads, 1, 1)];
|
||||
} else {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(IC, OH, OW) threadsPerThreadgroup:MTLSizeMake(N, KH, KW)];
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue