Speed up dequantize_block_q4_0()

This commit is contained in:
Ivan Komarov 2023-04-29 00:46:25 +03:00
parent 7fc50c051a
commit d67f481144

View file

@ -1,3 +1,4 @@
#include <assert.h>
#include <stdint.h>
#include <stdio.h>
#include <cuda_fp16.h>
@ -7,7 +8,13 @@
// 16-bit storage type used for half-precision values.
typedef uint16_t ggml_fp16_t;
// Compile-time guard: ggml_fp16_t must be exactly as wide as __half from <cuda_fp16.h>.
static_assert(sizeof(__half) == sizeof(ggml_fp16_t), "wrong fp16 size");
// Threads per warp; the q4_0 kernel derives lane and warp ids from threadIdx.x with it.
#define WARP_SIZE 32
// Threads per thread block used when launching the q4_0 dequantization kernel.
#define THREAD_COUNT 1024
// Warps per thread block.
#define WARP_COUNT (THREAD_COUNT / WARP_SIZE)

// Number of quantized values stored in one q4_0 block.
#define QK4_0 32
// Consecutive q4_0 blocks processed by a single warp.
#define QK4_0_Q_BLOCKS_PER_WARP 2
// q4_0 blocks covered by one thread block (all of its warps together).
#define QK4_0_Q_BLOCKS_PER_THREAD_BLOCK (WARP_COUNT * QK4_0_Q_BLOCKS_PER_WARP)
typedef struct {
float d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
@ -53,26 +60,28 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
// Dequantizes q4_0 blocks into floats, producing one output element per thread.
//
// Launch layout: 1-D grid with THREAD_COUNT threads per thread block. Every warp
// handles QK4_0_Q_BLOCKS_PER_WARP consecutive quantized blocks, and lane `l` of the
// warp writes element `l` of each of those blocks.
//
//   nb - total number of q4_0 blocks stored in vx
//   vx - device-accessible pointer to nb block_q4_0 records
//   y  - device-accessible pointer to nb * QK4_0 output floats
static __global__ void dequantize_block_q4_0(int nb, const void * vx, float * y) {
    // One lane per quantized value: the indexing below relies on this.
    static_assert(WARP_SIZE == QK4_0, "q4_0 kernel expects one lane per quantized value");

    const block_q4_0 * x = (const block_q4_0 *) vx;

    const unsigned lane_id = threadIdx.x % WARP_SIZE;
    const unsigned warp_id = threadIdx.x / WARP_SIZE;
    const int first_block_id = blockIdx.x * QK4_0_Q_BLOCKS_PER_THREAD_BLOCK + warp_id * QK4_0_Q_BLOCKS_PER_WARP;

    // qs is read as 32-bit words holding 8 nibbles each. With little-endian byte
    // order, nibble `lane_id` (low nibble of a byte first) sits at bit
    // 4 * (lane_id % 8) of word lane_id / 8.
    // NOTE(review): assumes vx is at least 4-byte aligned so that qs is too - confirm.
    const unsigned word_id = lane_id / 8;
    const unsigned shift   = 4 * (lane_id % 8);

    #pragma unroll
    for (int i = 0; i < QK4_0_Q_BLOCKS_PER_WARP; ++i) {
        const int block_id = first_block_id + i;

        // Check every block, not only the warp's first one, so a block count that
        // is not a multiple of QK4_0_Q_BLOCKS_PER_WARP never touches memory past
        // the end (the host-side assert disappears when NDEBUG is defined).
        if (block_id >= nb) {
            return;
        }

        const unsigned * int_qs = (const unsigned *) x[block_id].qs;
        const unsigned nibble = (int_qs[word_id] >> shift) & 0xf;

        // Stored nibbles are offset by 8; d is the per-block scale.
        y[block_id*QK4_0 + lane_id] = ((int) nibble - 8) * x[block_id].d;
    }
}
@ -197,9 +206,15 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
}
}
// Number of thread blocks needed to process `nb` quantized blocks when each thread
// block covers `q_blocks_per_thread_block` of them.
static int get_thread_block_count(int nb, int q_blocks_per_thread_block) {
    const int full_groups = nb / q_blocks_per_thread_block;
    const int leftover    = nb % q_blocks_per_thread_block;
    // A partially filled trailing group still needs a thread block of its own.
    return leftover == 0 ? full_groups : full_groups + 1;
}
// Enqueues dequantization of a row of k q4_0-quantized values on `stream`.
//
//   vx     - device-accessible pointer to the k / QK4_0 quantized blocks
//   y      - device-accessible pointer to the output floats
//   k      - number of values in the row
//   stream - CUDA stream the kernel is launched on; the call does not synchronize
void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_0;

    // Nothing to do for an empty row; a grid with zero thread blocks is not a
    // valid launch configuration.
    if (nb <= 0) {
        return;
    }

    // The kernel hands blocks to warps in groups of QK4_0_Q_BLOCKS_PER_WARP.
    assert(nb % QK4_0_Q_BLOCKS_PER_WARP == 0);

    const int n_thread_blocks = get_thread_block_count(nb, QK4_0_Q_BLOCKS_PER_THREAD_BLOCK);
    dequantize_block_q4_0<<<n_thread_blocks, THREAD_COUNT, 0, stream>>>(nb, vx, y);
}
void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {