metal : trying bs = 512 performance (wip)

This commit is contained in:
Georgi Gerganov 2024-02-12 19:21:57 +02:00
parent e8b00e2941
commit 5a668ea000
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 44 additions and 15 deletions

View file

@ -4784,10 +4784,10 @@ void kernel_mul_mm_impl(
}
}
#define NSG0 1
#define NSH0 16
#define NSG1 1
#define NSH1 64
#define NSG0 4
#define NSH0 4
#define NSG1 2
#define NSH1 4
// each block_q contains 16*nl weights
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
@ -4870,6 +4870,8 @@ void kernel_mul_mm2_impl(
}
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
threadgroup_barrier(mem_flags::mem_threadgroup);
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
{
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
@ -4896,10 +4898,10 @@ void kernel_mul_mm2_impl(
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
{
@ -4945,6 +4947,7 @@ void kernel_mul_mm2_impl(
simdgroup_barrier(mem_flags::mem_none);
#if 0
#pragma unroll(NSH0)
for (int k = 0; k < NSH0; ++k) {
for (int j = 0; j < NSG0; ++j) {
@ -4961,9 +4964,22 @@ void kernel_mul_mm2_impl(
}
}
}
}
#else
#pragma unroll(NSH0)
for (int k = 0; k < NSH0; ++k) {
for (int i = 0; i < NSG1; ++i) {
simdgroup_load(m1[i], s1 + (8*i)*(8*NSH1) + 8*NSH0*b0 + 8*k, 8*NSH1, 0, true);
}
threadgroup_barrier(mem_flags::mem_threadgroup);
for (int j = 0; j < NSG0; ++j) {
simdgroup_load(m0[j], s0 + (8*j)*(8*NSH0) + 8*k, 8*NSH0);
for (int i = 0; i < NSG1; ++i) {
simdgroup_multiply_accumulate(mr[j][i], m0[j], m1[i], mr[j][i]);
}
}
}
#endif
}
}
// write the mr to shared memory