diff --git a/ggml-cuda/mmq.cuh b/ggml-cuda/mmq.cuh index c454f3f0a..9233c70aa 100644 --- a/ggml-cuda/mmq.cuh +++ b/ggml-cuda/mmq.cuh @@ -194,7 +194,9 @@ static __device__ __forceinline__ void vec_dot_q4_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -328,7 +330,9 @@ static __device__ __forceinline__ void vec_dot_q4_1_q8_1_mma( half2 dmA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -485,7 +489,9 @@ static __device__ __forceinline__ void vec_dot_q5_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -635,7 +641,9 @@ static __device__ __forceinline__ void vec_dot_q5_1_q8_1_mma( half2 dmA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -762,7 +770,9 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma( float dA[mma_C::ne/2]; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int l = 0; l < mma_A::ne; ++l) { @@ -886,7 +896,9 @@ static __device__ __forceinline__ void vec_dot_q2_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[2]; float dA[mma_C::ne/2][2]; @@ -1071,7 +1083,9 @@ static __device__ __forceinline__ void vec_dot_q3_K_q8_1_mma( const float * y_df = (const float *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1234,7 +1248,9 @@ static __device__ __forceinline__ void vec_dot_q4_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1419,7 +1435,9 @@ static __device__ __forceinline__ void vec_dot_q5_K_q8_1_mma( const half2 * y_ds = (const half2 *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[2]; int scA[mma_C::ne/2][2]; @@ -1599,7 +1617,9 @@ static __device__ __forceinline__ void vec_dot_q6_K_q8_1_mma( const float * y_df = (const float *) y; const int i0 = threadIdx.y*mma_A::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_A::I == mmq_y, "nwarps*mma_A::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE mma_A A[4]; int scA[mma_C::ne/2][4]; @@ -1702,7 +1722,9 @@ static __device__ __forceinline__ void mmq_write_back_mma(const float * __restri typedef mma_int_C_I16J8 mma_C; const int i0 = threadIdx.y*mma_C::I; +#ifdef INT8_MMA_AVAILABLE static_assert(nwarps*mma_C::I == mmq_y, "nwarps*mma_C::I != mmq_y"); +#endif // INT8_MMA_AVAILABLE #pragma unroll for (int j0 = 0; j0 < mmq_x; j0 += mma_C::J) {