diff --git a/ggml-metal.metal b/ggml-metal.metal index ff289908a..c40a71a6b 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -234,6 +234,34 @@ kernel void kernel_diag_mask_inf_8( } } +kernel void kernel_diag_mask_inf_8( + device const float4 * src0, + device float4 * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int & n_past, + uint3 tpig[[thread_position_in_grid]]) { + + const int64_t i = 2*tpig[0]; + + dst[i+0] = src0[i+0]; + dst[i+1] = src0[i+1]; + int64_t i4 = 4*i; + const int64_t i02 = i4/(ne00*ne01); + i4 -= i02*ne00*ne01; + const int64_t i01 = i4/ne00; + const int64_t i00 = i4 - i01*ne00; + for (int k = 3; k >= 0; --k) { + if (i00 + 4 + k <= n_past + i01) { + break; + } + dst[i+1][k] = -INFINITY; + if (i00 + k > n_past + i01) { + dst[i][k] = -INFINITY; + } + } +} + kernel void kernel_norm( device const void * src0, device float * dst,