12% faster prompt processing (PP) for Falcon

This commit is contained in:
Iwan Kawrakow 2023-09-10 15:26:06 +02:00
parent 9f778877e3
commit d90b5981d0
2 changed files with 120 additions and 30 deletions

View file

@ -902,22 +902,28 @@ void ggml_metal_graph_compute(
} else { } else {
int nth0 = 32; int nth0 = 32;
int nth1 = 1; int nth1 = 1;
int nrows = 1; //int nrows = 1;
int nx = 1, ny = 1;
// use custom matrix x vector kernel // use custom matrix x vector kernel
switch (src0t) { switch (src0t) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
nth0 = 32; //[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
nth1 = 1; //nth0 = 32;
if (ne11 * ne12 < 4) { //nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row]; //if (ne11 * ne12 < 4) {
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) { // [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4]; [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
nrows = ne11; nx = ne01;
ny = 1;
nth0 = 32;
} else { } else {
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32]; [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
nrows = 4; nth0 = ne01 >= 32 ? 32 : ne01 >= 16 ? 16 : ne01 >= 8 ? 8 : ne01 >= 4 ? 4 : ne01 >= 2 ? 2 : 1;
nx = (ne01 + nth0 - 1)/nth0;
ny = ne11;
} }
} break; } break;
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
@ -1038,8 +1044,13 @@ void ggml_metal_graph_compute(
else if (src0t == GGML_TYPE_Q6_K) { else if (src0t == GGML_TYPE_Q6_K) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else { } else {
int64_t ny = (ne11 + nrows - 1)/nrows; ////printf("f16xf32: %d x %d x %d, %d x %d x %d -> %d\n",(int)ne00,(int)ne01,(int)ne02,
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; //// (int)ne10,(int)ne11,(int)ne12,nrows);
//int64_t ny = (ne11 + nrows - 1)/nrows;
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
//[encoder dispatchThreadgroups:MTLSizeMake(ne10*ne11*ne12, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
//int n = ne01 >= 32 ? 32 : ne01 >= 16 ? 16 : ne01 >= 8 ? 8 : ne01 >= 4 ? 4 : ne01 >= 2 ? 2 : 1;
[encoder dispatchThreadgroups:MTLSizeMake(nx, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, 1, 1)];
} }
} }
} break; } break;

View file

@ -610,9 +610,68 @@ kernel void kernel_mul_mat_f16_f32_1row(
} }
#define N_F16_F32 4 #define N_F16_F32 8
#
// Multiplies an F16 matrix (src0) by an F32 matrix (src1), producing F32 output.
// Every thread computes exactly one element of dst.
//
// Thread mapping:
//   row0  = tgpig.x * ntptg.x + tiitg   row of src0, fastest-varying dst index
//   row1  = tgpig.y                     row of src1
//   batch = tgpig.z                     batch index of src1 / dst
//
// Threads whose row0 is not below ne0 return without writing, which lets the
// x extent of the grid be rounded up to a whole number of threadgroups.
//
// The result is the dot product over ne00 elements of the selected src0 row
// and src1 row, stored at dst[batch*ne1*ne0 + row1*ne0 + row0].
//
// ne01, nb00, ne10, ne11 and nb10 are part of the kernel's argument list but
// are not read by this implementation.
kernel void kernel_mul_mat_f16_f32(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 ntptg[[threads_per_threadgroup]],
        uint  tiitg[[thread_index_in_threadgroup]]) {
    const int64_t row0 = tgpig.x * ntptg.x + tiitg;
    if (row0 >= ne0) {
        return;
    }

    const int64_t row1  = tgpig.y;
    const int64_t batch = tgpig.z;

    // The src0 slice index is batch/(ne12/ne02), so each run of ne12/ne02
    // consecutive src1 batches reads the same src0 slice.
    device const half  * x = (device const half  *) (src0 + row0*nb01 + batch/(ne12/ne02)*nb02);
    device const float * y = (device const float *) (src1 + row1*nb11 + batch*nb12);

    float acc = 0;

    if (ne00 >= 16) {
        // Long rows: walk the bulk of the row four elements at a time.
        // NOTE(review): the half4/float4 views presumably rely on the row
        // start addresses being suitably aligned - confirm against the caller.
        device const half4  * x4 = (device const half4  *) x;
        device const float4 * y4 = (device const float4 *) y;

        const int64_t n_vec = ne00/4;

        for (int i = 0; i < n_vec; ++i) {
            for (int k = 0; k < 4; ++k) {
                acc += (float) x4[i][k] * y4[i][k];
            }
        }

        // Leftover elements when ne00 is not a multiple of 4.
        for (int i = 4*n_vec; i < ne00; ++i) {
            acc += (float) x[i] * y[i];
        }
    } else {
        // Short rows: plain element-by-element accumulation.
        for (int i = 0; i < ne00; ++i) {
            acc += (float) x[i] * (float) y[i];
        }
    }

    dst[batch*ne1*ne0 + row1*ne0 + row0] = acc;
}
kernel void kernel_mul_mat_f16_f32_old(
device const char * src0, device const char * src0,
device const char * src1, device const char * src1,
device float * dst, device float * dst,
@ -639,7 +698,7 @@ kernel void kernel_mul_mat_f16_f32(
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02); device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
if (ne00 < 128) { if (ne00 < 64) { //128) {
for (int row = 0; row < N_F16_F32; ++row) { for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row; int r1 = rb + row;
if (r1 >= ne11) { if (r1 >= ne11) {
@ -659,27 +718,47 @@ kernel void kernel_mul_mat_f16_f32(
} }
} }
} else { } else {
const int ix = tiisg/N_F16_F32;
const int iy = tiisg%N_F16_F32;
const int r1 = rb + iy < ne11 ? rb + iy : ne11-1;
float sumf[N_F16_F32] = {0.f};
device const half4 * x4 = (device const half4 *)x; device const half4 * x4 = (device const half4 *)x;
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y;
for (int i = ix; i < ne00/4; i += 32/N_F16_F32) {
for (int k = 0; k < 4; ++k) sumf[iy] += (float) x4[i][k] * y4[i][k];
}
for (int i = 4*(ne00/4)+ix; i < ne00; i += 32/N_F16_F32) {
sumf[iy] += (float) x[i] * y[i];
}
for (int row = 0; row < N_F16_F32; ++row) { for (int row = 0; row < N_F16_F32; ++row) {
int r1 = rb + row; float all_sum = simd_sum(sumf[row]);
if (r1 >= ne11) { if (tiisg == 0 && rb + row < ne11) {
break; dst[im*ne1*ne0 + (rb + row)*ne0 + r0] = all_sum;
}
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
device const float4 * y4 = (device const float4 *) y;
float sumf = 0;
for (int i = tiisg; i < ne00/4; i += 32) {
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
}
float all_sum = simd_sum(sumf);
if (tiisg == 0) {
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
} }
} }
//device const half4 * x4 = (device const half4 *)x;
//for (int row = 0; row < N_F16_F32; ++row) {
// int r1 = rb + row;
// if (r1 >= ne11) {
// break;
// }
// device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
// device const float4 * y4 = (device const float4 *) y;
// float sumf = 0;
// for (int i = tiisg; i < ne00/4; i += 32) {
// for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
// }
// float all_sum = simd_sum(sumf);
// if (tiisg == 0) {
// for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
// dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
// }
//}
} }
} }