update metal
This commit is contained in:
parent
2479900a1c
commit
93db37e274
3 changed files with 178 additions and 382 deletions
195
ggml-metal.m
195
ggml-metal.m
|
@ -1685,37 +1685,31 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
{
|
||||
//GGML_ASSERT(ne00 == ne10);
|
||||
//GGML_ASSERT(ne03 == ne13);
|
||||
|
||||
GGML_ASSERT(src0t == GGML_TYPE_I32);
|
||||
|
||||
const int n_as = ((int32_t *) dst->op_params)[1];
|
||||
|
||||
// TODO: make this more general
|
||||
GGML_ASSERT(n_as <= 8);
|
||||
const int n_as = src0->ne[2];
|
||||
|
||||
// max size of the src1ids array in the kernel shared buffer
|
||||
GGML_ASSERT(ne11 <= 4096);
|
||||
|
||||
const int64_t ne20 = src2 ? src2->ne[0] : 0;
|
||||
const int64_t ne21 = src2 ? src2->ne[1] : 0;
|
||||
const int64_t ne22 = src2 ? src2->ne[2] : 0;
|
||||
const int64_t ne23 = src2 ? src2->ne[3] : 0; GGML_UNUSED(ne23);
|
||||
// src2 = ids
|
||||
const int64_t ne20 = src2->ne[0]; GGML_UNUSED(ne20);
|
||||
const int64_t ne21 = src2->ne[1];
|
||||
const int64_t ne22 = src2->ne[2]; GGML_UNUSED(ne22);
|
||||
const int64_t ne23 = src2->ne[3]; GGML_UNUSED(ne23);
|
||||
|
||||
const uint64_t nb20 = src2 ? src2->nb[0] : 0; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2 ? src2->nb[1] : 0;
|
||||
const uint64_t nb22 = src2 ? src2->nb[2] : 0;
|
||||
const uint64_t nb23 = src2 ? src2->nb[3] : 0; GGML_UNUSED(nb23);
|
||||
const uint64_t nb20 = src2->nb[0]; GGML_UNUSED(nb20);
|
||||
const uint64_t nb21 = src2->nb[1];
|
||||
const uint64_t nb22 = src2->nb[2]; GGML_UNUSED(nb22);
|
||||
const uint64_t nb23 = src2->nb[3]; GGML_UNUSED(nb23);
|
||||
|
||||
const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
|
||||
const enum ggml_type src2t = src2->type; GGML_UNUSED(src2t);
|
||||
|
||||
GGML_ASSERT(!ggml_is_transposed(src2));
|
||||
GGML_ASSERT(src2t == GGML_TYPE_I32);
|
||||
|
||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
|
||||
const uint r2 = ne12/ne22;
|
||||
const uint r3 = ne13/ne23;
|
||||
|
||||
// find the break-even point where the matrix-matrix kernel becomes more efficient compared
|
||||
// to the matrix-vector kernel
|
||||
int ne11_mm_min = n_as;
|
||||
|
@ -1723,7 +1717,10 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
const int idx = ((int32_t *) dst->op_params)[0];
|
||||
|
||||
// batch size
|
||||
GGML_ASSERT(ne01 == ne11);
|
||||
GGML_ASSERT(ne21 == ne11); // ?
|
||||
GGML_ASSERT(ne12 == 1 && ne13 == 1); // no broadcasting
|
||||
const uint r2 = 1;
|
||||
const uint r3 = 1;
|
||||
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
|
@ -1732,20 +1729,20 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
// indirect matrix multiplication
|
||||
// !!!
|
||||
if ([ctx->device supportsFamily:MTLGPUFamilyApple7] &&
|
||||
ne20 % 32 == 0 && ne20 >= 64 &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
ne11 > ne11_mm_min) {
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
switch (src2->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb21 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb21 % 8 == 0); break;
|
||||
default: break;
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
switch (src2->type) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32 ].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32 ].pipeline; break;
|
||||
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32 ].pipeline; break;
|
||||
|
@ -1774,36 +1771,27 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:5];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:7];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:9];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:16];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:17];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:18];
|
||||
// TODO: how to make this an array? read Metal docs
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
||||
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
||||
|
||||
size_t offs_src_cur = 0;
|
||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
||||
|
||||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:19 + j];
|
||||
}
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:6];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:9];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:10];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:11];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:12];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:13];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:14];
|
||||
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:15];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:16];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:17];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:18];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:19];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:GGML_PAD(8192 + 2*ne11, 16) atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne21 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne11 + 31)/32, (ne01 + 63)/64, n_as*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
} else {
|
||||
int nth0 = 32;
|
||||
int nth1 = 1;
|
||||
|
@ -1813,7 +1801,7 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
// use custom matrix x vector kernel
|
||||
switch (src2t) {
|
||||
switch (src0t) {
|
||||
case GGML_TYPE_F32:
|
||||
{
|
||||
GGML_ASSERT(src1t == GGML_TYPE_F32);
|
||||
|
@ -1947,8 +1935,8 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
}
|
||||
};
|
||||
|
||||
if (ggml_is_quantized(src2t)) {
|
||||
GGML_ASSERT(ne20 >= nth0*nth1);
|
||||
if (ggml_is_quantized(src0t)) {
|
||||
GGML_ASSERT(ne00 >= nth0*nth1);
|
||||
}
|
||||
|
||||
const int64_t _ne1 = 1; // kernels needs a reference in constant memory
|
||||
|
@ -1957,75 +1945,66 @@ static enum ggml_status ggml_metal_graph_compute(
|
|||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:3];
|
||||
[encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4];
|
||||
[encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5];
|
||||
[encoder setBytes:&ne22 length:sizeof(ne22) atIndex:6];
|
||||
[encoder setBytes:&nb20 length:sizeof(nb20) atIndex:7];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:8];
|
||||
[encoder setBytes:&nb22 length:sizeof(nb22) atIndex:9];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:11];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:18];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:20];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:21];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:22];
|
||||
// TODO: how to make this an array? read Metal docs
|
||||
for (int j = 0; j < 8; ++j) {
|
||||
// NOTE: this is done like this to avoid uninitialized kernel arguments when n_as < 8
|
||||
struct ggml_tensor * src_cur = dst->src[2 + (j % n_as)];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
||||
[encoder setBytes:&nb21 length:sizeof(nb21) atIndex:4];
|
||||
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:5];
|
||||
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:6];
|
||||
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:7];
|
||||
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:8];
|
||||
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9];
|
||||
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10];
|
||||
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:12];
|
||||
[encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13];
|
||||
[encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14];
|
||||
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15];
|
||||
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16];
|
||||
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17];
|
||||
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:18];
|
||||
[encoder setBytes:&_ne1 length:sizeof(_ne1) atIndex:19];
|
||||
[encoder setBytes:&nb1 length:sizeof(nb1) atIndex:20];
|
||||
[encoder setBytes:&r2 length:sizeof(r2) atIndex:21];
|
||||
[encoder setBytes:&r3 length:sizeof(r3) atIndex:22];
|
||||
[encoder setBytes:&idx length:sizeof(idx) atIndex:23];
|
||||
|
||||
size_t offs_src_cur = 0;
|
||||
id<MTLBuffer> id_src_cur = ggml_metal_get_buffer(src_cur, &offs_src_cur);
|
||||
|
||||
[encoder setBuffer:id_src_cur offset:offs_src_cur atIndex:23 + j];
|
||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
||||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
||||
src0t == GGML_TYPE_IQ1_S || src0t == GGML_TYPE_IQ1_M || src0t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
|
||||
if (src2t == GGML_TYPE_Q4_0 || src2t == GGML_TYPE_Q4_1 || src2t == GGML_TYPE_Q5_0 ||
|
||||
src2t == GGML_TYPE_Q5_1 || src2t == GGML_TYPE_Q8_0 || src2t == GGML_TYPE_Q2_K ||
|
||||
src2t == GGML_TYPE_IQ1_S || src2t == GGML_TYPE_IQ1_M || src2t == GGML_TYPE_IQ2_S) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ2_XXS || src2t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src2t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
else if (src0t == GGML_TYPE_IQ2_XXS || src0t == GGML_TYPE_IQ2_XS) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ2_XXS ? 256*8+128 : 512*8+128;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ3_XXS || src2t == GGML_TYPE_IQ3_S) {
|
||||
const int mem_size = src2t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
||||
else if (src0t == GGML_TYPE_IQ3_XXS || src0t == GGML_TYPE_IQ3_S) {
|
||||
const int mem_size = src0t == GGML_TYPE_IQ3_XXS ? 256*4+128 : 512*4;
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 7)/8, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7)/8, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_IQ4_NL || src2t == GGML_TYPE_IQ4_XS) {
|
||||
else if (src0t == GGML_TYPE_IQ4_NL || src0t == GGML_TYPE_IQ4_XS) {
|
||||
const int mem_size = 32*sizeof(float);
|
||||
[encoder setThreadgroupMemoryLength:mem_size atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
else if (src0t == GGML_TYPE_Q4_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q3_K) {
|
||||
else if (src0t == GGML_TYPE_Q3_K) {
|
||||
#ifdef GGML_QKK_64
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#else
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
#endif
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 3)/4, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
else if (src0t == GGML_TYPE_Q5_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3)/4, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
else if (src2t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 1)/2, _ne1, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
else if (src0t == GGML_TYPE_Q6_K) {
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, _ne1, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
} else {
|
||||
const int64_t ny = (_ne1 + nrows - 1)/nrows;
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne21, ny, ne01*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne21*ne12*ne13) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
|
||||
}
|
||||
}
|
||||
} break;
|
||||
|
|
362
ggml-metal.metal
362
ggml-metal.metal
|
@ -5785,9 +5785,10 @@ kernel void kernel_mul_mm(device const uchar * src0,
|
|||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread half4x4 &)>
|
||||
kernel void kernel_mul_mm_id(
|
||||
device const uchar * ids,
|
||||
device const uchar * src0s,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
device const uchar * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
|
@ -5804,22 +5805,14 @@ kernel void kernel_mul_mm_id(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const uchar * src00,
|
||||
device const uchar * src01,
|
||||
device const uchar * src02,
|
||||
device const uchar * src03,
|
||||
device const uchar * src04,
|
||||
device const uchar * src05,
|
||||
device const uchar * src06,
|
||||
device const uchar * src07,
|
||||
threadgroup uchar * shared_memory [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const uchar * src0s[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
// expert id
|
||||
const int32_t id = tgpig.z/(ne12*ne13);
|
||||
device const uchar * src0 = src0s + id*nb02;
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
|
@ -5834,7 +5827,7 @@ kernel void kernel_mul_mm_id(
|
|||
}
|
||||
|
||||
kernel_mul_mm_id_impl<block_q, nl, dequantize_func>(
|
||||
src0s[id],
|
||||
src0,
|
||||
src1,
|
||||
src1ids,
|
||||
dst,
|
||||
|
@ -5960,9 +5953,10 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mat_mm_t kernel_mul_m
|
|||
//
|
||||
|
||||
typedef void (mat_mm_id_t)(
|
||||
device const uchar * ids,
|
||||
device const uchar * src0s,
|
||||
device const uchar * src1,
|
||||
device float * dst,
|
||||
device const uchar * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne02,
|
||||
|
@ -5979,14 +5973,6 @@ typedef void (mat_mm_id_t)(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const uchar * src00,
|
||||
device const uchar * src01,
|
||||
device const uchar * src02,
|
||||
device const uchar * src03,
|
||||
device const uchar * src04,
|
||||
device const uchar * src05,
|
||||
device const uchar * src06,
|
||||
device const uchar * src07,
|
||||
threadgroup uchar *,
|
||||
uint3, uint, uint);
|
||||
|
||||
|
@ -6022,9 +6008,10 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|||
|
||||
[[host_name("kernel_mul_mv_id_f32_f32")]]
|
||||
kernel void kernel_mul_mv_id_f32_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6045,28 +6032,19 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_f32_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
src1 + bid*nb11,
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6091,9 +6069,10 @@ kernel void kernel_mul_mv_id_f32_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_f16_f32")]]
|
||||
kernel void kernel_mul_mv_id_f16_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6114,28 +6093,19 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_f16_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
src1 + bid*nb11,
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6160,9 +6130,10 @@ kernel void kernel_mul_mv_id_f16_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q8_0_f32")]]
|
||||
kernel void kernel_mul_mv_id_q8_0_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6183,28 +6154,19 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_q8_0_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6223,9 +6185,10 @@ kernel void kernel_mul_mv_id_q8_0_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q4_0_f32")]]
|
||||
kernel void kernel_mul_mv_id_q4_0_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6246,28 +6209,19 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6286,9 +6240,10 @@ kernel void kernel_mul_mv_id_q4_0_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q4_1_f32")]]
|
||||
kernel void kernel_mul_mv_id_q4_1_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6309,28 +6264,19 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6349,9 +6295,10 @@ kernel void kernel_mul_mv_id_q4_1_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q5_0_f32")]]
|
||||
kernel void kernel_mul_mv_id_q5_0_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6372,28 +6319,19 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6412,9 +6350,10 @@ kernel void kernel_mul_mv_id_q5_0_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q5_1_f32")]]
|
||||
kernel void kernel_mul_mv_id_q5_1_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6435,28 +6374,19 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6475,9 +6405,10 @@ kernel void kernel_mul_mv_id_q5_1_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q2_K_f32")]]
|
||||
kernel void kernel_mul_mv_id_q2_K_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6498,28 +6429,19 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_q2_K_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6538,9 +6460,10 @@ kernel void kernel_mul_mv_id_q2_K_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q3_K_f32")]]
|
||||
kernel void kernel_mul_mv_id_q3_K_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6561,28 +6484,19 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_q3_K_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6601,9 +6515,10 @@ kernel void kernel_mul_mv_id_q3_K_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q4_K_f32")]]
|
||||
kernel void kernel_mul_mv_id_q4_K_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6624,28 +6539,19 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_q4_K_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6664,9 +6570,10 @@ kernel void kernel_mul_mv_id_q4_K_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q5_K_f32")]]
|
||||
kernel void kernel_mul_mv_id_q5_K_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6687,28 +6594,19 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_q5_K_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6727,9 +6625,10 @@ kernel void kernel_mul_mv_id_q5_K_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_q6_K_f32")]]
|
||||
kernel void kernel_mul_mv_id_q6_K_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6750,28 +6649,19 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_q6_K_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6790,9 +6680,10 @@ kernel void kernel_mul_mv_id_q6_K_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq2_xxs_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6813,29 +6704,20 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6855,9 +6737,10 @@ kernel void kernel_mul_mv_id_iq2_xxs_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq2_xs_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq2_xs_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6878,29 +6761,20 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_iq2_xs_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6920,9 +6794,10 @@ kernel void kernel_mul_mv_id_iq2_xs_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq3_xxs_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -6943,29 +6818,20 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -6985,9 +6851,10 @@ kernel void kernel_mul_mv_id_iq3_xxs_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq3_s_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq3_s_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -7008,29 +6875,20 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_iq3_s_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -7050,9 +6908,10 @@ kernel void kernel_mul_mv_id_iq3_s_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq2_s_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq2_s_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -7073,29 +6932,20 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_iq2_s_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -7115,9 +6965,10 @@ kernel void kernel_mul_mv_id_iq2_s_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq1_s_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq1_s_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -7138,28 +6989,19 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_iq1_s_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -7178,9 +7020,10 @@ kernel void kernel_mul_mv_id_iq1_s_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq1_m_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq1_m_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -7201,28 +7044,19 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_iq1_m_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -7241,9 +7075,10 @@ kernel void kernel_mul_mv_id_iq1_m_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq4_nl_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq4_nl_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -7264,29 +7099,20 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup float * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
kernel_mul_mv_iq4_nl_f32_impl(
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
@ -7306,9 +7132,10 @@ kernel void kernel_mul_mv_id_iq4_nl_f32(
|
|||
|
||||
[[host_name("kernel_mul_mv_id_iq4_xs_f32")]]
|
||||
kernel void kernel_mul_mv_id_iq4_xs_f32(
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
constant uint64_t & nbi1,
|
||||
constant int64_t & ne00,
|
||||
constant int64_t & ne01,
|
||||
|
@ -7329,33 +7156,24 @@ kernel void kernel_mul_mv_id_iq4_xs_f32(
|
|||
constant uint & r2,
|
||||
constant uint & r3,
|
||||
constant int & idx,
|
||||
device const char * src00,
|
||||
device const char * src01,
|
||||
device const char * src02,
|
||||
device const char * src03,
|
||||
device const char * src04,
|
||||
device const char * src05,
|
||||
device const char * src06,
|
||||
device const char * src07,
|
||||
threadgroup float * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
device const char * src0[8] = {src00, src01, src02, src03, src04, src05, src06, src07};
|
||||
|
||||
const int64_t bid = tgpig.z/(ne12*ne13);
|
||||
|
||||
tgpig.z = tgpig.z%(ne12*ne13);
|
||||
|
||||
const int32_t id = ((device int32_t *) (ids + bid*nbi1))[idx];
|
||||
device const char * src0 = src0s + id*nb02;
|
||||
|
||||
#if QK_K == 64
|
||||
kernel_mul_mv_iq4_nl_f32_impl(
|
||||
#else
|
||||
kernel_mul_mv_iq4_xs_f32_impl(
|
||||
#endif
|
||||
src0[id],
|
||||
src0,
|
||||
(device const float *) (src1 + bid*nb11),
|
||||
dst + bid*ne0,
|
||||
ne00,
|
||||
|
|
3
ggml.c
3
ggml.c
|
@ -11049,8 +11049,7 @@ static void ggml_compute_forward_mul_mat_id(
|
|||
continue;
|
||||
}
|
||||
|
||||
//const struct ggml_tensor * src0_cur = dst->src[cur_a + 2];
|
||||
size_t src0_offset = src0->nb[2]*cur_a;
|
||||
size_t src0_offset = cur_a*src0->nb[2];
|
||||
|
||||
const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
|
||||
const size_t row_size = ggml_row_size(vec_dot_type, ne10);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue