CUDA: refactor mmq, dmmv, mmvq (#7716)
* CUDA: refactor mmq, dmmv, mmvq * fix out-of-bounds write * struct for qk, qr, qi * fix cmake build * mmq_type_traits
This commit is contained in:
		
							parent
							
								
									2b3389677a
								
							
						
					
					
						commit
						7d1a378b8f
					
				
					 112 changed files with 1783 additions and 1767 deletions
				
			
		|  | @ -1,9 +1,47 @@ | |||
| #include "mmvq.cuh" | ||||
| #include "vecdotq.cuh" | ||||
| 
 | ||||
// Function-pointer type for the per-quantization dot-product routines that the
// mul_mat_vec_q kernel calls.
// NOTE(review): two typedefs with the same name are visible because this text is a
// diff rendering. The second one adds a `kbx` parameter and matches the
// four-argument call in mul_mat_vec_q (base pointer, q8_1 block, block index, kqs);
// the first appears to be the removed pre-refactor line -- confirm against the file.
| typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs); | ||||
| typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & kbx, const int & iqs); | ||||
| 
 | ||||
| template <int ncols_y, int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda> | ||||
// Compile-time lookup from a ggml_type to its vec_dot_*_q8_1 routine.
// Returns the matching function pointer, or nullptr when `type` is not one of the
// quantization types listed below.
// mul_mat_vec_q evaluates this in a constexpr initializer, so the result has to be
// a constant expression; the body is one chained conditional expression
// (presumably to stay valid under single-return constexpr rules -- confirm the
// language standard before restructuring it).
| static constexpr __device__ vec_dot_q_cuda_t get_vec_dot_q_cuda(ggml_type type) { | ||||
|     return type == GGML_TYPE_Q4_0 ? vec_dot_q4_0_q8_1 : | ||||
|         type == GGML_TYPE_Q4_1 ? vec_dot_q4_1_q8_1 : | ||||
|         type == GGML_TYPE_Q5_0 ? vec_dot_q5_0_q8_1 : | ||||
|         type == GGML_TYPE_Q5_1 ? vec_dot_q5_1_q8_1 : | ||||
|         type == GGML_TYPE_Q8_0 ? vec_dot_q8_0_q8_1 : | ||||
|         type == GGML_TYPE_Q2_K ? vec_dot_q2_K_q8_1 : | ||||
|         type == GGML_TYPE_Q3_K ? vec_dot_q3_K_q8_1 : | ||||
|         type == GGML_TYPE_Q4_K ? vec_dot_q4_K_q8_1 : | ||||
|         type == GGML_TYPE_Q5_K ? vec_dot_q5_K_q8_1 : | ||||
|         type == GGML_TYPE_Q6_K ? vec_dot_q6_K_q8_1 : | ||||
|         type == GGML_TYPE_IQ2_XXS ? vec_dot_iq2_xxs_q8_1 : | ||||
|         type == GGML_TYPE_IQ2_XS ? vec_dot_iq2_xs_q8_1 : | ||||
|         type == GGML_TYPE_IQ2_S ? vec_dot_iq2_s_q8_1 : | ||||
|         type == GGML_TYPE_IQ3_XXS ? vec_dot_iq3_xxs_q8_1 : | ||||
|         type == GGML_TYPE_IQ1_S ? vec_dot_iq1_s_q8_1 : | ||||
|         type == GGML_TYPE_IQ1_M ? vec_dot_iq1_m_q8_1 : | ||||
|         type == GGML_TYPE_IQ4_NL ? vec_dot_iq4_nl_q8_1 : | ||||
|         type == GGML_TYPE_IQ4_XS ? vec_dot_iq4_xs_q8_1 : | ||||
|         type == GGML_TYPE_IQ3_S ? vec_dot_iq3_s_q8_1 : | ||||
|         nullptr; | ||||
| } | ||||
| 
 | ||||
// Compile-time lookup of the MMVQ `vdr` constant for a quantization type.
// Types that are not listed fall through to 1, which is the literal the
// pre-refactor IQ2/IQ3/IQ1/IQ4_XS launchers passed for this template parameter.
// GGML_TYPE_IQ4_NL maps to VDR_Q4_0_Q8_1_MMVQ because that is the constant the
// pre-refactor mul_mat_vec_iq4_nl_q8_1_cuda launcher instantiated the kernel with;
// using a different constant here would silently change the refactored kernel if
// the two values ever diverge.
// The body is kept as a single conditional expression (one return statement).
| static constexpr __device__ int get_vdr_mmvq(ggml_type type) { | ||||
|     return type == GGML_TYPE_Q4_0 ? VDR_Q4_0_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q4_1 ? VDR_Q4_1_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q5_0 ? VDR_Q5_0_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q5_1 ? VDR_Q5_1_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q8_0 ? VDR_Q8_0_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q2_K ? VDR_Q2_K_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q3_K ? VDR_Q3_K_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q4_K ? VDR_Q4_K_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q5_K ? VDR_Q5_K_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_Q6_K ? VDR_Q6_K_Q8_1_MMVQ : | ||||
|         type == GGML_TYPE_IQ4_NL ? VDR_Q4_0_Q8_1_MMVQ : | ||||
|         1; | ||||
| } | ||||
| 
 | ||||
| template <ggml_type type, int ncols_y> | ||||
| #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) | ||||
| // tell the compiler to use as many registers as it wants, see nwarps definition below | ||||
| __launch_bounds__((ncols_y <= 4 ? 4 : 2)*WARP_SIZE, 1) | ||||
|  | @ -12,6 +50,12 @@ static __global__ void mul_mat_vec_q( | |||
|     const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) { | ||||
| 
 | ||||
|     constexpr int qk  = ggml_cuda_type_traits<type>::qk; | ||||
|     constexpr int qi  = ggml_cuda_type_traits<type>::qi; | ||||
|     constexpr int vdr = get_vdr_mmvq(type); | ||||
| 
 | ||||
|     constexpr vec_dot_q_cuda_t vec_dot_q_cuda = get_vec_dot_q_cuda(type); | ||||
| 
 | ||||
| #if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3)) | ||||
|     constexpr int nwarps              = 1; | ||||
|     constexpr int rows_per_cuda_block = 1; | ||||
|  | @ -29,7 +73,6 @@ static __global__ void mul_mat_vec_q( | |||
| // partial sum for each thread | ||||
|     float tmp[ncols_y][rows_per_cuda_block] = {0.0f}; | ||||
| 
 | ||||
|     const block_q_t  * x = (const block_q_t  *) vx; | ||||
|     const block_q8_1 * y = (const block_q8_1 *) vy; | ||||
| 
 | ||||
|     for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { | ||||
|  | @ -42,8 +85,7 @@ static __global__ void mul_mat_vec_q( | |||
|         for (int j = 0; j < ncols_y; ++j) { | ||||
| #pragma unroll | ||||
|             for (int i = 0; i < rows_per_cuda_block; ++i) { | ||||
|                 tmp[j][i] += vec_dot_q_cuda( | ||||
|                     &x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs); | ||||
|                 tmp[j][i] += vec_dot_q_cuda(vx, &y[j*blocks_per_col_y + kby], (row0 + i)*blocks_per_row_x + kbx, kqs); | ||||
|             } | ||||
|         } | ||||
|     } | ||||
|  | @ -81,12 +123,12 @@ static __global__ void mul_mat_vec_q( | |||
|     } | ||||
| } | ||||
| 
 | ||||
| template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot> | ||||
| template <ggml_type type> | ||||
| static void mul_mat_vec_q_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     GGML_ASSERT(ncols_x % qk == 0); | ||||
|     GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); | ||||
|     GGML_ASSERT(ncols_y <= MMVQ_MAX_BATCH_SIZE); | ||||
| 
 | ||||
|     int id = ggml_cuda_get_device(); | ||||
|  | @ -124,36 +166,28 @@ static void mul_mat_vec_q_cuda( | |||
| 
 | ||||
|     switch (ncols_y) { | ||||
|         case 1: | ||||
|             mul_mat_vec_q<1, qk, qi, block_q_t, vdr, vec_dot> | ||||
|                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             mul_mat_vec_q<type, 1><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             break; | ||||
|         case 2: | ||||
|             mul_mat_vec_q<2, qk, qi, block_q_t, vdr, vec_dot> | ||||
|                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             mul_mat_vec_q<type, 2><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             break; | ||||
|         case 3: | ||||
|             mul_mat_vec_q<3, qk, qi, block_q_t, vdr, vec_dot> | ||||
|                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             mul_mat_vec_q<type, 3><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             break; | ||||
|         case 4: | ||||
|             mul_mat_vec_q<4, qk, qi, block_q_t, vdr, vec_dot> | ||||
|                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             mul_mat_vec_q<type, 4><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             break; | ||||
|         case 5: | ||||
|             mul_mat_vec_q<5, qk, qi, block_q_t, vdr, vec_dot> | ||||
|                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             mul_mat_vec_q<type, 5><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             break; | ||||
|         case 6: | ||||
|             mul_mat_vec_q<6, qk, qi, block_q_t, vdr, vec_dot> | ||||
|                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             mul_mat_vec_q<type, 6><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             break; | ||||
|         case 7: | ||||
|             mul_mat_vec_q<7, qk, qi, block_q_t, vdr, vec_dot> | ||||
|                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             mul_mat_vec_q<type, 7><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             break; | ||||
|         case 8: | ||||
|             mul_mat_vec_q<8, qk, qi, block_q_t, vdr, vec_dot> | ||||
|                 <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             mul_mat_vec_q<type, 8><<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols_x, nrows_x, nrows_y, nrows_dst); | ||||
|             break; | ||||
|         default: | ||||
|             GGML_ASSERT(false); | ||||
|  | @ -165,152 +199,133 @@ static void mul_mat_vec_q4_0_q8_1_cuda( | |||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK4_0, QI4_0, block_q4_0, VDR_Q4_0_Q8_1_MMVQ, vec_dot_q4_0_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q4_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q4_1 case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q4_1.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q4_1_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK4_1, QI4_1, block_q4_1, VDR_Q4_1_Q8_1_MMVQ, vec_dot_q4_1_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q4_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q5_0 case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q5_0.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q5_0_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK5_0, QI5_0, block_q5_0, VDR_Q5_0_Q8_1_MMVQ, vec_dot_q5_0_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q5_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q5_1 case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q5_1.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q5_1_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK5_1, QI5_1, block_q5_1, VDR_Q5_1_Q8_1_MMVQ, vec_dot_q5_1_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q5_1>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q8_0 case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q8_0.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q8_0_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK8_0, QI8_0, block_q8_0, VDR_Q8_0_Q8_1_MMVQ, vec_dot_q8_0_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q8_0>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q2_K case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q2_K.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q2_K_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI2_K, block_q2_K, VDR_Q2_K_Q8_1_MMVQ, vec_dot_q2_K_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q2_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q3_K case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q3_K.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q3_K_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI3_K, block_q3_K, VDR_Q3_K_Q8_1_MMVQ, vec_dot_q3_K_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q3_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q4_K case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q4_K.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q4_K_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI4_K, block_q4_K, VDR_Q4_K_Q8_1_MMVQ, vec_dot_q4_K_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q4_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q5_K case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q5_K.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q5_K_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI5_K, block_q5_K, VDR_Q5_K_Q8_1_MMVQ, vec_dot_q5_K_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q5_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the q6_K case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_Q6_K.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_q6_K_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI6_K, block_q6_K, VDR_Q6_K_Q8_1_MMVQ, vec_dot_q6_K_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_Q6_K>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq2_xxs case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ2_XXS.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq2_xxs_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI2_XXS, block_iq2_xxs, 1, vec_dot_iq2_xxs_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq2_xs case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ2_XS.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq2_xs_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI2_XS, block_iq2_xs, 1, vec_dot_iq2_xs_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ2_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq2_s case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ2_S.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq2_s_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI2_S, block_iq2_s, 1, vec_dot_iq2_s_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ2_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq3_xxs case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ3_XXS.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq3_xxs_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI3_XXS, block_iq3_xxs, 1, vec_dot_iq3_xxs_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ3_XXS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq1_s case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ1_S.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq1_s_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_s, 1, vec_dot_iq1_s_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ1_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq1_m case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ1_M.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq1_m_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI1_S, block_iq1_m, 1, vec_dot_iq1_m_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ1_M>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq4_nl case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ4_NL.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq4_nl_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK4_NL, QI4_NL, block_iq4_nl, VDR_Q4_0_Q8_1_MMVQ, vec_dot_iq4_nl_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ4_NL>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq4_xs case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ4_XS.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq4_xs_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI4_XS, block_iq4_xs, 1, vec_dot_iq4_xs_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ4_XS>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
// Thin wrapper for the iq3_s case: passes all of its arguments through unchanged
// to mul_mat_vec_q_cuda instantiated for GGML_TYPE_IQ3_S.
// NOTE(review): two calls are visible because this text is a diff rendering; the
// one with five template arguments appears to be the removed pre-refactor line.
| static void mul_mat_vec_iq3_s_q8_1_cuda( | ||||
|     const void * vx, const void * vy, float * dst, | ||||
|     const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int nrows_dst, cudaStream_t stream) { | ||||
| 
 | ||||
|     mul_mat_vec_q_cuda<QK_K, QI3_XS, block_iq3_s, 1, vec_dot_iq3_s_q8_1> | ||||
|         (vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
|     mul_mat_vec_q_cuda<GGML_TYPE_IQ3_S>(vx, vy, dst, ncols_x, nrows_x, nrows_y, ncols_y, nrows_dst, stream); | ||||
| } | ||||
| 
 | ||||
| void ggml_cuda_op_mul_mat_vec_q( | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue