Reverting the diag infinity change

It works for prompt processing (PP), but for some reason it fails for token generation (TG).
This needs further investigation.
This commit is contained in:
Iwan Kawrakow 2023-09-08 17:14:04 +02:00
parent 9f353f0536
commit f895d8c774
2 changed files with 8 additions and 8 deletions

View file

@ -838,8 +838,6 @@ void ggml_metal_graph_compute(
case GGML_OP_DIAG_MASK_INF:
{
const int n_past = ((int32_t *)(dst->op_params))[0];
const int64_t n00x01 = ne00*ne01;
assert((n00x01*ne02)%8 == 0);
if (ne00%8 == 0) {
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
@ -849,7 +847,7 @@ void ggml_metal_graph_compute(
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&n00x01 length:sizeof(n00x01) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
if (ne00%8 == 0) {

View file

@ -183,13 +183,15 @@ kernel void kernel_soft_max_4(
}
kernel void kernel_diag_mask_inf(
device const float4 * src0,
device float4 * dst,
device const float * src0,
device float * dst,
constant int64_t & ne00,
constant int64_t & n00x01,
constant int64_t & ne01,
constant int & n_past,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i = 2*tpig[0];
const int64_t i02 = tpig[2];
const int64_t i01 = tpig[1];
const int64_t i00 = tpig[0];
if (i00 > n_past + i01) {
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;