cont : use char ptr
This commit is contained in:
parent
481b05df22
commit
d2a055059e
1 changed files with 216 additions and 179 deletions
|
@ -1632,9 +1632,9 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre
|
|||
template<typename block_q_type, int nr, int nsg, int nw, typename args_t>
|
||||
void mul_vec_q_n_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -1653,8 +1653,8 @@ void mul_vec_q_n_f32_impl(
|
|||
//const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
//device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
//device const block_q_type * x = (device const block_q_type *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
// pointers to src0 rows
|
||||
device const block_q_type * ax[nr];
|
||||
|
@ -1695,19 +1695,22 @@ void mul_vec_q_n_f32_impl(
|
|||
yb += QK4_0 * 16;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
|
||||
if (tiisg == 0 && first_row + row < args.ne01) {
|
||||
dst[im*args.ne0*args.ne1 + r1*args.ne0 + first_row + row] = tot;
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
kernel void kernel_mul_mv_q4_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -1716,9 +1719,9 @@ kernel void kernel_mul_mv_q4_0_f32(
|
|||
|
||||
kernel void kernel_mul_mv_q4_1_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -1727,9 +1730,9 @@ kernel void kernel_mul_mv_q4_1_f32(
|
|||
|
||||
kernel void kernel_mul_mv_q5_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -1738,9 +1741,9 @@ kernel void kernel_mul_mv_q5_0_f32(
|
|||
|
||||
kernel void kernel_mul_mv_q5_1_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -1752,9 +1755,9 @@ kernel void kernel_mul_mv_q5_1_f32(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_q8_0_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -1776,8 +1779,8 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|||
//const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
//device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
//device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
// pointers to src0 rows
|
||||
device const block_q8_0 * ax[nr];
|
||||
|
@ -1813,10 +1816,12 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|||
yb += NB_Q8_0 * nw;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < nr; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
if (tiisg == 0 && first_row + row < args.ne01) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot;
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1824,9 +1829,9 @@ void kernel_mul_mv_q8_0_f32_impl(
|
|||
[[host_name("kernel_mul_mv_q8_0_f32")]]
|
||||
kernel void kernel_mul_mv_q8_0_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -1838,9 +1843,9 @@ kernel void kernel_mul_mv_q8_0_f32(
|
|||
template<typename T0, typename T04, typename T1, typename T14, typename args_t>
|
||||
void kernel_mul_mv_impl(
|
||||
args_t args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig,
|
||||
uint tiisg) {
|
||||
const int64_t r0 = tgpig.x;
|
||||
|
@ -1854,6 +1859,8 @@ void kernel_mul_mv_impl(
|
|||
|
||||
device const T0 * x = (device const T0 *) (src0 + offset0);
|
||||
|
||||
device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1;
|
||||
|
||||
if (args.ne00 < 128) {
|
||||
for (int row = 0; row < N_MV_T_T; ++row) {
|
||||
int r1 = rb + row;
|
||||
|
@ -1872,7 +1879,7 @@ void kernel_mul_mv_impl(
|
|||
|
||||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
|
||||
dst_f32[r1*args.ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
|
@ -1896,7 +1903,7 @@ void kernel_mul_mv_impl(
|
|||
float all_sum = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]);
|
||||
dst[im*args.ne1*args.ne0 + r1*args.ne0 + r0] = all_sum;
|
||||
dst_f32[r1*args.ne0 + r0] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -1905,9 +1912,9 @@ void kernel_mul_mv_impl(
|
|||
template<typename T0, typename T04, typename T1, typename T14>
|
||||
kernel void kernel_mul_mv(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]]) {
|
||||
kernel_mul_mv_impl<T0, T04, T1, T14, constant ggml_metal_kargs_mul_mv &>(
|
||||
|
@ -3935,9 +3942,9 @@ kernel void kernel_concat(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_q2_K_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -3956,8 +3963,8 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
@ -4014,10 +4021,12 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|||
y4 += 4 * QK_K;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4025,9 +4034,9 @@ void kernel_mul_mv_q2_K_f32_impl(
|
|||
[[host_name("kernel_mul_mv_q2_K_f32")]]
|
||||
kernel void kernel_mul_mv_q2_K_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -4038,9 +4047,9 @@ kernel void kernel_mul_mv_q2_K_f32(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_q3_K_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4060,8 +4069,8 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0);
|
||||
device const float * yy = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0);
|
||||
device const float * yy = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
|
||||
|
@ -4175,9 +4184,12 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|||
const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
|
||||
sumf1[row] = simd_sum(sumf);
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
if (tiisg == 0) {
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = sumf1[row];
|
||||
dst_f32[first_row + row] = sumf1[row];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4185,9 +4197,9 @@ void kernel_mul_mv_q3_K_f32_impl(
|
|||
[[host_name("kernel_mul_mv_q3_K_f32")]]
|
||||
kernel void kernel_mul_mv_q3_K_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -4198,9 +4210,9 @@ kernel void kernel_mul_mv_q3_K_f32(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_q4_K_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4228,8 +4240,8 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[16];
|
||||
float yh[16];
|
||||
|
@ -4290,10 +4302,12 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|||
y4 += 4 * QK_K;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4301,9 +4315,9 @@ void kernel_mul_mv_q4_K_f32_impl(
|
|||
[[host_name("kernel_mul_mv_q4_K_f32")]]
|
||||
kernel void kernel_mul_mv_q4_K_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -4314,9 +4328,9 @@ kernel void kernel_mul_mv_q4_K_f32(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_q5_K_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4336,8 +4350,8 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0);
|
||||
device const float * yy = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0);
|
||||
device const float * yy = (device const float *) (src1 + offset1);
|
||||
|
||||
float sumf[2]={0.f};
|
||||
|
||||
|
@ -4420,10 +4434,12 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|||
y1 += 4 * QK_K;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
const float tot = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = tot;
|
||||
dst_f32[first_row + row] = tot;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4431,9 +4447,9 @@ void kernel_mul_mv_q5_K_f32_impl(
|
|||
[[host_name("kernel_mul_mv_q5_K_f32")]]
|
||||
kernel void kernel_mul_mv_q5_K_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -4444,9 +4460,9 @@ kernel void kernel_mul_mv_q5_K_f32(
|
|||
template <typename args_t>
|
||||
void kernel_mul_mv_q6_K_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4471,8 +4487,8 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|||
const uint offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0);
|
||||
device const float * yy = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0);
|
||||
device const float * yy = (device const float *) (src1 + offset1);
|
||||
|
||||
float sumf = 0;
|
||||
|
||||
|
@ -4511,18 +4527,20 @@ void kernel_mul_mv_q6_K_f32_impl(
|
|||
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
const float tot = simd_sum(sumf);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + row] = tot;
|
||||
dst_f32[row] = tot;
|
||||
}
|
||||
}
|
||||
|
||||
[[host_name("kernel_mul_mv_q6_K_f32")]]
|
||||
kernel void kernel_mul_mv_q6_K_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -4535,9 +4553,9 @@ kernel void kernel_mul_mv_q6_K_f32(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_iq2_xxs_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4556,8 +4574,8 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
@ -4617,10 +4635,12 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f;
|
||||
dst_f32[first_row + row] = all_sum * 0.25f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4628,9 +4648,9 @@ void kernel_mul_mv_iq2_xxs_f32_impl(
|
|||
[[host_name("kernel_mul_mv_iq2_xxs_f32")]]
|
||||
kernel void kernel_mul_mv_iq2_xxs_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
|
@ -4642,9 +4662,9 @@ kernel void kernel_mul_mv_iq2_xxs_f32(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_iq2_xs_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4663,8 +4683,8 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
@ -4734,10 +4754,12 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f;
|
||||
dst_f32[first_row + row] = all_sum * 0.25f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4745,9 +4767,9 @@ void kernel_mul_mv_iq2_xs_f32_impl(
|
|||
[[host_name("kernel_mul_mv_iq2_xs_f32")]]
|
||||
kernel void kernel_mul_mv_iq2_xs_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
|
@ -4759,9 +4781,9 @@ kernel void kernel_mul_mv_iq2_xs_f32(
|
|||
template <typename args_t>
|
||||
void kernel_mul_mv_iq3_xxs_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4780,8 +4802,8 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
@ -4844,10 +4866,12 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.5f;
|
||||
dst_f32[first_row + row] = all_sum * 0.5f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4855,9 +4879,9 @@ void kernel_mul_mv_iq3_xxs_f32_impl(
|
|||
[[host_name("kernel_mul_mv_iq3_xxs_f32")]]
|
||||
kernel void kernel_mul_mv_iq3_xxs_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
|
@ -4869,9 +4893,9 @@ kernel void kernel_mul_mv_iq3_xxs_f32(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_iq3_s_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -4890,8 +4914,8 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
@ -4954,10 +4978,12 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4965,9 +4991,9 @@ void kernel_mul_mv_iq3_s_f32_impl(
|
|||
[[host_name("kernel_mul_mv_iq3_s_f32")]]
|
||||
kernel void kernel_mul_mv_iq3_s_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
|
@ -4979,9 +5005,9 @@ kernel void kernel_mul_mv_iq3_s_f32(
|
|||
template <typename args_t>
|
||||
void kernel_mul_mv_iq2_s_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5000,8 +5026,8 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
@ -5065,10 +5091,12 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum * 0.25f;
|
||||
dst_f32[first_row + row] = all_sum * 0.25f;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5076,9 +5104,9 @@ void kernel_mul_mv_iq2_s_f32_impl(
|
|||
[[host_name("kernel_mul_mv_iq2_s_f32")]]
|
||||
kernel void kernel_mul_mv_iq2_s_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
|
@ -5090,9 +5118,9 @@ kernel void kernel_mul_mv_iq2_s_f32(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_iq1_s_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_value,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5111,8 +5139,8 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
@ -5163,10 +5191,12 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5174,9 +5204,9 @@ void kernel_mul_mv_iq1_s_f32_impl(
|
|||
template <typename args_t>
|
||||
void kernel_mul_mv_iq1_m_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_value,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5195,8 +5225,8 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
float yl[32];
|
||||
float sumf[N_DST]={0.f}, all_sum;
|
||||
|
@ -5256,10 +5286,12 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|||
y4 += 32 * 32;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < N_DST; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5267,9 +5299,9 @@ void kernel_mul_mv_iq1_m_f32_impl(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_iq4_nl_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values_i8,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5288,8 +5320,8 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int ix = tiisg/2; // 0...15
|
||||
const int it = tiisg%2; // 0 or 1
|
||||
|
@ -5344,10 +5376,12 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|||
yb += 16 * QK4_NL;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5355,9 +5389,9 @@ void kernel_mul_mv_iq4_nl_f32_impl(
|
|||
template<typename args_t>
|
||||
void kernel_mul_mv_iq4_xs_f32_impl(
|
||||
args_t args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values_i8,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -5376,8 +5410,8 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|||
const uint offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03;
|
||||
const uint offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13;
|
||||
|
||||
device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0);
|
||||
device const float * y = (device const float *) ((device char *) src1 + offset1);
|
||||
device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0);
|
||||
device const float * y = (device const float *) (src1 + offset1);
|
||||
|
||||
const int ix = tiisg/16; // 0 or 1
|
||||
const int it = tiisg%16; // 0...15
|
||||
|
@ -5433,10 +5467,12 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|||
yb += 2 * QK_K;
|
||||
}
|
||||
|
||||
device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0;
|
||||
|
||||
for (int row = 0; row < 2; ++row) {
|
||||
all_sum = simd_sum(sumf[row]);
|
||||
if (tiisg == 0) {
|
||||
dst[r1*args.ne0 + im*args.ne0*args.ne1 + first_row + row] = all_sum;
|
||||
dst_f32[first_row + row] = all_sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -5444,9 +5480,9 @@ void kernel_mul_mv_iq4_xs_f32_impl(
|
|||
[[host_name("kernel_mul_mv_iq1_s_f32")]]
|
||||
kernel void kernel_mul_mv_iq1_s_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -5457,9 +5493,9 @@ kernel void kernel_mul_mv_iq1_s_f32(
|
|||
[[host_name("kernel_mul_mv_iq1_m_f32")]]
|
||||
kernel void kernel_mul_mv_iq1_m_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
uint sgitg[[simdgroup_index_in_threadgroup]]) {
|
||||
|
@ -5470,9 +5506,9 @@ kernel void kernel_mul_mv_iq1_m_f32(
|
|||
[[host_name("kernel_mul_mv_iq4_nl_f32")]]
|
||||
kernel void kernel_mul_mv_iq4_nl_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
|
@ -5484,9 +5520,9 @@ kernel void kernel_mul_mv_iq4_nl_f32(
|
|||
[[host_name("kernel_mul_mv_iq4_xs_f32")]]
|
||||
kernel void kernel_mul_mv_iq4_xs_f32(
|
||||
constant ggml_metal_kargs_mul_mv & args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiisg[[thread_index_in_simdgroup]],
|
||||
|
@ -6033,17 +6069,17 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel
|
|||
|
||||
typedef void (kernel_mul_mv_impl_t)(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig,
|
||||
uint tiisg);
|
||||
|
||||
typedef void (kernel_mul_mv2_impl_t)(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const void * src0,
|
||||
device const float * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiisg,
|
||||
|
@ -6052,9 +6088,9 @@ typedef void (kernel_mul_mv2_impl_t)(
|
|||
template<kernel_mul_mv_impl_t impl_fn>
|
||||
void mmv_fn(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiitg,
|
||||
|
@ -6066,15 +6102,15 @@ void mmv_fn(
|
|||
template<kernel_mul_mv2_impl_t impl_fn>
|
||||
void mmv_fn(
|
||||
ggml_metal_kargs_mul_mv args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
threadgroup int8_t * shared_values,
|
||||
uint3 tgpig,
|
||||
uint tiitg,
|
||||
uint tiisg,
|
||||
uint sgitg) {
|
||||
impl_fn(args, src0,(const device float *) src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
impl_fn(args, src0, src1, dst, shared_values, tgpig, tiisg, sgitg);
|
||||
}
|
||||
|
||||
typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_kargs_mul_mv>>) mul_mv_impl_fn_t;
|
||||
|
@ -6082,10 +6118,10 @@ typedef decltype(mmv_fn<kernel_mul_mv_impl<half, half4, half, half4, ggml_metal_
|
|||
template<mul_mv_impl_fn_t impl_fn>
|
||||
kernel void kernel_mul_mv_id(
|
||||
constant ggml_metal_kargs_mul_mv_id & args,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device float * dst,
|
||||
device const char * ids,
|
||||
device const char * src0s,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
device const char * ids,
|
||||
threadgroup int8_t * shared_values [[threadgroup(0)]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
uint tiitg[[thread_index_in_threadgroup]],
|
||||
|
@ -6106,7 +6142,8 @@ kernel void kernel_mul_mv_id(
|
|||
|
||||
device const char * src0_cur = src0s + i02*args.nb02;
|
||||
device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12;
|
||||
device float * dst_cur = dst + i1*args.ne0 + i2*args.ne1*args.ne0;
|
||||
|
||||
device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float);
|
||||
|
||||
ggml_metal_kargs_mul_mv args0 = {
|
||||
/*.ne00 =*/ args.ne00,
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue