metal: add back faster diagonal infinity

This time more carefully
This commit is contained in:
Iwan Kawrakow 2023-09-08 18:07:29 +02:00
parent f895d8c774
commit 9f778877e3

View file

@ -234,6 +234,34 @@ kernel void kernel_diag_mask_inf_8(
}
}
kernel void kernel_diag_mask_inf_8(
device const float4 * src0,
device float4 * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int & n_past,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i = 2*tpig[0];
dst[i+0] = src0[i+0];
dst[i+1] = src0[i+1];
int64_t i4 = 4*i;
const int64_t i02 = i4/(ne00*ne01);
i4 -= i02*ne00*ne01;
const int64_t i01 = i4/ne00;
const int64_t i00 = i4 - i01*ne00;
for (int k = 3; k >= 0; --k) {
if (i00 + 4 + k <= n_past + i01) {
break;
}
dst[i+1][k] = -INFINITY;
if (i00 + k > n_past + i01) {
dst[i][k] = -INFINITY;
}
}
}
kernel void kernel_norm(
device const void * src0,
device float * dst,