iq2_xxs: slightly faster dot product
TG-128 is now 50.9 t/s
This commit is contained in:
parent
1c96aa0d7f
commit
e211fadc8a
2 changed files with 24 additions and 4 deletions
12
ggml-metal.m
12
ggml-metal.m
|
@ -1708,10 +1708,14 @@ bool ggml_metal_graph_compute(
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 ||
|
||||||
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 ||
|
||||||
src0t == GGML_TYPE_IQ2_XXS ||
|
//src0t == GGML_TYPE_IQ2_XXS ||
|
||||||
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
src0t == GGML_TYPE_Q2_K) { // || src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
|
else if (src0t == GGML_TYPE_IQ2_XXS) {
|
||||||
|
[encoder setThreadgroupMemoryLength:256*8 atIndex:0];
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
}
|
||||||
else if (src0t == GGML_TYPE_Q4_K) {
|
else if (src0t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
|
@ -1972,10 +1976,14 @@ bool ggml_metal_graph_compute(
|
||||||
|
|
||||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 ||
|
||||||
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
src2t == GGML_TYPE_Q5_0 || src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 ||
|
||||||
src2t == GGML_TYPE_IQ2_XXS ||
|
//src2t == GGML_TYPE_IQ2_XXS ||
|
||||||
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
src2t == GGML_TYPE_Q2_K) { // || src2t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
|
else if (src0t == GGML_TYPE_IQ2_XXS) {
|
||||||
|
[encoder setThreadgroupMemoryLength:256*8 atIndex:0];
|
||||||
|
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, ne11, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
|
}
|
||||||
else if (src2t == GGML_TYPE_Q4_K) {
|
else if (src2t == GGML_TYPE_Q4_K) {
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||||
}
|
}
|
||||||
|
|
|
@ -3569,6 +3569,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
@ -3594,6 +3595,14 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||||
|
|
||||||
const int nb32 = nb * (QK_K / 32);
|
const int nb32 = nb * (QK_K / 32);
|
||||||
|
|
||||||
|
threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values;
|
||||||
|
{
|
||||||
|
const int nval = 4;
|
||||||
|
const int pos = (32*sgitg + tiisg)*nval;
|
||||||
|
for (int i = 0; i < nval; ++i) values[pos + i] = kgrid_iq2xxs[pos + i];
|
||||||
|
threadgroup_barrier(mem_flags::mem_threadgroup);
|
||||||
|
}
|
||||||
|
|
||||||
#if QK_K == 256
|
#if QK_K == 256
|
||||||
const int ix = tiisg;
|
const int ix = tiisg;
|
||||||
|
|
||||||
|
@ -3621,7 +3630,7 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||||
|
|
||||||
float sum = 0;
|
float sum = 0;
|
||||||
for (int l = 0; l < 4; ++l) {
|
for (int l = 0; l < 4; ++l) {
|
||||||
constant uint8_t * grid = (constant uint8_t *)(kgrid_iq2xxs + aux8[l]);
|
const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]);
|
||||||
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*l) & 127];
|
||||||
for (int j = 0; j < 8; ++j) {
|
for (int j = 0; j < 8; ++j) {
|
||||||
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f);
|
||||||
|
@ -3668,11 +3677,12 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
||||||
constant int64_t & ne1,
|
constant int64_t & ne1,
|
||||||
constant uint & r2,
|
constant uint & r2,
|
||||||
constant uint & r3,
|
constant uint & r3,
|
||||||
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
//============================= templates and their specializations =============================
|
//============================= templates and their specializations =============================
|
||||||
|
@ -5403,6 +5413,7 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
||||||
device const char * src05,
|
device const char * src05,
|
||||||
device const char * src06,
|
device const char * src06,
|
||||||
device const char * src07,
|
device const char * src07,
|
||||||
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
|
@ -5428,6 +5439,7 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
||||||
ne1,
|
ne1,
|
||||||
r2,
|
r2,
|
||||||
r3,
|
r3,
|
||||||
|
shared_values,
|
||||||
tgpig,
|
tgpig,
|
||||||
tiisg,
|
tiisg,
|
||||||
sgitg);
|
sgitg);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue