metal : trying bs = 512 performance (wip)
This commit is contained in:
parent
e8b00e2941
commit
5a668ea000
3 changed files with 44 additions and 15 deletions
|
@ -4784,10 +4784,10 @@ void kernel_mul_mm_impl(
|
|||
}
|
||||
}
|
||||
|
||||
#define NSG0 1
|
||||
#define NSH0 16
|
||||
#define NSG1 1
|
||||
#define NSH1 64
|
||||
#define NSG0 4
|
||||
#define NSH0 4
|
||||
#define NSG1 2
|
||||
#define NSH1 4
|
||||
|
||||
// each block_q contains 16*nl weights
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
|
@ -4870,6 +4870,8 @@ void kernel_mul_mm2_impl(
|
|||
}
|
||||
|
||||
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
|
||||
{
|
||||
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
|
||||
|
@ -4896,10 +4898,10 @@ void kernel_mul_mm2_impl(
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
|
||||
for (int b0 = 0; b0 < NSH1/NSH0; ++b0) {
|
||||
// load NSG0*NSH0 8x8 blocks of src0 to threadgroup memory
|
||||
{
|
||||
|
@ -4945,6 +4947,7 @@ void kernel_mul_mm2_impl(
|
|||
|
||||
simdgroup_barrier(mem_flags::mem_none);
|
||||
|
||||
#if 0
|
||||
#pragma unroll(NSH0)
|
||||
for (int k = 0; k < NSH0; ++k) {
|
||||
for (int j = 0; j < NSG0; ++j) {
|
||||
|
@ -4961,9 +4964,22 @@ void kernel_mul_mm2_impl(
|
|||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
#else
|
||||
#pragma unroll(NSH0)
|
||||
for (int k = 0; k < NSH0; ++k) {
|
||||
for (int i = 0; i < NSG1; ++i) {
|
||||
simdgroup_load(m1[i], s1 + (8*i)*(8*NSH1) + 8*NSH0*b0 + 8*k, 8*NSH1, 0, true);
|
||||
}
|
||||
|
||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||
for (int j = 0; j < NSG0; ++j) {
|
||||
simdgroup_load(m0[j], s0 + (8*j)*(8*NSH0) + 8*k, 8*NSH0);
|
||||
for (int i = 0; i < NSG1; ++i) {
|
||||
simdgroup_multiply_accumulate(mr[j][i], m0[j], m1[i], mr[j][i]);
|
||||
}
|
||||
}
|
||||
}
|
||||
#endif
|
||||
}
|
||||
}
|
||||
|
||||
// write the mr to shared memory
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue