metal: faster diagonal infinity
Although, to me it looks like one should simply fuse scale + diagonal infinity + soft_max on the KQ tensor.
This commit is contained in:
parent
14148ce3f5
commit
c9e38057f0
2 changed files with 16 additions and 9 deletions
|
@ -838,6 +838,8 @@ void ggml_metal_graph_compute(
|
||||||
case GGML_OP_DIAG_MASK_INF:
|
case GGML_OP_DIAG_MASK_INF:
|
||||||
{
|
{
|
||||||
const int n_past = ((int32_t *)(dst->op_params))[0];
|
const int n_past = ((int32_t *)(dst->op_params))[0];
|
||||||
|
const int64_t n00x01 = ne00*ne01;
|
||||||
|
assert((n00x01*ne02)%8 == 0);
|
||||||
|
|
||||||
if (ne00%8 == 0) {
|
if (ne00%8 == 0) {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
|
||||||
|
@ -847,7 +849,7 @@ void ggml_metal_graph_compute(
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
|
||||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
|
||||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
|
[encoder setBytes:&n00x01 length:sizeof(n00x01) atIndex:3];
|
||||||
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
|
||||||
|
|
||||||
if (ne00%8 == 0) {
|
if (ne00%8 == 0) {
|
||||||
|
|
|
@ -183,15 +183,13 @@ kernel void kernel_soft_max_4(
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_diag_mask_inf(
|
kernel void kernel_diag_mask_inf(
|
||||||
device const float * src0,
|
device const float4 * src0,
|
||||||
device float * dst,
|
device float4 * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & n00x01,
|
||||||
constant int & n_past,
|
constant int & n_past,
|
||||||
uint3 tpig[[thread_position_in_grid]]) {
|
uint3 tpig[[thread_position_in_grid]]) {
|
||||||
const int64_t i02 = tpig[2];
|
const int64_t i = 2*tpig[0];
|
||||||
const int64_t i01 = tpig[1];
|
|
||||||
const int64_t i00 = tpig[0];
|
|
||||||
|
|
||||||
if (i00 > n_past + i01) {
|
if (i00 > n_past + i01) {
|
||||||
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
dst[i02*ne01*ne00 + i01*ne00 + i00] = -INFINITY;
|
||||||
|
@ -216,6 +214,13 @@ kernel void kernel_diag_mask_inf_8(
|
||||||
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
|
||||||
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
|
const int64_t i01 = i4/(ne00); i4 -= i01*ne00;
|
||||||
const int64_t i00 = i4;
|
const int64_t i00 = i4;
|
||||||
|
dst[i+0] = src0[i+0];
|
||||||
|
dst[i+1] = src0[i+1];
|
||||||
|
int64_t i4 = 4*i;
|
||||||
|
const int64_t i02 = i4/n00x01;
|
||||||
|
i4 -= i02*n00x01;
|
||||||
|
const int64_t i01 = i4/ne00;
|
||||||
|
const int64_t i00 = i4 - i01*ne00;
|
||||||
for (int k = 3; k >= 0; --k) {
|
for (int k = 3; k >= 0; --k) {
|
||||||
if (i00 + 4 + k <= n_past + i01) {
|
if (i00 + 4 + k <= n_past + i01) {
|
||||||
break;
|
break;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue