mtl : fix kernel signature + roll inner loop

This commit is contained in:
Georgi Gerganov 2023-06-02 19:11:39 +03:00
parent 847bbfe9e6
commit 70c3387726
No known key found for this signature in database
GPG key ID: 449E073F9DC10735
2 changed files with 15 additions and 39 deletions

View file

@ -598,6 +598,7 @@ int llama_mtl_eval(
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
if (src0t == GGML_TYPE_Q4_0) {
//printf("nb = %d\n", ne00/32);
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {

View file

@ -247,8 +247,14 @@ kernel void kernel_mul_mat_q4_0_f32(
device float * dst,
constant int64_t & ne00,
constant int64_t & ne01,
constant uint64_t & nb00,
constant uint64_t & nb01,
constant uint64_t & nb02,
constant int64_t & ne10,
constant int64_t & ne11,
constant uint64_t & nb10,
constant uint64_t & nb11,
constant uint64_t & nb12,
constant int64_t & ne0,
constant int64_t & ne1,
threadgroup float * sum [[threadgroup(0)]],
@ -256,12 +262,11 @@ kernel void kernel_mul_mat_q4_0_f32(
uint2 tpig[[thread_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]],
uint2 tptg[[threads_per_threadgroup]]) {
const int nb = ne00/QK4_0;
const int64_t r0 = tgpig.x;
const int64_t r1 = tgpig.y;
const int qk = QK4_0;
const int nb = ne00/qk;
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
device const float * y = (device const float *) src1 + r1*ne10;
@ -272,7 +277,7 @@ kernel void kernel_mul_mat_q4_0_f32(
for (int i = tpitg.x; i < nb; i += tptg.x) {
device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
device const float4 * y0p = (device const float4 *) (y + i*qk);
device const float4 * y0p = (device const float4 *) (y + i*QK4_0);
const float d = (float)((x + i)->d);
@ -282,42 +287,12 @@ kernel void kernel_mul_mat_q4_0_f32(
float acc = 0.0f;
{
const int x0 = x0v[0] & 0x0F;
const int x1 = x0v[0] >> 4;
for (int j = 0; j < 4; ++j) {
const int x0 = x0v[j] & 0x0F;
const int x1 = x0v[j] >> 4;
const float y0 = y0v[0];
const float y1 = y1v[0];
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
}
{
const int x0 = x0v[1] & 0x0F;
const int x1 = x0v[1] >> 4;
const float y0 = y0v[1];
const float y1 = y1v[1];
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
}
{
const int x0 = x0v[2] & 0x0F;
const int x1 = x0v[2] >> 4;
const float y0 = y0v[2];
const float y1 = y1v[2];
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
}
{
const int x0 = x0v[3] & 0x0F;
const int x1 = x0v[3] >> 4;
const float y0 = y0v[3];
const float y1 = y1v[3];
const float y0 = y0v[j];
const float y1 = y1v[j];
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
}