CUDA: added int8 tensor core matrix multiplication, 4279 t/s

commit 979a9bf1be
parent 6f9939d119
Author: JohannesGaessler
Date:   2024-01-03 20:58:58 +01:00


@@ -104,6 +104,8 @@
#include <cuda.h>
#include <cublas_v2.h>
#include <cuda_fp16.h>
#include <mma.h>
#include <cuda/barrier>
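// mma.h provides the nvcuda::wmma tensor core fragment types, cuda/barrier the
// asynchronous shared memory copies (cuda::memcpy_async) used by the new int8 path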
#if CUDART_VERSION < 11020
#define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
@@ -122,13 +124,17 @@
#define CUDART_HMAX 11070 // CUDA 11.7, min. ver. for which __hmax and __hmax2 are known to work (may be higher than needed)
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
#define CC_PASCAL 600
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#define CC_VOLTA 700 // minimum compute capability for mma, i.e. tensor cores
#define CC_TURING 750
#define CC_AMPERE 800
#define CC_ADA_LOVELACE 890
#define CC_HOPPER 900
#define CC_OFFSET_AMD 1000000
#define CC_RDNA1 (CC_OFFSET_AMD + 1010)
#define CC_RDNA2 (CC_OFFSET_AMD + 1030)
#define CC_RDNA3 (CC_OFFSET_AMD + 1100)
#define GGML_CUDA_MAX_NODES 8192
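// Illustration (assumed helper, not part of this commit): with the encoding set
// up in ggml_init_cublas below (cc = 100*major + 10*minor for NVIDIA, RDNA targets
// offset by CC_OFFSET_AMD), a single integer compare is enough to gate the new
// int8 tensor core path on NVIDIA Volta or newer:
static bool device_supports_int8_mma(const int cc) {
    return cc >= CC_VOLTA && cc < CC_OFFSET_AMD;
}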
@@ -574,12 +580,14 @@ static std::array<float, GGML_CUDA_MAX_DEVICES> g_default_tensor_split = {};
struct cuda_device_capabilities {
int cc; // compute capability
size_t smpb; // max. shared memory per block
int nsm; // number of streaming multiprocessors
size_t smem; // shared memory per SM
size_t smempb; // max. shared memory per block
bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory
};
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, false, 0} };
static cuda_device_capabilities g_device_caps[GGML_CUDA_MAX_DEVICES] = { {0, 0, 0, 0, false, 0} };
static cublasHandle_t g_cublas_handles[GGML_CUDA_MAX_DEVICES] = {nullptr};
@@ -2290,6 +2298,192 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
reinterpret_cast<half&>(y[ib].ds.y) = sum;
}
template <int block_size, int kx_template>
static __global__ void convert_q8_0_to_i8(
const void * __restrict__ vx, int * __restrict__ y_qs_low, float * __restrict__ y_d, const int kx_par) {
const int kx = kx_template == 0 ? kx_par : kx_template;
const int nint = kx*sizeof(block_q8_0)/(QK8_0*sizeof(int));
typedef cuda::barrier<cuda::thread_scope::thread_scope_block> cuda_barrier;
const cuda::aligned_size_t<4*sizeof(int)> as(4*sizeof(int));
extern __shared__ float data_q8_0_i8[];
cuda_barrier * barrier = (cuda_barrier *) data_q8_0_i8;
float * buf_iw = (data_q8_0_i8 + 32);
int * vals = (int *) (buf_iw + WARP_SIZE);
const block_q8_0 * valsb = (const block_q8_0 *) vals;
if (threadIdx.x == 0) {
init(barrier, block_size);
}
if (threadIdx.x < WARP_SIZE) {
buf_iw[threadIdx.x] = 0.0f;
}
__syncthreads();
const int iy = blockDim.y*blockIdx.y + threadIdx.y;
const int * x = (const int *) vx;
char4 * y_qsc = (char4 *) y_qs_low;
#pragma unroll
for (int ix0 = 0; ix0 < nint; ix0 += 4*block_size) {
const int ix = ix0 + 4*threadIdx.x;
if (ix >= nint) {
break;
}
cuda::memcpy_async(&vals[ix], &x[iy*nint + ix], as, *barrier);
}
barrier->arrive_and_wait();
float amax = 0.0f;
#pragma unroll
for (int ix0 = 0; ix0 < kx/QK8_0; ix0 += block_size) {
const int ix = ix0 + threadIdx.x;
if (ix >= kx/QK8_0) {
break;
}
const float d = __half2float(valsb[ix].d);
amax = max(amax, fabsf(d));
}
amax = warp_reduce_max(amax);
if (threadIdx.x % WARP_SIZE == 0) {
buf_iw[threadIdx.x / WARP_SIZE] = amax;
}
__syncthreads();
amax = buf_iw[threadIdx.x % WARP_SIZE];
amax = warp_reduce_max(amax);
#pragma unroll
for (int ix0 = 0; ix0 < kx/sizeof(int); ix0 += block_size) {
const int ix = ix0 + threadIdx.x;
if (ix >= kx/sizeof(int)) {
break;
}
const block_q8_0 * bxi = valsb + ix/QI8_0;
const float scale = __half2float(bxi->d) / amax;
const int xi = get_int_from_int8(bxi->qs, ix % QI8_0);
const int8_t * xi8 = (const int8_t *) &xi;
int result32[4];
#pragma unroll
for (int l = 0; l < 4; ++l) {
result32[l] = roundf(xi8[l] * scale);
}
y_qsc[iy*kx/sizeof(int) + ix] = make_char4(result32[0], result32[1], result32[2], result32[3]);
}
if (threadIdx.x > 0) {
return;
}
y_d[iy] = amax;
}
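// Scalar reference (illustrative sketch, not part of the commit and not called
// anywhere): what the kernel above does for a single row. All q8_0 block scales
// d of the row are folded into one row scale amax = max|d| and the quants are
// re-rounded so that q_new * amax ~= q_old * d, i.e. dequantization afterwards
// needs only one scale per row.
static void convert_q8_0_row_to_i8_ref(const block_q8_0 * x, int8_t * qs, float * row_d, const int nblocks) {
    float amax = 0.0f;
    for (int ib = 0; ib < nblocks; ++ib) {
        amax = fmaxf(amax, fabsf(__half2float(x[ib].d)));
    }
    for (int ib = 0; ib < nblocks; ++ib) {
        const float scale = __half2float(x[ib].d) / amax;
        for (int l = 0; l < QK8_0; ++l) {
            qs[ib*QK8_0 + l] = (int8_t) roundf(x[ib].qs[l] * scale);
        }
    }
    *row_d = amax; // dequantized value ~= qs[i] * amax
}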
#define I8_MAX_FRAG_SCALE 256.0f
template <bool frag_scales, int nrows_smem>
static __global__ void convert_float_to_i8(
const float * __restrict__ x, int * __restrict__ y_qs_low, int * __restrict__ y_bs, float * __restrict__ y_d, const int kx) {
extern __shared__ float data_convert_float_to_i8[];
float * buf_iw = data_convert_float_to_i8;
half * valsh = (half *) (buf_iw + WARP_SIZE);
const int iy0 = 8*(blockDim.y*blockIdx.y + threadIdx.y);
int8_t * qs_low = (int8_t *) y_qs_low;
float amax_row[8] = {0.0f};
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int iy = iy0 + j;
float amax = 0.0f;
for (int ix0 = 0; ix0 < kx; ix0 += blockDim.x) {
const int ix = ix0 + threadIdx.x;
if (ix >= kx) {
break;
}
const float xi = x[iy*kx + ix];
amax = max(amax, fabsf(xi));
if (j < nrows_smem) {
valsh[j*kx + ix] = xi;
}
}
amax = warp_reduce_max(amax);
if (threadIdx.x % WARP_SIZE == 0) {
buf_iw[threadIdx.x / WARP_SIZE] = amax;
}
__syncthreads();
amax = buf_iw[threadIdx.x % WARP_SIZE];
amax = warp_reduce_max(amax);
amax_row[j] = amax;
if (threadIdx.x == 0) {
y_d[iy] = amax_row[j] / 127;
}
__syncthreads();
}
for (int ix0 = 0; ix0 < kx; ix0 += blockDim.x) {
const int ix = ix0 + threadIdx.x;
float rmax = 0.0f;
float valsi[8];
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int iy = iy0 + j;
const float xi = ix >= kx ? 0.0f : (j < nrows_smem ? __half2float(valsh[j*kx + ix]) : x[iy*kx + ix]);
rmax = max(rmax, fabsf(xi) / amax_row[j]);
valsi[j] = xi;
}
#pragma unroll
for (int mask = 8; mask > 0; mask >>= 1) {
rmax = max(rmax, __shfl_xor_sync(0xFFFFFFFF, rmax, mask, 32));
}
const int bs = roundf(rmax * I8_MAX_FRAG_SCALE);
if (ix >= kx) {
break;
}
#pragma unroll
for (int j = 0; j < 8; ++j) {
const int iy = iy0 + j;
const float xi = valsi[j];
const int q = rmax == 0.0f ? 0 : roundf(xi * 127 / ((frag_scales ? rmax : 1.0f) * amax_row[j]));
qs_low[iy*kx + ix] = q;
}
if (frag_scales && ix % 16 == 0) {
y_bs[(iy0/8)*(kx/16) + ix/16] = bs;
}
}
}
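// Scalar reference (illustrative sketch, not called anywhere): the quantization
// performed above for one group of 8 rows with frag_scales == true. Each row gets
// a scale amax_row/127; an extra scale bs is stored per 8x16 fragment so that
// fragments with small values keep more of the int8 range:
//     x ~= q * (bs / I8_MAX_FRAG_SCALE) * (amax_row / 127)
static void convert_float_to_i8_ref(
        const float * x, int8_t * qs, int * bs, float * d, const int kx, const int iy0) {
    float amax_row[8];
    for (int j = 0; j < 8; ++j) {
        float amax = 0.0f;
        for (int ix = 0; ix < kx; ++ix) {
            amax = fmaxf(amax, fabsf(x[(iy0 + j)*kx + ix]));
        }
        amax_row[j] = amax;
        d[iy0 + j]  = amax / 127;
    }
    for (int ix0 = 0; ix0 < kx; ix0 += 16) {
        float rmax = 0.0f; // max magnitude in the 8x16 fragment, relative to the row maxima
        for (int j = 0; j < 8; ++j) {
            for (int l = 0; l < 16; ++l) {
                rmax = fmaxf(rmax, fabsf(x[(iy0 + j)*kx + ix0 + l]) / amax_row[j]);
            }
        }
        bs[(iy0/8)*(kx/16) + ix0/16] = (int) roundf(rmax * I8_MAX_FRAG_SCALE);
        for (int j = 0; j < 8; ++j) {
            for (int l = 0; l < 16; ++l) {
                const float xi = x[(iy0 + j)*kx + ix0 + l];
                qs[(iy0 + j)*kx + ix0 + l] = rmax == 0.0f ? 0 : (int8_t) roundf(xi * 127 / (rmax * amax_row[j]));
            }
        }
    }
}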
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * src0, const int32_t * src1, dst_t * dst,
@@ -5116,6 +5310,248 @@ template <bool need_check> static __global__ void
#endif // __CUDA_ARCH__ >= CC_VOLTA
}
#define MMI8_PADDING 4
#define MMI8_TILE_STRIDE (2*WARP_SIZE + MMI8_PADDING)
#define MMI8_COPY_SIZE 2
#define MMI8_N_BARRIERS 8
static_assert(MMI8_N_BARRIERS < WARP_SIZE, "Support for at most 32 barriers implemented.");
#define MMI8_X_AMPERE 64
#define MMI8_Y_AMPERE 144
#define MMI8_NWARPS_AMPERE 4
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 32, 8, 16, int8_t, nvcuda::wmma::row_major> frag_thin_a;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 32, 8, 16, int8_t, nvcuda::wmma::col_major> frag_thin_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 32, 8, 16, int> frag_thin_c;
typedef cuda::barrier<cuda::thread_scope::thread_scope_block> cuda_barrier;
template <int mmi8_x, int mmi8_y, int prefetch, bool need_check_x, bool need_check_y>
static __device__ __forceinline__ void load_tiles_i8(
const int * __restrict__ x_qs_low, const int * __restrict__ y_qs_low, const int * __restrict__ y_bs,
int * __restrict__ tile_x_qs, int * __restrict__ tile_y_qs,
const int ncols_x, const int nrows_x, const int nrows_y, const int ncols_y, const int k0,
cuda_barrier * barriers, const int ib_next) {
constexpr int nwarps = mmi8_y/32;
const cuda::aligned_size_t<MMI8_COPY_SIZE*sizeof(int)> as(MMI8_COPY_SIZE*sizeof(int));
const cuda::aligned_size_t<sizeof(int)> as4(sizeof(int));
if (k0 % (4*WARP_SIZE) == 0) {
const int ib_next_4 = (ib_next + 3) % MMI8_N_BARRIERS;
#pragma unroll
for (int j0 = 0; j0 < mmi8_x; j0 += 8*nwarps) {
if (j0 + 8*nwarps > mmi8_x && j0 + 8*threadIdx.y > mmi8_x-8) {
break;
}
const int j_bs = (blockIdx.y*(mmi8_x/8) + j0/8 + threadIdx.y);
const int k_bs = k0/4 + prefetch*WARP_SIZE + threadIdx.x;
const int j_tile = (ib_next_4/4)*mmi8_x + j0 + 8*threadIdx.y + threadIdx.x/4;
const int k_tile = 2*WARP_SIZE + threadIdx.x % 4;
cuda::memcpy_async(&tile_x_qs[j_tile*MMI8_TILE_STRIDE + k_tile],
&y_bs[j_bs*(nrows_y/16) + k_bs], as4, barriers[ib_next_4]);
}
}
#pragma unroll
for (int i0 = 0; i0 < mmi8_y; i0 += nwarps*MMI8_COPY_SIZE) {
const int x = MMI8_COPY_SIZE*(threadIdx.x % (WARP_SIZE/MMI8_COPY_SIZE));
const int y = MMI8_COPY_SIZE*threadIdx.y + threadIdx.x/(WARP_SIZE/MMI8_COPY_SIZE);
const int i_tile = i0 + y;
int i_qs = blockIdx.x*mmi8_y + i_tile;
if (need_check_x) {
i_qs = min(i_qs, nrows_x-1);
}
const int index_tile = i_tile * MMI8_TILE_STRIDE + (ib_next % 2)*WARP_SIZE + x;
const int index_qs = i_qs * (ncols_x/sizeof(int)) + k0 + prefetch*WARP_SIZE + x;
cuda::memcpy_async(&tile_x_qs[index_tile], &x_qs_low[index_qs], as, barriers[ib_next]);
}
#pragma unroll
for (int j0 = 0; j0 < mmi8_x; j0 += nwarps*MMI8_COPY_SIZE) {
const int x = MMI8_COPY_SIZE*(threadIdx.x % (WARP_SIZE/MMI8_COPY_SIZE));
const int y = MMI8_COPY_SIZE*threadIdx.y + threadIdx.x/(WARP_SIZE/MMI8_COPY_SIZE);
int j_tile = j0 + y;
if (j0 + nwarps*MMI8_COPY_SIZE > mmi8_x) {
j_tile = min(j_tile, mmi8_x-1);
}
int j_qs = blockIdx.y*mmi8_x + j_tile;
if (need_check_y) {
j_qs = min(j_qs, ncols_y-1);
}
const int index_tile = j_tile * MMI8_TILE_STRIDE + (ib_next % 2)*WARP_SIZE + x;
const int index_qs = j_qs * (nrows_y/sizeof(int)) + k0 + prefetch*WARP_SIZE + x;
cuda::memcpy_async(&tile_y_qs[index_tile], &y_qs_low[index_qs], as, barriers[ib_next]);
}
}
template <int mmi8_x, int mmi8_y>
static __device__ __forceinline__ void vec_dot_i8(
const int * __restrict__ tile_x_qs, const int * __restrict__ tile_y_qs, frag_thin_c * fc,
const int k0, const int ib_current) {
#pragma unroll
for (int k = 0; k < 32; k += 16/sizeof(int)) {
frag_thin_a fa;
const int ibs = ((k0/WARP_SIZE) % 4)*8 + k/4;
nvcuda::wmma::load_matrix_sync(
fa, (int8_t *) &tile_x_qs[threadIdx.y*(32*MMI8_TILE_STRIDE) + (ib_current % 2)*WARP_SIZE + k],
MMI8_TILE_STRIDE*sizeof(int));
#pragma unroll
for (int j = 0; j < mmi8_x; j += 8) {
frag_thin_b fb;
frag_thin_c fc_tmp;
const int bs = tile_x_qs[((ib_current/4)*mmi8_x + j + ibs/4)*MMI8_TILE_STRIDE + 2*WARP_SIZE + ibs % 4];
nvcuda::wmma::load_matrix_sync(
fb, (int8_t *) &tile_y_qs[j*MMI8_TILE_STRIDE + (ib_current % 2)*WARP_SIZE + k],
MMI8_TILE_STRIDE*sizeof(int));
nvcuda::wmma::fill_fragment(fc_tmp, 0);
nvcuda::wmma::mma_sync(fc_tmp, fa, fb, fc_tmp);
#pragma unroll
for (int l = 0; l < 32*8/WARP_SIZE; ++l) {
fc[j/8].x[l] += bs * fc_tmp.x[l];
}
}
}
}
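// Per output element the loop above accumulates fc = sum over fragments of
// bs * (q_x . q_y); the epilogue in mul_mat_i8 multiplies by
// d_i*d_j = (x_d / I8_MAX_FRAG_SCALE) * y_d, so the factor of I8_MAX_FRAG_SCALE
// inside bs ~= rmax * I8_MAX_FRAG_SCALE cancels and the result approximates the
// fp32 dot product of the original row of x and column of y.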
// Set launch bounds based on the shared memory available per SM (argument in KiB):
#define MMI8_LAUNCH_BOUNDS(kiB) __launch_bounds__(WARP_SIZE*mmi8_y/32, (kiB)*1024 / (512 + sizeof(int)*(mmi8_x + mmi8_y)*MMI8_TILE_STRIDE))
template <int mmi8_x, int mmi8_y, bool need_check_x, bool need_check_y>
#if __CUDA_ARCH__ >= CC_HOPPER
MMI8_LAUNCH_BOUNDS(228)
#elif __CUDA_ARCH__ >= CC_ADA_LOVELACE
MMI8_LAUNCH_BOUNDS(100)
#elif __CUDA_ARCH__ >= 870 // Jetson
MMI8_LAUNCH_BOUNDS(164)
#elif __CUDA_ARCH__ >= 860 // Ampere consumer
MMI8_LAUNCH_BOUNDS(100)
#elif __CUDA_ARCH__ >= CC_AMPERE // Ampere A100
MMI8_LAUNCH_BOUNDS(164)
#elif __CUDA_ARCH__ >= CC_TURING
MMI8_LAUNCH_BOUNDS(64)
#else // Volta, Jetson
MMI8_LAUNCH_BOUNDS(96)
#endif
static __global__ void mul_mat_i8(
const int * __restrict__ x_qs_low, const float * x_d, const int * __restrict__ y_qs_low, const int * __restrict__ y_bs,
const float * y_d, float * __restrict__ dst, const int ncols_x, const int nrows_x, const int ncols_y,
const int nrows_y, const int nrows_dst) {
// #if __CUDA_ARCH__ >= CC_VOLTA && !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
constexpr int nwarps = mmi8_y/32;
extern __shared__ char data_mmi8[];
cuda_barrier * barriers = (cuda_barrier *) data_mmi8;
int * tile_x_qs = (int *) (data_mmi8 + 512);
int * tile_y_qs = tile_x_qs + mmi8_y*MMI8_TILE_STRIDE;
if (threadIdx.x < MMI8_N_BARRIERS && threadIdx.y == 0) {
init(&barriers[threadIdx.x], nwarps*WARP_SIZE);
}
__syncthreads();
const int & ncols_dst = ncols_y;
frag_thin_c fc[mmi8_x/8];
{
constexpr int k0 = 0;
constexpr int ib_next = 0;
constexpr int prefetch = 0;
load_tiles_i8<mmi8_x, mmi8_y, prefetch, need_check_x, need_check_y>(
x_qs_low, y_qs_low, y_bs, tile_x_qs, tile_y_qs,
ncols_x, nrows_x, nrows_y, ncols_y, k0, barriers, ib_next);
}
#pragma unroll
for (int j = 0; j < mmi8_x; j += 8) {
nvcuda::wmma::fill_fragment(fc[j/8], 0);
}
for (int k0 = 0; k0 < ncols_x/sizeof(int) - WARP_SIZE; k0 += WARP_SIZE) {
const int ib_current = (k0/WARP_SIZE + 0) % MMI8_N_BARRIERS;
const int ib_next = (k0/WARP_SIZE + 1) % MMI8_N_BARRIERS;
constexpr int prefetch = 1;
load_tiles_i8<mmi8_x, mmi8_y, prefetch, need_check_x, need_check_y>(
x_qs_low, y_qs_low, y_bs, tile_x_qs, tile_y_qs,
ncols_x, nrows_x, nrows_y, ncols_y, k0, barriers, ib_next);
barriers[ib_current].arrive_and_wait();
vec_dot_i8<mmi8_x, mmi8_y>(tile_x_qs, tile_y_qs, fc, k0, ib_current);
__syncthreads();
}
{
const int k0 = ncols_x/sizeof(int) - WARP_SIZE;
const int ib_current = (k0/WARP_SIZE + 0) % MMI8_N_BARRIERS;
barriers[ib_current].arrive_and_wait();
vec_dot_i8<mmi8_x, mmi8_y>(tile_x_qs, tile_y_qs, fc, k0, ib_current);
__syncthreads();
}
int * tmp_fc = tile_x_qs + threadIdx.y*(32*8);
float * tmp_d_j = ((float *) tile_y_qs) + threadIdx.y*WARP_SIZE;
const int row_dst = blockIdx.x*mmi8_y + 32*threadIdx.y + threadIdx.x;
const float d_i = x_d[row_dst] / I8_MAX_FRAG_SCALE;
#pragma unroll
for (int j0 = 0; j0 < mmi8_x; j0 += 8) {
nvcuda::wmma::store_matrix_sync(tmp_fc, fc[j0/8], 32, nvcuda::wmma::mem_col_major);
if ((mmi8_y % WARP_SIZE != 0 && 32*threadIdx.y + threadIdx.x >= mmi8_y) || (need_check_x && row_dst >= nrows_dst)) {
continue;
}
if (j0 % WARP_SIZE == 0) {
const int col_dst = blockIdx.y*mmi8_x + j0 + threadIdx.x;
tmp_d_j[threadIdx.x] = y_d[col_dst];
}
#pragma unroll
for (int l = 0; l < 32*8; l += WARP_SIZE) {
const int col_dst = blockIdx.y*mmi8_x + j0 + l/32;
if (need_check_y && col_dst >= ncols_dst) {
continue;
}
const float d_j = tmp_d_j[(j0 + l/32) % WARP_SIZE];
dst[col_dst*nrows_dst + row_dst] = tmp_fc[l + threadIdx.x] * d_i*d_j;
}
}
// #else
// (void)x_qs_low;(void)x_d;(void)y_qs_low;(void)y_bs;(void)y_d;(void)dst;
// (void)ncols_x;(void)nrows_x;(void)ncols_y;(void)nrows_y;(void)nrows_dst;
// bad_arch();
// #endif // __CUDA_ARCH__ >= CC_VOLTA && !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
}
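// Minimal self-contained sketch (illustration only, assuming a single block and
// n divisible by block_size) of the software pipelining pattern used by
// mul_mat_i8 above: the next tile is fetched with cuda::memcpy_async while the
// current one is consumed, with one cuda::barrier guarding each buffer.
template <int block_size>
static __global__ void pipelined_sum_sketch(const int * __restrict__ x, int * __restrict__ dst, const int n) {
    __shared__ int tiles[2][block_size];
    __shared__ cuda_barrier barriers[2];
    if (threadIdx.x < 2) {
        init(&barriers[threadIdx.x], block_size);
    }
    __syncthreads();
    cuda::memcpy_async(&tiles[0][threadIdx.x], &x[threadIdx.x], sizeof(int), barriers[0]); // prefetch tile 0
    int sum = 0;
    for (int i0 = 0; i0 < n; i0 += block_size) {
        const int cur = (i0/block_size + 0) % 2;
        const int nxt = (i0/block_size + 1) % 2;
        if (i0 + block_size < n) { // prefetch the next tile before waiting on the current one
            cuda::memcpy_async(&tiles[nxt][threadIdx.x], &x[i0 + block_size + threadIdx.x], sizeof(int), barriers[nxt]);
        }
        barriers[cur].arrive_and_wait(); // wait until the current tile has arrived
        sum += tiles[cur][threadIdx.x];
        __syncthreads(); // all threads are done with the tile, its buffer may be overwritten
    }
    atomicAdd(dst, sum);
}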
template <int qk, int qi, typename block_q_t, int vdr, vec_dot_q_cuda_t vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst, const int ncols, const int nrows) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
@@ -6273,6 +6709,99 @@ static void quantize_row_q8_1_cuda(const float * x, void * vy, const int kx, con
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx, kx_padded);
}
template <bool frag_scales>
static void convert_float_to_i8_cuda(
const float * x, int * y_qs_low, int * y_bs, float * y_d, const int kx, const int ky, cudaStream_t stream) {
GGML_ASSERT(ky % 8 == 0);
const dim3 num_blocks(1, ky/8, 1);
const dim3 block_size(1024, 1, 1);
int id;
CUDA_CHECK(cudaGetDevice(&id));
const int smempb = g_device_caps[id].smempb;
static bool smem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!smem_limit_raised[id]) {
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 0>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 1>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 2>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 3>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 4>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 5>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 6>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 7>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
CUDA_CHECK(cudaFuncSetAttribute(convert_float_to_i8<frag_scales, 8>, cudaFuncAttributeMaxDynamicSharedMemorySize, smempb));
smem_limit_raised[id] = true;
}
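// how many of the 8 rows handled per block fit into shared memory as fp16 copies,
// in addition to the WARP_SIZE floats used for the block-wide max reduction (capped at 8 below):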
const int nrows_smem = smempb / (WARP_SIZE*sizeof(int) + kx*sizeof(half));
switch (nrows_smem) {
case 0:
convert_float_to_i8<frag_scales, 0><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 0*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
case 1:
convert_float_to_i8<frag_scales, 1><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 1*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
case 2:
convert_float_to_i8<frag_scales, 2><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 2*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
case 3:
convert_float_to_i8<frag_scales, 3><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 3*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
case 4:
convert_float_to_i8<frag_scales, 4><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 4*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
case 5:
convert_float_to_i8<frag_scales, 5><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 5*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
case 6:
convert_float_to_i8<frag_scales, 6><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 6*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
case 7:
convert_float_to_i8<frag_scales, 7><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 7*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
default:
convert_float_to_i8<frag_scales, 8><<<num_blocks, block_size, WARP_SIZE*sizeof(int) + 8*kx*sizeof(half), stream>>>
(x, y_qs_low, y_bs, y_d, kx);
break;
}
}
static void convert_q8_0_to_i8_cuda(const void * x, int * y_qs_low, float * y_d, const int kx, const int ky, cudaStream_t stream) {
const dim3 num_blocks(1, ky, 1);
const size_t smem_vals = kx*ggml_type_size(GGML_TYPE_Q8_0)/ggml_blck_size(GGML_TYPE_Q8_0);
GGML_ASSERT(smem_vals % (4*sizeof(int)) == 0);
const size_t smem_barrier = 128; // the barrier itself only needs a few bytes but is padded to 128 for alignment
const size_t smem_buf_iw = WARP_SIZE*sizeof(float); // reduction buffer, laid out between the barrier and the data
const size_t smem = smem_barrier + smem_buf_iw + smem_vals;
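// the specialized row widths below presumably correspond to the LLaMA 7B (4096, 11008)
// and 13B (5120, 13824) weight matrix shapes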
switch (kx) {
case 4096:
convert_q8_0_to_i8<128, 4096><<<num_blocks, 128, smem, stream>>>(x, y_qs_low, y_d, kx);
break;
case 5120:
convert_q8_0_to_i8<128, 5120><<<num_blocks, 128, smem, stream>>>(x, y_qs_low, y_d, kx);
break;
case 11008:
convert_q8_0_to_i8<512, 11008><<<num_blocks, 512, smem, stream>>>(x, y_qs_low, y_d, kx);
break;
case 13824:
convert_q8_0_to_i8<512, 13824><<<num_blocks, 512, smem, stream>>>(x, y_qs_low, y_d, kx);
break;
default:
fprintf(stderr, "%s: no kernel specialization for kx=%d, falling back to the generic kernel\n", __func__, kx);
convert_q8_0_to_i8<256, 0><<<num_blocks, 256, smem, stream>>>(x, y_qs_low, y_d, kx);
break;
}
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * __restrict__ vx, dst_t * __restrict__ y, const int k, cudaStream_t stream) {
const int num_blocks = (k + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE);
@@ -7052,6 +7581,157 @@ static void ggml_mul_mat_q5_K_q8_1_cuda(
}
}
static const uint32_t g_mmi8_configs[6] = {0x00068100, 0x00040100, 0x00080080, 0x00040080, 0x00040040, 0x00020040};
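// Config word layout (matching the shifts in get_mmi8_config_score below):
// bits 12-23 hold mmi8_x (tile size along the output columns), bits 0-11 hold
// mmi8_y (tile size along the output rows), e.g. 0x00068100 -> 0x068 x 0x100 = 104x256.
// Bits 24/25 are OR'd in later to select the need_check_x/need_check_y template variants.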
static float get_mmi8_config_score(const uint32_t config, const int nrows_x, const int ncols_y) {
const int64_t mmi8_x = (config >> 12) & 0x00000FFF;
const int64_t mmi8_y = (config >> 0) & 0x00000FFF;
const int64_t nwarps = mmi8_y/32;
const int64_t smem = 512 + sizeof(int)*(mmi8_x + mmi8_y)*MMI8_TILE_STRIDE;
int id;
CUDA_CHECK(cudaGetDevice(&id));
if (smem > g_device_caps[id].smempb) {
return 0.0f;
}
const int64_t blocks_per_sm = g_device_caps[id].smem / smem;
if (blocks_per_sm == 0) {
return 0.0f;
}
float score = 1.0f;
const int64_t grid_x = (nrows_x + mmi8_y - 1) / mmi8_y;
score *= nrows_x;
score /= grid_x*mmi8_y;
const int64_t grid_y = (ncols_y + mmi8_x - 1) / mmi8_x;
score *= ncols_y;
score /= grid_y*mmi8_x;
if (nrows_x % mmi8_y == 0) {
score *= 1.03;
}
if (ncols_y % mmi8_x == 0) {
score *= 1.03;
}
score *= mmi8_x*mmi8_y;
score /= mmi8_x*mmi8_y + 8196;
const int64_t nsm = g_device_caps[id].nsm;
const int64_t nblocks = grid_x*grid_y;
const int64_t nwaves = (nblocks + nsm*blocks_per_sm - 1) / (nsm*blocks_per_sm);
score *= nblocks;
score /= nsm*blocks_per_sm;
score /= nwaves;
if (mmi8_x > ncols_y) {
score -= mmi8_x - ncols_y;
}
return score;
}
#define MMI8_SMEM(mmi8_x, mmi8_y) (512 + sizeof(int)*(0x##mmi8_x + 0x##mmi8_y)*MMI8_TILE_STRIDE)
#define MMI8_SWITCH_CASE(mmi8_x, mmi8_y) \
case 0x00##mmi8_x##mmi8_y: \
mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, false, false> \
<<<block_nums, block_dims, MMI8_SMEM(mmi8_x, mmi8_y), stream>>> \
(x_qs_low, x_d, y_qs_low, y_bs, y_d, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); \
break; \
case 0x01##mmi8_x##mmi8_y: \
mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, true, false> \
<<<block_nums, block_dims, MMI8_SMEM(mmi8_x, mmi8_y), stream>>> \
(x_qs_low, x_d, y_qs_low, y_bs, y_d, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); \
break; \
case 0x02##mmi8_x##mmi8_y: \
mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, false, true> \
<<<block_nums, block_dims, MMI8_SMEM(mmi8_x, mmi8_y), stream>>> \
(x_qs_low, x_d, y_qs_low, y_bs, y_d, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); \
break; \
case 0x03##mmi8_x##mmi8_y: \
mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, true, true> \
<<<block_nums, block_dims, MMI8_SMEM(mmi8_x, mmi8_y), stream>>> \
(x_qs_low, x_d, y_qs_low, y_bs, y_d, dst, ncols_x, nrows_x, ncols_y, nrows_y, nrows_dst); \
break; \
#define MMI8_RAISE_SMEM_LIMIT(mmi8_x, mmi8_y) \
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, false, false>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, g_device_caps[id].smempb)); \
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, true, false>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, g_device_caps[id].smempb)); \
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, false, true>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, g_device_caps[id].smempb)); \
CUDA_CHECK(cudaFuncSetAttribute(mul_mat_i8<0x##mmi8_x, 0x##mmi8_y, true, true>, \
cudaFuncAttributeMaxDynamicSharedMemorySize, g_device_caps[id].smempb)); \
static void ggml_mul_mat_i8_cuda(
const int * x_qs_low, const float * x_d, const int * y_qs_low, const int * y_bs, const float * y_d, float * dst,
const int ncols_x, const int nrows_x, const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
int id;
CUDA_CHECK(cudaGetDevice(&id));
static bool smem_limit_raised[GGML_CUDA_MAX_DEVICES] = {false};
if (!smem_limit_raised[id]) {
MMI8_RAISE_SMEM_LIMIT(068, 100) // 104x256
MMI8_RAISE_SMEM_LIMIT(040, 100) // 64x256
MMI8_RAISE_SMEM_LIMIT(080, 080) // 128x128
MMI8_RAISE_SMEM_LIMIT(040, 080) // 64x128
MMI8_RAISE_SMEM_LIMIT(040, 040) // 64x64
MMI8_RAISE_SMEM_LIMIT(020, 040) // 32x64
smem_limit_raised[id] = true;
}
uint32_t best_config = 0;
float best_score = 0.0f;
for (uint64_t i = 0; i < sizeof(g_mmi8_configs)/sizeof(g_mmi8_configs[0]); ++i) {
const uint32_t config = g_mmi8_configs[i];
const float score = get_mmi8_config_score(config, nrows_x, ncols_y);
if (score > best_score) {
best_config = config;
best_score = score;
}
}
GGML_ASSERT(best_config != 0);
const int mmi8_x = (best_config >> 12) & 0x00000FFF;
const int mmi8_y = (best_config >> 0) & 0x00000FFF;
const int nwarps = mmi8_y/32;
GGML_ASSERT(mmi8_x <= mmi8_y); // Otherwise not enough space for fragment scales.
const int block_num_x = (nrows_x + mmi8_y - 1) / mmi8_y;
const int block_num_y = (ncols_y + mmi8_x - 1) / mmi8_x;
const dim3 block_nums(block_num_x, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, nwarps, 1);
if (nrows_x % mmi8_y != 0) {
best_config |= 0x01000000;
}
if (ncols_y % mmi8_x != 0) {
best_config |= 0x02000000;
}
switch (best_config) {
MMI8_SWITCH_CASE(068, 100) // 104x256
MMI8_SWITCH_CASE(040, 100) // 64x256
MMI8_SWITCH_CASE(080, 080) // 128x128
MMI8_SWITCH_CASE(040, 080) // 64x128
MMI8_SWITCH_CASE(040, 040) // 64x64
MMI8_SWITCH_CASE(020, 040) // 32x64
default:
GGML_ASSERT(false);
break;
}
}
static void ggml_mul_mat_q6_K_q8_1_cuda(
const void * vx, const void * vy, float * dst, const int ncols_x, const int nrows_x,
const int ncols_y, const int nrows_y, const int nrows_dst, cudaStream_t stream) {
@@ -7290,7 +7970,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, 2*WARP_SIZE) + WARP_SIZE)*sizeof(half);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
if (shmem <= g_device_caps[g_main_device].smpb) {
if (shmem <= g_device_caps[g_main_device].smempb) {
switch (ncols_x) {
case 32:
soft_max_f16<true, 32, 32, true><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
@@ -7333,7 +8013,7 @@ static void soft_max_f32_cuda(const float * x, const float * y, float * dst, con
const dim3 block_nums(nrows_x, 1, 1);
const size_t shmem = (GGML_PAD(ncols_x, WARP_SIZE) + WARP_SIZE)*sizeof(float);
static_assert(CUDA_SOFT_MAX_BLOCK_SIZE == 1024, "These values need to be adjusted.");
if (shmem < g_device_caps[g_main_device].smpb) {
if (shmem < g_device_caps[g_main_device].smempb) {
switch (ncols_x) {
case 32:
soft_max_f32<true, 32, 32><<<block_nums, block_dims, shmem, stream>>>(x, y, dst, ncols_x, nrows_y, scale);
@@ -7682,7 +8362,9 @@ GGML_CALL void ggml_init_cublas() {
#else
g_device_caps[id].cc = 100*prop.major + 10*prop.minor;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
g_device_caps[id].smpb = prop.sharedMemPerBlock;
g_device_caps[id].nsm = prop.multiProcessorCount;
g_device_caps[id].smem = prop.sharedMemPerMultiprocessor;
g_device_caps[id].smempb = prop.sharedMemPerBlockOptin;
}
for (int id = 0; id < g_device_count; ++id) {
g_default_tensor_split[id] /= total_vram;
@@ -8160,6 +8842,62 @@ static void ggml_cuda_op_mul_mat_q(
(void) src1_ddf_i;
}
static void ggml_cuda_op_mul_mat_i8(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols,
const int64_t src1_padded_row_size, cudaStream_t stream) {
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const int64_t ne10 = src1->ne[0];
const int64_t ne11 = src1->ne[1];
const int64_t ne0 = dst->ne[0];
const int64_t row_diff = row_high - row_low;
int id;
CUDA_CHECK(cudaGetDevice(&id));
// the main device has a larger memory buffer to hold the results from all GPUs
// nrows_dst == nrows of the matrix that the dequantize_mul_mat kernel writes into
const int64_t nrows_dst = dst->backend == GGML_BACKEND_GPU && id == g_main_device ? ne0 : row_diff;
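// pool buffer layout: ne00*ne01 int8 quants followed by ne01 float row scales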
cuda_pool_alloc<char> src0_ddi8(ne00*ne01 + ne01*sizeof(float));
int * src0_qs_low = (int *) (src0_ddi8.get() + 0);
float * src0_d = (float *) (src0_ddi8.get() + ne00*ne01);
cuda_pool_alloc<char> src1_ddi8((ne10 + sizeof(float))*src1_ncols + sizeof(int)*(ne10/16)*(src1_ncols/8));
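// pool buffer layout: ne10*src1_ncols int8 quants, then one int scale per 8x16
// fragment (8 src1 columns x 16 elements), then src1_ncols float row scales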
int * src1_qs_low = (int *) (src1_ddi8.get() + 0);
int * src1_bs = (int *) (src1_ddi8.get() + ne10*src1_ncols);
float * src1_d = (float *) (src1_bs + (ne10/16) * (src1_ncols/8));
// efficient conversion to i8 is only implemented for q8_0; as a workaround, other types are converted to intermediate floats first
cuda_pool_alloc<float> src0_workaround;
to_fp32_cuda_t to_fp32 = ggml_get_to_fp32_cuda(src0->type);
switch (src0->type) {
case GGML_TYPE_Q8_0:
convert_q8_0_to_i8_cuda(src0_dd_i, src0_qs_low, src0_d, ne00, ne01, stream);
break;
default:
src0_workaround.alloc(ne00*ne01);
to_fp32(src0_dd_i, src0_workaround.get(), ne00*ne01, stream);
convert_float_to_i8_cuda<false>(src0_workaround.get(), src0_qs_low, nullptr, src0_d, ne00, ne01, stream);
break;
}
convert_float_to_i8_cuda<true>(src1_ddf_i, src1_qs_low, src1_bs, src1_d, ne10, ne11, stream);
ggml_mul_mat_i8_cuda(src0_qs_low, src0_d, src1_qs_low, src1_bs, src1_d, dst_dd_i,
ne00, row_diff, src1_ncols, ne10, nrows_dst, stream);
(void) src1;
(void) dst;
(void) src1_ddq_i;
(void) src1_padded_row_size;
}
static int64_t get_row_rounding(ggml_type type, const std::array<float, GGML_CUDA_MAX_DEVICES> & tensor_split) {
int64_t min_compute_capability = INT_MAX;
int64_t max_compute_capability = INT_MIN;
@@ -9576,7 +10314,27 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
if (use_mul_mat_q) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_q, true);
} else {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
const bool use_mmi8 = false;
#else
// const bool layer_0 = strncmp(src0->name+0, "blk.0.", 6) == 0;
// const bool layer_1 = strncmp(src0->name+0, "blk.1.", 6) == 0;
// const bool layer_2 = strncmp(src0->name+0, "blk.2.", 6) == 0;
// const bool attn_q = strncmp(src0->name+6, "attn_q", 6) == 0 || strncmp(src0->name+7, "attn_q", 6) == 0;
// const bool attn_k = strncmp(src0->name+6, "attn_k", 6) == 0 || strncmp(src0->name+7, "attn_k", 6) == 0;
// const bool attn_v = strncmp(src0->name+6, "attn_v", 6) == 0 || strncmp(src0->name+7, "attn_v", 6) == 0;
// const bool attn_output = strncmp(src0->name+6, "attn_output", 11) == 0 || strncmp(src0->name+7, "attn_output", 11) == 0;
// const bool ffn_up = strncmp(src0->name+6, "ffn_up", 6) == 0 || strncmp(src0->name+7, "ffn_up", 6) == 0;
// const bool ffn_gate = strncmp(src0->name+6, "ffn_gate", 8) == 0 || strncmp(src0->name+7, "ffn_gate", 8) == 0;
// const bool ffn_down = strncmp(src0->name+6, "ffn_down", 8) == 0 || strncmp(src0->name+7, "ffn_down", 8) == 0;
// const bool use_mmi8 = min_compute_capability >= CC_VOLTA && (attn_q || attn_k || attn_output || ffn_gate || ffn_down);
const bool use_mmi8 = min_compute_capability >= CC_VOLTA;
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)
if (use_mmi8) {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_i8, false);
} else {
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
}
}
}
} else {