More GPU threads for CUDA kernels

This commit is contained in:
JohannesGaessler 2023-05-06 12:17:45 +02:00
parent 173d0e6419
commit 50148408b5

View file

@ -80,11 +80,12 @@ typedef struct {
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding");
static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
static __global__ void dequantize_block_q4_0(const void * vx, float * y, int k) {
const block_q4_0 * x = (const block_q4_0 *) vx;
const int i = blockIdx.x;
const int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i < k) {
const float d = x[i].d;
const uint8_t * pp = x[i].qs;
@ -101,13 +102,15 @@ static __global__ void dequantize_block_q4_0(const void * vx, float * y) {
y[i*QK4_0 + l + 0] = v0;
y[i*QK4_0 + l + 1] = v1;
}
}
}
static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
static __global__ void dequantize_block_q4_1(const void * vx, float * y, int k) {
const block_q4_1 * x = (const block_q4_1 *) vx;
const int i = blockIdx.x;
const int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i < k) {
const float d = x[i].d;
const float m = x[i].m;
@ -125,13 +128,15 @@ static __global__ void dequantize_block_q4_1(const void * vx, float * y) {
y[i*QK4_1 + l + 0] = v0;
y[i*QK4_1 + l + 1] = v1;
}
}
}
static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
static __global__ void dequantize_block_q4_2(const void * vx, float * y, int k) {
const block_q4_2 * x = (const block_q4_2 *) vx;
const int i = blockIdx.x;
const int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i < k) {
const float d = x[i].d;
const uint8_t * pp = x[i].qs;
@ -148,13 +153,15 @@ static __global__ void dequantize_block_q4_2(const void * vx, float * y) {
y[i*QK4_2 + l + 0] = v0;
y[i*QK4_2 + l + 1] = v1;
}
}
}
static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
static __global__ void dequantize_block_q5_0(const void * vx, float * y, int k) {
const block_q5_0 * x = (const block_q5_0 *) vx;
const int i = blockIdx.x;
const int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i < k) {
const float d = x[i].d;
const uint8_t * pp = x[i].qs;
@ -177,13 +184,15 @@ static __global__ void dequantize_block_q5_0(const void * vx, float * y) {
y[i*QK5_0 + l + 0] = v0;
y[i*QK5_0 + l + 1] = v1;
}
}
}
static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
static __global__ void dequantize_block_q5_1(const void * vx, float * y, int k) {
const block_q5_1 * x = (const block_q5_1 *) vx;
const int i = blockIdx.x;
const int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i < k) {
const float d = x[i].d;
const float m = x[i].m;
@ -207,13 +216,15 @@ static __global__ void dequantize_block_q5_1(const void * vx, float * y) {
y[i*QK5_1 + l + 0] = v0;
y[i*QK5_1 + l + 1] = v1;
}
}
}
static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
static __global__ void dequantize_block_q8_0(const void * vx, float * y, int k) {
const block_q8_0 * x = (const block_q8_0 *) vx;
const int i = blockIdx.x;
const int i = blockIdx.x*blockDim.x + threadIdx.x;
if (i < k) {
const float d = x[i].d;
const int8_t * pp = x[i].qs;
@ -223,49 +234,73 @@ static __global__ void dequantize_block_q8_0(const void * vx, float * y) {
y[i*QK8_0 + l] = vi*d;
}
}
}
// Dequantize a row of k q4_0-quantized values into y (k must be a multiple of QK4_0).
// Launches one GPU thread per quantized block; the thread-block size is chosen by the
// occupancy API for maximum potential occupancy on the current device.
static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_0; // number of quantized blocks in the row
    int min_grid_size, block_size = 1; // Initialize to suppress compiler warning.
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0));
    int grid_size = (nb + block_size - 1) / block_size; // Round up.
    // Pass nb, not k: the kernel bounds-checks the quantized-block index i and reads x[i],
    // so an element-count bound would let up to block_size-1 tail threads run out of bounds.
    dequantize_block_q4_0<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
    CUDA_CHECK(cudaGetLastError()); // catch launch-configuration errors
}
// Dequantize a row of k q4_1-quantized values into y (k must be a multiple of QK4_1).
// One GPU thread per quantized block; block size chosen via the occupancy API.
static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_1; // number of quantized blocks in the row
    int min_grid_size, block_size = 1; // Initialize to suppress compiler warning.
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_1, 0, 0));
    int grid_size = (nb + block_size - 1) / block_size; // Round up.
    // Pass nb, not k: the kernel's bounds check is on the quantized-block index,
    // so an element-count bound would let tail threads index x[] out of bounds.
    dequantize_block_q4_1<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
    CUDA_CHECK(cudaGetLastError()); // catch launch-configuration errors
}
// Dequantize a row of k q4_2-quantized values into y (k must be a multiple of QK4_2).
// One GPU thread per quantized block; block size chosen via the occupancy API.
static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK4_2; // number of quantized blocks in the row
    int min_grid_size, block_size = 1; // Initialize to suppress compiler warning.
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_2, 0, 0));
    int grid_size = (nb + block_size - 1) / block_size; // Round up.
    // Pass nb, not k: the kernel's bounds check is on the quantized-block index,
    // so an element-count bound would let tail threads index x[] out of bounds.
    dequantize_block_q4_2<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
    CUDA_CHECK(cudaGetLastError()); // catch launch-configuration errors
}
// Dequantize a row of k q5_0-quantized values into y (k must be a multiple of QK5_0).
// One GPU thread per quantized block; block size chosen via the occupancy API.
static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK5_0; // number of quantized blocks in the row
    int min_grid_size, block_size = 1; // Initialize to suppress compiler warning.
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_0, 0, 0));
    int grid_size = (nb + block_size - 1) / block_size; // Round up.
    // Pass nb, not k: the kernel's bounds check is on the quantized-block index,
    // so an element-count bound would let tail threads index x[] out of bounds.
    dequantize_block_q5_0<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
    CUDA_CHECK(cudaGetLastError()); // catch launch-configuration errors
}
// Dequantize a row of k q5_1-quantized values into y (k must be a multiple of QK5_1).
// One GPU thread per quantized block; block size chosen via the occupancy API.
static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK5_1; // number of quantized blocks in the row
    int min_grid_size, block_size = 1; // Initialize to suppress compiler warning.
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_1, 0, 0));
    int grid_size = (nb + block_size - 1) / block_size; // Round up.
    // Pass nb, not k: the kernel's bounds check is on the quantized-block index,
    // so an element-count bound would let tail threads index x[] out of bounds.
    dequantize_block_q5_1<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
    CUDA_CHECK(cudaGetLastError()); // catch launch-configuration errors
}
// Dequantize a row of k q8_0-quantized values into y (k must be a multiple of QK8_0).
// One GPU thread per quantized block; block size chosen via the occupancy API.
static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) {
    const int nb = k / QK8_0; // number of quantized blocks in the row
    int min_grid_size, block_size = 1; // Initialize to suppress compiler warning.
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q8_0, 0, 0));
    int grid_size = (nb + block_size - 1) / block_size; // Round up.
    // Pass nb, not k: the kernel's bounds check is on the quantized-block index,
    // so an element-count bound would let tail threads index x[] out of bounds.
    dequantize_block_q8_0<<<grid_size, block_size, 0, stream>>>(vx, y, nb);
    CUDA_CHECK(cudaGetLastError()); // catch launch-configuration errors
}
// TODO: optimize (e.g. vectorized half2 loads)
// Convert k fp16 values at vx to fp32 in y.
// Expects a 1D launch with at least k total threads; one thread converts one element.
static __global__ void convert_fp16_to_fp32(const void * vx, float * y, int k) {
    const half * x = (const half *) vx;
    // Global thread index — must combine block and thread IDs now that the host
    // wrapper launches multiple threads per block (blockIdx.x alone would convert
    // only one element per block).
    const int i = blockIdx.x*blockDim.x + threadIdx.x;
    if (i < k) {
        y[i] = __half2float(x[i]);
    }
}
// Convert k fp16 values at x to fp32 in y on the given stream.
// One GPU thread per element; block size chosen via the occupancy API.
static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) {
    int min_grid_size, block_size = 1; // Initialize to suppress compiler warning.
    CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, convert_fp16_to_fp32, 0, 0));
    int grid_size = (k + block_size - 1) / block_size; // Round up.
    // The kernel is element-indexed, so k (the element count) is the correct bound here.
    convert_fp16_to_fp32<<<grid_size, block_size, 0, stream>>>(x, y, k);
    CUDA_CHECK(cudaGetLastError()); // catch launch-configuration errors
}
static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {