diff --git a/ggml-metal.m b/ggml-metal.m index da0fda8a1..cc4a236d5 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -65,6 +65,7 @@ struct ggml_metal_context { GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(diag_mask_inf); + GGML_METAL_DECL_KERNEL(diag_mask_inf_8); GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_1); @@ -211,6 +212,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(diag_mask_inf); + GGML_METAL_ADD_KERNEL(diag_mask_inf_8); GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_1); @@ -278,7 +280,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) { GGML_METAL_DEL_KERNEL(gelu); GGML_METAL_DEL_KERNEL(soft_max); GGML_METAL_DEL_KERNEL(soft_max_4); - GGML_METAL_DEL_KERNEL(diag_mask_inf); + GGML_METAL_DEL_KERNEL(diag_mask_inf_8); GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_1); @@ -819,14 +821,23 @@ void ggml_metal_graph_compute( { const int n_past = ((int32_t *)(dst->op_params))[0]; - [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; + if (ne00%8 == 0) { + [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8]; + } else { + [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; + } [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&n_past length:sizeof(int) atIndex:4]; - [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + if (ne00%8 == 0) { + [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } + else { + [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; + } } break; case GGML_OP_MUL_MAT: { diff --git a/ggml-metal.metal b/ggml-metal.metal index 7c9e03ee1..c42320996 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -200,6 +200,34 @@ kernel void kernel_diag_mask_inf( } } +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); + i4 -= i02*ne00*ne01; + const int64_t i01 = i4/ne00; + const int64_t i00 = i4 - i01*ne00; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + kernel void kernel_norm( device const void * src0, device float * dst,