Merge remote-tracking branch 'origin/master' into sl/custom-tensor-offload

This commit is contained in:
slaren 2025-02-09 00:36:06 +01:00
commit 2e54433a9d
122 changed files with 11294 additions and 4576 deletions

View file

@ -274,22 +274,25 @@ endif()
# Generate version info based on git commit.
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_NUMBER
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(NOT DEFINED GGML_BUILD_NUMBER)
find_program(GIT_EXE NAMES git git.exe REQUIRED NO_CMAKE_FIND_ROOT_PATH)
execute_process(COMMAND ${GIT_EXE} rev-list --count HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_NUMBER
OUTPUT_STRIP_TRAILING_WHITESPACE
)
if(GGML_BUILD_NUMBER EQUAL 1)
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
if(GGML_BUILD_NUMBER EQUAL 1)
message(WARNING "GGML build version fixed at 1 likely due to a shallow clone.")
endif()
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
endif()
execute_process(COMMAND ${GIT_EXE} rev-parse --short HEAD
WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}
OUTPUT_VARIABLE GGML_BUILD_COMMIT
OUTPUT_STRIP_TRAILING_WHITESPACE
)
# Capture variables prefixed with GGML_.

View file

@ -1775,7 +1775,7 @@ extern "C" {
struct ggml_tensor * a,
int k);
#define GGML_KQ_MASK_PAD 32
#define GGML_KQ_MASK_PAD 64
// q: [n_embd, n_batch, n_head, 1]
// k: [n_embd, n_kv, n_head_kv, 1]

View file

@ -989,19 +989,7 @@ ggml_backend_buffer_t ggml_backend_alloc_ctx_tensors_from_buft(struct ggml_conte
this_size = GGML_PAD(ggml_backend_buft_get_alloc_size(buft, t), alignment);
}
if (this_size > max_size) {
GGML_LOG_ERROR("%s: tensor %s is too large to fit in a %s buffer (tensor size: %zu, max buffer size: %zu)\n",
__func__, t->name,
ggml_backend_buft_name(buft),
this_size, max_size);
for (size_t i = 0; i < n_buffers; i++) {
ggml_backend_buffer_free(buffers[i]);
}
free(buffers);
return NULL;
}
if ((cur_buf_size + this_size) > max_size) {
if (cur_buf_size > 0 && (cur_buf_size + this_size) > max_size) {
// allocate tensors in the current buffer
if (!alloc_tensor_range(ctx, first, t, buft, cur_buf_size, &buffers, &n_buffers)) {
return NULL;

View file

@ -360,21 +360,15 @@ inline static int32x4_t ggml_vdotq_s32(int32x4_t acc, int8x16_t a, int8x16_t b)
#endif
#if defined(__loongarch_asx)
typedef union {
int32_t i;
float f;
} ft_union;
/* float type data load instructions */
static __m128 __lsx_vreplfr2vr_s(float val) {
ft_union fi_tmpval = {.f = val};
return (__m128)__lsx_vreplgr2vr_w(fi_tmpval.i);
static __m128 __lsx_vreplfr2vr_s(const float val) {
v4f32 res = {val, val, val, val};
return (__m128)res;
}
static __m256 __lasx_xvreplfr2vr_s(float val) {
ft_union fi_tmpval = {.f = val};
return (__m256)__lasx_xvreplgr2vr_w(fi_tmpval.i);
static __m256 __lasx_xvreplfr2vr_s(const float val) {
v8f32 res = {val, val, val, val, val, val, val, val};
return (__m256)res;
}
#endif

View file

@ -297,6 +297,90 @@ static const uint64_t table_b2b_0[1 << 8] = { B8(00, 10) }; // ( b) << 4
static const uint64_t table_b2b_1[1 << 8] = { B8(10, 00) }; // (!b) << 4
#endif
#if defined(__loongarch_sx)
static __m128i lsx_packs_w(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_w(a, 15);
tmp1 = __lsx_vsat_w(b, 15);
return __lsx_vpickev_h(tmp1, tmp);
}
static __m128i lsx_packs_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_h(a, 7);
tmp1 = __lsx_vsat_h(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
static __m128i lsx_packus_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_hu(a, 7);
tmp1 = __lsx_vsat_hu(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_h_b(a, b);
tmp2 = __lsx_vmulwod_h_b(a, b);
return __lsx_vsadd_h(tmp1, tmp2);
}
static __m128i lsx_madd_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_w_h(a, b);
tmp2 = __lsx_vmulwod_w_h(a, b);
return __lsx_vadd_w(tmp1, tmp2);
}
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
v4i32 __ret = {d, c, b, a};
return (__m128i)__ret;
}
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
__m128i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lsx_vreplgr2vr_b(f);
zero = __lsx_vldi(0);
tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
return __lsx_vshuf_b(a, zero, tmp2);
}
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_h(b, a);
__m128i tmp2 = __lsx_vpickod_h(b, a);
return __lsx_vadd_h(tmp1, tmp2);
}
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_w(b, a);
__m128i tmp2 = __lsx_vpickod_w(b, a);
return __lsx_vadd_w(tmp1, tmp2);
}
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
return __lsx_vfadd_s(tmp1, tmp2);
}
static inline float hsum_float_4x4(const __m128 a, const __m128 b, const __m128 c, const __m128 d) {
__m128 res_0 =lsx_hadd_s(a, b);
__m128 res_1 =lsx_hadd_s(c, d);
__m128 res =lsx_hadd_s(res_0, res_1);
res =lsx_hadd_s(res, res);
res =lsx_hadd_s(res, res);
return ((v4f32)res)[0];
}
#endif
#if defined(__loongarch_asx)
#ifdef __clang__
@ -395,11 +479,6 @@ static __m256i lasx_set_w(int e7, int e6, int e5, int e4, int e3, int e2, int e1
return (__m256i)__ret;
}
static __m128i lsx_set_w(int32_t a, int32_t b, int32_t c, int32_t d) {
v4i32 __ret = {d, c, b, a};
return (__m128i)__ret;
}
static __m256i lasx_set_d(int64_t a, int64_t b, int64_t c, int64_t d) {
v4i64 __ret = {d, c, b, a};
return (__m256i)__ret;
@ -409,18 +488,6 @@ static __m256i lasx_insertf128( __m128i x, __m128i y) {
return lasx_set_q(x, y);
}
static __m128i lsx_shuffle_b(__m128i a, __m128i b) {
__m128i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
mask_f = __lsx_vreplgr2vr_b(f);
zero = __lsx_vldi(0);
tmp0 = __lsx_vand_v(b, mask_f); // get mask with low 4 bit and sign bits
tmp0 = __lsx_vori_b(tmp0, 0x10); // make each mask or with 0x10 prepare for positive
mask = __lsx_vsle_b(zero, tmp0); // if mask >= 0, set mask
tmp2 = __lsx_vand_v(tmp0, mask); // maskout the in2 < ones
return __lsx_vshuf_b(a, zero, tmp2);
}
static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
__m256i mask_f, zero, tmp0, tmp2, mask;
int f = 0x8f;
@ -434,30 +501,15 @@ static __m256i lasx_shuffle_b(__m256i a, __m256i b) {
}
static __m256i lasx_extu8_16(__m128i a) {
__m128i zero = __lsx_vldi(0);
__m128i vlo = __lsx_vilvl_b(zero, a);
__m128i vhi = __lsx_vilvh_b(zero, a);
return lasx_set_q(vhi, vlo);
return __lasx_vext2xv_hu_bu(____m256i(a));
}
static __m256i lasx_ext8_16(__m128i a) {
__m128i sign = __lsx_vslti_b(a, 0);
__m128i vlo = __lsx_vilvl_b(sign, a);
__m128i vhi = __lsx_vilvh_b(sign, a);
return lasx_set_q(vhi, vlo);
return __lasx_vext2xv_h_b(____m256i(a));
}
static __m256i lasx_ext16_32(__m128i a) {
__m256i tmp1;
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 0), 0);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 1), 1);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 2), 2);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 3), 3);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 4), 4);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 5), 5);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 6), 6);
tmp1 = __lasx_xvinsgr2vr_w(tmp1, __lsx_vpickve2gr_h(a, 7), 7);
return tmp1;
return __lasx_vext2xv_w_h(____m256i(a));
}
static __m128i lasx_extracti128( __m256i a, int pos) {
@ -482,25 +534,6 @@ static __m128 lasx_extractf128( __m256 a, int pos) {
return ret;
}
static __m128i lsx_hadd_h(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_h(b, a);
__m128i tmp2 = __lsx_vpickod_h(b, a);
return __lsx_vadd_h(tmp1, tmp2);
}
static __m128i lsx_hadd_w(__m128i a, __m128i b) {
__m128i tmp1 = __lsx_vpickev_w(b, a);
__m128i tmp2 = __lsx_vpickod_w(b, a);
return __lsx_vadd_w(tmp1, tmp2);
}
static __m128 lsx_hadd_s(__m128 a, __m128 b) {
__m128 tmp1 = (__m128)__lsx_vpickev_w((__m128i)b, (__m128i)a);
__m128 tmp2 = (__m128)__lsx_vpickod_w((__m128i)b, (__m128i)a);
return __lsx_vfadd_s(tmp1, tmp2);
}
static __m256i lasx_maddubs_h(__m256i a, __m256i b) {
__m256i tmp1, tmp2;
tmp1 = __lasx_xvmulwev_h_b(a, b);
@ -529,42 +562,6 @@ static __m256i lasx_packs_h(__m256i a, __m256i b) {
return __lasx_xvpickev_b(tmp1, tmp);
}
static __m128i lsx_packs_w(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_w(a, 15);
tmp1 = __lsx_vsat_w(b, 15);
return __lsx_vpickev_h(tmp1, tmp);
}
static __m128i lsx_packs_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_h(a, 7);
tmp1 = __lsx_vsat_h(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
static __m128i lsx_packus_h(__m128i a, __m128i b) {
__m128i tmp, tmp1;
tmp = __lsx_vsat_hu(a, 7);
tmp1 = __lsx_vsat_hu(b, 7);
return __lsx_vpickev_b(tmp1, tmp);
}
static __m128i lsx_maddubs_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_h_b(a, b);
tmp2 = __lsx_vmulwod_h_b(a, b);
return __lsx_vsadd_h(tmp1, tmp2);
}
static __m128i lsx_madd_h(__m128i a, __m128i b) {
__m128i tmp1, tmp2;
tmp1 = __lsx_vmulwev_w_h(a, b);
tmp2 = __lsx_vmulwod_w_h(a, b);
return __lsx_vadd_w(tmp1, tmp2);
}
// multiply int8_t, add results pairwise twice
static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// Get absolute values of x vectors
@ -580,12 +577,10 @@ static inline __m128i mul_sum_i8_pairs(const __m128i x, const __m128i y) {
// horizontally add 8 floats
static inline float hsum_float_8(const __m256 x) {
__m128 res = lasx_extractf128(x, 1);
ft_union tmp;
res = __lsx_vfadd_s(res, lasx_extractf128(x, 0));
res = __lsx_vfadd_s(res, (__m128)__lsx_vpickod_d((__m128i)res, (__m128i)res));
res = __lsx_vfadd_s(res, (__m128)__lsx_vinsgr2vr_w(__lsx_vldi(0), __lsx_vpickve2gr_w(res, 1), 0));
tmp.i = __lsx_vpickve2gr_w(res, 0);
return tmp.f;
return ((v4f32)res)[0];
}
// horizontally add 8 int32_t
@ -927,7 +922,6 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
#elif defined(__loongarch_asx)
for (int i = 0; i < nb; i++) {
ft_union fi;
__m256 v0 = (__m256)__lasx_xvld( x , 0);
__m256 v1 = (__m256)__lasx_xvld( x , 32);
__m256 v2 = (__m256)__lasx_xvld( x , 64);
@ -945,8 +939,7 @@ void quantize_row_q8_0(const float * restrict x, void * restrict vy, int64_t k)
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
__m128 tmp = max4;
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vinsgr2vr_w(tmp, __lsx_vpickve2gr_w( max4, 1 ), 0 ));
fi.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
const float max_scalar = fi.f;
const float max_scalar = ((v4f32)max4)[0];
// Quantize these floats
const float d = max_scalar / 127.f;
@ -1251,7 +1244,6 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
#elif defined(__loongarch_asx)
for (int i = 0; i < nb; i++) {
ft_union ft;
__m256 v0 = (__m256)__lasx_xvld( x , 0 );
__m256 v1 = (__m256)__lasx_xvld( x , 32 );
__m256 v2 = (__m256)__lasx_xvld( x , 64 );
@ -1269,8 +1261,7 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vpickod_d((__m128i) max4, (__m128i)max4 ) );
__m128 tmp = max4;
max4 = __lsx_vfmax_s( max4, (__m128)__lsx_vextrins_w((__m128i)tmp, (__m128i)max4, 0x10 ));
ft.i = __lsx_vpickve2gr_w( (__m128i)max4, 0 );
const float max_scalar = ft.f;
const float max_scalar = ((v4f32)max4)[0];
// Quantize these floats
const float d = max_scalar / 127.f;
@ -2232,21 +2223,22 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
}
sumf = hsum_float_8(acc);
#elif defined(__loongarch_sx)
// set constants
const __m128i low_mask = __lsx_vreplgr2vr_b(0xF);
const __m128i off = __lsx_vreplgr2vr_b(8);
// Initialize accumulator with zeros
__m128 acc_0 = __lsx_vldi(0);
__m128 acc_1 = __lsx_vldi(0);
__m128 acc_2 = __lsx_vldi(0);
__m128 acc_3 = __lsx_vldi(0);
__m128 acc_0 = (__m128)__lsx_vldi(0);
__m128 acc_1 = (__m128)__lsx_vldi(0);
__m128 acc_2 = (__m128)__lsx_vldi(0);
__m128 acc_3 = (__m128)__lsx_vldi(0);
for (; ib + 1 < nb; ib += 2) {
// Compute combined scale for the block 0 and 1
const __m128 d_0_1 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
const __m128 d_0_1 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib].d) * GGML_FP16_TO_FP32(y[ib].d) );
const __m128i tmp_0_1 = __lsx_vld((const __m128i *)x[ib].qs, 0);
@ -2264,7 +2256,7 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, size_t bs, const void * r
//_mm_prefetch(&y[ib] + 2 * sizeof(block_q8_0), _MM_HINT_T0);
// Compute combined scale for the block 2 and 3
const __m128 d_2_3 = __lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
const __m128 d_2_3 = (__m128)__lsx_vreplgr2vr_w( GGML_FP16_TO_FP32(x[ib + 1].d) * GGML_FP16_TO_FP32(y[ib + 1].d) );
const __m128i tmp_2_3 = __lsx_vld((const __m128i *)x[ib + 1].qs, 0);
@ -6141,9 +6133,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * restrict s, size_t bs, const void * r
acc_m = __lsx_vfadd_s(acc_m, (__m128)tmp1);
ft_union fi;
fi.i = __lsx_vpickve2gr_w(acc_m, 0);
*s = hsum_float_8(acc) + fi.f ;
*s = hsum_float_8(acc) + ((v4f32)acc_m)[0];
#else
const uint8_t * scales = (const uint8_t*)&utmp[0];

View file

@ -1078,29 +1078,23 @@ do { \
#define GGML_F16_STEP 32
#define GGML_F16_EPR 8
// F16 arithmetic is not supported by AVX, so we use F32 instead
// F16 arithmetic is not supported by LASX, so we use F32 instead
#define GGML_F32Cx8 __m256
#define GGML_F32Cx8_ZERO (__m256)__lasx_xvldi(0)
#define GGML_F32Cx8_SET1(x) (__m256)__lasx_xvreplgr2vr_w((x))
static inline __m256 __lasx_f32cx8_load(const ggml_fp16_t * x) {
float tmp[8];
for (int i = 0; i < 8; i++) {
tmp[i] = GGML_FP16_TO_FP32(x[i]);
}
return (__m256)__lasx_xvld(tmp, 0);
__m256i a;
memcpy(&a, x, sizeof(ggml_fp16_t) * 8);
a = __lasx_xvpermi_d(a, 0 | (1 << 4));
return __lasx_xvfcvtl_s_h(a);
}
static inline void __lasx_f32cx8_store(ggml_fp16_t * x, __m256 y) {
float arr[8];
__lasx_xvst(y, arr, 0);
for (int i = 0; i < 8; i++) {
x[i] = GGML_FP32_TO_FP16(arr[i]);
}
__m256i a = __lasx_xvfcvt_h_s(y, y);
a = __lasx_xvpermi_d(a, 0 | (2 << 2));
memcpy(x, &a, sizeof(ggml_fp16_t) * 8);
}
#define GGML_F32Cx8_LOAD(x) __lasx_f32cx8_load(x)
#define GGML_F32Cx8_STORE(x, y) __lasx_f32cx8_store(x, y)
@ -13862,9 +13856,13 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
tp->ec = GGML_STATUS_ABORTED;
}
ggml_barrier(state->threadpool);
if (node_n + 1 < cgraph->n_nodes) {
ggml_barrier(state->threadpool);
}
}
ggml_barrier(state->threadpool);
return 0;
}

View file

@ -28,7 +28,7 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_CUDA "*.cu")
file(GLOB SRCS "template-instances/fattn-wmma*.cu")
file(GLOB SRCS "template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})

View file

@ -61,6 +61,13 @@
#define GGML_CUDA_CC_RDNA2 (GGML_CUDA_CC_OFFSET_AMD + 0x1030) // RX 6000, minimum for dp4a
#define GGML_CUDA_CC_RDNA3 (GGML_CUDA_CC_OFFSET_AMD + 0x1100) // RX 7000, minimum for WMMA
#define GGML_CUDA_CC_IS_RDNA(cc) (cc >= GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_IS_RDNA1(cc) (cc >= GGML_CUDA_CC_RDNA1 && cc < GGML_CUDA_CC_RDNA2)
#define GGML_CUDA_CC_IS_RDNA2(cc) (cc >= GGML_CUDA_CC_RDNA2 && cc < GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_RDNA3(cc) (cc >= GGML_CUDA_CC_RDNA3)
#define GGML_CUDA_CC_IS_GCN(cc) (cc > GGML_CUDA_CC_OFFSET_AMD && cc < GGML_CUDA_CC_CDNA)
#define GGML_CUDA_CC_IS_CDNA(cc) (cc >= GGML_CUDA_CC_CDNA && cc < GGML_CUDA_CC_RDNA1)
#define GGML_CUDA_CC_QY1 210
#define GGML_CUDA_CC_QY2 220
@ -148,7 +155,7 @@ typedef float2 dfloat2;
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#define INT8_MMA_AVAILABLE
#define NEW_MMA_AVAILABLE
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= GGML_CUDA_CC_TURING
#if !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ <= GGML_CUDA_CC_QY1)
@ -159,14 +166,24 @@ static constexpr bool fast_fp16_available(const int cc) {
return cc >= GGML_CUDA_CC_PASCAL && cc != 610;
}
// Any FP16 tensor cores are available.
static constexpr bool fp16_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_VOLTA;
}
static constexpr bool int8_mma_available(const int cc) {
// Volta technically had FP16 tensor cores but they work very differently compared to Turing and later.
static constexpr bool new_mma_available(const int cc) {
return cc < GGML_CUDA_CC_OFFSET_AMD && cc >= GGML_CUDA_CC_TURING;
}
static constexpr __device__ int ggml_cuda_get_physical_warp_size() {
#if defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
return __AMDGCN_WAVEFRONT_SIZE;
#else
return 32;
#endif // defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__)
}
[[noreturn]]
static __device__ void no_device_code(
const char * file_name, const int line, const char * function_name, const int arch, const char * arch_list) {

View file

@ -516,6 +516,114 @@ constexpr __device__ dequantize_1_f32_t get_dequantize_1_f32(ggml_type type_V) {
nullptr;
}
// The HIP compiler for some reason complains that it can't unroll a loop because of the jt*ncols + j >= ne01 conditional.
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template<int D, int ncols, int KQ_stride> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_stream_k_fixup(
float * __restrict__ dst, const float2 * __restrict__ dst_fixup, const int ne01, const int ne02, const int ne11) {
const float * dst_fixup_data = ((const float *) dst_fixup) + gridDim.x*(2*2*ncols);
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
const int bidx0 = blockIdx.x;
const int kbc0 = (bidx0 + 0)*iter_k*iter_j*ne02 / gridDim.x;
const int kbc0_stop = (bidx0 + 1)*iter_k*iter_j*ne02 / gridDim.x;
const bool did_not_have_any_data = kbc0 == kbc0_stop;
const bool wrote_beginning_of_tile = kbc0 % iter_k == 0;
const bool did_not_write_last = kbc0/iter_k == kbc0_stop/iter_k && kbc0_stop % iter_k != 0;
if (did_not_have_any_data || wrote_beginning_of_tile || did_not_write_last) {
return;
}
const int channel = kbc0 / (iter_k*iter_j);
const int jt = (kbc0 - channel*iter_k*iter_j) / iter_k;
dst += jt*ncols*ne02*D + channel*D;
// Load the partial result that needs a fixup:
float dst_val[ncols] = {0.0f};
float max_val[ncols] = {0.0f};
float rowsum[ncols] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
dst_val[j] = dst[j*ne02*D + threadIdx.x];
const float2 tmp = dst_fixup[bidx0*ncols + j];
max_val[j] = tmp.x;
rowsum[j] = tmp.y;
}
// Iterate over previous blocks and compute the combined results.
// All CUDA blocks that get here must have a previous block that needs a fixup.
int bidx = bidx0 - 1;
int kbc_stop = kbc0;
while(true) {
const int kbc = bidx*iter_k*iter_j*ne02 / gridDim.x;
if (kbc == kbc_stop) { // Did not have any data.
bidx--;
kbc_stop = kbc;
continue;
}
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
break;
}
const float dst_add = dst_fixup_data[bidx*ncols*D + j*D + threadIdx.x];
const float2 tmp = dst_fixup[(gridDim.x + bidx)*ncols + j];
// Scale the current and new value accumulators depending on the max. values.
const float max_val_new = fmaxf(max_val[j], tmp.x);
const float diff_val = max_val[j] - max_val_new;
const float diff_add = tmp.x - max_val_new;
const float scale_val = diff_val >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_val) : 0.0f;
const float scale_add = diff_add >= SOFTMAX_FTZ_THRESHOLD ? expf(diff_add) : 0.0f;
dst_val[j] = scale_val*dst_val[j] + scale_add*dst_add;
rowsum[j] = scale_val*rowsum[j] + scale_add*tmp.y;
max_val[j] = max_val_new;
}
// If this block started in a previous tile we are done and don't need to combine additional partial results.
if (kbc % iter_k == 0 || kbc/iter_k < kbc0/iter_k) {
break;
}
bidx--;
kbc_stop = kbc;
}
// Write back final result:
#pragma unroll
for (int j = 0; j < ncols; ++j) {
if (jt*ncols + j >= ne01) {
return;
}
dst[j*ne02*D + threadIdx.x] = dst_val[j] / rowsum[j];
}
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif // __clang__
template<int D, int parallel_blocks> // D == head size
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(D, 1)
@ -581,10 +689,11 @@ static void on_no_fattn_vec_case(const int D) {
}
}
template <int D, int parallel_blocks>
// parallel_blocks == 0 is stream-k decomposition
template <int D, int cols_per_block, int parallel_blocks, int KQ_stride>
void launch_fattn(
ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel,
const int nwarps, const int cols_per_block, const bool need_f16_K, const bool need_f16_V
const int nwarps, const size_t nbytes_shared, const bool need_f16_K, const bool need_f16_V
) {
const ggml_tensor * Q = dst->src[0];
const ggml_tensor * K = dst->src[1];
@ -603,20 +712,23 @@ void launch_fattn(
GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding.");
GGML_ASSERT(Q->ne[3] == 1);
ggml_cuda_pool & pool = ctx.pool();
cudaStream_t main_stream = ctx.stream();
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
ggml_cuda_pool_alloc<half> K_f16(pool);
ggml_cuda_pool_alloc<half> V_f16(pool);
ggml_cuda_pool_alloc<float> dst_tmp(pool);
ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool);
char * K_data = (char *) K->data;
const char * K_data = (const char *) K->data;
size_t nb11 = K->nb[1];
size_t nb12 = K->nb[2];
size_t nb13 = K->nb[3];
char * V_data = (char *) V->data;
const char * V_data = (const char *) V->data;
size_t nb21 = V->nb[1];
size_t nb22 = V->nb[2];
size_t nb23 = V->nb[3];
@ -649,39 +761,60 @@ void launch_fattn(
nb23 = nb23*bs*sizeof(half)/ts;
}
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
const int ntiles_x = ((Q->ne[1] + cols_per_block - 1) / cols_per_block);
const int ntiles_total = ntiles_x*Q->ne[2]*Q->ne[3];
const dim3 block_dim(WARP_SIZE, nwarps, 1);
const dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
const int shmem = 0;
dim3 blocks_num;
if (parallel_blocks == 0) {
// For short contexts it can be faster to have the SMs work on whole tiles because this lets us skip the fixup.
const int tiles_nwaves = (ntiles_total - nsm - 1) / nsm;
const bool tiles_inefficient = 3*nsm < 2*tiles_nwaves*ntiles_total;
const bool short_context = K->ne[1] < 4096;
const int nblocks_stream_k = 2*nsm;
blocks_num.x = short_context && !tiles_inefficient ? ntiles_total : nblocks_stream_k;
blocks_num.y = 1;
blocks_num.z = 1;
dst_tmp_meta.alloc(blocks_num.x*cols_per_block * (2*2 + D) * sizeof(float));
} else {
blocks_num.x = parallel_blocks*ntiles_x;
blocks_num.y = Q->ne[2];
blocks_num.z = Q->ne[3];
if (parallel_blocks > 1) {
dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV));
dst_tmp_meta.alloc(parallel_blocks*ggml_nrows(KQV));
}
}
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) KQV->op_params + 2, sizeof(float));
memcpy(&scale, (const float *) KQV->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (logit_softcap != 0.0f) {
scale /= logit_softcap;
}
const uint32_t n_head = Q->ne[2];
const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
const uint32_t n_head_log2 = 1u << uint32_t(floorf(log2f(float(n_head))));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
fattn_kernel<<<blocks_num, block_dim, shmem, main_stream>>>(
fattn_kernel<<<blocks_num, block_dim, nbytes_shared, main_stream>>>(
(const char *) Q->data,
K_data,
V_data,
mask ? ((const char *) mask->data) : nullptr,
(parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
(parallel_blocks) > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr,
scale, max_bias, m0, m1, n_head_log2, logit_softcap,
Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
K->ne[0], K->ne[1], K->ne[2], K->ne[3],
@ -693,16 +826,22 @@ void launch_fattn(
);
CUDA_CHECK(cudaGetLastError());
if ((parallel_blocks) == 1) {
return;
if constexpr (parallel_blocks == 0) {
if (blocks_num.x % ntiles_total != 0) { // Fixup is only needed if the SMs work on fractional tiles.
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine = blocks_num;
flash_attn_stream_k_fixup<D, cols_per_block, KQ_stride>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
((float *) KQV->data, dst_tmp_meta.ptr, Q->ne[1], Q->ne[2], K->ne[1]);
}
} else if constexpr (parallel_blocks > 1) {
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, 0, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
}
const dim3 block_dim_combine(D, 1, 1);
const dim3 blocks_num_combine(Q->ne[1], blocks_num.y, blocks_num.z);
const int shmem_combine = 0;
flash_attn_combine_results<D, parallel_blocks>
<<<blocks_num_combine, block_dim_combine, shmem_combine, main_stream>>>
(dst_tmp.ptr, dst_tmp_meta.ptr, (float *) KQV->data);
CUDA_CHECK(cudaGetLastError());
}

View file

@ -0,0 +1,637 @@
#include "common.cuh"
#include "mma.cuh"
#include "fattn-common.cuh"
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap, bool needs_fixup, bool is_fixup>
static __device__ __forceinline__ void flash_attn_ext_f16_process_tile(
const float2 * const __restrict__ Q_f2,
const half2 * const __restrict__ K_h2,
const half2 * const __restrict__ V_h2,
const half * const __restrict__ maskh,
float2 * const __restrict__ dstk,
float2 * const __restrict__ dstk_fixup,
const float scale,
const float slope,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3,
const int jt,
const int kb0_start,
const int kb0_stop) {
#ifdef NEW_MMA_AVAILABLE
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
typedef mma_A_I16K8<half2> mma_A;
typedef mma_B_J8K8<half2> mma_B;
typedef mma_C_I16J8<float> mma_C_KQ;
typedef mma_C_I16J8<half2> mma_C_VKQ;
static_assert(nwarps*mma_B::J % ncols == 0, "bad nwarps");
constexpr int np = nwarps*mma_B::J / ncols; // Number of parallel CUDA warps per Q column.
static_assert(D % nwarps == 0, "bad D");
static_assert(KQ_stride % nwarps == 0, "bad KQ_stride");
constexpr int D2_padded = D/2 + 4; // Size of D in half2, padded to avoid shared memory bank conflicts.
extern __shared__ half2 tile_KV[]; // Temporary shared buffer for loading K/V data with KQ_stride*D logical elements.
const int stride_Q = nb01 / sizeof(float2);
const int stride_KV = nb11 / sizeof(half2);
const int stride_mask = nb31 / sizeof(half);
mma_B Q_B[D/(2*mma_B::K)];
mma_C_VKQ VKQ_C[D/mma_C_VKQ::I];
float2 KQ_rowsum = {0.0f, 0.0f};
float2 KQ_max = {-FLT_MAX/2.0f, -FLT_MAX/2.0f};
float2 KQ_max_scale = {0.0f, 0.0f};
// Temporarily load Q data into tile_KV, will be loaded into registers afterwards.
// The loading is done with decreasing granularity for D for better memory bandwidth.
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps*stride_j) {
const int j = j0 + threadIdx.y*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
if (jt*ncols + j < ne01) {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
const float2 tmp = Q_f2[(jt*ncols + j)*stride_Q + k];
tile_KV[j*D2_padded + k] = scale_h2 * make_half2(tmp.x, tmp.y);
}
} else {
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[j*D2_padded + k] = make_half2(0.0f, 0.0f);
}
}
}
}
__syncthreads();
{
const int j0 = (threadIdx.y / np) * mma_B::J;
#pragma unroll
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
Q_B[k0/mma_B::K].load_ldmatrix(tile_KV + j0*D2_padded + k0, D2_padded);
}
}
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int kb0 = kb0_start; kb0 < kb0_stop; ++kb0) {
const int k_VKQ_0 = kb0*KQ_stride;
mma_C_KQ KQ_C[KQ_stride/(np*mma_C_KQ::I)];
// Load K data into tile with decreasing granularity for D for better memory bandwidth:
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_i = WARP_SIZE / stride_k;
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < KQ_stride; i_KQ_0 += nwarps*stride_i) {
const int i_KQ = i_KQ_0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
#pragma unroll
for (int k_KQ_0 = k0_start; k_KQ_0 < k0_stop; k_KQ_0 += stride_k) {
const int k_KQ = k_KQ_0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
tile_KV[i_KQ*D2_padded + k_KQ] = K_h2[(k_VKQ_0 + i_KQ)*stride_KV + k_KQ];
}
}
}
__syncthreads();
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_00 = 0; i_KQ_00 < KQ_stride; i_KQ_00 += np*mma_A::I) {
const int i_KQ_0 = i_KQ_00 + (threadIdx.y % np)*mma_A::I;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += mma_A::K) {
mma_A K_A;
K_A.load_ldmatrix(tile_KV + i_KQ_0*D2_padded + k_KQ_0, D2_padded);
KQ_C[i_KQ_00/(np*mma_A::I)].mma(K_A, Q_B[k_KQ_0/mma_A::K]);
}
}
__syncthreads();
if (use_logit_softcap) {
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int i = 0; i < KQ_stride/(np*mma_C_KQ::I); ++i) {
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
KQ_C[i].x[l] = logit_softcap*tanhf(KQ_C[i].x[l]);
}
}
}
if (maskh) {
static_assert(KQ_stride % (np *mma_C_KQ::I) == 0, "bad loop size");
static_assert(ncols % (nwarps/np*mma_C_KQ::J) == 0, "bad loop size");
#pragma unroll
for (int i00 = 0; i00 < KQ_stride; i00 += np*mma_C_KQ::I) {
const int i0 = i00 + (threadIdx.y % np)*mma_C_KQ::I;
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
const int i = i0 + mma_C_KQ::get_i(l);
const int j = (threadIdx.y / np)*mma_C_KQ::J + mma_C_KQ::get_j(l);
KQ_C[i00/(np*mma_C_KQ::I)].x[l] += slope*__half2float(maskh[j*stride_mask + k_VKQ_0 + i]);
}
}
}
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
float2 KQ_max_new = KQ_max;
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
for (int l0 = 0; l0 < mma_C_KQ::ne; l0 += 2) {
KQ_max_new.x = fmaxf(KQ_max_new.x, KQ_C[k].x[l0 + 0]);
KQ_max_new.y = fmaxf(KQ_max_new.y, KQ_C[k].x[l0 + 1]);
}
}
// Values per KQ column are spread across 8 threads, does not need full warp reduce:
#pragma unroll
for (int offset = 16; offset > 2; offset >>= 1) {
KQ_max_new.x = fmaxf(KQ_max_new.x, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.x, offset, WARP_SIZE));
KQ_max_new.y = fmaxf(KQ_max_new.y, __shfl_xor_sync(0xFFFFFFFF, KQ_max_new.y, offset, WARP_SIZE));
}
{
const float2 diff = make_float2(KQ_max.x - KQ_max_new.x, KQ_max.y - KQ_max_new.y);
KQ_max_scale = make_float2(expf(diff.x), expf(diff.y));
if (diff.x <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale.x = 0.0f;
}
if (diff.y <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale.y = 0.0f;
}
KQ_max = KQ_max_new;
}
float2 KQ_rowsum_add = make_float2(0.0f, 0.0f);
static_assert(KQ_stride % (np*mma_C_KQ::I) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*mma_C_KQ::I); ++k) {
#pragma unroll
for (int l = 0; l < mma_C_KQ::ne; ++l) {
const float KQ_max_l = l % 2 == 0 ? KQ_max.x : KQ_max.y;
const float diff = KQ_C[k].x[l] - KQ_max_l;
KQ_C[k].x[l] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_C[k].x[l] = 0.0f;
}
if (l % 2 == 0) {
KQ_rowsum_add.x += KQ_C[k].x[l];
} else {
KQ_rowsum_add.y += KQ_C[k].x[l];
}
}
}
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum.x = KQ_max_scale.x*KQ_rowsum.x + KQ_rowsum_add.x;
KQ_rowsum.y = KQ_max_scale.y*KQ_rowsum.y + KQ_rowsum_add.y;
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale.x, KQ_max_scale.y);
#pragma unroll
for (int i = 0; i < D/mma_C_VKQ::I; ++i) {
#pragma unroll
for (int l = 0; l < mma_C_VKQ::ne; ++l) {
VKQ_C[i].x[l] *= KQ_max_scale_h2;
}
}
// Convert KQ C tiles into B tiles for VKQ calculation:
mma_B B[KQ_stride/(np*2*mma_B::K)];
static_assert(KQ_stride % (np*2*mma_B::K) == 0, "bad loop size");
#pragma unroll
for (int k = 0; k < KQ_stride/(np*2*mma_B::K); ++k) {
B[k] = KQ_C[k].to_mma_B();
}
// Load V data into tile with decreasing granularity for D for better memory bandwidth:
static_assert(KQ_stride % (4*nwarps) == 0, "out of bounds");
#pragma unroll
for (int stride_i : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int i0_start = stride_i == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_i);
const int i0_stop = D/2 - (D/2) % (1*stride_i);
const int stride_k = WARP_SIZE / stride_i;
#pragma unroll
for (int k_V_0 = 0; k_V_0 < KQ_stride; k_V_0 += nwarps*stride_k) {
const int k_V = k_V_0 + threadIdx.y*stride_k + (stride_i == WARP_SIZE ? 0 : threadIdx.x / stride_i);
#pragma unroll
for (int i_V_0 = i0_start; i_V_0 < i0_stop; i_V_0 += stride_i) {
const int i_V = i_V_0 + (stride_i == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_i);
tile_KV[k_V*D2_padded + i_V] = V_h2[(k_VKQ_0 + k_V)*stride_KV + i_V];
}
}
}
__syncthreads();
// Calculate VKQ tile:
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += mma_C_VKQ::I) {
static_assert((KQ_stride/2) % (np*mma_A::K) == 0, "bad loop size");
#pragma unroll
for (int k00 = 0; k00 < KQ_stride/2; k00 += np*mma_A::K) {
const int k0 = k00 + (threadIdx.y % np)*mma_A::K;
mma_A A;
A.load_ldmatrix_trans(tile_KV + 2*k0*D2_padded + i_VKQ_0/2, D2_padded);
VKQ_C[i_VKQ_0/mma_C_VKQ::I].mma(A, B[k00/(np*mma_A::K)]);
}
}
__syncthreads();
}
// Finally, sum up partial KQ rowsums.
// The partial sums are spread across 8 threads each, does not need full reduce.
#pragma unroll
for (int offset = 16; offset > 2; offset >>= 1) {
KQ_rowsum.x += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.x, offset, WARP_SIZE);
KQ_rowsum.y += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum.y, offset, WARP_SIZE);
}
// Write VKQ accumulators to shared memory in column-major format.
// It's faster to do small writes to shared memory, then large write to VRAM than to do small writes to VRAM.
// Also for np > 1 the combination is done via these values in shared memory.
const int j_cwd = threadIdx.y*mma_B::J + mma_B::get_j(-1); // j combine write data
#pragma unroll
for (int k0 = 0; k0 < D/2; k0 += mma_B::K) {
const mma_B B = VKQ_C[k0/mma_B::K].to_mma_B(); // Conversion of C to B matrix puts it in column-major format.
#pragma unroll
for (int l = 0; l < mma_B::ne; ++l) {
const int k = k0 + mma_B::get_k(l);
tile_KV[j_cwd*D2_padded + k] = B.x[l];
}
}
const int j_cwmo = (threadIdx.x % (2*mma_C_VKQ::J)) / mma_C_VKQ::J; // j combine write meta offset
const int j_cwm = threadIdx.y*(2*mma_C_VKQ::J) + 2*mma_C_VKQ::get_j(-1) + j_cwmo; // j combine write meta
const float2 KQ_cmr = make_float2(((const float *) &KQ_max)[j_cwmo], ((const float *) &KQ_rowsum)[j_cwmo]); // KQ combine max rowsum
if (((!needs_fixup && !is_fixup) || np > 1) && threadIdx.x < 2*mma_C_VKQ::J) {
// Use the 16 bytes of padding in each row to store the meta data: KQ max, KQ rowsum, KQ max scale.
((float2 *) tile_KV)[j_cwm*(D2_padded/2) + D/4] = KQ_cmr;
}
__syncthreads();
static_assert(np == 1 || np == 2 || np == 4, "bad np");
if (np == 1) {
// No combination is needed, the meta data can be directly written from registers to VRAM.
if (needs_fixup && threadIdx.x < mma_B::J) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
if (is_fixup && threadIdx.x < mma_B::J) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[j_cwm] = KQ_cmr;
}
} else if (threadIdx.y % np == 0) {
// Combine the meta data for parallel warps via shared memory.
// Warps with threadIdx.y % np != 0 must NOT return early.
// All threads must return simultaneously to avoid race conditions with work on the next tile.
float * meta_j = (float *) tile_KV + (threadIdx.y*mma_B::J + threadIdx.x)*D2_padded + D/2;
float KQ_cm = -FLT_MAX/2; // KQ combine max per parallel warp.
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
KQ_cm = meta_j[0];
}
float KQ_cmn = KQ_cm; // KQ combine max new, max between all parallel warps.
#pragma unroll
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE));
}
const float KQ_cms = expf(KQ_cm - KQ_cmn); // KQ combine max scale per warp.
float KQ_crs = 0.0f; // KQ combine rowsum, scaled sum of all parallel warps.
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
KQ_crs = KQ_cms*meta_j[1];
}
#pragma unroll
for (int offset = np*mma_B::J/2; offset >= mma_B::J; offset >>= 1) {
KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE);
}
// Write back combined meta data:
if (np*mma_B::J == WARP_SIZE || threadIdx.x < np*mma_B::J) {
meta_j[0] = KQ_cmn; // Combined max. KQ values.
meta_j[1] = KQ_crs; // Combined KQ rowsums.
meta_j[2] = KQ_cms; // KQ max scales per parallel warp.
}
if (needs_fixup && threadIdx.x < mma_B::J) {
float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols;
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
if (is_fixup && threadIdx.x < mma_B::J) {
float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols;
dstk_fixup_meta[(threadIdx.y/np)*mma_B::J + threadIdx.x] = make_float2(KQ_cmn, KQ_crs);
}
}
if (np > 1) {
__syncthreads();
}
if (np == 1 || threadIdx.y % np == 0) {
// The first 2*2*gridDim.x*ncols floats in dstk_fixup are for storing max. values and row sums.
// The values after that are for the partial results of the individual blocks.
float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(D/2));
#pragma unroll
for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) {
const int k0_start = stride_k == WARP_SIZE ? 0 : D/2 - (D/2) % (2*stride_k);
const int k0_stop = D/2 - (D/2) % (1*stride_k);
const int stride_j = WARP_SIZE / stride_k;
if (nwarps*stride_j > ncols && threadIdx.y*stride_j >= ncols) {
break;
}
#pragma unroll
for (int j0_dst = 0; j0_dst < ncols; j0_dst += (nwarps/np)*stride_j) {
const int j_dst = j0_dst + (threadIdx.y/np)*stride_j + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k);
const int j_tile_KV = (j_dst/mma_B::J)*(np*mma_B::J) + j_dst % mma_B::J;
if (!is_fixup && jt*ncols + j_dst >= ne01) {
continue;
}
const float * meta_j = (const float *) tile_KV + j_tile_KV*D2_padded + D/2;
#pragma unroll
for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) {
const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k);
float2 dstk_val = make_float2(0.0f, 0.0f);
#pragma unroll
for (int ip = 0; ip < np; ++ip) {
const float KQ_crs = np == 1 ? 1.0f : meta_j[ip*mma_B::J*D2_padded + 2];
const float2 dstk_val_add = __half22float2(tile_KV[(j_tile_KV + ip*mma_B::J)*D2_padded + k]);
dstk_val.x += dstk_val_add.x*KQ_crs;
dstk_val.y += dstk_val_add.y*KQ_crs;
}
if (!needs_fixup && !is_fixup) {
const float KQ_rowsum_j = meta_j[1];
dstk_val.x /= KQ_rowsum_j;
dstk_val.y /= KQ_rowsum_j;
}
if (is_fixup) {
dstk_fixup_data[j_dst*(D/2) + k] = dstk_val;
} else {
dstk[(jt*ncols + j_dst)*ne02*(D/2) + k] = dstk_val;
}
}
}
}
}
if (np > 1) {
__syncthreads();
}
#else
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
template<int D, int ncols, int nwarps, int KQ_stride, bool use_logit_softcap>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
static_assert(FATTN_KQ_STRIDE % KQ_stride == 0, "bad KQ_stride");
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const int iter_k = ne11 / KQ_stride;
const int iter_j = (ne01 + (ncols - 1)) / ncols;
// kbc == k block continuous, current index in continuous ijk space.
int kbc = (blockIdx.x + 0)*iter_k*iter_j*ne02 / gridDim.x;
const int kbc_stop = (blockIdx.x + 1)*iter_k*iter_j*ne02 / gridDim.x;
// If the seams of 2 CUDA blocks fall within an output tile their results need to be combined.
// For this we need to track both the block that starts the tile (needs_fixup) and the block that finishes the tile (is_fixup).
// In the most general case >2 seams can fall into the same tile.
// kb0 == k start index when in the output tile.
int kb0_start = kbc % iter_k;
int kb0_stop = min(iter_k, kb0_start + kbc_stop - kbc);
while (kbc < kbc_stop && kb0_stop == iter_k) {
const int channel = kbc / (iter_k*iter_j);
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(D/2);
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer.
if (kb0_start == 0) {
constexpr bool needs_fixup = false; // CUDA block is working on an entire tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
} else {
constexpr bool needs_fixup = true; // CUDA block is working on the beginning of a tile.
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
}
kbc += iter_k;
kbc -= kbc % iter_k;
kb0_start = 0;
kb0_stop = min(iter_k, kbc_stop - kbc);
}
if (kbc >= kbc_stop) {
return;
}
const int channel = kbc / (iter_k*iter_j);
const int jt = (kbc - channel*iter_k*iter_j) / iter_k; // j index of current tile.
const float2 * Q_f2 = (const float2 *) (Q + nb02* channel);
const half2 * K_h2 = (const half2 *) (K + nb12*(channel / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb12*(channel / gqa_ratio)); // K and V have same shape
const half * maskh = mask ? (const half *) mask + (nb31/sizeof(half))*jt*ncols : nullptr;
float2 * dstk = ((float2 *) dst) + channel*(D/2);
const float slope = get_alibi_slope(max_bias, channel, n_head_log2, m0, m1);
constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks.
constexpr bool needs_fixup = false;
flash_attn_ext_f16_process_tile<D, ncols, nwarps, KQ_stride, use_logit_softcap, needs_fixup, is_fixup>
(Q_f2, K_h2, V_h2, maskh, dstk, dst_meta, scale, slope, logit_softcap,
ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne31, nb31, nb01, nb02, nb03, nb11, nb12, nb13, nb21, nb22, nb23, ne0, ne1, ne2, ne3,
jt, kb0_start, kb0_stop);
}
template <int D, int cols_per_block>
void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
typedef mma_A_I16K8<half2> mma_A;
typedef mma_B_J8K8<half2> mma_B;
static_assert(D % mma_B::K == 0, "bad D");
static_assert(cols_per_block % mma_B::J == 0, "bad cols_per_block");
const ggml_tensor * KQV = dst;
constexpr int KQ_stride = D <= 128 ? 64 : 32;
constexpr int nwarps = (KQ_stride == 32 && cols_per_block <= 16) ?
cols_per_block/mma_B::J * KQ_stride/mma_A::I : (cols_per_block <= 8 ? 4 : 8);
constexpr size_t nbytes_shared = std::max(KQ_stride, nwarps*mma_B::J) * (D + 8) * sizeof(half);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<D, cols_per_block, nwarps, KQ_stride, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, 0, KQ_stride>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
}
#define DECL_FATTN_MMA_F16_CASE(D, cols_per_block) \
template void ggml_cuda_flash_attn_ext_mma_f16_case \
<D, cols_per_block>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
extern DECL_FATTN_MMA_F16_CASE( 64, 8);
extern DECL_FATTN_MMA_F16_CASE( 80, 8);
extern DECL_FATTN_MMA_F16_CASE( 96, 8);
extern DECL_FATTN_MMA_F16_CASE(112, 8);
extern DECL_FATTN_MMA_F16_CASE(128, 8);
extern DECL_FATTN_MMA_F16_CASE(256, 8);
extern DECL_FATTN_MMA_F16_CASE( 64, 16);
extern DECL_FATTN_MMA_F16_CASE( 80, 16);
extern DECL_FATTN_MMA_F16_CASE( 96, 16);
extern DECL_FATTN_MMA_F16_CASE(112, 16);
extern DECL_FATTN_MMA_F16_CASE(128, 16);
extern DECL_FATTN_MMA_F16_CASE(256, 16);
extern DECL_FATTN_MMA_F16_CASE( 64, 32);
extern DECL_FATTN_MMA_F16_CASE( 80, 32);
extern DECL_FATTN_MMA_F16_CASE( 96, 32);
extern DECL_FATTN_MMA_F16_CASE(112, 32);
extern DECL_FATTN_MMA_F16_CASE(128, 32);
extern DECL_FATTN_MMA_F16_CASE(256, 32);
extern DECL_FATTN_MMA_F16_CASE( 64, 64);
extern DECL_FATTN_MMA_F16_CASE( 80, 64);
extern DECL_FATTN_MMA_F16_CASE( 96, 64);
extern DECL_FATTN_MMA_F16_CASE(112, 64);
extern DECL_FATTN_MMA_F16_CASE(128, 64);
extern DECL_FATTN_MMA_F16_CASE(256, 64);

View file

@ -45,7 +45,17 @@ static __global__ void flash_attn_tile_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
@ -288,16 +298,18 @@ void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

View file

@ -48,7 +48,12 @@ static __global__ void flash_attn_tile_ext_f32(
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
@ -287,16 +292,18 @@ void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor *
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, parallel_blocks, use_logit_softcap>;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, true, true);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");

View file

@ -42,6 +42,12 @@ static __global__ void flash_attn_vec_ext_f16(
const int ne2,
const int ne3) {
#ifdef FP16_AVAILABLE
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@ -303,7 +309,8 @@ void ggml_cuda_flash_attn_ext_vec_f16_case_impl(ggml_backend_cuda_context & ctx,
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>

View file

@ -41,6 +41,11 @@ static __global__ void flash_attn_vec_ext_f32(
const int ne1,
const int ne2,
const int ne3) {
#ifndef FLASH_ATTN_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
@ -284,7 +289,8 @@ void ggml_cuda_flash_attn_ext_vec_f32_case_impl(ggml_backend_cuda_context & ctx,
fattn_kernel_t fattn_kernel = flash_attn_vec_ext_f32<D, cols_per_block, parallel_blocks, type_K, type_V, use_logit_softcap>;
constexpr bool need_f16_K = D != 128;
constexpr bool need_f16_V = D != 128 && D != 64;
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, need_f16_K, need_f16_V);
constexpr size_t nbytes_shared = 0;
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, nbytes_shared, need_f16_K, need_f16_V);
}
template <int D, ggml_type type_K, ggml_type type_V>

View file

@ -0,0 +1,648 @@
// Old and deprecated WMMA FlashAttention implementation.
// It is still needed for Volta since the memory layout of NVIDIA tensor cores changed with Turing.
// Long-term the WMMA code should be replaced with a dedicated Volta implementation.
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-wmma-f16.cuh"
#ifdef FP16_MMA_AVAILABLE
#include <mma.h>
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
#if __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
constexpr int D_padded = D + 8;
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
constexpr int mem_KQ = ncols*kqs_padded*kqar;
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
float * KQ_f = (float *) KQ;
half2 * KQ2 = (half2 *) KQ;
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
float KQ_max_f[ncols/nwarps];
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_f[j] = -FLT_MAX/2.0f;
}
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
half2 KQ_max_h2[ncols/nwarps];
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
}
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
break;
}
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
}
}
// Convert Q to half and apply scale, temporarily store in KQ:
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
}
__syncthreads();
// Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
}
}
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
}
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
}
}
__syncthreads();
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
if (use_logit_softcap) {
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
}
}
float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
KQ_max_scale_f[j0/nwarps] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale_f[j0/nwarps] = 0.0f;
}
KQ_max_f[j0/nwarps] = KQ_max_new;
float KQ_rowsum_add = 0.0f;
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
}
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
} else {
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
if (use_logit_softcap) {
// There is no dedicated tangens hyperbolicus function for half2.
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
}
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
KQ_max_h2[j0/nwarps] = KQ_max_new;
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
}
}
__syncthreads();
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
nvcuda::wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
}
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
}
}
}
__syncthreads();
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, nvcuda::wmma::mem_col_major);
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
half2 VKQ_scale;
if (std::is_same<KQ_acc_t, float>::value) {
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
} else {
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
break;
}
half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int l = 0; l < VKQ_ratio; ++l) {
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
}
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
}
}
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
} else {
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
if (parallel_blocks == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
}
if (parallel_blocks == 1 || threadIdx.x != 0) {
continue;
}
float2 dst_meta_val;
if (std::is_same<KQ_acc_t, float>::value) {
dst_meta_val.x = KQ_max_f[j0/nwarps];
} else {
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
}
#else
NO_DEVICE_CODE;
#endif // __CUDA_ARCH__ == GGML_CUDA_CC_VOLTA
}
constexpr int get_max_power_of_2(int x) {
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
// Number of VKQ rows calculated in parallel:
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
constexpr int nwarps = 4;
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
return;
}
constexpr int parallel_blocks = 1;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, cols_per_block, parallel_blocks, -1>(ctx, dst, fattn_kernel, nwarps, 0, true, true);
}
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
if (prec != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
} else {
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
// case 256:
// ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
// break;
default:
GGML_ABORT("fatal error");
break;
}
}
return;
}
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}

View file

@ -1,543 +1,3 @@
#include "common.cuh"
#include "fattn-common.cuh"
#ifdef FP16_MMA_AVAILABLE
#include <mma.h>
#endif // FP16_MMA_AVAILABLE
// D == head size, VKQ_stride == num VKQ rows calculated in parallel:
template<int D, int ncols, int nwarps, int VKQ_stride, int parallel_blocks, typename KQ_acc_t, bool use_logit_softcap>
#if !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
__launch_bounds__(nwarps*WARP_SIZE, 1)
#endif // !(defined(GGML_USE_HIP) && defined(__HIP_PLATFORM_AMD__))
static __global__ void flash_attn_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int ne00,
const int ne01,
const int ne02,
const int ne03,
const int ne10,
const int ne11,
const int ne12,
const int ne13,
const int ne31,
const int nb31,
const int nb01,
const int nb02,
const int nb03,
const int nb11,
const int nb12,
const int nb13,
const int nb21,
const int nb22,
const int nb23,
const int ne0,
const int ne1,
const int ne2,
const int ne3) {
#ifdef FP16_MMA_AVAILABLE
// Skip unused kernel variants for faster compilation:
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = ncols*(blockIdx.x / parallel_blocks); // Index of the first Q/QKV column to work on.
const int ip = blockIdx.x % parallel_blocks; // Index in group of blocks running for the same column in parallel.
static_assert(D <= FATTN_KQ_STRIDE, "D must be <= FATTN_KQ_STRIDE.");
static_assert(ncols == 8 || ncols % 16 == 0, "ncols must be 8 or a multiple of 16.");
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::row_major> frag_a_K;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_a_V;
typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, frag_m, frag_n, 16, half, nvcuda::wmma::col_major> frag_b;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
static_assert(VKQ_ratio <= nwarps, "VKQ_ratio must be <= nwarps.");
// Pad internal representation of KQ, KQV to reduce shared memory bank conflicts:
constexpr int D_padded = D + 8;
constexpr int kqs_padded = FATTN_KQ_STRIDE + 8;
constexpr int kqar = sizeof(KQ_acc_t)/sizeof(half);
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float * Q_f = (const float *) (Q + nb02* blockIdx.y + nb01*ic0);
const half * K_h = (const half *) (K + nb12*(blockIdx.y / gqa_ratio));
const half * V_h = (const half *) (V + nb12*(blockIdx.y / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) mask + (nb31/sizeof(half))* ic0;
const half2 * mask2 = (const half2 *) mask + (nb31/sizeof(half))*(ic0/2);
const int stride_Q = nb01 / sizeof(float);
const int stride_KV = nb11 / sizeof(half);
const float slopef = get_alibi_slope(max_bias, blockIdx.y, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
const half2 slope2 = make_half2(slopef, slopef);
const half2 logit_softcap_2 = make_half2(logit_softcap, logit_softcap);
frag_b Q_b[D/16][ncols/frag_n];
// A single buffer for temporarily holding tiles of KQ and VKQ parts:
constexpr int mem_KQ = ncols*kqs_padded*kqar;
constexpr int mem_VKQ_parts = VKQ_ratio*ncols*D_padded;
__shared__ half KQ[mem_KQ >= mem_VKQ_parts ? mem_KQ : mem_VKQ_parts];
float * KQ_f = (float *) KQ;
half2 * KQ2 = (half2 *) KQ;
float KQ_rowsum_f[ncols/nwarps] = {0.0f};
float KQ_max_f[ncols/nwarps];
float KQ_max_scale_f[ncols/nwarps] = {0.0f};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_f[j] = -FLT_MAX/2.0f;
}
half2 KQ_rowsum_h2[ncols/nwarps] = {{0.0f, 0.0f}};
half2 KQ_max_h2[ncols/nwarps];
half2 KQ_max_scale_h2[ncols/nwarps] = {{0.0f, 0.0f}};
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
KQ_max_h2[j] = make_half2(-HALF_MAX_HALF, -HALF_MAX_HALF);
}
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
break;
}
VKQ2[j*(D_padded/2) + i] = make_half2(0.0f, 0.0f);
}
}
// Convert Q to half and apply scale, temporarily store in KQ:
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
KQ[j*D_padded + i] = ic0 + j < ne01 ? Q_f[j*stride_Q + i] * scale : 0.0f;
}
}
__syncthreads();
// Load Q into tensor core fragments/registers since it will be used frequently:
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
}
}
__syncthreads();
// Iterate over ne11 == previous tokens:
for (int k_VKQ_0 = ip*FATTN_KQ_STRIDE; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*FATTN_KQ_STRIDE) {
// Calculate tile of KQ:
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE; i_KQ_0 += KQ_stride_tc) {
frag_c_KQ KQ_c[ncols/frag_n];
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(KQ_c[j], 0.0f);
}
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
nvcuda::wmma::load_matrix_sync(K_a, K_h + (k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
}
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync((KQ_acc_t *) KQ + j0*kqs_padded + i_KQ_0 + frag_m*threadIdx.y, KQ_c[j0/frag_n], kqs_padded, nvcuda::wmma::mem_col_major);
}
}
__syncthreads();
// Calculate softmax for each KQ column using the current max. value.
// The divisor is stored in KQ_rowsum and will be applied at the end.
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (std::is_same<KQ_acc_t, float>::value) {
float KQ_f_tmp[FATTN_KQ_STRIDE / WARP_SIZE];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] = KQ_f[j*kqs_padded + k];
if (use_logit_softcap) {
KQ_f_tmp[k0/WARP_SIZE] = logit_softcap*tanhf(KQ_f_tmp[k0/WARP_SIZE]);
}
}
float KQ_max_new = KQ_max_f[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
}
KQ_max_new = warp_reduce_max(KQ_max_new);
const float diff = KQ_max_f[j0/nwarps] - KQ_max_new;
KQ_max_scale_f[j0/nwarps] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_max_scale_f[j0/nwarps] = 0.0f;
}
KQ_max_f[j0/nwarps] = KQ_max_new;
float KQ_rowsum_add = 0.0f;
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
const float diff = KQ_f_tmp[k0/WARP_SIZE] - KQ_max_f[j0/nwarps];
KQ_f_tmp[k0/WARP_SIZE] = expf(diff);
if (diff <= SOFTMAX_FTZ_THRESHOLD) {
KQ_f_tmp[k0/WARP_SIZE] = 0.0f;
}
KQ_rowsum_add += KQ_f_tmp[k0/WARP_SIZE];
KQ[j*(kqar*kqs_padded) + k] = KQ_f_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_f[j0/nwarps] = KQ_max_scale_f[j0/nwarps]*KQ_rowsum_f[j0/nwarps] + KQ_rowsum_add;
} else {
half2 KQ2_tmp[FATTN_KQ_STRIDE/(2*WARP_SIZE)];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] = KQ2[j*(kqs_padded/2) + k];
if (use_logit_softcap) {
// There is no dedicated tangens hyperbolicus function for half2.
KQ2_tmp[k0/WARP_SIZE] = h2exp(KQ2_tmp[k0/WARP_SIZE]*make_half2(2.0f, 2.0f));
KQ2_tmp[k0/WARP_SIZE] = (KQ2_tmp[k0/WARP_SIZE] - make_half2(1.0f, 1.0f))
/(KQ2_tmp[k0/WARP_SIZE] + make_half2(1.0f, 1.0f));
KQ2_tmp[k0/WARP_SIZE] *= logit_softcap_2;
}
}
half2 KQ_max_new = KQ_max_h2[j0/nwarps];
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
}
KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
const half2 diff = KQ_max_h2[j0/nwarps] - KQ_max_new;
KQ_max_scale_h2[j0/nwarps] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ_max_scale_h2[j0/nwarps]) &= ftz_mask;
KQ_max_h2[j0/nwarps] = KQ_max_new;
half2 KQ_rowsum_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
const int k = k0 + threadIdx.x;
const half2 diff = KQ2_tmp[k0/WARP_SIZE] - KQ_max_h2[j0/nwarps];
KQ2_tmp[k0/WARP_SIZE] = h2exp(diff);
const uint32_t ftz_mask = __hgt2_mask(diff, make_half2(SOFTMAX_FTZ_THRESHOLD, SOFTMAX_FTZ_THRESHOLD));
*((uint32_t *) &KQ2_tmp[k0/WARP_SIZE]) &= ftz_mask;
KQ_rowsum_add += KQ2_tmp[k0/WARP_SIZE];
KQ2[j*(kqs_padded/2) + k] = KQ2_tmp[k0/WARP_SIZE];
}
KQ_rowsum_add = warp_reduce_sum(KQ_rowsum_add);
// Scale previous KQ_rowsum to account for a potential increase in KQ_max:
KQ_rowsum_h2[j0/nwarps] = KQ_max_scale_h2[j0/nwarps]*KQ_rowsum_h2[j0/nwarps] + KQ_rowsum_add;
}
}
__syncthreads();
frag_b KQ_b[FATTN_KQ_STRIDE/(VKQ_ratio*16)][ncols/frag_n];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
nvcuda::wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
frag_c_VKQ VKQ_c[D/VKQ_stride][ncols/frag_n];
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D; i_VKQ_0 += VKQ_stride) {
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::fill_fragment(VKQ_c[i_VKQ_0/VKQ_stride][j], 0.0f);
}
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += VKQ_ratio*16) {
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
nvcuda::wmma::load_matrix_sync(v_a, V_h + (k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
nvcuda::wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
}
}
}
__syncthreads();
const int offset_k = (threadIdx.y % VKQ_ratio) * (ncols*D_padded);
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += VKQ_stride) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
nvcuda::wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, nvcuda::wmma::mem_col_major);
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
half2 VKQ_scale;
if (std::is_same<KQ_acc_t, float>::value) {
VKQ_scale = make_half2(KQ_max_scale_f[j0/nwarps], KQ_max_scale_f[j0/nwarps]);
} else {
VKQ_scale = KQ_max_scale_h2[j0/nwarps];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D/2 && i >= D/2) {
break;
}
half2 VKQ_add = make_half2(0.0f, 0.0f);
#pragma unroll
for (int l = 0; l < VKQ_ratio; ++l) {
VKQ_add += KQ2[l*(ncols*D_padded/2) + j*(D_padded/2) + i];
}
VKQ2[j*(D_padded/2) + i] = VKQ_scale*VKQ2[j*(D_padded/2) + i] + VKQ_add;
}
}
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j_VKQ = j0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
const int j_dst = (ic0 + j_VKQ)*parallel_blocks + ip;
float KQ_rowsum_j;
if (std::is_same<KQ_acc_t, float>::value) {
KQ_rowsum_j = KQ_rowsum_f[j0/nwarps];
} else {
KQ_rowsum_j = __low2float(KQ_rowsum_h2[j0/nwarps]) + __high2float(KQ_rowsum_h2[j0/nwarps]);
}
#pragma unroll
for (int i0 = 0; i0 < D; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE > D && i >= D) {
break;
}
float dst_val = VKQ[j_VKQ*D_padded + i];
if (parallel_blocks == 1) {
dst_val /= KQ_rowsum_j;
}
dst[j_dst*gridDim.y*D + blockIdx.y*D + i] = dst_val;
}
if (parallel_blocks == 1 || threadIdx.x != 0) {
continue;
}
float2 dst_meta_val;
if (std::is_same<KQ_acc_t, float>::value) {
dst_meta_val.x = KQ_max_f[j0/nwarps];
} else {
dst_meta_val.x = __low2float(KQ_max_h2[j0/nwarps]);
}
dst_meta_val.y = KQ_rowsum_j;
dst_meta[(ic0 + j_VKQ)*gridDim.y*parallel_blocks + blockIdx.y*parallel_blocks + ip] = dst_meta_val;
}
#else
NO_DEVICE_CODE;
#endif // FP16_MMA_AVAILABLE
}
constexpr int get_max_power_of_2(int x) {
return x % 2 == 0 ? 2*get_max_power_of_2(x/2) : 1;
}
static_assert(get_max_power_of_2(1) == 1, "Test failed.");
static_assert(get_max_power_of_2(2) == 2, "Test failed.");
static_assert(get_max_power_of_2(4) == 4, "Test failed.");
static_assert(get_max_power_of_2(6) == 2, "Test failed.");
// Number of VKQ rows calculated in parallel:
constexpr int get_VKQ_stride(int D, int nwarps, int frag_m) {
return (get_max_power_of_2(D/frag_m) < nwarps ? get_max_power_of_2(D/frag_m) : nwarps)*frag_m;
}
static_assert(get_VKQ_stride(128, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride(128, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride(128, 4, 32) == 128, "Test failed.");
static_assert(get_VKQ_stride( 64, 1, 32) == 32, "Test failed.");
static_assert(get_VKQ_stride( 64, 2, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 64, 4, 32) == 64, "Test failed.");
static_assert(get_VKQ_stride( 80, 1, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 2, 16) == 16, "Test failed.");
static_assert(get_VKQ_stride( 80, 4, 16) == 16, "Test failed.");
template <int D, int cols_per_block, typename KQ_acc_t>
void ggml_cuda_flash_attn_ext_wmma_f16_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
constexpr int nwarps = 4;
constexpr int frag_m = cols_per_block == 8 && D % 32 == 0 ? 32 : 16;
const int blocks_num_pb1 = ((Q->ne[1] + cols_per_block - 1) / cols_per_block)*Q->ne[2]*Q->ne[3];
const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (4*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 4;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
if (2*blocks_num_pb1 < 2*nsm) {
constexpr int parallel_blocks = 2;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
return;
}
constexpr int parallel_blocks = 1;
fattn_kernel_t fattn_kernel;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
} else {
constexpr bool use_logit_softcap = true;
fattn_kernel = flash_attn_ext_f16<
D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t, use_logit_softcap>;
}
launch_fattn<D, parallel_blocks>(ctx, dst, fattn_kernel, nwarps, cols_per_block, true, true);
}
#define DECL_FATTN_WMMA_F16_CASE(D, cols_per_block, KQ_acc_t) \
template void ggml_cuda_flash_attn_ext_wmma_f16_case \
<D, cols_per_block, KQ_acc_t>(ggml_backend_cuda_context & ctx, ggml_tensor * dst) \
extern DECL_FATTN_WMMA_F16_CASE( 64, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 80, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 96, 16, float);
extern DECL_FATTN_WMMA_F16_CASE(112, 16, float);
extern DECL_FATTN_WMMA_F16_CASE(128, 16, float);
extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 64, 32, float);
extern DECL_FATTN_WMMA_F16_CASE( 80, 32, float);
extern DECL_FATTN_WMMA_F16_CASE( 96, 32, float);
extern DECL_FATTN_WMMA_F16_CASE(112, 32, float);
extern DECL_FATTN_WMMA_F16_CASE(128, 32, float);
// extern DECL_FATTN_WMMA_F16_CASE(256, 16, float);
extern DECL_FATTN_WMMA_F16_CASE( 64, 8, half);
extern DECL_FATTN_WMMA_F16_CASE( 96, 8, half);
extern DECL_FATTN_WMMA_F16_CASE(128, 8, half);
extern DECL_FATTN_WMMA_F16_CASE(256, 8, half);
extern DECL_FATTN_WMMA_F16_CASE( 64, 16, half);
extern DECL_FATTN_WMMA_F16_CASE( 80, 16, half);
extern DECL_FATTN_WMMA_F16_CASE( 96, 16, half);
extern DECL_FATTN_WMMA_F16_CASE(112, 16, half);
extern DECL_FATTN_WMMA_F16_CASE(128, 16, half);
extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
extern DECL_FATTN_WMMA_F16_CASE( 64, 32, half);
extern DECL_FATTN_WMMA_F16_CASE( 80, 32, half);
extern DECL_FATTN_WMMA_F16_CASE( 96, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(112, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(128, 32, half);
extern DECL_FATTN_WMMA_F16_CASE(256, 16, half);
void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View file

@ -1,5 +1,6 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-mma-f16.cuh"
#include "fattn-tile-f16.cuh"
#include "fattn-tile-f32.cuh"
#include "fattn-vec-f16.cuh"
@ -7,144 +8,56 @@
#include "fattn-wmma-f16.cuh"
#include "fattn.cuh"
#include <cstdint>
template <int cols_per_block>
static void ggml_cuda_flash_attn_ext_mma_f16_switch_hs(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
static void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const enum ggml_prec prec = ggml_flash_attn_ext_get_prec(KQV);
if (prec != GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 32 || Q->ne[0] > 128) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, float>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
} else {
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, float>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, float>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, float>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, float>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
break;
// case 256:
// ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, float>(ctx, dst);
// break;
default:
GGML_ABORT("fatal error");
break;
}
}
return;
}
if (Q->ne[1] <= 8 && Q->ne[0] % WARP_SIZE == 0) {
constexpr int cols_per_block = 8;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
if (Q->ne[1] <= 32) {
constexpr int cols_per_block = 16;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
return;
}
constexpr int cols_per_block = 32;
switch (Q->ne[0]) {
case 64:
ggml_cuda_flash_attn_ext_wmma_f16_case< 64, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case< 64, cols_per_block>(ctx, dst);
break;
case 80:
ggml_cuda_flash_attn_ext_wmma_f16_case< 80, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case< 80, cols_per_block>(ctx, dst);
break;
case 96:
ggml_cuda_flash_attn_ext_wmma_f16_case< 96, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case< 96, cols_per_block>(ctx, dst);
break;
case 112:
ggml_cuda_flash_attn_ext_wmma_f16_case<112, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<112, cols_per_block>(ctx, dst);
break;
case 128:
ggml_cuda_flash_attn_ext_wmma_f16_case<128, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<128, cols_per_block>(ctx, dst);
break;
case 256:
ggml_cuda_flash_attn_ext_wmma_f16_case<256, cols_per_block, half>(ctx, dst);
ggml_cuda_flash_attn_ext_mma_f16_case<256, cols_per_block>(ctx, dst);
break;
default:
GGML_ABORT("fatal error");
break;
}
}
static void ggml_cuda_flash_attn_ext_mma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<8>(ctx, dst);
return;
}
if (Q->ne[1] <= 16) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<16>(ctx, dst);
return;
}
if (Q->ne[1] <= 32) {
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<32>(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16_switch_hs<64>(ctx, dst);
}
#define FATTN_VEC_F16_CASE(D, type_K, type_V) \
if (Q->ne[0] == (D) && K->type == (type_K) && V->type == (type_V)) { \
ggml_cuda_flash_attn_ext_vec_f16_case<D, type_K, type_V>(ctx, dst); \
@ -323,10 +236,18 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
if (!fp16_mma_available(cc)) {
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
if (prec == GGML_PREC_DEFAULT) {
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_vec_f16(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
}
} else {
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
if (Q->ne[1] <= 8) {
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);
} else {
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
}
}
return;
}
@ -341,5 +262,11 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
}
}
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
// The MMA implementation needs Turing or newer, use the old WMMA code for Volta:
if (cc == GGML_CUDA_CC_VOLTA) {
ggml_cuda_flash_attn_ext_wmma_f16(ctx, dst);
return;
}
ggml_cuda_flash_attn_ext_mma_f16(ctx, dst);
}

View file

@ -38,6 +38,7 @@
#include "ggml-cuda/upscale.cuh"
#include "ggml-cuda/wkv6.cuh"
#include "ggml-cuda/gla.cuh"
#include "ggml.h"
#include <algorithm>
#include <array>
@ -1205,7 +1206,7 @@ static void ggml_cuda_op_mul_mat_cublas(
CUBLAS_CHECK(cublasSetStream(ctx.cublas_handle(id), stream));
if (compute_capability == GGML_CUDA_CC_CDNA) {
if (GGML_CUDA_CC_IS_CDNA(compute_capability)) {
const float alpha = 1.0f;
const float beta = 0.0f;
CUBLAS_CHECK(
@ -1365,8 +1366,6 @@ static void ggml_cuda_op_mul_mat(
const int64_t ne13 = src1->ne[3];
const int64_t nrows1 = ggml_nrows(src1);
GGML_ASSERT(ne03 == ne13);
const int64_t ne0 = dst->ne[0];
const int64_t ne1 = dst->ne[1];
@ -1380,9 +1379,11 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(src1->type == GGML_TYPE_F32 || (src1->ne[2] == 1 && src1->ne[3] == 1));
GGML_ASSERT(ne12 >= ne02 && ne12 % ne02 == 0);
GGML_ASSERT(ne12 % ne02 == 0);
GGML_ASSERT(ne13 % ne03 == 0);
const int64_t i02_divisor = ne12 / ne02;
const int64_t i03_divisor = ne13 / ne03;
const size_t src0_ts = ggml_type_size(src0->type);
const size_t src0_bs = ggml_blck_size(src0->type);
@ -1398,6 +1399,7 @@ static void ggml_cuda_op_mul_mat(
GGML_ASSERT(!(split && ne02 > 1));
GGML_ASSERT(!(split && ne03 > 1));
GGML_ASSERT(!(split && ne02 < ne12));
GGML_ASSERT(!(split && ne03 < ne13));
ggml_tensor_extra_gpu * src0_extra = split ? (ggml_tensor_extra_gpu *) src0->extra : nullptr;
@ -1561,7 +1563,8 @@ static void ggml_cuda_op_mul_mat(
}
// for split tensors the data begins at i0 == i0_offset_low
char * src0_dd_i = dev[id].src0_dd + (i0/i02_divisor) * (ne01*ne00*src0_ts)/src0_bs;
const size_t nbytes_src0_matrix = ne01*ne00*src0_ts / src0_bs;
char * src0_dd_i = dev[id].src0_dd + ((i03/i03_divisor)*ne02 + (i02/i02_divisor)) * nbytes_src0_matrix;
float * src1_ddf_i = dev[id].src1_ddf + (i0*ne11 + src1_col_0) * ne10;
char * src1_ddq_i = dev[id].src1_ddq + src1_ddq_i_offset;
float * dst_dd_i = dev[id].dst_dd + (i0*ne1 + src1_col_0) * (dst_on_device ? ne0 : row_diff);
@ -1605,8 +1608,9 @@ static void ggml_cuda_op_mul_mat(
CUDA_CHECK(cudaGetLastError());
}
if (src1_col_0 == 0 && !src0_is_contiguous && i02 % i02_divisor == 0) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(src0_dd_i, src0, i03, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
if (src1_col_0 == 0 && !src0_is_contiguous && i03 % i03_divisor == 0 && i02 % i02_divisor == 0) {
CUDA_CHECK(ggml_cuda_cpy_tensor_2d(
src0_dd_i, src0, i03/i03_divisor, i02/i02_divisor, dev[id].row_low, dev[id].row_high, stream));
}
// do the computation
@ -1750,7 +1754,7 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co
beta = &beta_f32;
}
if (ggml_cuda_info().devices[ctx.device].cc == GGML_CUDA_CC_CDNA) {
if (GGML_CUDA_CC_IS_CDNA(ggml_cuda_info().devices[ctx.device].cc)) {
cu_compute_type = CUBLAS_COMPUTE_32F;
alpha = &alpha_f32;
beta = &beta_f32;
@ -1881,7 +1885,7 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor
//printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name);
//printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name);
if (!split && use_mul_mat_vec && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
if (!split && use_mul_mat_vec && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) {
// the custom F16 vector kernel can be used over batched cuBLAS GEMM
// but this is only faster for GPUs without tensor cores or with a thin src0 matrix (particularly KQV in attention)
ggml_cuda_mul_mat_vec(ctx, src0, src1, dst);
@ -2215,12 +2219,7 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
ggml_cuda_op_rms_norm_back(ctx, dst);
break;
case GGML_OP_MUL_MAT:
if (dst->src[0]->ne[3] != dst->src[1]->ne[3]) {
GGML_LOG_ERROR("%s: cannot compute %s: src0->ne[3] = %" PRId64 ", src1->ne[3] = %" PRId64 " - fallback to CPU\n", __func__, dst->name, dst->src[0]->ne[3], dst->src[1]->ne[3]);
return false;
} else {
ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
}
ggml_cuda_mul_mat(ctx, dst->src[0], dst->src[1], dst);
break;
case GGML_OP_MUL_MAT_ID:
ggml_cuda_mul_mat_id(ctx, dst);
@ -2997,9 +2996,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (b->type == GGML_TYPE_F16 && a->type != GGML_TYPE_F16) {
return false;
}
if (op->op == GGML_OP_MUL_MAT && a->ne[3] != b->ne[3]) {
return false;
}
#ifdef GGML_USE_MUSA
if (b->type == GGML_TYPE_F16 && b->ne[2]*b->ne[3] > 1 &&
!ggml_is_transposed(a) && !ggml_is_transposed(b)) {
@ -3139,6 +3135,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
break;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
return true;
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->ne[0] % WARP_SIZE == 0;
break;
@ -3181,7 +3178,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
return true;
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_ARANGE:

View file

@ -1,11 +1,67 @@
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under:
// https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#matrix-multiply-accumulate-operation-using-mma-instruction
//
// Like with nvcuda::wmma there are three types of matrix tiles: A, B, and C with A @ B = C.
// A is a row-major matrix with shape I x K.
// B is a column-major matrix with shape K x J.
// C is a column-major matrix with shape I x J.
// Note that along their lowest dimension I, J, and K are measured in physical 32 bit elements instead of logical elements.
// The functions get_i, get_j, and get_k can be used to get the physical 32 bit index of the lth element of a thread within a tile.
// All matrix tiles have ne physical 32 bit elements per warp.
//
// As described in the documentation, all pointers for load_ldmatrix must be to shared memory and aligned to 16 bytes.
#include "common.cuh"
struct mma_int_A_I16K4 {
#if CUDART_VERSION >= 11080
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
int ret = 0;
#ifdef NEW_MMA_AVAILABLE
asm("movmatrix.sync.aligned.m8n8.trans.b16 %0, %1;"
: "+r"(ret) : "r"(x));
#else
NO_DEVICE_CODE;
#endif // defined(NEW_MMA_AVAILABLE)
return ret;
}
#else
static __device__ __forceinline__ int ggml_cuda_movmatrix(const int x) {
// Imagine transposing row-major matrix to column-major matrix.
const int src_i_low = 2 * (threadIdx.x % 4);
const int src_i_high = src_i_low + 1;
const int src_j = threadIdx.x / 4;
const int src_laneid_low = src_i_low * 4 + src_j / 2;
const int src_laneid_high = src_i_high * 4 + src_j / 2;
const int shift_low = ((src_j + 0) % 2) * 16;
const int shift_high = ((src_j + 1) % 2) * 16;
const int ret_low = (__shfl_sync(0xFFFFFFFF, x, src_laneid_low, WARP_SIZE) >> shift_low) & 0x0000FFFF;
const int ret_high = (__shfl_sync(0xFFFFFFFF, x, src_laneid_high, WARP_SIZE) << shift_high) & 0xFFFF0000;
return ret_low | ret_high;
}
#endif // CUDART_VERSION >= 11080
template <typename T>
struct mma_A_I16K4 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int I = 16;
static constexpr int K = 4;
static constexpr int ne = 2;
int x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / K;
@ -21,27 +77,35 @@ struct mma_int_A_I16K4 {
return ret;
}
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(xi[0]), "+r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
struct mma_int_A_I16K8 {
template <typename T>
struct mma_A_I16K8 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int I = 16;
static constexpr int K = 8;
static constexpr int ne = 4;
int x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l%2) * (I/2) + threadIdx.x / (K/2);
@ -57,31 +121,62 @@ struct mma_int_A_I16K8 {
return ret;
}
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE)
const int * xs = xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
: "l"(xs));
#else
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_i(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_low(const int * __restrict__ xs0, const int & stride) {
((mma_int_A_I16K4 *) x)[0].load(xs0, stride);
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void load_ldmatrix_trans(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int * ) x;
const int * xs = (const int *) xs0 + (threadIdx.x%I)*stride + (threadIdx.x/I)*(K/2);
asm("ldmatrix.sync.aligned.m8n8.x4.trans.b16 {%0, %1, %2, %3}, [%4];"
: "+r"(xi[0]), "+r"(xi[2]), "+r"(xi[1]), "+r"(xi[3])
: "l"(xs));
#else
GGML_UNUSED(xs0);
GGML_UNUSED(stride);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void transpose() {
int * xi = (int *) x;
xi[0] = ggml_cuda_movmatrix(xi[0]);
const int tmp = ggml_cuda_movmatrix(xi[1]);
xi[1] = ggml_cuda_movmatrix(xi[2]);
xi[2] = tmp;
xi[3] = ggml_cuda_movmatrix(xi[3]);
}
};
struct mma_int_B_J8K4 {
template <typename T>
struct mma_B_J8K4 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int J = 8;
static constexpr int K = 4;
static constexpr int ne = 1;
int x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / K;
@ -97,27 +192,34 @@ struct mma_int_B_J8K4 {
return ret;
}
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(x[0])
: "l"(xs));
#else
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride;
asm("ldmatrix.sync.aligned.m8n8.x1.b16 {%0}, [%1];"
: "+r"(xi[0]) : "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
struct mma_int_B_J8K8 {
template <typename T>
struct mma_B_J8K8 {
static_assert(sizeof(T) == 4, "bad type size");
static constexpr int J = 8;
static constexpr int K = 8;
static constexpr int ne = 2;
int x[ne] = {0};
T x[ne];
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x / (K/2);
@ -133,22 +235,31 @@ struct mma_int_B_J8K8 {
return ret;
}
__device__ __forceinline__ void load(const int * __restrict__ xs0, const int & stride) {
#if defined(INT8_MMA_AVAILABLE) && false // Loading as 4 byte values is faster
const int * xs = xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(x[0]), "+r"(x[1])
: "l"(xs));
#else
__device__ __forceinline__ void load_generic(const T * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_k(l)];
}
#endif // defined(INT8_MMA_AVAILABLE)
}
__device__ __forceinline__ void load_ldmatrix(const T * __restrict__ xs0, const int & stride) {
#ifdef NEW_MMA_AVAILABLE
int * xi = (int *) x;
const int * xs = (const int *) xs0 + (threadIdx.x%J)*stride + ((threadIdx.x/J)*(K/2)) % K;
asm("ldmatrix.sync.aligned.m8n8.x2.b16 {%0, %1}, [%2];"
: "+r"(xi[0]), "+r"(xi[1])
: "l"(xs));
#else
load_generic(xs0, stride);
#endif // NEW_MMA_AVAILABLE
}
};
struct mma_int_C_I16J8 {
template <typename T>
struct mma_C_I16J8 {};
template <>
struct mma_C_I16J8<int> {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
@ -169,8 +280,8 @@ struct mma_int_C_I16J8 {
return ret;
}
__device__ __forceinline__ void mma_K4(const mma_int_A_I16K4 & mma_A, const mma_int_B_J8K4 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
__device__ __forceinline__ void mma(const mma_A_I16K4<int> & mma_A, const mma_B_J8K4<int> & mma_B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@ -188,11 +299,11 @@ struct mma_int_C_I16J8 {
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ void mma_K8(const mma_int_A_I16K8 & mma_A, const mma_int_B_J8K8 & mma_B) {
#ifdef INT8_MMA_AVAILABLE
__device__ __forceinline__ void mma(const mma_A_I16K8<int> & mma_A, const mma_B_J8K8<int> & mma_B) {
#ifdef NEW_MMA_AVAILABLE
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(x[0]), "+r"(x[1]), "+r"(x[2]), "+r"(x[3])
@ -216,6 +327,132 @@ struct mma_int_C_I16J8 {
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // INT8_MMA_AVAILABLE
#endif // NEW_MMA_AVAILABLE
}
};
template <>
struct mma_C_I16J8<half2> {
static constexpr int I = 16;
static constexpr int J = 4;
static constexpr int ne = 2;
half2 x[ne] = {{0.0f, 0.0f}, {0.0f, 0.0f}};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = l * (I/2) + threadIdx.x / J;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int /* l */) {
const int ret = threadIdx.x % J;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
#ifdef NEW_MMA_AVAILABLE
int * Axi = (int *) mma_A.x;
int * Bxi = (int *) mma_B.x;
int * xi = (int *) x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3, %4, %5}, {%6, %7}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 {%0, %1}, {%2, %3}, {%4}, {%0, %1};"
: "+r"(xi[0]), "+r"(xi[1])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
mma_B_J8K8<half2> mma_B;
int * xi = (int *) x;
int * Bxi = (int *) mma_B.x;
Bxi[0] = ggml_cuda_movmatrix(xi[0]);
Bxi[1] = ggml_cuda_movmatrix(xi[1]);
return mma_B;
}
};
template <>
struct mma_C_I16J8<float> {
static constexpr int I = 16;
static constexpr int J = 8;
static constexpr int ne = 4;
float x[ne] = {0.0f, 0.0f, 0.0f, 0.0f};
static __device__ __forceinline__ int get_i(const int l) {
const int ret = (l/2) * (I/2) + threadIdx.x / (J/2);
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < I);
return ret;
}
static __device__ __forceinline__ int get_j(const int l) {
const int ret = 2 * (threadIdx.x % (J/2)) + l%2;
GGML_CUDA_ASSUME(ret >= 0);
GGML_CUDA_ASSUME(ret < J);
return ret;
}
__device__ __forceinline__ void mma(const mma_A_I16K8<half2> & mma_A, const mma_B_J8K8<half2> & mma_B) {
#ifdef NEW_MMA_AVAILABLE
int * Axi = (int *) mma_A.x;
int * Bxi = (int *) mma_B.x;
int * xi = (int *) x;
#if __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[0]), "r"(Bxi[1]));
#else
// On Turing m16n8k16 mma is not available, use 2x m8n8k8 mma instead:
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "r"(Axi[0]), "r"(Axi[1]), "r"(Bxi[0]));
asm("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 {%0, %1, %2, %3}, {%4, %5}, {%6}, {%0, %1, %2, %3};"
: "+r"(xi[0]), "+r"(xi[1]), "+r"(xi[2]), "+r"(xi[3])
: "r"(Axi[2]), "r"(Axi[3]), "r"(Bxi[1]));
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE
#else
GGML_UNUSED(mma_A);
GGML_UNUSED(mma_B);
NO_DEVICE_CODE;
#endif // NEW_MMA_AVAILABLE
}
__device__ __forceinline__ mma_B_J8K8<half2> to_mma_B() {
mma_B_J8K8<half2> mma_B;
mma_B.x[0] = make_half2(x[0], x[1]);
mma_B.x[1] = make_half2(x[2], x[3]);
int * Bxi = (int *) mma_B.x;
Bxi[0] = ggml_cuda_movmatrix(Bxi[0]);
Bxi[1] = ggml_cuda_movmatrix(Bxi[1]);
return mma_B;
}
__device__ __forceinline__ void load_generic(const float * __restrict__ xs0, const int & stride) {
#pragma unroll
for (int l = 0; l < ne; ++l) {
x[l] = xs0[get_j(l)*stride + get_i(l)];
}
}
};

View file

@ -132,7 +132,7 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return false;
}
if (int8_mma_available(cc)) {
if (new_mma_available(cc)) {
return true;
}
@ -148,5 +148,5 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11) {
return cc < GGML_CUDA_CC_VOLTA || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}
return (cc < GGML_CUDA_CC_RDNA3 && cc != GGML_CUDA_CC_CDNA && cc != GGML_CUDA_CC_VEGA20) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
return (!GGML_CUDA_CC_IS_RDNA3(cc) && !GGML_CUDA_CC_IS_CDNA(cc) && !GGML_CUDA_CC_IS_GCN(cc)) || ne11 < MMQ_DP4A_MAX_BATCH_SIZE;
}

File diff suppressed because it is too large Load diff

View file

@ -1,25 +1,29 @@
#include "ggml.h"
#include "common.cuh"
#include "mmv.cuh"
template <typename T, typename type_acc, int block_size>
static __global__ void mul_mat_vec(
const T * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) {
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.z;
const int tid = threadIdx.x;
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst) {
const int64_t row = blockIdx.x;
const int64_t channel = blockIdx.y;
const int64_t sample = blockIdx.z;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += (channel/channel_ratio)*stride_channel_x + row*stride_row;
y += channel *stride_channel_y;
dst += channel *stride_channel_dst;
x += (sample/sample_ratio)*stride_sample_x + (channel/channel_ratio)*stride_channel_x + row*stride_row;
y += sample *stride_sample_y + channel *stride_channel_y;
dst += sample *stride_sample_dst + channel *stride_channel_dst;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
float * buf_iw = (float *) data_mmv;
if (block_size > WARP_SIZE) {
if (tid < WARP_SIZE) {
if (block_size > warp_size) {
if (tid < warp_size) {
buf_iw[tid] = 0.0f;
}
__syncthreads();
@ -67,16 +71,16 @@ static __global__ void mul_mat_vec(
static_assert(std::is_same<T, void>::value, "unsupported type");
}
sumf = warp_reduce_sum(sumf);
sumf = warp_reduce_sum<warp_size>(sumf);
if (block_size > WARP_SIZE) {
buf_iw[tid/WARP_SIZE] = sumf;
if (block_size > warp_size) {
buf_iw[tid/warp_size] = sumf;
__syncthreads();
if (tid >= WARP_SIZE) {
if (tid >= warp_size) {
return;
}
sumf = buf_iw[tid];
sumf = warp_reduce_sum(sumf);
sumf = warp_reduce_sum<warp_size>(sumf);
}
if (tid != 0) {
@ -90,16 +94,28 @@ template <typename T, typename type_acc>
static void launch_mul_mat_vec_cuda(
const T * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(nchannels_y % nchannels_x == 0);
GGML_ASSERT(nsamples_y % nsamples_x == 0);
const int64_t channel_ratio = nchannels_y / nchannels_x;
const int64_t sample_ratio = nsamples_y / nsamples_x;
int device;
int warp_size;
int64_t block_size_best = WARP_SIZE;
int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE);
for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) {
CUDA_CHECK(cudaGetDevice(&device));
warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t block_size_best = warp_size;
int64_t niter_best = (ncols + 2*warp_size - 1) / (2*warp_size);
int64_t max_block_size = 256;
if(ggml_cuda_info().devices[device].cc > GGML_CUDA_CC_OFFSET_AMD && ggml_cuda_info().devices[device].cc < GGML_CUDA_CC_RDNA1) {
max_block_size = 128;
}
for (int64_t block_size = 2*warp_size; block_size <= max_block_size; block_size += warp_size) {
const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size);
if (niter < niter_best) {
niter_best = niter;
@ -107,41 +123,49 @@ static void launch_mul_mat_vec_cuda(
}
}
const int smem = WARP_SIZE*sizeof(float);
const dim3 block_nums(nrows, 1, nchannels_y);
const int smem = warp_size*sizeof(float);
const dim3 block_nums(nrows, nchannels_y, nsamples_y);
const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) {
case 32: {
mul_mat_vec<T, type_acc, 32><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 64: {
mul_mat_vec<T, type_acc, 64><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 96: {
mul_mat_vec<T, type_acc, 96><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 128: {
mul_mat_vec<T, type_acc, 128><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 160: {
mul_mat_vec<T, type_acc, 160><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 192: {
mul_mat_vec<T, type_acc, 192><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 224: {
mul_mat_vec<T, type_acc, 224><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 256: {
mul_mat_vec<T, type_acc, 256><<<block_nums, block_dims, smem, stream>>>
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst);
(x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
@ -153,16 +177,19 @@ template<typename T>
static void mul_mat_vec_cuda(
const T * x, const float * y, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_y, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) {
switch (prec) {
case GGML_PREC_DEFAULT: {
launch_mul_mat_vec_cuda<T, half>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
launch_mul_mat_vec_cuda<T, half>
(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case GGML_PREC_F32: {
launch_mul_mat_vec_cuda<T, float>(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y,
stride_channel_x, stride_channel_y, stride_channel_dst, stream);
launch_mul_mat_vec_cuda<T, float>
(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
}
}
@ -171,10 +198,19 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
GGML_TENSOR_BINARY_OP_LOCALS;
GGML_ASSERT(src1->ne[1] == 1);
const size_t ts_src0 = ggml_type_size(src0->type);
const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(ne11 == 1);
GGML_ASSERT(ne12 == ne2);
GGML_ASSERT(ne13 == ne3);
GGML_ASSERT(nb00 == ts_src0);
GGML_ASSERT(nb10 == ts_src1);
GGML_ASSERT(nb0 == ts_dst);
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32;
@ -182,29 +218,22 @@ void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor *
const float * src1_d = (const float *) src1->data;
float * dst_d = (float *) dst->data;
const int64_t ne02 = src0->ne[2];
const int64_t ne12 = src1->ne[2];
GGML_ASSERT(dst->ne[2] == ne12);
GGML_ASSERT(src0->ne[3] == 1);
GGML_ASSERT(src1->ne[3] == 1);
GGML_ASSERT( dst->ne[3] == 1);
const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type);
const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type);
const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type);
const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type);
const int64_t s01 = src0->nb[1] / ts_src0;
const int64_t s02 = src0->nb[2] / ts_src0;
const int64_t s12 = src1->nb[2] / ts_src1;
const int64_t s2 = dst->nb[2] / ts_dst;
const int64_t s03 = src0->nb[3] / ts_src0;
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
switch (src0->type) {
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12,
channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream());
mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, s01, ne02, ne12, s02, s12, s2, ne03, ne13, s03, s13, s3, prec, ctx.stream());
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@ -233,20 +262,27 @@ void ggml_cuda_op_mul_mat_vec(
const int64_t stride_row = ne00;
const int64_t nchannels_x = 1;
const int64_t nchannels_y = 1;
const int64_t channel_stride_x = 0;
const int64_t channel_stride_y = 0;
const int64_t channel_stride_dst = 0;
const int64_t stride_channel_x = 0;
const int64_t stride_channel_y = 0;
const int64_t stride_channel_dst = 0;
const int64_t nsamples_x = 1;
const int64_t nsamples_y = 1;
const int64_t stride_sample_x = 0;
const int64_t stride_sample_y = 0;
const int64_t stride_sample_dst = 0;
switch (src0->type) {
case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_cuda(src0_d, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row,
nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream);
nchannels_x, nchannels_y, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_y, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream);
} break;
default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));

View file

@ -1,12 +1,20 @@
#include "norm.cuh"
#include <cstdint>
template <int block_size>
static __global__ void norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
static __global__ void norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
x += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float2 mean_var = make_float2(0.0f, 0.0f);
@ -97,12 +105,19 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
template <int block_size>
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols, const float eps) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
static __global__ void rms_norm_f32(
const float * x, float * dst, const int ncols, const int64_t stride_row, const int64_t stride_channel,
const int64_t stride_sample, const float eps) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
x += int64_t(row)*ncols;
dst += int64_t(row)*ncols;
const int row = blockIdx.x;
const int channel = blockIdx.y;
const int sample = blockIdx.z;
const int tid = threadIdx.x;
x += sample*stride_sample + channel*stride_channel + row*stride_row;
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
float tmp = 0.0f; // partial sum for thread in warp
@ -186,13 +201,16 @@ static __global__ void rms_norm_back_f32(
}
}
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
static void norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
@ -207,13 +225,16 @@ static void group_norm_f32_cuda(
}
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
static void rms_norm_f32_cuda(
const float * x, float * dst, const int ncols, const int nrows, const int nchannels, const int nsamples,
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
rms_norm_f32<WARP_SIZE><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
rms_norm_f32<1024><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
}
}
@ -229,23 +250,26 @@ static void rms_norm_back_f32_cuda(const float * grad, const float * xf, float *
void ggml_cuda_op_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -254,8 +278,6 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
@ -271,23 +293,26 @@ void ggml_cuda_op_group_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst)
void ggml_cuda_op_rms_norm(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *)src0->data;
float * dst_d = (float *)dst->data;
const float * src0_d = (const float *) src0->data;
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
const int64_t ne00 = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
GGML_TENSOR_UNARY_OP_LOCALS;
float eps;
memcpy(&eps, dst->op_params, sizeof(float));
GGML_ASSERT(eps >= 0.0f);
rms_norm_f32_cuda(src0_d, dst_d, ne00, nrows, eps, stream);
const size_t ts0 = ggml_type_size(src0->type);
GGML_ASSERT(nb00 == ts0);
const int64_t s01 = nb01 / ts0;
const int64_t s02 = nb02 / ts0;
const int64_t s03 = nb03 / ts0;
rms_norm_f32_cuda(src0_d, dst_d, ne00, ne01, ne02, ne03, s01, s02, s03, eps, stream);
}
void ggml_cuda_op_rms_norm_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View file

@ -18,7 +18,7 @@ __device__ float __forceinline__ t2f32<half>(half val) {
#ifdef __clang__
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wpass-failed"
#endif
#endif // __clang__
template <bool use_shared, int ncols_template, int block_size_template, typename T>
static __global__ void soft_max_f32(
const float * x, const T * mask, float * dst, const int ncols_par, const int nrows_y,
@ -126,7 +126,7 @@ static __global__ void soft_max_f32(
}
#ifdef __clang__
#pragma clang diagnostic pop
#endif
#endif // __clang__
static __global__ void soft_max_back_f32(
const float * grad, const float * dstf, float * dst, const int ncols, const float scale) {

View file

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 16);
DECL_FATTN_MMA_F16_CASE(80, 16);
DECL_FATTN_MMA_F16_CASE(96, 16);
DECL_FATTN_MMA_F16_CASE(112, 16);
DECL_FATTN_MMA_F16_CASE(128, 16);
DECL_FATTN_MMA_F16_CASE(256, 16);

View file

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 32);
DECL_FATTN_MMA_F16_CASE(80, 32);
DECL_FATTN_MMA_F16_CASE(96, 32);
DECL_FATTN_MMA_F16_CASE(112, 32);
DECL_FATTN_MMA_F16_CASE(128, 32);
DECL_FATTN_MMA_F16_CASE(256, 32);

View file

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 64);
DECL_FATTN_MMA_F16_CASE(80, 64);
DECL_FATTN_MMA_F16_CASE(96, 64);
DECL_FATTN_MMA_F16_CASE(112, 64);
DECL_FATTN_MMA_F16_CASE(128, 64);
DECL_FATTN_MMA_F16_CASE(256, 64);

View file

@ -0,0 +1,10 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-mma-f16.cuh"
DECL_FATTN_MMA_F16_CASE(64, 8);
DECL_FATTN_MMA_F16_CASE(80, 8);
DECL_FATTN_MMA_F16_CASE(96, 8);
DECL_FATTN_MMA_F16_CASE(112, 8);
DECL_FATTN_MMA_F16_CASE(128, 8);
DECL_FATTN_MMA_F16_CASE(256, 8);

View file

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 16, float);
DECL_FATTN_WMMA_F16_CASE(80, 16, float);
DECL_FATTN_WMMA_F16_CASE(96, 16, float);
DECL_FATTN_WMMA_F16_CASE(112, 16, float);
DECL_FATTN_WMMA_F16_CASE(128, 16, float);
DECL_FATTN_WMMA_F16_CASE(256, 16, float);

View file

@ -1,9 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 32, float);
DECL_FATTN_WMMA_F16_CASE(80, 32, float);
DECL_FATTN_WMMA_F16_CASE(96, 32, float);
DECL_FATTN_WMMA_F16_CASE(112, 32, float);
DECL_FATTN_WMMA_F16_CASE(128, 32, float);

View file

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 16, half);
DECL_FATTN_WMMA_F16_CASE(80, 16, half);
DECL_FATTN_WMMA_F16_CASE(96, 16, half);
DECL_FATTN_WMMA_F16_CASE(112, 16, half);
DECL_FATTN_WMMA_F16_CASE(128, 16, half);
DECL_FATTN_WMMA_F16_CASE(256, 16, half);

View file

@ -1,10 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 32, half);
DECL_FATTN_WMMA_F16_CASE(80, 32, half);
DECL_FATTN_WMMA_F16_CASE(96, 32, half);
DECL_FATTN_WMMA_F16_CASE(112, 32, half);
DECL_FATTN_WMMA_F16_CASE(128, 32, half);
DECL_FATTN_WMMA_F16_CASE(256, 32, half);

View file

@ -1,8 +0,0 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
DECL_FATTN_WMMA_F16_CASE(64, 8, half);
DECL_FATTN_WMMA_F16_CASE(96, 8, half);
DECL_FATTN_WMMA_F16_CASE(128, 8, half);
DECL_FATTN_WMMA_F16_CASE(256, 8, half);

View file

@ -12,13 +12,13 @@ SOURCE_FATTN_VEC = """// This file has been autogenerated by generate_cu_files.p
DECL_FATTN_VEC_F{vkq_size}_CASE({head_size}, {type_k}, {type_v});
"""
SOURCE_FATTN_WMMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
SOURCE_FATTN_MMA_START = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../fattn-wmma-f16.cuh"
#include "../fattn-mma-f16.cuh"
"""
SOURCE_FATTN_WMMA_CASE = "DECL_FATTN_WMMA_F16_CASE({head_size}, {cols_per_block}, {kq_acc_t});\n"
SOURCE_FATTN_MMA_CASE = "DECL_FATTN_MMA_F16_CASE({head_size}, {cols_per_block});\n"
TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
@ -57,20 +57,12 @@ for vkq_size in [16, 32]:
with open(f"fattn-vec-f{vkq_size}-instance-hs{head_size}-{get_short_name(type_k)}-{get_short_name(type_v)}.cu", "w") as f:
f.write(SOURCE_FATTN_VEC.format(vkq_size=vkq_size, head_size=head_size, type_k=type_k, type_v=type_v))
for kq_acc_t in ["half", "float"]:
for cols_per_block in [8, 16, 32]:
if kq_acc_t == "float" and cols_per_block == 8:
continue
for cols_per_block in [8, 16, 32, 64]:
with open(f"fattn-mma-f16-instance-cpb{cols_per_block}.cu", "w") as f:
f.write(SOURCE_FATTN_MMA_START)
with open(f"fattn-wmma-f16-instance-kq{kq_acc_t}-cpb{cols_per_block}.cu", "w") as f:
f.write(SOURCE_FATTN_WMMA_START)
for head_size in [64, 80, 96, 112, 128, 256]:
if cols_per_block == 8 and head_size % 32 != 0: # wmma fragment is 8x32
continue
if kq_acc_t == "float" and cols_per_block == 32 and head_size == 256: # register spilling, bad performance
continue
f.write(SOURCE_FATTN_WMMA_CASE.format(kq_acc_t=kq_acc_t, cols_per_block=cols_per_block, head_size=head_size))
for head_size in [64, 80, 96, 112, 128, 256]:
f.write(SOURCE_FATTN_MMA_CASE.format(cols_per_block=cols_per_block, head_size=head_size))
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:

View file

@ -1,5 +1,6 @@
#pragma once
#define HIP_ENABLE_WARP_SYNC_BUILTINS 1
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
@ -8,6 +9,7 @@
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif // __HIP_PLATFORM_AMD__
#define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
@ -25,6 +27,7 @@
#define CU_MEM_LOCATION_TYPE_DEVICE hipMemLocationTypeDevice
#define CU_MEM_ACCESS_FLAGS_PROT_READWRITE hipMemAccessFlagsProtReadWrite
#define CU_CHECK(fn) {hipError_t err = fn; if(err != hipSuccess) { GGML_ABORT("HipVMM Failure: %s\n", hipGetErrorString(err)); }}
#define __shfl_sync(mask, var, laneMask, width) __shfl(var, laneMask, width)
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
#define cublasCreate hipblasCreate

View file

@ -46,11 +46,14 @@ endif()
message(STATUS "HIP and hipBLAS found")
# Workaround old compilers
set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} --gpu-max-threads-per-block=1024")
file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh")
list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_ROCM ${SRCS})

View file

@ -19,8 +19,15 @@
// max number of MTLCommandBuffer used to submit a graph for processing
#define GGML_METAL_MAX_COMMAND_BUFFERS 8
#ifndef TARGET_OS_VISION
#define TARGET_OS_VISION 0
#endif
// create residency sets only on macOS >= 15.0
#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000
#if TARGET_OS_OSX && __MAC_OS_X_VERSION_MAX_ALLOWED >= 150000 || \
TARGET_OS_IOS && __IPHONE_OS_VERSION_MAX_ALLOWED >= 180000 || \
TARGET_OS_TV && __TV_OS_VERSION_MAX_ALLOWED >= 180000 || \
TARGET_OS_VISION && __VISION_OS_VERSION_MAX_ALLOWED >= 200000
#define GGML_METAL_HAS_RESIDENCY_SETS 1
#endif
@ -1071,7 +1078,7 @@ static bool ggml_backend_metal_buffer_rset_init(
}
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
if (@available(macOS 15.0, *)) {
if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
MTLResidencySetDescriptor * desc = [[MTLResidencySetDescriptor alloc] init];
desc.label = @"ggml_backend_metal";
desc.initialCapacity = ctx->n_buffers;
@ -1106,7 +1113,7 @@ static bool ggml_backend_metal_buffer_rset_init(
// rset free
static void ggml_backend_metal_buffer_rset_free(struct ggml_backend_metal_buffer_context * ctx) {
#if defined(GGML_METAL_HAS_RESIDENCY_SETS)
if (@available(macOS 15.0, *)) {
if (@available(macOS 15.0, iOS 18.0, tvOS 18.0, visionOS 2.0, *)) {
if (ctx->rset) {
[ctx->rset endResidency];
[ctx->rset removeAllAllocations];
@ -1201,12 +1208,13 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_SUM_ROWS:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction;
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0);
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
case GGML_OP_NORM:
return true;
case GGML_OP_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ROPE:
{
const int mode = ((const int32_t *) op->op_params)[2];

View file

@ -29,7 +29,7 @@ if (MUSAToolkit_FOUND)
list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
file(GLOB GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-wmma*.cu")
file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})
file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_MUSA ${SRCS})

View file

@ -1045,7 +1045,28 @@ bool rpc_server::copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_co
ggml_free(ctx);
return false;
}
GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n", __func__, (void*)src->buffer, (void*)dst->buffer);
uint64_t src_size = (uint64_t) ggml_nbytes(src);
uint64_t dst_data = (uint64_t) dst->data;
uint64_t dst_base = (uint64_t) ggml_backend_buffer_get_base(dst->buffer);
uint64_t dst_buf_sz = (uint64_t) ggml_backend_buffer_get_size(dst->buffer);
if (dst_data + src_size > dst_base + dst_buf_sz) {
GGML_PRINT_DEBUG("[%s] out-of-bounds write in rpc_server::copy_tensor:\n"
" write range : [0x%" PRIx64 ", 0x%" PRIx64 "]\n"
" buffer base: [0x%" PRIx64 ", 0x%" PRIx64 "]\n",
__func__,
dst_data,
dst_data + src_size,
dst_base,
dst_base + dst_buf_sz);
ggml_free(ctx);
return false;
}
GGML_PRINT_DEBUG("[%s] src->buffer: %p, dst->buffer: %p\n",
__func__, (void*) src->buffer, (void*) dst->buffer);
response.result = ggml_backend_buffer_copy_tensor(src, dst);
ggml_free(ctx);
return true;

View file

@ -103,11 +103,10 @@ void print_device_detail(int id, sycl::device &device, std::string device_type)
name = std::regex_replace(name, std::regex("\\(TM\\)"), "");
auto global_mem_size = prop.get_global_mem_size()/1000000;
std::string xmx = gpu_has_xmx(device) ? "yes" : "no";
GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|%14s|\n", id, device_type.c_str(),
GGML_LOG_INFO("|%2d|%19s|%39s|%7s|%7d|%8d|%5d|%6luM|%21s|\n", id, device_type.c_str(),
name.c_str(), version.c_str(), prop.get_max_compute_units(),
prop.get_max_work_group_size(), prop.get_max_sub_group_size(),
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str(), xmx.c_str());
global_mem_size, device.get_info<sycl::info::device::driver_version>().c_str());
}
void ggml_backend_sycl_print_sycl_devices() {
@ -118,16 +117,16 @@ void ggml_backend_sycl_print_sycl_devices() {
GGML_LOG_INFO(
"| | | | "
" |Max | |Max |Global | | XMX |\n");
" |Max | |Max |Global | |\n");
GGML_LOG_INFO(
"| | | | "
" |compute|Max work|sub |mem | | or |\n");
" |compute|Max work|sub |mem | |\n");
GGML_LOG_INFO(
"|ID| Device Type| "
"Name|Version|units |group |group|size | Driver version| Tensor Cores |\n");
"Name|Version|units |group |group|size | Driver version|\n");
GGML_LOG_INFO(
"|--|-------------------|---------------------------------------|------"
"-|-------|--------|-----|-------|---------------------|--------------|\n");
"-|-------|--------|-----|-------|---------------------|\n");
for (int id = 0; id < device_count; ++id) {
sycl::device device = dpct::dev_mgr::instance().get_device(id);
@ -4537,14 +4536,17 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_NORM:
case GGML_OP_ADD:
case GGML_OP_ADD1:
case GGML_OP_LOG:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
return true;
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_SCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
@ -4576,7 +4578,6 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
case GGML_OP_ACC:
case GGML_OP_GROUP_NORM:
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_LEAKY_RELU:

View file

@ -156,6 +156,7 @@ struct vk_device_struct {
vk::PhysicalDeviceProperties properties;
std::string name;
uint64_t max_memory_allocation_size;
uint64_t suballocation_block_size;
bool fp16;
bool pipeline_robustness;
vk::Device device;
@ -1621,6 +1622,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
//CREATE_FA(GGML_TYPE_IQ2_S, iq2_s)
//CREATE_FA(GGML_TYPE_IQ3_XXS, iq3_xxs)
//CREATE_FA(GGML_TYPE_IQ3_S, iq3_s)
//CREATE_FA(GGML_TYPE_IQ4_XS, iq4_xs)
CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl)
#undef CREATE_FA
@ -1654,6 +1656,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f16, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3)
CREATE_MM2(pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4)
@ -1672,6 +1675,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f16, , mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4)
#undef CREATE_MM
#undef CREATE_MM2
@ -1725,6 +1729,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
} else {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f16acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
@ -1743,6 +1748,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
}
@ -1769,6 +1775,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
} else {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f16acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
@ -1787,6 +1794,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
}
@ -1836,6 +1844,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f16acc, matmul_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f16acc, matmul_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f16acc, matmul_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f16acc, matmul_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f16acc, matmul_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
@ -1860,6 +1869,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f16acc, matmul_id_iq2_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f16acc, matmul_id_iq3_xxs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f16acc, matmul_id_iq3_s_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f16acc, matmul_id_iq4_xs_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f16acc, matmul_id_iq4_nl_f32, _f16acc, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM2
@ -1901,6 +1911,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
CREATE_MM(pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, );
// If there's not enough shared memory for row_ids and the result tile, don't create these pipelines.
@ -1925,6 +1936,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
CREATE_MM(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id);
}
#undef CREATE_MM
@ -1961,6 +1973,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f32_f32_len, mul_mat_vec_iq2_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f32_f32_len, mul_mat_vec_iq3_xxs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f32_f32_len, mul_mat_vec_iq3_s_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f32_f32_len, mul_mat_vec_iq4_xs_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f32_f32_len, mul_mat_vec_iq4_nl_f32_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32_"+std::to_string(i+1), mul_mat_vec_f32_f16_f32_len, mul_mat_vec_f32_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {device->subgroup_size, 2, i+1}, 1);
@ -1980,6 +1993,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq2_s_f16_f32_len, mul_mat_vec_iq2_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_xxs_f16_f32_len, mul_mat_vec_iq3_xxs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq3_s_f16_f32_len, mul_mat_vec_iq3_s_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_xs_f16_f32_len, mul_mat_vec_iq4_xs_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq, i+1}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32_"+std::to_string(i+1), mul_mat_vec_iq4_nl_f16_f32_len, mul_mat_vec_iq4_nl_f16_f32_data, "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq, i+1}, 1, true);
}
@ -2000,6 +2014,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {subgroup_size_16, 2*rm_stdq}, 1, true);
// dequant shaders
@ -2019,6 +2034,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
// get_rows
@ -2034,6 +2050,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1);
@ -2048,6 +2065,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1);
@ -2269,6 +2287,7 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->physical_device.getProperties2(&props2);
device->properties = props2.properties;
device->vendor_id = device->properties.vendorID;
const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE");
@ -2280,7 +2299,20 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->max_memory_allocation_size = props3.maxMemoryAllocationSize;
}
device->vendor_id = device->properties.vendorID;
const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE");
if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) {
device->suballocation_block_size = std::stoul(GGML_VK_SUBALLOCATION_BLOCK_SIZE);
#if defined(_WIN32)
} else if (device->vendor_id == VK_VENDOR_ID_NVIDIA) {
// Limit batching of allocations to 1GB by default to avoid fragmentation issues
device->suballocation_block_size = 1024*1024*1024;
#endif
} else {
device->suballocation_block_size = device->max_memory_allocation_size;
}
device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size);
device->subgroup_size = subgroup_props.subgroupSize;
device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu;
if (sm_builtins) {
@ -2748,8 +2780,9 @@ static void ggml_vk_print_gpu_info(size_t idx) {
std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none";
std::string device_name = props2.properties.deviceName.data();
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size, matrix_cores.c_str());
GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | warp size: %zu | shared memory: %d | matrix cores: %s\n",
idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, subgroup_size,
props2.properties.limits.maxComputeSharedMemorySize, matrix_cores.c_str());
if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) {
GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n");
@ -2980,6 +3013,7 @@ static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -3033,6 +3067,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_conte
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -3069,6 +3104,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -3117,6 +3153,7 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -3148,6 +3185,7 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -7561,7 +7599,7 @@ static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type
static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) {
ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context;
return ctx->device->max_memory_allocation_size;
return ctx->device->suballocation_block_size;
}
static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) {
@ -8022,6 +8060,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -8095,6 +8134,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
//case GGML_TYPE_IQ2_S:
//case GGML_TYPE_IQ3_XXS:
//case GGML_TYPE_IQ3_S:
//case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
break;
default:
@ -8117,6 +8157,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_TYPE_IQ2_S:
case GGML_TYPE_IQ3_XXS:
case GGML_TYPE_IQ3_S:
case GGML_TYPE_IQ4_XS:
case GGML_TYPE_IQ4_NL:
return true;
default:
@ -8182,9 +8223,11 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_VIEW:
case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
return true;
case GGML_OP_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_RMS_NORM:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_ADD:
case GGML_OP_ACC:
case GGML_OP_MUL:

View file

@ -12,7 +12,7 @@ layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in;
#endif
void main() {
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;

View file

@ -217,7 +217,7 @@ void quantize(uint dst_idx, uint src_idx)
#endif
void main() {
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
if (gl_LocalInvocationIndex.x != 0) {
return;

View file

@ -304,6 +304,42 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
}
#endif
#if defined(DATA_A_IQ4_XS)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint iq = 16 * ib32 + (iqs % 16);
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]);
qs = (qs >> qshift) & uint8_t(0xF);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
}
vec4 dequantize4(uint ib, uint iqs, uint a_offset) {
const uint ib32 = iqs / 32;
const uint iq = 16 * ib32 + (iqs % 16);
const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3;
const uint qshift = (iqs & 16) >> 2;
u8vec4 qs = u8vec4(
data_a[a_offset + ib].qs[iq + 0],
data_a[a_offset + ib].qs[iq + 1],
data_a[a_offset + ib].qs[iq + 2],
data_a[a_offset + ib].qs[iq + 3]
);
qs = (qs >> qshift) & uint8_t(0xF);
const float dl = float(int(sl | (sh << 4)) - 32);
return dl * vec4(
kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y],
kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]);
}
#endif
#if defined(DATA_A_IQ4_NL)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
const uint vui = uint(data_a[a_offset + ib].qs[iqs]);
@ -321,7 +357,7 @@ vec2 get_dm(uint ib, uint a_offset) {
}
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
vec2 get_dm(uint ib, uint a_offset) {
return vec2(float(data_a[a_offset + ib].d), 0);
}

View file

@ -323,15 +323,16 @@ float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCo
const uint8_t qs = bl.block.qs[iqs];
const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3]));
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t(signscale >> 28));
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28));
uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7);
sign |= bitCount(sign) << 7;
const uint8_t g = unpack8(iq2xxs_grid[qs][(idx & 4) >> 2])[idx & 3];
uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 g = vec2(unpack8(g2));
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return ret;
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
#endif
@ -350,14 +351,16 @@ float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoor
const uint iqs = (idx & 0xF8) >> 3; // 0..63
const uint16_t qs = bl.block.qs[iqs];
const float16_t dscale = bl.block.d * 0.25hf * (0.5hf + float16_t((bl.block.scales[is] >> sshift) & 0xF));
const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF));
uint sign = uint(qs >> 9);
sign |= bitCount(sign) << 7;
const uint8_t g = unpack8(iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2])[idx & 3];
uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 g = vec2(unpack8(g2));
float16_t ret = dscale * float16_t(g) * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return ret;
vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf);
return float16_t(ret[idx & 1]);
}
#endif
@ -369,24 +372,23 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2
float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
uint lsb = idx & 1;
idx /= 2;
const uint ib8 = (idx % 128) / 4; // 0..31
const uint ib32 = ib8 / 4; // 0..7
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint ib8 = (idx & 0xF8) >> 3; // 0..31
const uint qhshift = 2 * (ib8 % 4);
const uint scale = (bl.block.scales[ib32] >> (2 * (ib8 & 2))) & 0xf;
const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf;
const uint qs = bl.block.qs[ib8];
const uint qh = bl.block.qh[ib32];
const uint qhshift = 2 * (ib8 % 4);
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (2 * (idx % 4));
const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6);
const float d = float(bl.block.d);
const float db = d * 0.25 * (0.5 + scale);
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint16_t grid = unpack16(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 2) >> 1])[idx & 1];
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid));
return float16_t(v[lsb]);
const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign));
uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2];
g2 >>= (idx & 2) * 8;
const vec2 v = db * vec2(sign01) * vec2(unpack8(g2));
return float16_t(v[idx & 1]);
}
#endif
@ -401,28 +403,25 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3
float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl);
uint idx = coordInBlock[1];
uint lsb = idx & 1;
idx /= 2;
const uint iqs = (idx % 128) / 2; // 0..63
const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values
const uint iqs = (idx & 0xFC) >> 2; // 0..63
const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint signs = pack32(u8vec4(
bl.block.qs[is+0],
bl.block.qs[is+1],
bl.block.qs[is+2],
bl.block.qs[is+3]
const uint signs = pack32(u16vec2(
bl16.block.qs[is/2+0],
bl16.block.qs[is/2+1]
));
const float db = d * 0.5 * (0.5 + (signs >> 28));
const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7);
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (2 * (idx % 4));
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(int8_t(sign << 1), int8_t(sign))));
const uint grid = iq3xxs_grid[qs] >> (16 * (idx & 1));
const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6);
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[lsb]);
return float16_t(v[idx & 1]);
}
#endif
@ -434,26 +433,45 @@ layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3
float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
uint idx = coordInBlock[1];
uint lsb = idx & 1;
idx /= 2;
const uint iqs = (idx % 128) / 2; // 0..63
const uint iqh = iqs / 8;
const uint iqs = (idx & 0xFC) >> 2; // 0..63
const uint iqh = (idx & 0xE0) >> 5;
const float d = float(bl.block.d);
const uint qs = bl.block.qs[iqs];
const uint qh = bl.block.qh[iqh];
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (2 * (idx % 4)));
const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6));
const uint scale = bl.block.scales[iqs / 16];
const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign)));
const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign)));
const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3);
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
return float16_t(v[lsb]);
return float16_t(v[idx & 1]);
}
#endif
#if defined(DATA_A_IQ4_XS)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS {
block_iq4_xs block;
};
float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
const float16_t d = bl.block.d;
const uint idx = coordInBlock[1];
const uint ib32 = (idx & 0xE0) >> 5; // 0..7
const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 16) >> 2;
const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF;
float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]);
return ret;
}
#endif
#if defined(DATA_A_IQ4_NL)
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL {
@ -504,6 +522,8 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
#define dequantFuncA dequantFuncIQ3_XXS
#elif defined(DATA_A_IQ3_S)
#define dequantFuncA dequantFuncIQ3_S
#elif defined(DATA_A_IQ4_XS)
#define dequantFuncA dequantFuncIQ4_XS
#elif defined(DATA_A_IQ4_NL)
#define dequantFuncA dequantFuncIQ4_NL
#endif

View file

@ -0,0 +1,34 @@
#version 450
#include "dequant_head.comp"
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_b[];};
void main() {
// Each thread handles 1 subblock (1 scale and 32 quantized values)
const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8;
init_iq_shmem(gl_WorkGroupSize);
if (ib >= p.nel / 256) {
return;
}
const uint ib32 = gl_LocalInvocationID.x % 8;
const float d = float(data_a[ib].d);
// Scales are 6 bits
const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF)
| (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4);
const float dl = d * (int(scale) - 32);
const uint b_idx = 256 * ib + 32 * ib32;
const uint q_idx = 16 * ib32;
[[unroll]] for (uint l = 0; l < 16; ++l) {
data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]);
data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]);
}
}

View file

@ -104,7 +104,7 @@ ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE ele
#endif
void main() {
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif

View file

@ -12,7 +12,7 @@ void main() {
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif

View file

@ -133,7 +133,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif

View file

@ -95,7 +95,7 @@ shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS];
#endif
void main() {
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif
@ -547,6 +547,25 @@ void main() {
const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> (16 * (idx % 2));
const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ4_XS)
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
const uint ib = idx / 128; // 2 values per idx
const uint ib32 = (idx % 128) / 16; // 0..7
const uint iq = 16 * ib32 + 2 * (idx % 8);
const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3;
const uint qshift = (idx & 8) >> 1;
u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]);
qs = (qs >> qshift) & uint8_t(0xF);
const float d = float(data_a[ib].d);
const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]);
buf_a[buf_idx ] = FLOAT_TYPE(v.x);
buf_a[buf_idx + 1] = FLOAT_TYPE(v.y);
#elif defined(DATA_A_IQ4_NL)

View file

@ -106,7 +106,7 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
#endif
void main() {
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_NL)
#if defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
init_iq_shmem(gl_WorkGroupSize);
#endif

View file

@ -1026,6 +1026,23 @@ void init_iq_shmem(uvec3 wgsize)
#define A_TYPE_PACKED16 block_iq3_s_packed16
#endif
#define QUANT_K_IQ4_XS 256
#define QUANT_R_IQ4_XS 1
struct block_iq4_xs
{
float16_t d;
uint16_t scales_h;
uint8_t scales_l[QUANT_K_IQ4_XS/64];
uint8_t qs[QUANT_K_IQ4_XS/2];
};
#if defined(DATA_A_IQ4_XS)
#define QUANT_K QUANT_K_IQ4_XS
#define QUANT_R QUANT_R_IQ4_XS
#define A_TYPE block_iq4_xs
#endif
#define QUANT_K_IQ4_NL 32
#define QUANT_R_IQ4_NL 2
@ -1042,7 +1059,13 @@ struct block_iq4_nl_packed16
};
#if defined(DATA_A_IQ4_NL)
#define QUANT_K QUANT_K_IQ4_NL
#define QUANT_R QUANT_R_IQ4_NL
#define A_TYPE block_iq4_nl
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif
#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS)
const int8_t kvalues_iq4nl_const[16] = {
int8_t(-127), int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10),
int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113)
@ -1058,11 +1081,6 @@ void init_iq_shmem(uvec3 wgsize)
}
barrier();
}
#define QUANT_K QUANT_K_IQ4_NL
#define QUANT_R QUANT_R_IQ4_NL
#define A_TYPE block_iq4_nl
#define A_TYPE_PACKED16 block_iq4_nl_packed16
#endif
#endif // !defined(GGML_TYPES_COMP)

View file

@ -60,6 +60,7 @@ const std::vector<std::string> type_names = {
"iq2_s",
"iq3_xxs",
"iq3_s",
"iq4_xs",
"iq4_nl"
};