mtl : fix kernel signature + roll inner loop
This commit is contained in:
parent
847bbfe9e6
commit
70c3387726
2 changed files with 15 additions and 39 deletions
|
@ -598,6 +598,7 @@ int llama_mtl_eval(
|
||||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0) {
|
if (src0t == GGML_TYPE_Q4_0) {
|
||||||
|
//printf("nb = %d\n", ne00/32);
|
||||||
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -247,8 +247,14 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant int64_t & ne00,
|
constant int64_t & ne00,
|
||||||
constant int64_t & ne01,
|
constant int64_t & ne01,
|
||||||
|
constant uint64_t & nb00,
|
||||||
|
constant uint64_t & nb01,
|
||||||
|
constant uint64_t & nb02,
|
||||||
constant int64_t & ne10,
|
constant int64_t & ne10,
|
||||||
constant int64_t & ne11,
|
constant int64_t & ne11,
|
||||||
|
constant uint64_t & nb10,
|
||||||
|
constant uint64_t & nb11,
|
||||||
|
constant uint64_t & nb12,
|
||||||
constant int64_t & ne0,
|
constant int64_t & ne0,
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
threadgroup float * sum [[threadgroup(0)]],
|
threadgroup float * sum [[threadgroup(0)]],
|
||||||
|
@ -256,12 +262,11 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
uint2 tpig[[thread_position_in_grid]],
|
uint2 tpig[[thread_position_in_grid]],
|
||||||
uint2 tpitg[[thread_position_in_threadgroup]],
|
uint2 tpitg[[thread_position_in_threadgroup]],
|
||||||
uint2 tptg[[threads_per_threadgroup]]) {
|
uint2 tptg[[threads_per_threadgroup]]) {
|
||||||
|
const int nb = ne00/QK4_0;
|
||||||
|
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
const int64_t r1 = tgpig.y;
|
const int64_t r1 = tgpig.y;
|
||||||
|
|
||||||
const int qk = QK4_0;
|
|
||||||
const int nb = ne00/qk;
|
|
||||||
|
|
||||||
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
|
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
|
||||||
device const float * y = (device const float *) src1 + r1*ne10;
|
device const float * y = (device const float *) src1 + r1*ne10;
|
||||||
|
|
||||||
|
@ -272,7 +277,7 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
|
|
||||||
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
for (int i = tpitg.x; i < nb; i += tptg.x) {
|
||||||
device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
|
device const uchar4 * x0p = (device const uchar4 *) (x + i)->qs;
|
||||||
device const float4 * y0p = (device const float4 *) (y + i*qk);
|
device const float4 * y0p = (device const float4 *) (y + i*QK4_0);
|
||||||
|
|
||||||
const float d = (float)((x + i)->d);
|
const float d = (float)((x + i)->d);
|
||||||
|
|
||||||
|
@ -282,42 +287,12 @@ kernel void kernel_mul_mat_q4_0_f32(
|
||||||
|
|
||||||
float acc = 0.0f;
|
float acc = 0.0f;
|
||||||
|
|
||||||
{
|
for (int j = 0; j < 4; ++j) {
|
||||||
const int x0 = x0v[0] & 0x0F;
|
const int x0 = x0v[j] & 0x0F;
|
||||||
const int x1 = x0v[0] >> 4;
|
const int x1 = x0v[j] >> 4;
|
||||||
|
|
||||||
const float y0 = y0v[0];
|
const float y0 = y0v[j];
|
||||||
const float y1 = y1v[0];
|
const float y1 = y1v[j];
|
||||||
|
|
||||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const int x0 = x0v[1] & 0x0F;
|
|
||||||
const int x1 = x0v[1] >> 4;
|
|
||||||
|
|
||||||
const float y0 = y0v[1];
|
|
||||||
const float y1 = y1v[1];
|
|
||||||
|
|
||||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const int x0 = x0v[2] & 0x0F;
|
|
||||||
const int x1 = x0v[2] >> 4;
|
|
||||||
|
|
||||||
const float y0 = y0v[2];
|
|
||||||
const float y1 = y1v[2];
|
|
||||||
|
|
||||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
|
||||||
}
|
|
||||||
|
|
||||||
{
|
|
||||||
const int x0 = x0v[3] & 0x0F;
|
|
||||||
const int x1 = x0v[3] >> 4;
|
|
||||||
|
|
||||||
const float y0 = y0v[3];
|
|
||||||
const float y1 = y1v[3];
|
|
||||||
|
|
||||||
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
acc += (x0 - 8)*y0 + (x1 - 8)*y1;
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue