CUDA: refactor mmq, dmmv, mmvq (#7716)
* CUDA: refactor mmq, dmmv, mmvq
* fix out-of-bounds write
* struct for qk, qr, qi
* fix cmake build
* mmq_type_traits
parent 2b3389677a
commit 7d1a378b8f
112 changed files with 1783 additions and 1767 deletions
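
A note on the shape of the change, since every hunk below repeats it: each vec_dot_*_q8_1 device function gains a kbx block-index argument and applies the offset to the raw vbq pointer itself ((const block_X *) vbq + kbx), where previously the caller had to pre-offset the pointer before the call. A minimal sketch of the calling convention this implies (the helper name, flat loop structure, and indexing are assumptions for illustration, not this commit's code):

    // Sketch only; block_q8_1 and the vec_dot_* functions come from the file
    // being patched. qi = number of quant indices handled per block.
    typedef float (*vec_dot_q_cuda_t)(
        const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1,
        const int & kbx, const int & iqs);

    template <int qi, vec_dot_q_cuda_t vec_dot>
    static __device__ float mmvq_row_sum_sketch(
        const void * __restrict__ vx, const block_q8_1 * __restrict__ vy,
        const int blocks_per_row) {
        float sum = 0.0f;
        for (int kbx = 0; kbx < blocks_per_row; ++kbx) {
            for (int iqs = 0; iqs < qi; ++iqs) {
                // old convention (illustrative): the caller pre-offset vx,
                // e.g. vec_dot((const block_q4_0 *) vx + kbx, &vy[kbx], iqs)
                sum += vec_dot(vx, &vy[kbx], kbx, iqs);
            }
        }
        return sum;
    }

Keeping vbq un-offset and threading kbx through means one function-pointer type can serve code paths that compute block indices differently, which fits the mmq/dmmv/mmvq consolidation named in the commit title.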
@@ -566,9 +566,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq(
 }
 
 static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
+    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq + kbx;
 
     int v[VDR_Q4_0_Q8_1_MMVQ];
     int u[2*VDR_Q4_0_Q8_1_MMVQ];
@@ -585,9 +585,9 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
+    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq + kbx;
 
     int v[VDR_Q4_1_Q8_1_MMVQ];
     int u[2*VDR_Q4_1_Q8_1_MMVQ];
@@ -603,9 +603,9 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
+    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq + kbx;
 
     int vl[VDR_Q5_0_Q8_1_MMVQ];
    int vh[VDR_Q5_0_Q8_1_MMVQ];
@@ -623,9 +623,9 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
+    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq + kbx;
 
     int vl[VDR_Q5_1_Q8_1_MMVQ];
     int vh[VDR_Q5_1_Q8_1_MMVQ];
@@ -643,9 +643,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
+    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq + kbx;
 
     int v[VDR_Q8_0_Q8_1_MMVQ];
     int u[VDR_Q8_0_Q8_1_MMVQ];
@@ -660,9 +660,9 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q2_K * bq2_K = (const block_q2_K *) vbq;
+    const block_q2_K * bq2_K = (const block_q2_K *) vbq + kbx;
 
     const int bq8_offset = QR2_K * (iqs / QI8_1);
     const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
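
The unchanged context lines in the q2_K hunk above carry the iqs-to-scale mapping that this refactor leaves intact. A worked check of that arithmetic, assuming QI8_1 == 8 (its value in this codebase, stated here as an assumption):

    // scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2), QI8_1 == 8
    __host__ __device__ constexpr int q2_K_scale_offset_sketch(int iqs) {
        return iqs - iqs % 8 + (iqs % 8) / 4;
    }
    // q2_K_scale_offset_sketch(3)  == 0
    // q2_K_scale_offset_sketch(5)  == 1
    // q2_K_scale_offset_sketch(12) == 9
    // Each run of QI8_1/2 == 4 consecutive iqs values shares one scale slot,
    // while bq8_offset = QR2_K * (iqs / QI8_1) advances whole q8_1 blocks.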
@@ -683,9 +683,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q3_K * bq3_K = (const block_q3_K *) vbq;
+    const block_q3_K * bq3_K = (const block_q3_K *) vbq + kbx;
 
     const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
     const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
@@ -710,9 +710,9 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q4_K * bq4_K = (const block_q4_K *) vbq;
+    const block_q4_K * bq4_K = (const block_q4_K *) vbq + kbx;
 
     int v[2];
     int u[2*QR4_K];
@@ -756,9 +756,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q5_K * bq5_K = (const block_q5_K *) vbq;
+    const block_q5_K * bq5_K = (const block_q5_K *) vbq + kbx;
 
     int vl[2];
     int vh[2];
@@ -802,9 +802,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_q6_K * bq6_K = (const block_q6_K *) vbq;
+    const block_q6_K * bq6_K = (const block_q6_K *) vbq + kbx;
 
     const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
     const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
@@ -828,8 +828,8 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq2_xxs * bq2 = (const block_iq2_xxs *) vbq + kbx;
 
 #if QR2_XXS == 8
     const int ib32 = iqs;
@@ -872,9 +872,9 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq;
+    const block_iq2_xs * bq2 = (const block_iq2_xs *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint16_t * q2 = bq2->qs + 4*ib32;
@@ -911,9 +911,9 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1(
 
 // TODO
 static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq2_s * bq2 = (const block_iq2_s *) vbq;
+    const block_iq2_s * bq2 = (const block_iq2_s *) vbq + kbx;
 
     const int ib32 = iqs;
     const int8_t * q8 = bq8_1[ib32].qs;
@@ -951,9 +951,9 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq;
+    const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint8_t * q3 = bq2->qs + 8*ib32;
@@ -981,9 +981,9 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1(
 
 // TODO: don't use lookup table for signs
 static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq3_s * bq2 = (const block_iq3_s *) vbq;
+    const block_iq3_s * bq2 = (const block_iq3_s *) vbq + kbx;
 
     const int ib32 = iqs;
     const uint8_t * qs = bq2->qs + 8*ib32;
@@ -1008,8 +1008,8 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-    const block_iq1_s * bq1 = (const block_iq1_s *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq1_s * bq1 = (const block_iq1_s *) vbq + kbx;
 
     const int ib32 = iqs;
     int sumi = 0;
@@ -1039,8 +1039,8 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq1_m_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
-    const block_iq1_m * bq1 = (const block_iq1_m *) vbq;
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
+    const block_iq1_m * bq1 = (const block_iq1_m *) vbq + kbx;
 
     const int ib32 = iqs;
     int sumi[2] = {0, 0};
@@ -1094,9 +1094,9 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4
 #endif
 
 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
-    const block_iq4_nl * bq = (const block_iq4_nl *) vbq;
+    const block_iq4_nl * bq = (const block_iq4_nl *) vbq + kbx;
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
     const uint16_t * q4 = (const uint16_t *)bq->qs + 2*iqs;
@@ -1128,10 +1128,10 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1(
 }
 
 static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
-    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) {
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs) {
 
 #if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
-    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq;
+    const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq + kbx;
     const uint8_t * values = (const uint8_t *)kvalues_iq4nl;
 
     // iqs is 0...7
@@ -1149,6 +1149,6 @@ static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1(
 }
     return d * (sumi1 + sumi2);
 #else
-    return vec_dot_iq4_xs_q8_1(vbq, bq8_1, iqs);
+    return vec_dot_iq4_xs_q8_1(vbq, bq8_1, kbx, iqs);
 #endif
 }
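
The struct for qk, qr, qi and mmq_type_traits named in the commit message sit in other files of this 112-file change, so they do not appear in the hunks above. For orientation only, a sketch of what such a traits scheme looks like; the struct name, members, and specializations below are assumptions for illustration, not this commit's definitions:

    // Per-type compile-time constants (block size qk, packing ratio qr,
    // 32-bit ints of quant data per block qi) gathered into one template
    // instead of scattered QK*/QR*/QI* macros. ggml_type is the existing
    // ggml enum; the traits struct itself is hypothetical.
    template <ggml_type type> struct cuda_type_traits_sketch;

    template <> struct cuda_type_traits_sketch<GGML_TYPE_Q4_0> {
        static constexpr int qk = 32; // quants per block
        static constexpr int qr = 2;  // two 4-bit quants per byte
        static constexpr int qi = 4;  // qk / (4*qr)
    };

    template <> struct cuda_type_traits_sketch<GGML_TYPE_Q8_0> {
        static constexpr int qk = 32;
        static constexpr int qr = 1;
        static constexpr int qi = 8;
    };

A kernel templated on ggml_type can then read cuda_type_traits_sketch<type>::qi at compile time instead of dispatching on runtime constants.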