mtl : working mul_mat q4

This commit is contained in:
Georgi Gerganov 2023-05-30 23:06:49 +03:00
parent 29bec00ba0
commit b2fd06c6aa
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 9 additions and 6 deletions

View file

@ -389,11 +389,12 @@ int llama_mtl_eval(
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:4]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:5]; [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:5];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6]; [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:6];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:7];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:8];
[encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1); printf("mul_mat: %lldx%lld * %lldx%lld -> %lldx%lld\n", ne00, ne01, ne10, ne11, ne0, ne1);
@ -446,7 +447,7 @@ int llama_mtl_eval(
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4]; [encoder setBytes:&eps length:sizeof( float) atIndex:4];
const int64_t nrows = ggml_nrows(gf->nodes[i]->src0); const int64_t nrows = ggml_nrows(gf->nodes[i]->src0);

View file

@ -127,6 +127,7 @@ kernel void kernel_mul_mat_q4_0(
constant int64_t & ne11, constant int64_t & ne11,
constant int64_t & ne0, constant int64_t & ne0,
constant int64_t & ne1, constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
uint2 tgpig[[threadgroup_position_in_grid]], uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpig[[thread_position_in_grid]], uint2 tpig[[thread_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]], uint2 tpitg[[thread_position_in_threadgroup]],
@ -140,10 +141,9 @@ kernel void kernel_mul_mat_q4_0(
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb; device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
device const float * y = (device const float *) src1 + r1*ne10; device const float * y = (device const float *) src1 + r1*ne10;
threadgroup float sum[32]; // TODO: should be equal to threadgroup size
sum[tpitg.x] = 0.0f; sum[tpitg.x] = 0.0f;
for (int i = 0; i < nb; i += tptg.x) { for (int i = tpitg.x; i < nb; i += tptg.x) {
//device const uint4 * x0p = (device const uint4 *) (x + i)->qs; //device const uint4 * x0p = (device const uint4 *) (x + i)->qs;
//device const float4 * y0p = (device const float4 *) (y + i*qk); //device const float4 * y0p = (device const float4 *) (y + i*qk);
@ -206,5 +206,7 @@ kernel void kernel_mul_mat_q4_0(
threadgroup_barrier(mem_flags::mem_threadgroup); threadgroup_barrier(mem_flags::mem_threadgroup);
} }
dst[r1*ne0 + r0] = sum[0]; if (tpitg.x == 0) {
dst[r1*ne0 + r0] = sum[0];
}
} }