WAKE ME UP
This commit is contained in:
parent
8a9d7ce624
commit
4b12881329
2 changed files with 24 additions and 15 deletions
35
ggml-cuda.cu
35
ggml-cuda.cu
|
@ -225,8 +225,9 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
|
|||
}
|
||||
}
|
||||
|
||||
template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const float * y, float * dst, const int ncols) {
|
||||
template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const void * vx, const void * vy, float * dst, const int ncols) {
|
||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||
const block_q8_0 * y = (const block_q8_0 *) vy;
|
||||
|
||||
const int row = blockIdx.x;
|
||||
const int tid = threadIdx.x;
|
||||
|
@ -238,21 +239,25 @@ template <int block_size> static __global__ void dequantize_mul_mat_q4_0(const v
|
|||
const int col = i*block_size + 2*tid;
|
||||
|
||||
// dequantize
|
||||
const float d = x[(row*ncols + col)/QK4_0].d;
|
||||
const float d0 = x[(row*ncols + col)/QK4_0].d;
|
||||
const float d1 = y[col/QK8_0].d;
|
||||
|
||||
const uint8_t * pp = x[(row*ncols + col)/QK4_0].qs;
|
||||
const uint8_t * p0 = x[(row*ncols + col)/QK4_0].qs;
|
||||
const int8_t * p1 = y[col/QK8_0].qs;
|
||||
|
||||
const uint8_t vui = pp[((row*ncols + col)%QK4_0)/2];
|
||||
const uint8_t vui0 = p0[((row*ncols + col)%QK4_0)/2];
|
||||
const int vi10 = p1[(col + 0)%QK8_0];
|
||||
const int vi11 = p1[(col + 1)%QK8_0];
|
||||
|
||||
const int8_t vi0 = vui & 0xF;
|
||||
const int8_t vi1 = vui >> 4;
|
||||
const int vi00 = vui0 & 0xF;
|
||||
const int vi01 = vui0 >> 4;
|
||||
|
||||
const float v0 = (vi0 - 8)*d;
|
||||
const float v1 = (vi1 - 8)*d;
|
||||
const float v0 = (vi00 - 8)*vi10*d0*d1;
|
||||
const float v1 = (vi01 - 8)*vi11*d0*d1;
|
||||
|
||||
// matrix multiplication
|
||||
tmp[tid] += v0 * y[col + 0];
|
||||
tmp[tid] += v1 * y[col + 1];
|
||||
tmp[tid] += v0;
|
||||
tmp[tid] += v1;
|
||||
}
|
||||
|
||||
// sum up partial sums and write back result
|
||||
|
@ -297,7 +302,7 @@ static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStre
|
|||
dequantize_block_q8_0<<<nb, 1, 0, stream>>>(vx, y);
|
||||
}
|
||||
|
||||
static void dequantize_mul_mat_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
static void dequantize_mul_mat_q4_0_cuda(const void * vx, const void * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||
// static int block_size = -1;
|
||||
// if (block_size == -1) {
|
||||
// int min_grid_size, max_block_size = 1;
|
||||
|
@ -634,7 +639,7 @@ static void ggml_cuda_mul_mat_f16(const ggml_tensor * src0, const ggml_tensor *
|
|||
ggml_cuda_pool_free(d_D, d_size);
|
||||
}
|
||||
|
||||
static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, void * wdata) {
|
||||
const int64_t ne00 = src0->ne[0];
|
||||
const int64_t ne01 = src0->ne[1];
|
||||
const int64_t ne02 = src0->ne[2];
|
||||
|
@ -696,7 +701,7 @@ static void ggml_cuda_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor
|
|||
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0));
|
||||
|
||||
// compute
|
||||
dequantize_mul_mat_q4_0_cuda(c_Q, c_Y, c_D, ne00, ne01, cudaStream);
|
||||
dequantize_mul_mat_q4_0_cuda(c_Q, wdata + i * QK8_0, c_D, ne00, ne01, cudaStream);
|
||||
CUDA_CHECK(cudaGetLastError());
|
||||
|
||||
} else {
|
||||
|
@ -781,11 +786,11 @@ void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_
|
|||
ggml_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize);
|
||||
}
|
||||
else {
|
||||
ggml_cuda_mul_mat_q_f32(src0, src1, dst);
|
||||
ggml_cuda_mul_mat_q_f32(src0, src1, dst, wdata);
|
||||
}
|
||||
}
|
||||
else if (ggml_is_quantized(src0->type)) {
|
||||
ggml_cuda_mul_mat_q_f32(src0, src1, dst);
|
||||
ggml_cuda_mul_mat_q_f32(src0, src1, dst, wdata);
|
||||
}
|
||||
else {
|
||||
GGML_ASSERT(false);
|
||||
|
|
4
ggml.c
4
ggml.c
|
@ -8811,6 +8811,10 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
|||
#if defined(GGML_USE_CUBLAS)
|
||||
if (ggml_cuda_can_mul_mat(src0, src1, dst)) {
|
||||
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
|
||||
if (ne11 == 1 && ne12 == 1 && ne13 == 1) {
|
||||
char * wdata = params->wdata;
|
||||
quantize_row_q_dot((float *)((char *) src1->data), (void *) wdata, ne10);
|
||||
}
|
||||
ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize);
|
||||
}
|
||||
return;
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue