metal: add back faster diagonal infinity

This time more carefully
This commit is contained in:
Iwan Kawrakow 2023-09-08 18:07:29 +02:00
parent 4fc615e827
commit 7331d1e0b4
2 changed files with 42 additions and 3 deletions

View file

@ -65,6 +65,7 @@ struct ggml_metal_context {
GGML_METAL_DECL_KERNEL(soft_max); GGML_METAL_DECL_KERNEL(soft_max);
GGML_METAL_DECL_KERNEL(soft_max_4); GGML_METAL_DECL_KERNEL(soft_max_4);
GGML_METAL_DECL_KERNEL(diag_mask_inf); GGML_METAL_DECL_KERNEL(diag_mask_inf);
GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
GGML_METAL_DECL_KERNEL(get_rows_f16); GGML_METAL_DECL_KERNEL(get_rows_f16);
GGML_METAL_DECL_KERNEL(get_rows_q4_0); GGML_METAL_DECL_KERNEL(get_rows_q4_0);
GGML_METAL_DECL_KERNEL(get_rows_q4_1); GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@ -211,6 +212,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
GGML_METAL_ADD_KERNEL(soft_max); GGML_METAL_ADD_KERNEL(soft_max);
GGML_METAL_ADD_KERNEL(soft_max_4); GGML_METAL_ADD_KERNEL(soft_max_4);
GGML_METAL_ADD_KERNEL(diag_mask_inf); GGML_METAL_ADD_KERNEL(diag_mask_inf);
GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
GGML_METAL_ADD_KERNEL(get_rows_f16); GGML_METAL_ADD_KERNEL(get_rows_f16);
GGML_METAL_ADD_KERNEL(get_rows_q4_0); GGML_METAL_ADD_KERNEL(get_rows_q4_0);
GGML_METAL_ADD_KERNEL(get_rows_q4_1); GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@ -278,7 +280,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
GGML_METAL_DEL_KERNEL(gelu); GGML_METAL_DEL_KERNEL(gelu);
GGML_METAL_DEL_KERNEL(soft_max); GGML_METAL_DEL_KERNEL(soft_max);
GGML_METAL_DEL_KERNEL(soft_max_4); GGML_METAL_DEL_KERNEL(soft_max_4);
GGML_METAL_DEL_KERNEL(diag_mask_inf); GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
GGML_METAL_DEL_KERNEL(get_rows_f16); GGML_METAL_DEL_KERNEL(get_rows_f16);
GGML_METAL_DEL_KERNEL(get_rows_q4_0); GGML_METAL_DEL_KERNEL(get_rows_q4_0);
GGML_METAL_DEL_KERNEL(get_rows_q4_1); GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@ -819,14 +821,23 @@ void ggml_metal_graph_compute(
{ {
const int n_past = ((int32_t *)(dst->op_params))[0]; const int n_past = ((int32_t *)(dst->op_params))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; if (ne00%8 == 0) {
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
} else {
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4]; [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; if (ne00%8 == 0) {
[encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
else {
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
}
} break; } break;
case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT:
{ {

View file

@ -200,6 +200,34 @@ kernel void kernel_diag_mask_inf(
} }
} }
kernel void kernel_diag_mask_inf_8(
device const float4 * src0,
device float4 * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant int & n_past,
uint3 tpig[[thread_position_in_grid]]) {
const int64_t i = 2*tpig[0];
dst[i+0] = src0[i+0];
dst[i+1] = src0[i+1];
int64_t i4 = 4*i;
const int64_t i02 = i4/(ne00*ne01);
i4 -= i02*ne00*ne01;
const int64_t i01 = i4/ne00;
const int64_t i00 = i4 - i01*ne00;
for (int k = 3; k >= 0; --k) {
if (i00 + 4 + k <= n_past + i01) {
break;
}
dst[i+1][k] = -INFINITY;
if (i00 + k > n_past + i01) {
dst[i][k] = -INFINITY;
}
}
}
kernel void kernel_norm( kernel void kernel_norm(
device const void * src0, device const void * src0,
device float * dst, device float * dst,