iq3_xxs: CUDA dequantize works

This commit is contained in:
Iwan Kawrakow 2024-01-27 11:34:52 +02:00
parent 8524d277ec
commit bf9349c610

View file

@@ -505,6 +505,14 @@ typedef struct {
} block_iq2_xs;
static_assert(sizeof(block_iq2_xs) == sizeof(ggml_fp16_t) + QK_K/8*sizeof(uint16_t) + QK_K/32, "wrong iq2_xs block size/padding");
// IQ3_XXS quantization constants; presumably consumed by the quantized
// dot-product kernels like the QR/QI constants of the other types — the
// users are not visible in this hunk (TODO confirm against full file).
#define QR3_XXS 8
#define QI3_XXS (QK_K / (4*QR3_XXS))
// One IQ3_XXS super-block covering QK_K values.
// Layout of qs (3*QK_K/8 bytes total, split by dequantize_block_iq3_xxs):
//   - first QK_K/4 bytes: one iq3xxs_grid index per 4 values;
//   - remaining QK_K/8 bytes: per-32-value-group packed sign bits and
//     4-bit sub-scales, read as little-endian 32-bit words.
typedef struct {
half d;               // fp16 super-block scale
uint8_t qs[3*(QK_K/8)];
} block_iq3_xxs;
static_assert(sizeof(block_iq3_xxs) == sizeof(ggml_fp16_t) + 3*(QK_K/8), "wrong iq3_xxs block size/padding");
#define WARP_SIZE 32
#define MATRIX_ROW_PADDING 512 // last row of quant. matrices is a multiple of this to avoid out-of-bounds memory accesses
@@ -1613,6 +1621,41 @@ static const __device__ uint64_t iq2xs_grid[512] = {
0x2b2b2b2b082b2b08, 0x2b2b2b2b082b2b2b, 0x2b2b2b2b2b190819, 0x2b2b2b2b2b2b2b2b,
};
// Codebook for IQ3_XXS: 256 entries, each packing 4 unsigned 8-bit
// magnitudes into one uint32 (dequantize_block_iq3_xxs reads an entry
// through a uint8_t pointer, 4 bytes per grid index).  Every byte is one
// of {0x04, 0x0c, 0x14, 0x1d, 0x26, 0x30, 0x3a, 0x45}.
static const __device__ uint32_t iq3xxs_grid[256] = {
0x04040404, 0x04040414, 0x04040c0c, 0x04040c30, 0x04041404, 0x04041414, 0x0404450c, 0x0404451d,
0x04044530, 0x04044545, 0x040c040c, 0x040c0445, 0x040c0c04, 0x040c0c14, 0x040c140c, 0x040c1d04,
0x040c1d14, 0x040c3014, 0x04140404, 0x04140414, 0x04140c0c, 0x04141404, 0x04141d1d, 0x04143045,
0x04143a04, 0x04144545, 0x041d0430, 0x041d0c04, 0x041d2630, 0x041d4526, 0x04260c0c, 0x04262604,
0x04263a14, 0x04264545, 0x0430141d, 0x04301445, 0x04302645, 0x04303026, 0x04304504, 0x043a4530,
0x043a4545, 0x0445041d, 0x04450445, 0x04450c04, 0x04451430, 0x04451d04, 0x04452645, 0x04453014,
0x0c04040c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c041d04, 0x0c041d14, 0x0c042645, 0x0c043004,
0x0c043026, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, 0x0c0c1404, 0x0c0c1d30, 0x0c14040c, 0x0c140c04,
0x0c14140c, 0x0c141445, 0x0c14260c, 0x0c144514, 0x0c1d301d, 0x0c26041d, 0x0c260445, 0x0c261430,
0x0c300404, 0x0c301404, 0x0c302614, 0x0c30451d, 0x0c3a0430, 0x0c3a3004, 0x0c451414, 0x0c452626,
0x0c45450c, 0x0c45451d, 0x14040404, 0x14040414, 0x14040426, 0x14040c0c, 0x14040c45, 0x14041404,
0x14041d1d, 0x1404450c, 0x140c040c, 0x140c0c04, 0x140c0c14, 0x140c140c, 0x140c4526, 0x14140404,
0x14140426, 0x14143a3a, 0x141d0c04, 0x141d0c3a, 0x141d1d04, 0x141d1d14, 0x141d3004, 0x14260c0c,
0x14261d45, 0x14262626, 0x1426450c, 0x14264530, 0x14264545, 0x1430141d, 0x1430303a, 0x143a0414,
0x14450c04, 0x14451d04, 0x14451d3a, 0x14453014, 0x1445303a, 0x1d040c04, 0x1d040c14, 0x1d041430,
0x1d043004, 0x1d04303a, 0x1d0c0404, 0x1d0c0c1d, 0x1d0c1d0c, 0x1d140445, 0x1d142630, 0x1d143014,
0x1d1d0414, 0x1d1d1426, 0x1d1d3045, 0x1d1d451d, 0x1d260430, 0x1d300404, 0x1d300c45, 0x1d301404,
0x1d30300c, 0x1d3a3026, 0x1d450426, 0x1d45043a, 0x1d451d1d, 0x1d454545, 0x26042614, 0x26042626,
0x2604451d, 0x26044530, 0x26044545, 0x260c0430, 0x26141414, 0x26141d45, 0x26142604, 0x26144530,
0x261d0c04, 0x261d4504, 0x26262604, 0x26262626, 0x2630041d, 0x2630141d, 0x26301430, 0x26303a45,
0x26304514, 0x263a1d0c, 0x263a4530, 0x2645040c, 0x26451445, 0x26453014, 0x2645303a, 0x26454504,
0x3004041d, 0x30040445, 0x3004140c, 0x30041d3a, 0x30043004, 0x300c0404, 0x300c1426, 0x300c3030,
0x300c450c, 0x3014261d, 0x30143a45, 0x301d0414, 0x301d0426, 0x301d0c45, 0x301d1426, 0x301d3030,
0x301d3a14, 0x30261d14, 0x30264526, 0x3026453a, 0x30300404, 0x3030301d, 0x30303030, 0x30303a04,
0x303a0430, 0x303a2645, 0x30451414, 0x30451426, 0x30452604, 0x30452626, 0x3045451d, 0x3a0c3a1d,
0x3a0c453a, 0x3a141414, 0x3a143a04, 0x3a1d1d3a, 0x3a262604, 0x3a263045, 0x3a300c14, 0x3a300c3a,
0x3a3a1404, 0x3a3a1d30, 0x3a3a300c, 0x3a45041d, 0x3a450445, 0x3a451445, 0x45040430, 0x45040c04,
0x45040c14, 0x45041d04, 0x45041d14, 0x45041d26, 0x45042645, 0x45043004, 0x45043014, 0x45043a30,
0x45043a45, 0x45044504, 0x4514040c, 0x45140c26, 0x45141445, 0x4514260c, 0x45142630, 0x45142645,
0x45143a30, 0x45143a45, 0x45144514, 0x451d1404, 0x451d1d1d, 0x4526040c, 0x45260445, 0x45261430,
0x45263014, 0x45263a30, 0x45264504, 0x45300426, 0x45301d45, 0x45302626, 0x4530451d, 0x45304545,
0x453a1d14, 0x453a303a, 0x45450404, 0x45450c30, 0x45452604, 0x4545301d, 0x4545450c, 0x45454530,
};
static const __device__ uint8_t ksigns_iq2xs[128] = {
0, 129, 130, 3, 132, 5, 6, 135, 136, 9, 10, 139, 12, 141, 142, 15,
144, 17, 18, 147, 20, 149, 150, 23, 24, 153, 154, 27, 156, 29, 30, 159,
@@ -1690,6 +1733,34 @@ static __global__ void dequantize_block_iq2_xs(const void * __restrict__ vx, dst
}
// Dequantize IQ3_XXS blocks: each CUDA thread block expands one
// block_iq3_xxs (QK_K quantized values) into `yy`.
// Launch layout (see dequantize_row_iq3_xxs_cuda): gridDim.x = number of
// super-blocks, blockDim.x = 32; each thread writes 8 consecutive outputs.
template<typename dst_t>
static __global__ void dequantize_block_iq3_xxs(const void * __restrict__ vx, dst_t * __restrict__ yy) {
const int i = blockIdx.x;
const block_iq3_xxs * x = (const block_iq3_xxs *) vx;
const int tid = threadIdx.x;
#if QK_K == 256
const int il = tid/8; // 0...3  -- 8-value sub-group within a 32-value group
const int ib = tid%8; // 0...7  -- 32-value group within the super-block
dst_t * y = yy + i*QK_K + 32*ib + 8*il;
// First QK_K/4 bytes of qs hold grid indices (1 byte per 4 values);
// this thread uses two of them: q3[2*il] and q3[2*il+1].
const uint8_t * q3 = x[i].qs + 8*ib;
// Rest of qs, read as two uint16 halves of a 32-bit word per group:
// packed sign bits plus a 4-bit sub-scale in the top nibble.
// NOTE(review): assumes the uint16 view of qs+QK_K/4 is sufficiently
// aligned for the device load -- holds if blocks are 2-byte aligned.
const uint16_t * gas = (const uint16_t *)(x[i].qs + QK_K/4) + 2*ib;
// Each iq3xxs_grid entry packs 4 uint8 magnitudes into one uint32.
const uint8_t * grid1 = (const uint8_t *)(iq3xxs_grid + q3[2*il+0]);
const uint8_t * grid2 = (const uint8_t *)(iq3xxs_grid + q3[2*il+1]);
const uint32_t aux32 = gas[0] | (gas[1] << 16);
// Effective scale: fp16 super-scale * (0.5 + 4-bit sub-scale) * 0.5.
const float d = (float)x[i].d * (0.5f + (aux32 >> 28)) * 0.5f;
// 7 packed sign bits per 8-value sub-group, expanded to 8 sign flags
// via the ksigns_iq2xs lookup and the kmask_iq2xs bit masks below.
const uint8_t signs = ksigns_iq2xs[(aux32 >> 7*il) & 127];
for (int j = 0; j < 4; ++j) {
y[j+0] = d * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f);
y[j+4] = d * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f);
}
#else
assert(false); // only QK_K == 256 is implemented for IQ3_XXS
#endif
}
static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
@@ -6381,6 +6452,12 @@ static void dequantize_row_iq2_xs_cuda(const void * vx, dst_t * y, const int k,
dequantize_block_iq2_xs<<<nb, 32, 0, stream>>>(vx, y);
}
// Host-side launcher: dequantize k IQ3_XXS-quantized values from vx into y
// on `stream`.  One 32-thread block per block_iq3_xxs super-block; k is
// expected to be a multiple of QK_K.
template<typename dst_t>
static void dequantize_row_iq3_xxs_cuda(const void * vx, dst_t * y, const int k, cudaStream_t stream) {
    const int num_blocks = k / QK_K;
    dequantize_block_iq3_xxs<<<num_blocks, 32, 0, stream>>>(vx, y);
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
@@ -6418,6 +6495,8 @@ static to_fp16_cuda_t ggml_get_to_fp16_cuda(ggml_type type) {
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_F32:
return convert_unary_cuda<float>;
default:
@@ -6451,6 +6530,8 @@ static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
return dequantize_row_iq2_xxs_cuda;
case GGML_TYPE_IQ2_XS:
return dequantize_row_iq2_xs_cuda;
case GGML_TYPE_IQ3_XXS:
return dequantize_row_iq3_xxs_cuda;
case GGML_TYPE_F16:
return convert_unary_cuda<half>;
default:
@@ -8213,6 +8294,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
return max_compute_capability >= CC_RDNA2 ? 128 : 64;
default:
GGML_ASSERT(false);
@@ -8235,6 +8317,7 @@ static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUD
case GGML_TYPE_Q5_K:
case GGML_TYPE_IQ2_XXS:
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ3_XXS:
return max_compute_capability >= CC_VOLTA ? 128 : 64;
case GGML_TYPE_Q6_K:
return 64;
@@ -10931,7 +11014,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
return false;
}
ggml_type a_type = a->type;
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS) {
if (a_type == GGML_TYPE_IQ2_XXS || a_type == GGML_TYPE_IQ2_XS || a_type == GGML_TYPE_IQ3_XXS) {
if (b->ne[1] == 1 && ggml_nrows(b) > 1) {
return false;
}