cont : pass by reference
This commit is contained in:
parent
07bc7610ad
commit
4af3a87962
1 changed file with 57 additions and 42 deletions
|
@ -1629,12 +1629,12 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|||
// quantizations where the block size is 32. It also does not
|
||||
// guard against the number of rows not being divisible by
|
||||
// N_DST, so this is another explicit assumption of the implementation.
|
||||
template<typename block_q_type, int nr, int nsg, int nw>
|
||||
template<typename block_q_type, int nr, int nsg, int nw, typename A>
|
||||
void mul_vec_q_n_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -1711,7 +1711,7 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
mul_vec_q_n_f32_impl<block_q4_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q4_1_f32(
|
||||
|
@ -1720,9 +1720,9 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|||
device float * dst,
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q5_0_f32(
|
||||
|
@ -1733,7 +1733,7 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
mul_vec_q_n_f32_impl<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q5_1_f32(
|
||||
|
@ -1744,16 +1744,17 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
mul_vec_q_n_f32_impl<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH, constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
#define NB_Q8_0 8
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_q8_0_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -1829,17 +1830,17 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
kernel_mul_mv_q8_0_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q8_0_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
#define N_MV_T_T 4
|
||||
|
||||
template<typename T0, typename T04, typename T1, typename T14>
|
||||
template<typename T0, typename T04, typename T1, typename T14, typename A>
|
||||
void kernel_mul_mv_impl(
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
uint3 tgpig,
|
||||
uint tiisg) {
|
||||
const int64_t r0 = tgpig.x;
|
||||
|
@ -1909,7 +1910,7 @@ kernel void kernel_mul_mv(
|
|||
constant ggml_metal_kargs_mul_mv & args,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
kernel_mul_mv_impl<T0, T04, T1, T14>(
|
||||
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
|
||||
src0,
|
||||
src1,
|
||||
dst,
|
||||
|
@ -3931,11 +3932,12 @@ kernel void kernel_concat(
|
|||
}
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_q2_K_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4030,14 +4032,15 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q2_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_q3_K_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4189,14 +4192,15 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q3_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_q4_K_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4304,14 +4308,15 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q4_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_q5_K_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4433,14 +4438,15 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q5_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
void kernel_mul_mv_q6_K_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4521,16 +4527,17 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_q6_K_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
// ======================= "True" 2-bit
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4629,14 +4636,15 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4745,14 +4753,15 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4854,14 +4863,15 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq3_xxs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_iq3_s_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4963,14 +4973,15 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq3_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
void kernel_mul_mv_iq2_s_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5073,14 +5084,15 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq2_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_iq1_s_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_value,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5159,11 +5171,12 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
}
|
||||
}
|
||||
|
||||
template <typename A>
|
||||
void kernel_mul_mv_iq1_m_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_value,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5251,11 +5264,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|||
}
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values_i8,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5338,11 +5352,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|||
}
|
||||
}
|
||||
|
||||
template<typename A>
|
||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
A args,
|
||||
threadgroup int8_t * shared_values_i8,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5436,7 +5451,7 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq1_s_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
||||
|
@ -5449,7 +5464,7 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq1_m_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, nullptr, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
||||
|
@ -5463,7 +5478,7 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq4_nl_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
||||
|
@ -5477,7 +5492,7 @@ kernel void kernel_mul_mv_iq4_xs_f32(
|
|||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
||||
kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
kernel_mul_mv_iq4_xs_f32_impl<constant ggml_metal_kargs_mul_mv &>(src0, src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
template<typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread float4x4 &)>
|
||||
|
@ -6062,7 +6077,7 @@ void mmv_fn(
|
|||
impl_fn(src0,(const device float *) src1, dst, args, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4>>) mul_mv_impl_fn_t;
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
||||
|
||||
template<mul_mv_impl_fn_t impl_fn>
|
||||
kernel void kernel_mul_mv_id(
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue