More int mult, less float mult, worse performance
This commit is contained in:
parent
d882d1c2fe
commit
e7b9d97bae
1 changed files with 12 additions and 10 deletions
22
ggml-cuda.cu
22
ggml-cuda.cu
|
@ -235,8 +235,8 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
|
||||||
__shared__ float tmp[block_size]; // separate sum for each thread
|
__shared__ float tmp[block_size]; // separate sum for each thread
|
||||||
tmp[tid] = 0;
|
tmp[tid] = 0;
|
||||||
|
|
||||||
for (int i = 0; i < ncols/block_size; i += 2) {
|
for (int i = 0; i < ncols/block_size; i += 4) {
|
||||||
const int col = i*block_size + 2*tid;
|
const int col = i*block_size + 4*tid;
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
const float d0 = x[(row*ncols + col)/QK4_0].d;
|
const float d0 = x[(row*ncols + col)/QK4_0].d;
|
||||||
|
@ -245,19 +245,21 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
|
||||||
const uint8_t * p0 = x[(row*ncols + col)/QK4_0].qs;
|
const uint8_t * p0 = x[(row*ncols + col)/QK4_0].qs;
|
||||||
const int8_t * p1 = y[col/QK8_0].qs;
|
const int8_t * p1 = y[col/QK8_0].qs;
|
||||||
|
|
||||||
const uint8_t vui0 = p0[((row*ncols + col)%QK4_0)/2];
|
const uint8_t vui00 = p0[((row*ncols + col)%QK4_0)/2];
|
||||||
|
const uint8_t vui01 = p0[((row*ncols + col + 2)%QK4_0)/2];
|
||||||
const int vi10 = p1[(col + 0)%QK8_0];
|
const int vi10 = p1[(col + 0)%QK8_0];
|
||||||
const int vi11 = p1[(col + 1)%QK8_0];
|
const int vi11 = p1[(col + 1)%QK8_0];
|
||||||
|
const int vi12 = p1[(col + 2)%QK8_0];
|
||||||
|
const int vi13 = p1[(col + 3)%QK8_0];
|
||||||
|
|
||||||
const int vi00 = vui0 & 0xF;
|
const int vi00 = vui00 & 0xF;
|
||||||
const int vi01 = vui0 >> 4;
|
const int vi01 = vui00 >> 4;
|
||||||
|
const int vi02 = vui01 & 0xF;
|
||||||
const float v0 = (vi00 - 8)*vi10*d0*d1;
|
const int vi03 = vui01 >> 4;
|
||||||
const float v1 = (vi01 - 8)*vi11*d0*d1;
|
|
||||||
|
|
||||||
// matrix multiplication
|
// matrix multiplication
|
||||||
tmp[tid] += v0;
|
const int sumi = (vi00 - 8)*vi10 + (vi01 - 8)*vi11 + (vi02 - 8)*vi12 + (vi03 - 8)*vi13;
|
||||||
tmp[tid] += v1;
|
tmp[tid] += sumi*d0*d1;
|
||||||
}
|
}
|
||||||
|
|
||||||
// sum up partial sums and write back result
|
// sum up partial sums and write back result
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue