works
This commit is contained in:
parent
cde81f91ca
commit
28e2762324
1 changed files with 56 additions and 32 deletions
88
ggml-cuda.cu
88
ggml-cuda.cu
|
@ -54,11 +54,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
typedef half dfloat; // dequantize float
|
typedef half dfloat; // dequantize float
|
||||||
typedef half2 dfloat2;
|
typedef half2 dfloat2;
|
||||||
|
|
||||||
static __device__ __forceinline__ void dadd_inplace(dfloat2 & a, const dfloat b) {
|
static __device__ __forceinline__ void dadd_scalar_inplace(dfloat2 & a, const dfloat b) {
|
||||||
a = __hadd2(a, {b, b});
|
a = __hadd2(a, {b, b});
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b) {
|
static __device__ __forceinline__ void dadd_vector_inplace(dfloat2 & a, const dfloat2 b) {
|
||||||
|
a = __hadd2(a, b);
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dsub_scalar_inplace(dfloat2 & a, const dfloat b) {
|
||||||
a = __hsub2(a, {b, b});
|
a = __hsub2(a, {b, b});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -75,12 +79,17 @@ static __device__ __forceinline__ void dmul_vector_inplace(dfloat2 & a, const df
|
||||||
typedef float dfloat; // dequantize float
|
typedef float dfloat; // dequantize float
|
||||||
typedef float2 dfloat2;
|
typedef float2 dfloat2;
|
||||||
|
|
||||||
static __device__ __forceinline__ void dadd_inplace(dfloat2 & a, const dfloat b) {
|
static __device__ __forceinline__ void dadd_scalar_inplace(dfloat2 & a, const dfloat b) {
|
||||||
a.x += b;
|
a.x += b;
|
||||||
a.y += b;
|
a.y += b;
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ void dsub_inplace(dfloat2 & a, const dfloat b) {
|
static __device__ __forceinline__ void dadd_vector_inplace(dfloat2 & a, const dfloat2 b) {
|
||||||
|
a.x += b.x;
|
||||||
|
a.y += b.y;
|
||||||
|
}
|
||||||
|
|
||||||
|
static __device__ __forceinline__ void dsub_scalar_inplace(dfloat2 & a, const dfloat b) {
|
||||||
a.x -= b;
|
a.x -= b;
|
||||||
a.y -= b;
|
a.y -= b;
|
||||||
}
|
}
|
||||||
|
@ -290,7 +299,7 @@ static __device__ void dequantize_q4_0(const void * vx, const int ib, const int
|
||||||
v.x = vui & 0xF;
|
v.x = vui & 0xF;
|
||||||
v.y = vui >> 4;
|
v.y = vui >> 4;
|
||||||
|
|
||||||
dsub_inplace(v, 8.0f);
|
dsub_scalar_inplace(v, 8.0f);
|
||||||
dmul_scalar_inplace(v, d);
|
dmul_scalar_inplace(v, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -306,7 +315,7 @@ static __device__ void dequantize_q4_1(const void * vx, const int ib, const int
|
||||||
v.y = vui >> 4;
|
v.y = vui >> 4;
|
||||||
|
|
||||||
dmul_scalar_inplace(v, d);
|
dmul_scalar_inplace(v, d);
|
||||||
dadd_inplace(v, m);
|
dadd_scalar_inplace(v, m);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
|
@ -323,7 +332,7 @@ static __device__ void dequantize_q5_0(const void * vx, const int ib, const int
|
||||||
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||||
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
|
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
dsub_inplace(v, 16.0f);
|
dsub_scalar_inplace(v, 16.0f);
|
||||||
dmul_scalar_inplace(v, d);
|
dmul_scalar_inplace(v, d);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,7 +352,7 @@ static __device__ void dequantize_q5_1(const void * vx, const int ib, const int
|
||||||
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
|
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
dmul_scalar_inplace(v, d);
|
dmul_scalar_inplace(v, d);
|
||||||
dadd_inplace(v, m);
|
dadd_scalar_inplace(v, m);
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
|
@ -891,8 +900,9 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
|
||||||
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const half * x = (const half *) vx;
|
const half * x = (const half *) vx;
|
||||||
|
|
||||||
v.x = __half2float(x[ib + iqs + 0]);
|
// automatic half -> float type cast if dfloat == float
|
||||||
v.y = __half2float(x[ib + iqs + 1]);
|
v.x = x[ib + iqs + 0];
|
||||||
|
v.y = x[ib + iqs + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
|
@ -917,7 +927,7 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, float * dst, const int ncols, const int nrows) {
|
static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
|
||||||
// qk = quantized weights per x block
|
// qk = quantized weights per x block
|
||||||
// qr = number of quantized weights per data value in x block
|
// qr = number of quantized weights per data value in x block
|
||||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||||
|
@ -932,7 +942,7 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
|
||||||
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
|
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
|
||||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
dfloat tmp = 0.0f; // partial sum for thread in warp
|
dfloat2 tmp = {0.0f, 0.0f}; // partial sum for thread in warp
|
||||||
|
|
||||||
for (int i = 0; i < ncols; i += iter_stride) {
|
for (int i = 0; i < ncols; i += iter_stride) {
|
||||||
const int col = i + vals_per_iter*tid;
|
const int col = i + vals_per_iter*tid;
|
||||||
|
@ -951,10 +961,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
|
||||||
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
||||||
|
|
||||||
// matrix multiplication
|
// matrix multiplication
|
||||||
dfloat2 yi = {y[iybs + iqs + j/qr + 0], y[iybs + iqs + j/qr + y_offset]};
|
const dfloat2 yi = {y[iybs + iqs + j/qr + 0], y[iybs + iqs + j/qr + y_offset]};
|
||||||
tmp += v.x * yi.x;
|
|
||||||
tmp += v.y * yi.y;
|
|
||||||
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
||||||
|
|
||||||
|
dmul_vector_inplace(v, yi);
|
||||||
|
dadd_vector_inplace(tmp, v);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -962,11 +973,16 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const half * y, f
|
||||||
__syncthreads();
|
__syncthreads();
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int mask = 16; mask > 0; mask >>= 1) {
|
for (int mask = 16; mask > 0; mask >>= 1) {
|
||||||
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
|
||||||
|
#else
|
||||||
|
tmp.x += __shfl_xor_sync(0xffffffff, tmp.x, mask, 32);
|
||||||
|
tmp.y += __shfl_xor_sync(0xffffffff, tmp.y, mask, 32);
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
dst[row] = tmp;
|
dst[row] = tmp.x + tmp.y;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1261,7 +1277,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
|
||||||
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1270,7 +1286,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const half * y, fl
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1279,7 +1295,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const half * y, fl
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1288,7 +1304,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const half * y, fl
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1297,7 +1313,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const half * y, fl
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1347,7 +1363,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
|
||||||
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const half * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1762,34 +1778,40 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t nrows = i01_high - i01_low;
|
const int64_t nrows = i01_high - i01_low;
|
||||||
|
|
||||||
|
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
|
||||||
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
size_t ash;
|
size_t ash;
|
||||||
half * src1_ddh_i = nullptr;
|
dfloat * src1_dfloat = nullptr; // dfloat == half
|
||||||
|
|
||||||
bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
|
bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
|
||||||
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
|
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
|
||||||
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
|
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
|
||||||
|
|
||||||
if (src1_convert_f16) {
|
if (src1_convert_f16) {
|
||||||
src1_ddh_i = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
|
src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
|
||||||
ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_ddh_i, ne00,
|
ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
|
||||||
ne00, 1, sizeof(float), 0, 0,
|
ne00, 1, sizeof(float), 0, 0,
|
||||||
ne00, 1, sizeof(half), 0, 0, cudaStream_main);
|
ne00, 1, sizeof(half), 0, 0, cudaStream_main);
|
||||||
}
|
}
|
||||||
|
#else
|
||||||
|
dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
|
@ -1807,7 +1829,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddh_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -1815,9 +1837,11 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
}
|
}
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
if (src1_convert_f16) {
|
if (src1_convert_f16) {
|
||||||
ggml_cuda_pool_free(src1_ddh_i, ash);
|
ggml_cuda_pool_free(src1_dfloat, ash);
|
||||||
}
|
}
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
|
|
||||||
(void) src1;
|
(void) src1;
|
||||||
(void) dst;
|
(void) dst;
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue