faster q8_1 loading

This commit is contained in:
JohannesGaessler 2023-07-28 14:21:25 +02:00
parent b53e713883
commit a3505fac64

View file

@ -2462,8 +2462,10 @@ static __global__ void mul_mat_q(
allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc);
const int blocks_per_tile_y_col = qr*WARP_SIZE/QI8_1;
__shared__ int tile_y_qs[(WARP_SIZE) * (qr*WARP_SIZE)];
__shared__ half2 tile_y_ds[(WARP_SIZE) * (qr*WARP_SIZE/QI8_1)];
__shared__ half2 tile_y_ds[(WARP_SIZE) * blocks_per_tile_y_col];
float sum[2][4] = {0.0f};
@ -2474,8 +2476,6 @@ static __global__ void mul_mat_q(
i + tid_y, tid_x, blocks_per_row_x);
}
const int iqsy = sizeof(int) * (tid_x % QI8_1);
for (int ir = 0; ir < qr; ++ir) {
const int kqs = ir*WARP_SIZE + tid_x;
const int kby = kqs / QI8_1;
@ -2485,11 +2485,17 @@ static __global__ void mul_mat_q(
const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby];
tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = *((int *) &by0->qs[iqsy]);
tile_y_ds[(tid_y + i) * (qr*WARP_SIZE/QI8_1) + kby] = by0->ds;
tile_y_qs[(tid_y + i) * (qr*WARP_SIZE) + kqs] = get_int_from_int8_aligned(by0->qs, tid_x % QI8_1);
}
}
for (int ids0 = 0; ids0 < WARP_SIZE; ids0 += 8 * (WARP_SIZE/blocks_per_tile_y_col)) {
const int ids = (ids0 + tid_y * (WARP_SIZE/blocks_per_tile_y_col) + tid_x / blocks_per_tile_y_col) % WARP_SIZE;
const int kby = tid_x % blocks_per_tile_y_col;
const int col_y_eff = min(col_y_0 + ids, ncols_y-1);
tile_y_ds[ids * (qr*WARP_SIZE/QI8_1) + kby] = y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kby].ds;
}
__syncthreads();
for (int k = 0; k < WARP_SIZE/vdr; ++k) {