12% faster prompt processing (PP) for Falcon
This commit is contained in:
parent
9f778877e3
commit
d90b5981d0
2 changed files with 120 additions and 30 deletions
31
ggml-metal.m
31
ggml-metal.m
|
@ -902,22 +902,28 @@ void ggml_metal_graph_compute(
|
||||||
} else {
|
} else {
|
||||||
int nth0 = 32;
|
int nth0 = 32;
|
||||||
int nth1 = 1;
|
int nth1 = 1;
|
||||||
int nrows = 1;
|
//int nrows = 1;
|
||||||
|
int nx = 1, ny = 1;
|
||||||
|
|
||||||
// use custom matrix x vector kernel
|
// use custom matrix x vector kernel
|
||||||
switch (src0t) {
|
switch (src0t) {
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
{
|
{
|
||||||
nth0 = 32;
|
//[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||||
nth1 = 1;
|
//nth0 = 32;
|
||||||
if (ne11 * ne12 < 4) {
|
//nth1 = 1;
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
//if (ne11 * ne12 < 4) {
|
||||||
} else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
// [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
|
||||||
|
if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
|
||||||
nrows = ne11;
|
nx = ne01;
|
||||||
|
ny = 1;
|
||||||
|
nth0 = 32;
|
||||||
} else {
|
} else {
|
||||||
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
|
||||||
nrows = 4;
|
nth0 = ne01 >= 32 ? 32 : ne01 >= 16 ? 16 : ne01 >= 8 ? 8 : ne01 >= 4 ? 4 : ne01 >= 2 ? 2 : 1;
|
||||||
|
nx = (ne01 + nth0 - 1)/nth0;
|
||||||
|
ny = ne11;
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
|
@ -1038,8 +1044,13 @@ void ggml_metal_graph_compute(
|
||||||
else if (src0t == GGML_TYPE_Q6_K) {
|
else if (src0t == GGML_TYPE_Q6_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
int64_t ny = (ne11 + nrows - 1)/nrows;
|
////printf("f16xf32: %d x %d x %d, %d x %d x %d -> %d\n",(int)ne00,(int)ne01,(int)ne02,
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
//// (int)ne10,(int)ne11,(int)ne12,nrows);
|
||||||
|
//int64_t ny = (ne11 + nrows - 1)/nrows;
|
||||||
|
//[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
//[encoder dispatchThreadgroups:MTLSizeMake(ne10*ne11*ne12, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
|
||||||
|
//int n = ne01 >= 32 ? 32 : ne01 >= 16 ? 16 : ne01 >= 8 ? 8 : ne01 >= 4 ? 4 : ne01 >= 2 ? 2 : 1;
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake(nx, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, 1, 1)];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} break;
|
} break;
|
||||||
|
|
119
ggml-metal.metal
119
ggml-metal.metal
|
@ -610,9 +610,68 @@ kernel void kernel_mul_mat_f16_f32_1row(
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#define N_F16_F32 4
|
#define N_F16_F32 8
|
||||||
|
#
|
||||||
kernel void kernel_mul_mat_f16_f32(
|
// F16 (src0) x F32 (src1) matrix product: every thread produces exactly one
// element of dst by serially accumulating an ne00-long dot product in float.
//
// Thread mapping (the host dispatches (nx, ny, ne12) threadgroups of nth0
// threads, with nx = (ne01 + nth0 - 1)/nth0 and ny = ne11):
//   row   = tgpig.x*ntptg.x + tiitg : row of src0, index along dst dimension 0
//   col   = tgpig.y                 : row of src1, index along dst dimension 1
//   batch = tgpig.z                 : slice of src1 and dst; src0 is read from
//                                     slice batch/(ne12/ne02)
//
// NOTE(review): the half4/float4 path reinterprets the row pointers as vector
// pointers, so it presumably relies on nb01/nb02 and nb11/nb12 keeping each row
// start suitably aligned - confirm against the callers.
kernel void kernel_mul_mat_f16_f32(
        device const  char * src0,
        device const  char * src1,
        device       float * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant  uint64_t & nb00,
        constant  uint64_t & nb01,
        constant  uint64_t & nb02,
        constant   int64_t & ne10,
        constant   int64_t & ne11,
        constant   int64_t & ne12,
        constant  uint64_t & nb10,
        constant  uint64_t & nb11,
        constant  uint64_t & nb12,
        constant   int64_t & ne0,
        constant   int64_t & ne1,
        uint3 tgpig[[threadgroup_position_in_grid]],
        uint3 ntptg[[threads_per_threadgroup]],
        uint  tiitg[[thread_index_in_threadgroup]]) {

    const int64_t row = tgpig.x * ntptg.x + tiitg;
    if (row >= ne0) {
        // the grid is rounded up along x, so trailing threads have no output element
        return;
    }

    const int64_t col   = tgpig.y;
    const int64_t batch = tgpig.z;

    // byte strides (nb*) are applied on the raw char pointers before the typed casts
    device const half  * a = (device const half  *) (src0 + row*nb01 + batch/(ne12/ne02)*nb02);
    device const float * b = (device const float *) (src1 + col*nb11 + batch*nb12);

    float acc = 0;

    if (ne00 >= 16) {
        // long rows: walk the data four elements at a time, then finish the remainder
        device const half4  * a4 = (device const half4  *) a;
        device const float4 * b4 = (device const float4 *) b;

        for (int i = 0; i < ne00/4; ++i) {
            for (int k = 0; k < 4; ++k) acc += (float) a4[i][k] * b4[i][k];
        }
        for (int i = 4*(ne00/4); i < ne00; ++i) {
            acc += (float) a[i] * b[i];
        }
    } else {
        // short rows: plain element-by-element accumulation
        for (int i = 0; i < ne00; ++i) {
            acc += (float) a[i] * (float) b[i];
        }
    }

    dst[batch*ne1*ne0 + col*ne0 + row] = acc;
}
|
||||||
|
|
||||||
|
kernel void kernel_mul_mat_f16_f32_old(
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
|
@ -639,7 +698,7 @@ kernel void kernel_mul_mat_f16_f32(
|
||||||
|
|
||||||
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
device const half * x = (device const half *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
|
||||||
|
|
||||||
if (ne00 < 128) {
|
if (ne00 < 64) { //128) {
|
||||||
for (int row = 0; row < N_F16_F32; ++row) {
|
for (int row = 0; row < N_F16_F32; ++row) {
|
||||||
int r1 = rb + row;
|
int r1 = rb + row;
|
||||||
if (r1 >= ne11) {
|
if (r1 >= ne11) {
|
||||||
|
@ -659,27 +718,47 @@ kernel void kernel_mul_mat_f16_f32(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
const int ix = tiisg/N_F16_F32;
|
||||||
|
const int iy = tiisg%N_F16_F32;
|
||||||
|
const int r1 = rb + iy < ne11 ? rb + iy : ne11-1;
|
||||||
|
float sumf[N_F16_F32] = {0.f};
|
||||||
device const half4 * x4 = (device const half4 *)x;
|
device const half4 * x4 = (device const half4 *)x;
|
||||||
|
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
device const float4 * y4 = (device const float4 *) y;
|
||||||
|
for (int i = ix; i < ne00/4; i += 32/N_F16_F32) {
|
||||||
|
for (int k = 0; k < 4; ++k) sumf[iy] += (float) x4[i][k] * y4[i][k];
|
||||||
|
}
|
||||||
|
for (int i = 4*(ne00/4)+ix; i < ne00; i += 32/N_F16_F32) {
|
||||||
|
sumf[iy] += (float) x[i] * y[i];
|
||||||
|
}
|
||||||
for (int row = 0; row < N_F16_F32; ++row) {
|
for (int row = 0; row < N_F16_F32; ++row) {
|
||||||
int r1 = rb + row;
|
float all_sum = simd_sum(sumf[row]);
|
||||||
if (r1 >= ne11) {
|
if (tiisg == 0 && rb + row < ne11) {
|
||||||
break;
|
dst[im*ne1*ne0 + (rb + row)*ne0 + r0] = all_sum;
|
||||||
}
|
|
||||||
|
|
||||||
device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
|
||||||
device const float4 * y4 = (device const float4 *) y;
|
|
||||||
|
|
||||||
float sumf = 0;
|
|
||||||
for (int i = tiisg; i < ne00/4; i += 32) {
|
|
||||||
for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
|
||||||
}
|
|
||||||
|
|
||||||
float all_sum = simd_sum(sumf);
|
|
||||||
if (tiisg == 0) {
|
|
||||||
for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
|
||||||
dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//device const half4 * x4 = (device const half4 *)x;
|
||||||
|
//for (int row = 0; row < N_F16_F32; ++row) {
|
||||||
|
// int r1 = rb + row;
|
||||||
|
// if (r1 >= ne11) {
|
||||||
|
// break;
|
||||||
|
// }
|
||||||
|
|
||||||
|
// device const float * y = (device const float *) (src1 + r1*nb11 + im*nb12);
|
||||||
|
// device const float4 * y4 = (device const float4 *) y;
|
||||||
|
|
||||||
|
// float sumf = 0;
|
||||||
|
// for (int i = tiisg; i < ne00/4; i += 32) {
|
||||||
|
// for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
|
||||||
|
// }
|
||||||
|
|
||||||
|
// float all_sum = simd_sum(sumf);
|
||||||
|
// if (tiisg == 0) {
|
||||||
|
// for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) x[i] * y[i];
|
||||||
|
// dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
|
||||||
|
// }
|
||||||
|
//}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue