Convert vector to f16 for dmmv
This commit is contained in:
parent
90cc59d6ab
commit
98b8596b2a
2 changed files with 33 additions and 16 deletions
47
ggml-cuda.cu
47
ggml-cuda.cu
|
@ -870,7 +870,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
|
static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, float * dst, const int ncols, const int nrows) {
|
||||||
// qk = quantized weights per x block
|
// qk = quantized weights per x block
|
||||||
// qr = number of quantized weights per data value in x block
|
// qr = number of quantized weights per data value in x block
|
||||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||||
|
@ -904,8 +904,8 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
||||||
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
||||||
|
|
||||||
// matrix multiplication
|
// matrix multiplication
|
||||||
tmp += v0 * y[iybs + iqs + j/qr + 0];
|
tmp += v0 * __half2float(y[iybs + iqs + j/qr + 0]);
|
||||||
tmp += v1 * y[iybs + iqs + j/qr + y_offset];
|
tmp += v1 * __half2float(y[iybs + iqs + j/qr + y_offset]);
|
||||||
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1213,7 +1213,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
|
||||||
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1222,7 +1222,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, f
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1231,7 +1231,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, f
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1240,7 +1240,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, f
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1249,7 +1249,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, f
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1299,7 +1299,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
|
||||||
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void convert_mul_mat_vec_f16_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1714,21 +1714,34 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t nrows = i01_high - i01_low;
|
const int64_t nrows = i01_high - i01_low;
|
||||||
|
|
||||||
|
size_t ash;
|
||||||
|
half * src1_ddh_i = nullptr;
|
||||||
|
bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
|
||||||
|
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
|
||||||
|
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
|
||||||
|
|
||||||
|
if (src1_convert_f16) {
|
||||||
|
src1_ddh_i = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
|
||||||
|
ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_ddh_i, ne00,
|
||||||
|
ne00, 1, sizeof(float), 0, 0,
|
||||||
|
ne00, 1, sizeof(half), 0, 0, cudaStream_main);
|
||||||
|
}
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
|
@ -1746,7 +1759,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -1754,6 +1767,10 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
}
|
}
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
if (src1_convert_f16) {
|
||||||
|
ggml_cuda_pool_free(src1_ddh_i, ash);
|
||||||
|
}
|
||||||
|
|
||||||
(void) src1;
|
(void) src1;
|
||||||
(void) dst;
|
(void) dst;
|
||||||
(void) src0_ddf_i;
|
(void) src0_ddf_i;
|
||||||
|
|
|
@ -1615,7 +1615,7 @@ static bool llama_eval_internal(
|
||||||
model.layers[il].w1,
|
model.layers[il].w1,
|
||||||
cur);
|
cur);
|
||||||
offload_func(cur);
|
offload_func(cur);
|
||||||
ggml_set_name(cur, "result_w2");
|
ggml_set_name(cur, "result_w1");
|
||||||
|
|
||||||
// SILU activation
|
// SILU activation
|
||||||
cur = ggml_silu(ctx0, cur);
|
cur = ggml_silu(ctx0, cur);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue