WAKE ME UP

JohannesGaessler 2023-05-11 22:47:38 +02:00
parent 8a9d7ce624
commit 4b12881329
2 changed files with 24 additions and 15 deletions


@@ -225,8 +225,9 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
     }
 }
-template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
+template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const void * vy, float * dst, const int ncols) {
     const block_q4_0 * x = (const block_q4_0 *) vx;
+    const block_q8_0 * y = (const block_q8_0 *) vy;
     const int row = blockIdx.x;
     const int tid = threadIdx.x;
@@ -238,21 +239,25 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
         const int col = i*block_size + 2*tid;
         // dequantize
-        const float d = x[(row*ncols + col)/QK4_0].d;
+        const float d0 = x[(row*ncols + col)/QK4_0].d;
+        const float d1 = y[col/QK8_0].d;
-        const uint8_t * pp = x[(row*ncols + col)/QK4_0].qs;
+        const uint8_t * p0 = x[(row*ncols + col)/QK4_0].qs;
+        const int8_t * p1 = y[col/QK8_0].qs;
-        const uint8_t vui = pp[((row*ncols + col)%QK4_0)/2];
+        const uint8_t vui0 = p0[((row*ncols + col)%QK4_0)/2];
+        const int vi10 = p1[(col + 0)%QK8_0];
+        const int vi11 = p1[(col + 1)%QK8_0];
-        const int8_t vi0 = vui & 0xF;
-        const int8_t vi1 = vui >> 4;
+        const int vi00 = vui0 & 0xF;
+        const int vi01 = vui0 >> 4;
-        const float v0 = (vi0 - 8)*d;
-        const float v1 = (vi1 - 8)*d;
+        const float v0 = (vi00 - 8)*vi10*d0*d1;
+        const float v1 = (vi01 - 8)*vi11*d0*d1;
         // matrix multiplication
-        tmp[tid] += v0 * y[col + 0];
-        tmp[tid] += v1 * y[col + 1];
+        tmp[tid] += v0;
+        tmp[tid] += v1;
     }
     // sum up partial sums and write back result
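
With this change the kernel computes a q4_0 × q8_0 dot product: the int4 quants of x are multiplied against the int8 quants of y and both block scales (d0, d1) are applied to each product, instead of multiplying dequantized floats against an f32 column. Below is a minimal CPU-side sketch of the same per-block math, assuming the block layouts ggml used at the time (QK4_0 = QK8_0 = 32, one float scale per block); the struct and macro names mirror ggml, but the function itself is illustrative only, not code from the commit's CUDA file (presumably ggml-cuda.cu):

#include <stdint.h>

#define QK4_0 32
#define QK8_0 32

typedef struct { float d; uint8_t qs[QK4_0/2]; } block_q4_0; // 32 x 4-bit quants + scale
typedef struct { float d; int8_t  qs[QK8_0];   } block_q8_0; // 32 x 8-bit quants + scale

// Dot product of one q4_0 block with one q8_0 block. Element 2*j of the q4_0
// block sits in the low nibble of qs[j] and element 2*j+1 in the high nibble,
// matching the kernel's (vui0 & 0xF) / (vui0 >> 4) split; the scales d0*d1 are
// applied to every product just like v0 and v1 above (factoring them out of
// the loop would be an equivalent simplification).
static float dot_q4_0_q8_0_block(const block_q4_0 * x, const block_q8_0 * y) {
    float sum = 0.0f;
    for (int j = 0; j < QK4_0/2; ++j) {
        const uint8_t vui = x->qs[j];
        const int vi0 = (vui & 0xF) - 8; // low nibble, shifted to [-8, 7]
        const int vi1 = (vui >>  4) - 8; // high nibble
        sum += (vi0 * y->qs[2*j + 0] + vi1 * y->qs[2*j + 1]) * x->d * y->d;
    }
    return sum;
}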
@@ -297,7 +302,7 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
     dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
 }
-static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
     // static int block_size = -1;
     // if (block_size == -1) {
     // int min_grid_size, max_block_size = 1;
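
The commented-out lines in dequantize_mul_mat_q4_0_cuda (min_grid_size, max_block_size) look like the beginnings of picking a launch block size at runtime, which is usually done with CUDA's occupancy API. The following is a sketch of that pattern for a plain kernel, not code from this commit; note that the real kernel is templated on block_size, so a runtime-chosen value could not be plugged into it directly:

#include <cuda_runtime.h>
#include <stdio.h>

// Stand-in kernel: the occupancy query below only needs a concrete __global__ function.
__global__ void scale_kernel(const float * x, float * y, int n) {
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < n) {
        y[i] = 2.0f*x[i];
    }
}

int main(void) {
    int min_grid_size = 0;
    int block_size    = 0;
    // Ask the runtime which block size maximizes occupancy for this kernel
    // (0 bytes of dynamic shared memory, no upper limit on the block size).
    cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, scale_kernel, 0, 0);
    printf("suggested block size: %d (minimum grid size: %d)\n", block_size, min_grid_size);
    return 0;
}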
@@ -634,7 +639,7 @@ static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor *
     ggml_cuda_pool_free(d_D, d_size);
 }
-static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
+static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata) {
     const int64_t ne00 = src0->ne[0];
     const int64_t ne01 = src0->ne[1];
     const int64_t ne02 = src0->ne[2];
@@ -696,7 +701,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
             CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
             // compute
-            dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
+            dequantize_mul_mat_q4_0_cuda(c_Q, wdata + i * QK8_0, c_D, ne00, ne01, cudaStream);
             CUDA_CHECK(cudaGetLastError());
         } else {
@@ -781,11 +786,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
             ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
         }
         else {
-            ggml_cuda_mul_mat_q_f32(src0, src1, dst);
+            ggml_cuda_mul_mat_q_f32(src0, src1, dst, wdata);
         }
     }
     else if (ggml_is_quantized(src0->type)) {
-        ggml_cuda_mul_mat_q_f32(src0, src1, dst);
+        ggml_cuda_mul_mat_q_f32(src0, src1, dst, wdata);
    }
    else {
        GGML_ASSERT(false);

ggml.c

@@ -8811,6 +8811,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
 #if defined(GGML_USE_CUBLAS)
     if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
         if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
+            if (ne11 == 1 && ne12 == 1 && ne13 == 1) {
+                char * wdata = params->wdata;
+                quantize_row_q_dot((float *)((char *) src1->data), (void *) wdata, ne10);
+            }
             ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
         }
         return;
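
On the CPU side, for the single-column case (ne11 == ne12 == ne13 == 1), src1 is converted with quantize_row_q_dot into params->wdata before the CUDA path is entered; that is the same scratch buffer the new wdata parameter of ggml_cuda_mul_mat_q_f32 receives, and for q4_0 weights the conversion should yield the q8_0 blocks the kernel now reinterprets vy as. Below is a scalar sketch of that q8_0 quantization, assuming the block layout from the sketch above; the real ggml routine has SIMD paths, and this version is illustrative only:

#include <math.h>
#include <stdint.h>

#define QK8_0 32

typedef struct { float d; int8_t qs[QK8_0]; } block_q8_0; // 32 x 8-bit quants + scale

// Scalar reference: quantize k floats (k divisible by QK8_0) into q8_0 blocks.
static void quantize_row_q8_0_ref(const float * x, block_q8_0 * y, int k) {
    const int nb = k/QK8_0;
    for (int i = 0; i < nb; ++i) {
        float amax = 0.0f; // absolute maximum of the block
        for (int j = 0; j < QK8_0; ++j) {
            const float v = fabsf(x[i*QK8_0 + j]);
            if (v > amax) {
                amax = v;
            }
        }
        const float d  = amax/127.0f;                // scale so quants fit in int8
        const float id = d != 0.0f ? 1.0f/d : 0.0f;  // inverse scale, guard against 0
        y[i].d = d;
        for (int j = 0; j < QK8_0; ++j) {
            y[i].qs[j] = (int8_t) roundf(x[i*QK8_0 + j]*id);
        }
    }
}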