iq3_xs: make new version work on metal
Performance is very similar to Q3_K_S
This commit is contained in:
parent
1328331db7
commit
1777825550
2 changed files with 21 additions and 26 deletions
|
@ -1555,7 +1555,7 @@ static bool ggml_metal_graph_compute(
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_XS) {
|
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_XS) {
|
||||||
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4+128;
|
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
||||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
|
@ -1874,7 +1874,7 @@ static bool ggml_metal_graph_compute(
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_XS) {
|
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_XS) {
|
||||||
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4+128;
|
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
||||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
|
|
|
@ -2525,11 +2525,13 @@ typedef struct {
|
||||||
} block_iq3_xxs;
|
} block_iq3_xxs;
|
||||||
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
// 98 bytes / block for QK_K = 256, so 3.0625 bpw
|
||||||
|
|
||||||
// 3.3125 bpw
|
// 3.4375 bpw
|
||||||
typedef struct {
|
typedef struct {
|
||||||
half d;
|
half d;
|
||||||
uint8_t qs[3*QK_K/8];
|
uint8_t qs[QK_K/4];
|
||||||
uint8_t qh[QK_K/32];
|
uint8_t qh[QK_K/32];
|
||||||
|
uint8_t signs[QK_K/8];
|
||||||
|
uint8_t scales[QK_K/64];
|
||||||
} block_iq3_xs;
|
} block_iq3_xs;
|
||||||
|
|
||||||
typedef struct {
|
typedef struct {
|
||||||
|
@ -4475,14 +4477,10 @@ void kernel_mul_mv_iq3_xs_f32_impl(
|
||||||
const int nb32 = nb * (QK_K / 32);
|
const int nb32 = nb * (QK_K / 32);
|
||||||
|
|
||||||
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values;
|
||||||
threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512);
|
|
||||||
{
|
{
|
||||||
int nval = 8;
|
int nval = 8;
|
||||||
int pos = (32*sgitg + tiisg)*nval;
|
int pos = (32*sgitg + tiisg)*nval;
|
||||||
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
|
for (int i = 0; i < nval; ++i) values[pos + i] = iq3xs_grid[pos + i];
|
||||||
nval = 2;
|
|
||||||
pos = (32*sgitg + tiisg)*nval;
|
|
||||||
for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i];
|
|
||||||
threadgroup_barrier(mem_flags::mem_threadgroup);
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -4503,23 +4501,22 @@ void kernel_mul_mv_iq3_xs_f32_impl(
|
||||||
device const block_iq3_xs * xr = x + ibl;
|
device const block_iq3_xs * xr = x + ibl;
|
||||||
device const uint8_t * qs = xr->qs + 8 * ib;
|
device const uint8_t * qs = xr->qs + 8 * ib;
|
||||||
device const uint8_t * qh = xr->qh;
|
device const uint8_t * qh = xr->qh;
|
||||||
device const uint16_t * gas = (device const uint16_t *)(xr->qs + QK_K/4) + 2 * ib;
|
device const uint8_t * sc = xr->scales + (ib/2);
|
||||||
|
device const uint8_t * signs = xr->signs + 4 * ib;
|
||||||
device const half * dh = &xr->d;
|
device const half * dh = &xr->d;
|
||||||
|
|
||||||
for (int row = 0; row < N_DST; row++) {
|
for (int row = 0; row < N_DST; row++) {
|
||||||
|
|
||||||
const float db = dh[0];
|
const float db = dh[0];
|
||||||
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
const float d = db * (0.5f + ((sc[0] >> 4*(ib%2)) & 0xf));
|
||||||
const float d = db * (0.5f + (aux32 >> 28));
|
|
||||||
|
|
||||||
float2 sum = {0};
|
float2 sum = {0};
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[ib] << (8-2*l)) & 256)));
|
const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[2*l+0] | ((qh[ib] << (8-2*l)) & 256)));
|
||||||
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[ib] << (7-2*l)) & 256)));
|
const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[2*l+1] | ((qh[ib] << (7-2*l)) & 256)));
|
||||||
const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127];
|
|
||||||
for (int j = 0; j < 4; ++j) {
|
for (int j = 0; j < 4; ++j) {
|
||||||
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
sum[0] += yl[8*l + j + 0] * grid1[j] * (signs[l] & kmask_iq2xs[j+0] ? -1.f : 1.f);
|
||||||
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
sum[1] += yl[8*l + j + 4] * grid2[j] * (signs[l] & kmask_iq2xs[j+4] ? -1.f : 1.f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
sumf[row] += d * (sum[0] + sum[1]);
|
sumf[row] += d * (sum[0] + sum[1]);
|
||||||
|
@ -4527,7 +4524,8 @@ void kernel_mul_mv_iq3_xs_f32_impl(
|
||||||
dh += nb*sizeof(block_iq3_xs)/2;
|
dh += nb*sizeof(block_iq3_xs)/2;
|
||||||
qs += nb*sizeof(block_iq3_xs);
|
qs += nb*sizeof(block_iq3_xs);
|
||||||
qh += nb*sizeof(block_iq3_xs);
|
qh += nb*sizeof(block_iq3_xs);
|
||||||
gas += nb*sizeof(block_iq3_xs)/2;
|
sc += nb*sizeof(block_iq3_xs);
|
||||||
|
signs += nb*sizeof(block_iq3_xs);
|
||||||
}
|
}
|
||||||
|
|
||||||
y4 += 32 * 32;
|
y4 += 32 * 32;
|
||||||
|
@ -5176,22 +5174,19 @@ void dequantize_iq3_xs(device const block_iq3_xs * xb, short il, thread type4x4
|
||||||
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
// il = 0 or 1. il = 0 processes the first 16 quants in a block of 32, il = 1 the second 16
|
||||||
device const uint8_t * qs = xb->qs + 8*ib32;
|
device const uint8_t * qs = xb->qs + 8*ib32;
|
||||||
device const uint8_t * qh = xb->qh;
|
device const uint8_t * qh = xb->qh;
|
||||||
device const uint16_t * gas = (device const uint16_t *)(xb->qs + QK_K/4) + 2*ib32;
|
device const uint8_t * signs = xb->signs + 4*ib32 + 2*il;
|
||||||
const uint32_t aux32 = gas[0] | (gas[1] << 16);
|
const float dl = d * (0.5f + ((xb->scales[ib32/2] >> 4*(ib32%2)) & 0xf)) * 0.5f;
|
||||||
const float dl = d * (0.5f + (aux32 >> 28)) * 0.5f;
|
|
||||||
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh[ib32] << (8-4*il)) & 256)));
|
constant uint8_t * grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+0] | ((qh[ib32] << (8-4*il)) & 256)));
|
||||||
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh[ib32] << (7-4*il)) & 256)));
|
constant uint8_t * grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+1] | ((qh[ib32] << (7-4*il)) & 256)));
|
||||||
uint8_t signs = ksigns_iq2xs[(aux32 >> 14*il) & 127];
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
reg[0][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
reg[0][i] = dl * grid1[i] * (signs[0] & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
||||||
reg[1][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
reg[1][i] = dl * grid2[i] * (signs[0] & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
||||||
}
|
}
|
||||||
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh[ib32] << (6-4*il)) & 256)));
|
grid1 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+2] | ((qh[ib32] << (6-4*il)) & 256)));
|
||||||
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh[ib32] << (5-4*il)) & 256)));
|
grid2 = (constant uint8_t *)(iq3xs_grid + (qs[4*il+3] | ((qh[ib32] << (5-4*il)) & 256)));
|
||||||
signs = ksigns_iq2xs[(aux32 >> (14*il+7)) & 127];
|
|
||||||
for (int i = 0; i < 4; ++i) {
|
for (int i = 0; i < 4; ++i) {
|
||||||
reg[2][i] = dl * grid1[i] * (signs & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
reg[2][i] = dl * grid1[i] * (signs[1] & kmask_iq2xs[i+0] ? -1.f : 1.f);
|
||||||
reg[3][i] = dl * grid2[i] * (signs & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
reg[3][i] = dl * grid2[i] * (signs[1] & kmask_iq2xs[i+4] ? -1.f : 1.f);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue