revert low register pressure changes
This commit is contained in:
parent
2bb97fca5e
commit
76a0128bec
1 changed files with 8 additions and 4 deletions
12
ggml-cuda.cu
12
ggml-cuda.cu
|
@ -5327,6 +5327,10 @@ static __global__ void mul_mat_vec_q(
|
||||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||||
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && !defined(RDNA2) && !defined(RDNA3)
|
||||||
|
|
||||||
|
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||||
|
const int row0 = rows_per_cuda_block*blockIdx.x;
|
||||||
|
const int blocks_per_row_x = ncols_x / qk;
|
||||||
|
const int blocks_per_col_y = nrows_y / QK8_1;
|
||||||
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
|
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
|
||||||
|
|
||||||
// partial sum for each thread
|
// partial sum for each thread
|
||||||
|
@ -5335,18 +5339,18 @@ static __global__ void mul_mat_vec_q(
|
||||||
const block_q_t * x = (const block_q_t *) vx;
|
const block_q_t * x = (const block_q_t *) vx;
|
||||||
const block_q8_1 * y = (const block_q8_1 *) vy;
|
const block_q8_1 * y = (const block_q8_1 *) vy;
|
||||||
|
|
||||||
for (int kbx = (WARP_SIZE*threadIdx.y + threadIdx.x) / (qi/vdr); kbx < (ncols_x / qk); kbx += blocks_per_iter) {
|
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
|
||||||
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
|
const int kby = kbx * (qk/QK8_1); // y block index that aligns with kbx
|
||||||
|
|
||||||
// x block quant index when casting the quants to int
|
// x block quant index when casting the quants to int
|
||||||
const int kqs = vdr * ((WARP_SIZE*threadIdx.y + threadIdx.x) % (qi/vdr));
|
const int kqs = vdr * (tid % (qi/vdr));
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols_y; ++j) {
|
for (int j = 0; j < ncols_y; ++j) {
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
for (int i = 0; i < rows_per_cuda_block; ++i) {
|
||||||
tmp[j][i] += vec_dot_q_cuda(
|
tmp[j][i] += vec_dot_q_cuda(
|
||||||
&x[kbx + (rows_per_cuda_block*blockIdx.x + i)*(ncols_x / qk)], &y[j*(nrows_y / QK8_1) + kby], kqs);
|
&x[kbx + (row0 + i)*blocks_per_row_x], &y[j*blocks_per_col_y + kby], kqs);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -5379,7 +5383,7 @@ static __global__ void mul_mat_vec_q(
|
||||||
}
|
}
|
||||||
|
|
||||||
if (threadIdx.x < rows_per_cuda_block) {
|
if (threadIdx.x < rows_per_cuda_block) {
|
||||||
dst[j*nrows_dst + rows_per_cuda_block*blockIdx.x + threadIdx.x] = tmp[j][threadIdx.x];
|
dst[j*nrows_dst + row0 + threadIdx.x] = tmp[j][threadIdx.x];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue