GGML_CUDA_MMQ_Y
This commit is contained in:
parent
abed446346
commit
3c09e11c97
1 changed files with 54 additions and 49 deletions
103
ggml-cuda.cu
103
ggml-cuda.cu
|
@ -262,6 +262,10 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
|
|||
#define CUDA_QUANTIZE_BLOCK_SIZE 256
|
||||
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
|
||||
|
||||
#ifndef GGML_CUDA_MMQ_Y
|
||||
#define GGML_CUDA_MMQ_Y 128
|
||||
#endif
|
||||
|
||||
// dmmv = dequantize_mul_mat_vec
|
||||
#ifndef GGML_CUDA_DMMV_X
|
||||
#define GGML_CUDA_DMMV_X 32
|
||||
|
@ -1393,8 +1397,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI4_0)];
|
||||
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_0)];
|
||||
|
||||
*x_ql = tile_x_qs;
|
||||
*x_dm = tile_x_d;
|
||||
|
@ -1467,8 +1471,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_1)];
|
||||
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_1)];
|
||||
|
||||
*x_ql = tile_x_qs;
|
||||
*x_dm = tile_x_dm;
|
||||
|
@ -1539,9 +1543,9 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/QI5_0)];
|
||||
__shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI5_0)];
|
||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)];
|
||||
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_0)];
|
||||
|
||||
*x_ql = tile_x_ql;
|
||||
*x_qh = tile_x_qh;
|
||||
|
@ -1625,9 +1629,9 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/QI5_1)];
|
||||
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI5_1)];
|
||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)];
|
||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI5_1)];
|
||||
|
||||
*x_ql = tile_x_ql;
|
||||
*x_qh = tile_x_qh;
|
||||
|
@ -1688,8 +1692,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_qs[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_d[(2*WARP_SIZE) * (WARP_SIZE/QI8_0)];
|
||||
__shared__ int tile_x_qs[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_d[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI8_0)];
|
||||
|
||||
*x_ql = tile_x_qs;
|
||||
*x_dm = tile_x_d;
|
||||
|
@ -1772,9 +1776,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE / QI2_K)];
|
||||
__shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE / 4)];
|
||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)];
|
||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)];
|
||||
|
||||
*x_ql = tile_x_ql;
|
||||
*x_dm = tile_x_dm;
|
||||
|
@ -1799,7 +1803,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat(
|
|||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||
|
||||
__builtin_assume(i < 2*WARP_SIZE);
|
||||
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||
__builtin_assume(j < WARP_SIZE);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
|
||||
|
@ -1888,10 +1892,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE / QI2_K)];
|
||||
__shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE / 2)];
|
||||
__shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE / 4)];
|
||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE / QI2_K)];
|
||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE / 2)];
|
||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE / 4)];
|
||||
|
||||
*x_ql = tile_x_ql;
|
||||
*x_dm = tile_x_dm;
|
||||
|
@ -2064,9 +2068,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_K)];
|
||||
__shared__ int tile_x_sc[(2*WARP_SIZE) * (3*WARP_SIZE/32)];
|
||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)];
|
||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)];
|
||||
|
||||
*x_ql = tile_x_ql;
|
||||
*x_dm = tile_x_dm;
|
||||
|
@ -2091,7 +2095,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat(
|
|||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||
|
||||
__builtin_assume(i < 2*WARP_SIZE);
|
||||
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||
__builtin_assume(j < WARP_SIZE);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
|
||||
|
@ -2256,10 +2260,10 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI4_K)];
|
||||
__shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/4)];
|
||||
__shared__ int tile_x_sc[(2*WARP_SIZE) * (3*WARP_SIZE/32)];
|
||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI4_K)];
|
||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/4)];
|
||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (3*WARP_SIZE/32)];
|
||||
|
||||
*x_ql = tile_x_ql;
|
||||
*x_dm = tile_x_dm;
|
||||
|
@ -2383,10 +2387,10 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
|
|||
|
||||
static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
|
||||
|
||||
__shared__ int tile_x_ql[(2*WARP_SIZE) * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[(2*WARP_SIZE) * (WARP_SIZE/QI6_K)];
|
||||
__shared__ int tile_x_qh[(2*WARP_SIZE) * (WARP_SIZE/2)];
|
||||
__shared__ int tile_x_sc[(2*WARP_SIZE) * (WARP_SIZE/8)];
|
||||
__shared__ int tile_x_ql[GGML_CUDA_MMQ_Y * (WARP_SIZE + 1)];
|
||||
__shared__ half2 tile_x_dm[GGML_CUDA_MMQ_Y * (WARP_SIZE/QI6_K)];
|
||||
__shared__ int tile_x_qh[GGML_CUDA_MMQ_Y * (WARP_SIZE/2)];
|
||||
__shared__ int tile_x_sc[GGML_CUDA_MMQ_Y * (WARP_SIZE/8)];
|
||||
|
||||
*x_ql = tile_x_ql;
|
||||
*x_dm = tile_x_dm;
|
||||
|
@ -2413,7 +2417,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat(
|
|||
const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc,
|
||||
const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) {
|
||||
|
||||
__builtin_assume(i < 2*WARP_SIZE);
|
||||
__builtin_assume(i < GGML_CUDA_MMQ_Y);
|
||||
__builtin_assume(j < WARP_SIZE);
|
||||
__builtin_assume(k < WARP_SIZE);
|
||||
|
||||
|
@ -2459,7 +2463,7 @@ static __global__ void mul_mat_q(
|
|||
const int tid_x = threadIdx.x;
|
||||
const int tid_y = threadIdx.y;
|
||||
|
||||
const int row_dst_0 = 2*blockIdx.x*WARP_SIZE;
|
||||
const int row_dst_0 = blockIdx.x*GGML_CUDA_MMQ_Y;
|
||||
const int & row_x_0 = row_dst_0;
|
||||
const int row_dst = row_dst_0 + tid_x;
|
||||
|
||||
|
@ -2478,11 +2482,11 @@ static __global__ void mul_mat_q(
|
|||
__shared__ int tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)];
|
||||
__shared__ half2 tile_y_ds[(WARP_SIZE) * blocks_per_tile_y_col];
|
||||
|
||||
float sum[2][4] = {0.0f};
|
||||
float sum[GGML_CUDA_MMQ_Y/WARP_SIZE][4] = {0.0f};
|
||||
|
||||
for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) {
|
||||
|
||||
for (int i = 0; i < 2*WARP_SIZE; i += 8) {
|
||||
for (int i = 0; i < GGML_CUDA_MMQ_Y; i += 8) {
|
||||
load_tiles(x + row_x_0*blocks_per_row_x + ib0, tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc,
|
||||
i + tid_y, tid_x, blocks_per_row_x);
|
||||
}
|
||||
|
@ -2516,7 +2520,7 @@ static __global__ void mul_mat_q(
|
|||
#pragma unroll
|
||||
for (int j = 0; j < WARP_SIZE; j += 8) {
|
||||
#pragma unroll
|
||||
for (int i = 0; i < 2*WARP_SIZE; i += WARP_SIZE) {
|
||||
for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
|
||||
sum[i/WARP_SIZE][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
|
||||
tid_x + i, tid_y + j, k);
|
||||
}
|
||||
|
@ -2538,8 +2542,9 @@ static __global__ void mul_mat_q(
|
|||
return;
|
||||
}
|
||||
|
||||
dst[col_dst*nrows_dst + row_dst] = sum[0][j/8];
|
||||
dst[col_dst*nrows_dst + row_dst + WARP_SIZE] = sum[1][j/8];
|
||||
for (int i = 0; i < GGML_CUDA_MMQ_Y; i += WARP_SIZE) {
|
||||
dst[col_dst*nrows_dst + row_dst + i] = sum[i/WARP_SIZE][j/8];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -3239,7 +3244,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3251,7 +3256,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3263,7 +3268,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3275,7 +3280,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3287,7 +3292,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3299,7 +3304,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3311,7 +3316,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3323,7 +3328,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3335,7 +3340,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
@ -3347,7 +3352,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda(
|
|||
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
|
||||
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
|
||||
|
||||
const int block_num_x = (nrows_x + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
|
||||
const int block_num_x = (nrows_x + GGML_CUDA_MMQ_Y - 1) / GGML_CUDA_MMQ_Y;
|
||||
const int block_num_y = (ncols_y + WARP_SIZE - 1) / WARP_SIZE;
|
||||
const dim3 block_nums(block_num_x, block_num_y, 1);
|
||||
const dim3 block_dims(WARP_SIZE, WARP_SIZE/4, 1);
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue