diff --git a/ggml-cuda.cu b/ggml-cuda.cu index e8a1e77cb..7c856e9ee 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -80,192 +80,227 @@ typedef struct { } block_q8_0; static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); -static __global__ void dequantize_block_q4_0(const void * vx, float * y) { +static __global__ void dequantize_block_q4_0(const void * vx, float * y, int k) { const block_q4_0 * x = (const block_q4_0 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; + if (i < k) { + const float d = x[i].d; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK4_0; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_0; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; - y[i*QK4_0 + l + 0] = v0; - y[i*QK4_0 + l + 1] = v1; + y[i*QK4_0 + l + 0] = v0; + y[i*QK4_0 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q4_1(const void * vx, float * y) { +static __global__ void dequantize_block_q4_1(const void * vx, float * y, int k) { const block_q4_1 * x = (const block_q4_1 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; - const float m = x[i].m; + if (i < k) { + const float d = x[i].d; + const float m = x[i].m; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK4_1; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_1; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; - y[i*QK4_1 + l + 0] = v0; - y[i*QK4_1 + l + 1] = v1; + y[i*QK4_1 + l + 0] = v0; + y[i*QK4_1 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q4_2(const void * vx, float * y) { +static __global__ void dequantize_block_q4_2(const void * vx, float * y, int k) { const block_q4_2 * x = (const block_q4_2 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; + if (i < k) { + const float d = x[i].d; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - for (int l = 0; l < QK4_2; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK4_2; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vi0 = vi & 0xf; - const int8_t vi1 = vi >> 4; + const int8_t vi0 = vi & 0xf; + const int8_t vi1 = vi >> 4; - const float v0 = (vi0 - 8)*d; - const float v1 = (vi1 - 8)*d; + const float v0 = (vi0 - 8)*d; + const float v1 = (vi1 - 8)*d; - y[i*QK4_2 + l + 0] = v0; - y[i*QK4_2 + l + 1] = v1; + y[i*QK4_2 + l + 0] = v0; + y[i*QK4_2 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q5_0(const void * vx, float * y) { +static __global__ void dequantize_block_q5_0(const void * vx, float * y, int k) { const block_q5_0 * x = (const block_q5_0 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; + if (i < k) { + const float d = x[i].d; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - for (int l = 0; l < QK5_0; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK5_0; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; - const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; - const int8_t vi0 = ((vi & 0xf) | vh0); - const int8_t vi1 = ((vi >> 4) | vh1); + const int8_t vi0 = ((vi & 0xf) | vh0); + const int8_t vi1 = ((vi >> 4) | vh1); - const float v0 = (vi0 - 16)*d; - const float v1 = (vi1 - 16)*d; + const float v0 = (vi0 - 16)*d; + const float v1 = (vi1 - 16)*d; - y[i*QK5_0 + l + 0] = v0; - y[i*QK5_0 + l + 1] = v1; + y[i*QK5_0 + l + 0] = v0; + y[i*QK5_0 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q5_1(const void * vx, float * y) { +static __global__ void dequantize_block_q5_1(const void * vx, float * y, int k) { const block_q5_1 * x = (const block_q5_1 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; - const float m = x[i].m; + if (i < k) { + const float d = x[i].d; + const float m = x[i].m; - const uint8_t * pp = x[i].qs; + const uint8_t * pp = x[i].qs; - uint32_t qh; - memcpy(&qh, x[i].qh, sizeof(qh)); + uint32_t qh; + memcpy(&qh, x[i].qh, sizeof(qh)); - for (int l = 0; l < QK5_1; l += 2) { - const uint8_t vi = pp[l/2]; + for (int l = 0; l < QK5_1; l += 2) { + const uint8_t vi = pp[l/2]; - const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; - const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; + const int8_t vh0 = ((qh & (1 << (l + 0))) >> (l + 0)) << 4; + const int8_t vh1 = ((qh & (1 << (l + 1))) >> (l + 1)) << 4; - const int8_t vi0 = (vi & 0xf) | vh0; - const int8_t vi1 = (vi >> 4) | vh1; + const int8_t vi0 = (vi & 0xf) | vh0; + const int8_t vi1 = (vi >> 4) | vh1; - const float v0 = vi0*d + m; - const float v1 = vi1*d + m; + const float v0 = vi0*d + m; + const float v1 = vi1*d + m; - y[i*QK5_1 + l + 0] = v0; - y[i*QK5_1 + l + 1] = v1; + y[i*QK5_1 + l + 0] = v0; + y[i*QK5_1 + l + 1] = v1; + } } } -static __global__ void dequantize_block_q8_0(const void * vx, float * y) { +static __global__ void dequantize_block_q8_0(const void * vx, float * y, int k) { const block_q8_0 * x = (const block_q8_0 *) vx; - const int i = blockIdx.x; + const int i = blockIdx.x*blockDim.x + threadIdx.x; - const float d = x[i].d; + if (i < k) { + const float d = x[i].d; - const int8_t * pp = x[i].qs; + const int8_t * pp = x[i].qs; - for (int l = 0; l < QK8_0; l++) { - const int8_t vi = pp[l]; + for (int l = 0; l < QK8_0; l++) { + const int8_t vi = pp[l]; - y[i*QK8_0 + l] = vi*d; + y[i*QK8_0 + l] = vi*d; + } } } static void dequantize_row_q4_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_0; - dequantize_block_q4_0<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_0, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q4_0<<>>(vx, y, k); } static void dequantize_row_q4_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_1; - dequantize_block_q4_1<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_1, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q4_1<<>>(vx, y, k); } static void dequantize_row_q4_2_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK4_2; - dequantize_block_q4_2<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q4_2, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q4_2<<>>(vx, y, k); } static void dequantize_row_q5_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_0; - dequantize_block_q5_0<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_0, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q5_0<<>>(vx, y, k); } static void dequantize_row_q5_1_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK5_1; - dequantize_block_q5_1<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q5_1, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q5_1<<>>(vx, y, k); } static void dequantize_row_q8_0_cuda(const void * vx, float * y, int k, cudaStream_t stream) { const int nb = k / QK8_0; - dequantize_block_q8_0<<>>(vx, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, dequantize_block_q8_0, 0, 0)); + int grid_size = (nb + block_size - 1) / block_size; // Round up. + dequantize_block_q8_0<<>>(vx, y, k); } // TODO: optimize -static __global__ void convert_fp16_to_fp32(const void * vx, float * y) { +static __global__ void convert_fp16_to_fp32(const void * vx, float * y, int k) { const half * x = (const half *) vx; const int i = blockIdx.x; - y[i] = __half2float(x[i]); + if (i < k) { + y[i] = __half2float(x[i]); + } } static void convert_fp16_to_fp32_cuda(const void * x, float * y, int k, cudaStream_t stream) { - convert_fp16_to_fp32<<>>(x, y); + int min_grid_size, block_size = 1; // Initialize to suppress compiler warning. + CUDA_CHECK(cudaOccupancyMaxPotentialBlockSize(&min_grid_size, &block_size, convert_fp16_to_fp32, 0, 0)); + int grid_size = (k + block_size - 1) / block_size; // Round up. + convert_fp16_to_fp32<<>>(x, y, k); } static to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {