loop unrolling
This commit is contained in:
parent
a3505fac64
commit
6808800c17
1 changed files with 9 additions and 4 deletions
13
ggml-cuda.cu
13
ggml-cuda.cu
|
@ -2498,12 +2498,17 @@ static __global__ void mul_mat_q(
|
||||||
|
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
|
|
||||||
|
#if __CUDA_ARCH__ >= 700 // TODO: actually test this with compute capability 7.X cards
|
||||||
|
#pragma unroll
|
||||||
|
#endif // __CUDA_ARCH__ >= 700
|
||||||
for (int k = 0; k < WARP_SIZE/vdr; ++k) {
|
for (int k = 0; k < WARP_SIZE/vdr; ++k) {
|
||||||
|
#pragma unroll
|
||||||
for (int j = 0; j < WARP_SIZE; j += 8) {
|
for (int j = 0; j < WARP_SIZE; j += 8) {
|
||||||
sum[0][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
|
#pragma unroll
|
||||||
tid_x, tid_y + j, k);
|
for (int i = 0; i < 2*WARP_SIZE; i += WARP_SIZE) {
|
||||||
sum[1][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
|
sum[i/WARP_SIZE][j/8] += vec_dot(tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds,
|
||||||
tid_x + WARP_SIZE, tid_y + j, k);
|
tid_x + i, tid_y + j, k);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue