diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index 9233c70aa..c9019242b 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -180,6 +180,7 @@ template static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -194,9 +195,7 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -239,6 +238,10 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( sum[(j0/B.J)*C.ne + l] += dA[l/2]*__low2float(dsB[l%2])*C.x[l]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q4_1( @@ -317,6 +320,7 @@ template static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -330,9 +334,7 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( half2 dmA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -376,6 +378,10 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_0( @@ -475,6 +481,7 @@ template static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -489,9 +496,7 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -533,6 +538,10 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( sum[(j0/B.J)*C.ne + l] += dA[l/2]*dB[l%2]*C.x[l]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_1( @@ -628,6 +637,7 @@ template static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -641,9 +651,7 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( half2 dmA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -686,6 +694,10 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( sum[(j0/B.J)*C.ne + l] += __low2float(dmA_dsB)*C.x[l] + __high2float(dmA_dsB); } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q8_0( @@ -756,6 +768,7 @@ template static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE GGML_UNUSED(x_sc); typedef mma_int_A_I16K8 mma_A; @@ -770,9 +783,7 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -814,6 +825,10 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( sum[(j0/B.J)*C.ne + l] += C.x[l]*dA[l/2]*dB[l%2]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q2_K( @@ -887,6 +902,7 @@ template static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; @@ -896,9 +912,7 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE mma_A A[2]; float dA[mma_C::ne/2][2]; @@ -961,6 +975,10 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*dA[l/2][0] + C[1].x[l]*dA[l/2][1] + mA[l/2][0]*sB[l%2][0] + mA[l/2][1]*sB[l%2][1])*dB[l%2]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q3_K( @@ -1073,6 +1091,7 @@ template static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; @@ -1083,9 +1102,7 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( const float * y_df = (const float *) y; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1150,6 +1167,10 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += (C[0].x[l]*scA[l/2][0] + C[1].x[l]*scA[l/2][1])*dA[l/2]*dB[l%2]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q4_K( @@ -1239,6 +1260,7 @@ template static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -1248,9 +1270,7 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1324,6 +1344,10 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q5_K( @@ -1426,6 +1450,7 @@ template static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K8 mma_A; typedef mma_int_B_J8K8 mma_B; @@ -1435,9 +1460,7 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; -#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); -#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1511,6 +1534,10 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += __low2float(dmA[l/2])*tmpd[l] - __high2float(dmA[l/2])*tmpm[l]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template static __device__ __forceinline__ void load_tiles_q6_K( @@ -1607,6 +1634,7 @@ template static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const int * __restrict__ x_qs, const half2 * __restrict__ x_dm, const int * __restrict__ x_sc, const int * __restrict__ y, float * __restrict__ sum, const int & k0) { +#ifdef INT8_MMA_AVAILABLE typedef mma_int_A_I16K4 mma_A; typedef mma_int_B_J8K4 mma_B; @@ -1692,6 +1720,10 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( sum[(j0/mma_B::J)*mma_C::ne + l] += tmp[l]*dA[l/2]; } } +#else + GGML_UNUSED(x_qs); GGML_UNUSED(x_dm); GGML_UNUSED(x_sc); GGML_UNUSED(y); GGML_UNUSED(sum); GGML_UNUSED(k0); + NO_DEVICE_CODE; +#endif // INT8_MMA_AVAILABLE } template