Use make_half2 for better compatibility

This commit is contained in:
lijiahao 2023-08-27 11:06:28 +08:00
parent d01f52409f
commit af31f1f00d

View file

@ -4197,6 +4197,12 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
// make_dfloat2(x, y): builds a dfloat2 from two scalar components.
// Resolved at compile time: expands to make_half2 when GGML_CUDA_F16 is
// defined and to make_float2 otherwise. Both arguments are parenthesized in
// the expansion so that compound expressions can be passed safely.
// NOTE(review): presumably dfloat2 is typedef'd to half2/float2 under the
// same GGML_CUDA_F16 switch; the typedef is not visible here, so confirm.
#ifdef GGML_CUDA_F16
#define make_dfloat2(x, y) make_half2((x), (y))
#else
#define make_dfloat2(x, y) make_float2((x), (y))
#endif
static __device__ __forceinline__ dfloat2 dfmul2(dfloat2 a, dfloat2 b) {
#ifdef GGML_CUDA_F16
return __hmul2(a, b);
@ -4227,15 +4233,11 @@ static __global__ void dequantize_block_q4_0(const void * __restrict__ vx, float
const uchar2 qs = *(const uchar2 *)(x[ib].qs + iqs*2);
const dfloat d = x[ib].d;
dfloat2 dv0;
dv0.x = (int)(qs.x & 0xf) - 8;
dv0.y = (int)(qs.y & 0xf) - 8;
dfloat2 dv0 = make_dfloat2((int)(qs.x & 0xf) - 8, (int)(qs.y & 0xf) - 8);
const float2 v0 = dfloat22float2(dfmul2(dv0, {d, d}));
*(float2 *)(y + ib*QK4_0 + iqs*2) = v0;
dfloat2 dv1;
dv1.x = (int)(qs.x >> 4) - 8;
dv1.y = (int)(qs.y >> 4) - 8;
dfloat2 dv1 = make_dfloat2((int)(qs.x >> 4) - 8, (int)(qs.y >> 4) - 8);
const float2 v1 = dfloat22float2(dfmul2(dv1, {d, d}));
*(float2 *)(y + ib*QK4_0 + QK4_0/2 + iqs*2) = v1;
}
@ -5755,6 +5757,7 @@ inline void ggml_cuda_op_alibi(
(void) src1;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i02;
(void) i1;
}