cont : args is first argument
This commit is contained in:
parent
4af3a87962
commit
481b05df22
2 changed files with 117 additions and 117 deletions
|
@ -1981,10 +1981,10 @@ static void ggml_metal_encode_node(
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
|
|
||||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||||
|
@ -2185,10 +2185,10 @@ static void ggml_metal_encode_node(
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||||
[encoder setBytes:&args length:sizeof(args) atIndex:3];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
|
|
||||||
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 ||
|
||||||
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K ||
|
||||||
|
@ -2503,11 +2503,11 @@ static void ggml_metal_encode_node(
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||||
[encoder setBytes:&args length:sizeof(args) atIndex:4];
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:4];
|
||||||
|
|
||||||
const int64_t _ne1 = 1;
|
const int64_t _ne1 = 1;
|
||||||
const int tgz = dst_rows;
|
const int tgz = dst_rows;
|
||||||
|
@ -2752,15 +2752,15 @@ static void ggml_metal_encode_node(
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||||
if (id_src2 != nil) {
|
if (id_src2 != nil) {
|
||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
||||||
} else {
|
} else {
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:2];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
||||||
[encoder setBytes:&args length:sizeof(args) atIndex:4];
|
|
||||||
|
|
||||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||||
} break;
|
} break;
|
||||||
|
@ -3270,16 +3270,16 @@ static void ggml_metal_encode_node(
|
||||||
};
|
};
|
||||||
|
|
||||||
[encoder setComputePipelineState:pipeline];
|
[encoder setComputePipelineState:pipeline];
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
|
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||||
|
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:3];
|
||||||
if (id_src3) {
|
if (id_src3) {
|
||||||
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
|
[encoder setBuffer:id_src3 offset:offs_src3 atIndex:4];
|
||||||
} else {
|
} else {
|
||||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:3];
|
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:4];
|
||||||
}
|
}
|
||||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:4];
|
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
|
||||||
[encoder setBytes:&args length:sizeof(args) atIndex:5];
|
|
||||||
|
|
||||||
if (!use_vec_kernel) {
|
if (!use_vec_kernel) {
|
||||||
// half8x8 kernel
|
// half8x8 kernel
|
||||||
|
|
|
@ -1629,12 +1629,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
||||||
// quantizations where the block size is 32. It also does not
|
// quantizations where the block size is 32. It also does not
|
||||||
// guard against the number of rows not being divisible by
|
// guard against the number of rows not being divisible by
|
||||||
// N_DST, so this is another explicit assumption of the implementation.
|
// N_DST, so this is another explicit assumption of the implementation.
|
||||||
template<typename block_q_type, int nr, int nsg, int nw, typename A>
|
template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
|
||||||
void mul_vec_q_n_f32_impl(
|
void mul_vec_q_n_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -1704,57 +1704,57 @@ void mul_vec_q_n_f32_impl(
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mv_q4_0_f32(
|
kernel void kernel_mul_mv_q4_0_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mv_q4_1_f32(
|
kernel void kernel_mul_mv_q4_1_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mv_q5_0_f32(
|
kernel void kernel_mul_mv_q5_0_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
kernel void kernel_mul_mv_q5_1_f32(
|
kernel void kernel_mul_mv_q5_1_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define NB_Q8_0 8
|
#define NB_Q8_0 8
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_q8_0_f32_impl(
|
void kernel_mul_mv_q8_0_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -1823,24 +1823,24 @@ void kernel_mul_mv_q8_0_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
||||||
kernel void kernel_mul_mv_q8_0_f32(
|
kernel void kernel_mul_mv_q8_0_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
#define N_MV_T_T 4
|
#define N_MV_T_T 4
|
||||||
|
|
||||||
template<typename T0, typename T04, typename T1, typename T14, typename A>
|
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
||||||
void kernel_mul_mv_impl(
|
void kernel_mul_mv_impl(
|
||||||
|
args_t args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg) {
|
uint tiisg) {
|
||||||
const int64_t r0 = tgpig.x;
|
const int64_t r0 = tgpig.x;
|
||||||
|
@ -1904,17 +1904,17 @@ void kernel_mul_mv_impl(
|
||||||
|
|
||||||
template<typename T0, typename T04, typename T1, typename T14>
|
template<typename T0, typename T04, typename T1, typename T14>
|
||||||
kernel void kernel_mul_mv(
|
kernel void kernel_mul_mv(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
|
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
|
||||||
|
args,
|
||||||
src0,
|
src0,
|
||||||
src1,
|
src1,
|
||||||
dst,
|
dst,
|
||||||
args,
|
|
||||||
tgpig,
|
tgpig,
|
||||||
tiisg);
|
tiisg);
|
||||||
}
|
}
|
||||||
|
@ -1931,10 +1931,10 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv<
|
||||||
|
|
||||||
template<typename T, typename T4>
|
template<typename T, typename T4>
|
||||||
kernel void kernel_mul_mv_1row(
|
kernel void kernel_mul_mv_1row(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
|
@ -1987,10 +1987,10 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne
|
||||||
// Assumes row size (ne00) is a multiple of 4
|
// Assumes row size (ne00) is a multiple of 4
|
||||||
template<typename T, typename T4>
|
template<typename T, typename T4>
|
||||||
kernel void kernel_mul_mv_l4(
|
kernel void kernel_mul_mv_l4(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||||
|
|
||||||
|
@ -2069,11 +2069,11 @@ static void rope_yarn_corr_dims(
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
kernel void kernel_rope_norm(
|
kernel void kernel_rope_norm(
|
||||||
|
constant ggml_metal_kargs_rope & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device const char * src2,
|
device const char * src2,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant ggml_metal_kargs_rope & args,
|
|
||||||
ushort tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
ushort3 tptg [[threads_per_threadgroup]],
|
ushort3 tptg [[threads_per_threadgroup]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
|
@ -2122,11 +2122,11 @@ kernel void kernel_rope_norm(
|
||||||
|
|
||||||
template<typename T>
|
template<typename T>
|
||||||
kernel void kernel_rope_neox(
|
kernel void kernel_rope_neox(
|
||||||
|
constant ggml_metal_kargs_rope & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device const char * src2,
|
device const char * src2,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant ggml_metal_kargs_rope & args,
|
|
||||||
ushort tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
ushort3 tptg [[threads_per_threadgroup]],
|
ushort3 tptg [[threads_per_threadgroup]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||||
|
@ -2563,13 +2563,13 @@ template<
|
||||||
short KV = 8, // key/value processed per each simdgroup
|
short KV = 8, // key/value processed per each simdgroup
|
||||||
short C = 32> // cache items per threadgroup
|
short C = 32> // cache items per threadgroup
|
||||||
kernel void kernel_flash_attn_ext(
|
kernel void kernel_flash_attn_ext(
|
||||||
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||||
device const char * q,
|
device const char * q,
|
||||||
device const char * k,
|
device const char * k,
|
||||||
device const char * v,
|
device const char * v,
|
||||||
device const char * mask,
|
device const char * mask,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
threadgroup half * shared [[threadgroup(0)]],
|
||||||
threadgroup half * shared [[threadgroup(0)]],
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort3 ntg[[threads_per_threadgroup]],
|
ushort3 ntg[[threads_per_threadgroup]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
@ -3058,13 +3058,13 @@ template<
|
||||||
short Q = 1, // queries per threadgroup
|
short Q = 1, // queries per threadgroup
|
||||||
short C = 32> // cache items per threadgroup
|
short C = 32> // cache items per threadgroup
|
||||||
kernel void kernel_flash_attn_ext_vec(
|
kernel void kernel_flash_attn_ext_vec(
|
||||||
|
constant ggml_metal_kargs_flash_attn_ext & args,
|
||||||
device const char * q,
|
device const char * q,
|
||||||
device const char * k,
|
device const char * k,
|
||||||
device const char * v,
|
device const char * v,
|
||||||
device const char * mask,
|
device const char * mask,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant ggml_metal_kargs_flash_attn_ext & args,
|
threadgroup half * shared [[threadgroup(0)]],
|
||||||
threadgroup half * shared [[threadgroup(0)]],
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort3 ntg[[threads_per_threadgroup]],
|
ushort3 ntg[[threads_per_threadgroup]],
|
||||||
ushort tiisg[[thread_index_in_simdgroup]],
|
ushort tiisg[[thread_index_in_simdgroup]],
|
||||||
|
@ -3932,12 +3932,12 @@ kernel void kernel_concat(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_q2_K_f32_impl(
|
void kernel_mul_mv_q2_K_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4024,23 +4024,23 @@ void kernel_mul_mv_q2_K_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
||||||
kernel void kernel_mul_mv_q2_K_f32(
|
kernel void kernel_mul_mv_q2_K_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_q3_K_f32_impl(
|
void kernel_mul_mv_q3_K_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4184,23 +4184,23 @@ void kernel_mul_mv_q3_K_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
||||||
kernel void kernel_mul_mv_q3_K_f32(
|
kernel void kernel_mul_mv_q3_K_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_q4_K_f32_impl(
|
void kernel_mul_mv_q4_K_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4300,23 +4300,23 @@ void kernel_mul_mv_q4_K_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
||||||
kernel void kernel_mul_mv_q4_K_f32(
|
kernel void kernel_mul_mv_q4_K_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_q5_K_f32_impl(
|
void kernel_mul_mv_q5_K_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4430,23 +4430,23 @@ void kernel_mul_mv_q5_K_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
||||||
kernel void kernel_mul_mv_q5_K_f32(
|
kernel void kernel_mul_mv_q5_K_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename A>
|
template <typename args_t>
|
||||||
void kernel_mul_mv_q6_K_f32_impl(
|
void kernel_mul_mv_q6_K_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4519,25 +4519,25 @@ void kernel_mul_mv_q6_K_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
||||||
kernel void kernel_mul_mv_q6_K_f32(
|
kernel void kernel_mul_mv_q6_K_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ======================= "True" 2-bit
|
// ======================= "True" 2-bit
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_iq2_xxs_f32_impl(
|
void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4627,24 +4627,24 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
|
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
|
||||||
kernel void kernel_mul_mv_iq2_xxs_f32(
|
kernel void kernel_mul_mv_iq2_xxs_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_iq2_xs_f32_impl(
|
void kernel_mul_mv_iq2_xs_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4744,24 +4744,24 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
|
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
|
||||||
kernel void kernel_mul_mv_iq2_xs_f32(
|
kernel void kernel_mul_mv_iq2_xs_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename A>
|
template <typename args_t>
|
||||||
void kernel_mul_mv_iq3_xxs_f32_impl(
|
void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4854,24 +4854,24 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
|
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
|
||||||
kernel void kernel_mul_mv_iq3_xxs_f32(
|
kernel void kernel_mul_mv_iq3_xxs_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_iq3_s_f32_impl(
|
void kernel_mul_mv_iq3_s_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -4964,24 +4964,24 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq3_s_f32")]]
|
[[host_name("kernel_mul_mv_iq3_s_f32")]]
|
||||||
kernel void kernel_mul_mv_iq3_s_f32(
|
kernel void kernel_mul_mv_iq3_s_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename A>
|
template <typename args_t>
|
||||||
void kernel_mul_mv_iq2_s_f32_impl(
|
void kernel_mul_mv_iq2_s_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -5075,24 +5075,24 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq2_s_f32")]]
|
[[host_name("kernel_mul_mv_iq2_s_f32")]]
|
||||||
kernel void kernel_mul_mv_iq2_s_f32(
|
kernel void kernel_mul_mv_iq2_s_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_iq1_s_f32_impl(
|
void kernel_mul_mv_iq1_s_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_value,
|
threadgroup int8_t * shared_value,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -5171,12 +5171,12 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename A>
|
template <typename args_t>
|
||||||
void kernel_mul_mv_iq1_m_f32_impl(
|
void kernel_mul_mv_iq1_m_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_value,
|
threadgroup int8_t * shared_value,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -5264,12 +5264,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values_i8,
|
threadgroup int8_t * shared_values_i8,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -5352,12 +5352,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename A>
|
template<typename args_t>
|
||||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||||
|
args_t args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
A args,
|
|
||||||
threadgroup int8_t * shared_values_i8,
|
threadgroup int8_t * shared_values_i8,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -5443,56 +5443,56 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
||||||
kernel void kernel_mul_mv_iq1_s_f32(
|
kernel void kernel_mul_mv_iq1_s_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
||||||
kernel void kernel_mul_mv_iq1_m_f32(
|
kernel void kernel_mul_mv_iq1_m_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
||||||
kernel void kernel_mul_mv_iq4_nl_f32(
|
kernel void kernel_mul_mv_iq4_nl_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
||||||
kernel void kernel_mul_mv_iq4_xs_f32(
|
kernel void kernel_mul_mv_iq4_xs_f32(
|
||||||
|
constant ggml_metal_kargs_mul_mv & args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
constant ggml_metal_kargs_mul_mv & args,
|
|
||||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiisg[[thread_index_in_simdgroup]],
|
uint tiisg[[thread_index_in_simdgroup]],
|
||||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||||
|
|
||||||
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||||
|
@ -5597,10 +5597,10 @@ kernel void kernel_get_rows_i32(
|
||||||
// each block_q contains 16*nl weights
|
// each block_q contains 16*nl weights
|
||||||
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>
|
||||||
kernel void kernel_mul_mm(
|
kernel void kernel_mul_mm(
|
||||||
|
constant ggml_metal_kargs_mul_mm & args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device char * dst,
|
device char * dst,
|
||||||
constant ggml_metal_kargs_mul_mm & args,
|
|
||||||
threadgroup char * shared_memory [[threadgroup(0)]],
|
threadgroup char * shared_memory [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
ushort tiitg[[thread_index_in_threadgroup]],
|
ushort tiitg[[thread_index_in_threadgroup]],
|
||||||
|
@ -6032,18 +6032,18 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
||||||
//
|
//
|
||||||
|
|
||||||
typedef void (kernel_mul_mv_impl_t)(
|
typedef void (kernel_mul_mv_impl_t)(
|
||||||
|
ggml_metal_kargs_mul_mv args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
ggml_metal_kargs_mul_mv args,
|
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg);
|
uint tiisg);
|
||||||
|
|
||||||
typedef void (kernel_mul_mv2_impl_t)(
|
typedef void (kernel_mul_mv2_impl_t)(
|
||||||
|
ggml_metal_kargs_mul_mv args,
|
||||||
device const void * src0,
|
device const void * src0,
|
||||||
device const float * src1,
|
device const float * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
ggml_metal_kargs_mul_mv args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
|
@ -6051,41 +6051,41 @@ typedef void (kernel_mul_mv2_impl_t)(
|
||||||
|
|
||||||
template<kernel_mul_mv_impl_t impl_fn>
|
template<kernel_mul_mv_impl_t impl_fn>
|
||||||
void mmv_fn(
|
void mmv_fn(
|
||||||
|
ggml_metal_kargs_mul_mv args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
ggml_metal_kargs_mul_mv args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiitg,
|
uint tiitg,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
uint sgitg) {
|
uint sgitg) {
|
||||||
impl_fn(src0, src1, dst, args, tgpig, tiisg);
|
impl_fn(args, src0, src1, dst, tgpig, tiisg);
|
||||||
}
|
}
|
||||||
|
|
||||||
template<kernel_mul_mv2_impl_t impl_fn>
|
template<kernel_mul_mv2_impl_t impl_fn>
|
||||||
void mmv_fn(
|
void mmv_fn(
|
||||||
|
ggml_metal_kargs_mul_mv args,
|
||||||
device const char * src0,
|
device const char * src0,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
ggml_metal_kargs_mul_mv args,
|
|
||||||
threadgroup int8_t * shared_values,
|
threadgroup int8_t * shared_values,
|
||||||
uint3 tgpig,
|
uint3 tgpig,
|
||||||
uint tiitg,
|
uint tiitg,
|
||||||
uint tiisg,
|
uint tiisg,
|
||||||
uint sgitg) {
|
uint sgitg) {
|
||||||
impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||||
}
|
}
|
||||||
|
|
||||||
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
||||||
|
|
||||||
template<mul_mv_impl_fn_t impl_fn>
|
template<mul_mv_impl_fn_t impl_fn>
|
||||||
kernel void kernel_mul_mv_id(
|
kernel void kernel_mul_mv_id(
|
||||||
|
constant ggml_metal_kargs_mul_mv_id & args,
|
||||||
device const char * src0s,
|
device const char * src0s,
|
||||||
device const char * src1,
|
device const char * src1,
|
||||||
device float * dst,
|
device float * dst,
|
||||||
device const char * ids,
|
device const char * ids,
|
||||||
constant ggml_metal_kargs_mul_mv_id & args,
|
|
||||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||||
uint tiitg[[thread_index_in_threadgroup]],
|
uint tiitg[[thread_index_in_threadgroup]],
|
||||||
|
@ -6130,10 +6130,10 @@ kernel void kernel_mul_mv_id(
|
||||||
};
|
};
|
||||||
|
|
||||||
impl_fn(
|
impl_fn(
|
||||||
|
args0,
|
||||||
/* src0 */ src0_cur,
|
/* src0 */ src0_cur,
|
||||||
/* src1 */ src1_cur,
|
/* src1 */ src1_cur,
|
||||||
/* dst */ dst_cur,
|
/* dst */ dst_cur,
|
||||||
args0,
|
|
||||||
shared_values,
|
shared_values,
|
||||||
tgpig,
|
tgpig,
|
||||||
tiitg,
|
tiitg,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue