metal : works with ne00 % 4 == 0

This commit is contained in:
Georgi Gerganov 2024-02-08 13:26:50 +02:00
parent e68e32548f
commit 845876d012
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
3 changed files with 31 additions and 18 deletions

View file

@ -4872,8 +4872,6 @@ void kernel_mul_mm2_impl(
for (int i00 = 0; i00 < ne00; i00 += 8*NSH1) {
// load NSG1*NSH1 8x8 blocks of src1 to threadgroup memory
{
threadgroup_barrier(mem_flags::mem_threadgroup);
const int nload = MIN(8*NSG1, ne11 - i11) * (8*NSH1);
const size_t offs0 = im*nb12;
@ -4884,11 +4882,17 @@ void kernel_mul_mm2_impl(
device const float4 * p1 = (device const float4 *)(src1 + offs0 + (i11 + ir)*nb11 + (i00 + ic)*nb10);
//float4 tmp0 = *p1;
//tmp0[0] = 1; tmp0[1] = 1; tmp0[2] = 1; tmp0[3] = 1;
if (i00 + ic + 4 <= ne00) {
s14[(8*NSH1*ir + ic)/4] = *p1;
} else {
for (int k = 0; i00 + ic + k < ne00; k++){
s1[8*NSH1*ir + ic + k] = (*p1)[k];
s14[(8*NSH1*ir + ic)/4] = 0.0f;
for (int k = 0; k < 4; k++){
if (i00 + ic + k < ne00) {
s1[8*NSH1*ir + ic + k] = (*p1)[k];
}
}
}
}
@ -4918,11 +4922,16 @@ void kernel_mul_mm2_impl(
dequantize_func(p0, il, tmp0);
//for (int z = 0; z < 16; z++) {
// tmp0[z/4][z%4] = 1;
//}
if (icc + 16 <= ne00) {
s016[(8*NSH0*ir + ic)/16] = tmp0;
} else {
s016[(8*NSH0*ir + ic)/16] = half4x4(0.0h);
for (int k = 0; k < 4; k++){
if (icc + 4*k <= ne00) {
if (icc + 4*k < ne00) {
s04[(8*NSH0*ir + ic)/4 + k] = tmp0[k];
} else {
for (int p = 0; icc + 4*k + p < ne00; p++) {
@ -4953,9 +4962,9 @@ void kernel_mul_mm2_impl(
}
}
}
}
threadgroup_barrier(mem_flags::mem_threadgroup);
threadgroup_barrier(mem_flags::mem_threadgroup);
}
// write the mr to shared memory