cuda : increase C to 128 for better performance
commit ac26f27028 (parent 9a5c2a1681)
4 changed files with 37 additions and 29 deletions
ggml-cuda.cu (61 changed lines)
@@ -6495,8 +6495,8 @@ static __global__ void flash_attn_ext_f16(
        half M[Q];

        for(int i = 0; i < Q; i++) {
-           S[i] = 0.0f;
-           M[i] = -INFINITY;
+           S[i] = __float2half(0.0f);
+           M[i] = __float2half(-INFINITY);
        }

        // assume K and V are same shape
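The same pattern recurs throughout this commit: float literals and float math on half values (0.0f, expf) are replaced with the explicit half-precision intrinsics from cuda_fp16.h (__float2half, hexp, __hmax), avoiding implicit half/float round trips. A minimal standalone sketch of those intrinsics, not part of the diff; the kernel name and shape are illustrative, and it assumes a GPU with native fp16 arithmetic (sm_53 or newer; __hmax may need a newer CUDA/arch):

#include <cuda_fp16.h>
#include <math.h> // INFINITY

// Single-thread toy kernel: y[i] = exp(x[i] - max(x)), computed entirely in
// half precision with the intrinsics the diff switches to.
__global__ void half_intrinsics_sketch(const half * x, half * y, int n) {
    half m = __float2half(-INFINITY);
    for (int i = 0; i < n; ++i) {
        m = __hmax(m, x[i]);          // half max
    }
    for (int i = 0; i < n; ++i) {
        y[i] = hexp(__hsub(x[i], m)); // half exp, half subtract
    }
}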
@@ -6579,7 +6579,7 @@ static __global__ void flash_attn_ext_f16(
            }

            // used to detect blocks full of -INF
-           half smax = -INFINITY;
+           half smax = __float2half(-INFINITY);

            // online softmax
            if (C == 32) {
@@ -6592,8 +6592,8 @@ static __global__ void flash_attn_ext_f16(
                    smax = warp_reduce_max(__hmax(smax, s));
                    M[j] = warp_reduce_max(__hmax(M[j], s));

-                   const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
-                   const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);
+                   const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);
+                   const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

                    S[j] = S[j]*ms + warp_reduce_sum(vs);

@@ -6612,32 +6612,37 @@ static __global__ void flash_attn_ext_f16(
                for (int64_t p = lane_id; p < C; p += NW) {
                    const half s = ss[j*T + p];

-                   smax = warp_reduce_max(__hmax(smax, s));
-                   M[j] = warp_reduce_max(__hmax(M[j], s));
+                   smax = __hmax(smax, s);
+                   M[j] = __hmax(M[j], s);
                }

-               const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
+               smax = warp_reduce_max(smax);
+               M[j] = warp_reduce_max(M[j]);

-               S[j] = S[j]*ms;
+               const half ms = __hisinf(m) ? __float2half(0.0f) : hexp(m - M[j]);

                // create a QxQ diagonal matrix for rescaling the output
                if (lane_id == j) {
                    ss[j*T + C + j] = ms;
                }

+               // local sum
+               half ls = 0.0f;
+
                for (int64_t p = lane_id; p < C; p += NW) {
                    const half s = ss[j*T + p];

-                   const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);
+                   const half vs = __hisinf(s) ? __float2half(0.0f) : hexp(s - M[j]);

-                   S[j] = S[j] + warp_reduce_sum(vs);
+                   ls += vs;

                    // the P matrix from the paper (Q rows, C columns)
                    ss[j*T + p] = vs;
                }
+
+               S[j] = S[j]*ms + warp_reduce_sum(ls);
            }

            // skip -INF blocks
            if (__hisinf(smax)) {
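The key change in this hunk: in the C > 32 path, the per-element warp reductions are pulled out of the p loop. Each lane keeps a private running max (__hmax) and a private partial sum ls, and a single warp_reduce_max / warp_reduce_sum runs once per row after the loop, which is what makes larger C values (each lane covering C/NW positions) cheap. A sketch of what shuffle-based warp reductions of this kind typically look like; the actual helpers in ggml-cuda.cu may differ in detail:

#include <cuda_fp16.h>

#define WARP_SIZE 32

// Butterfly reduction over the 32 lanes of a warp via __shfl_xor_sync;
// after the loop every lane holds the warp-wide result.
static __device__ __forceinline__ half warp_reduce_max_sketch(half x) {
#pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, WARP_SIZE));
    }
    return x;
}

static __device__ __forceinline__ half warp_reduce_sum_sketch(half x) {
#pragma unroll
    for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) {
        x = __hadd(x, __shfl_xor_sync(0xffffffff, x, mask, WARP_SIZE));
    }
    return x;
}

// Usage pattern matching the hunk above: accumulate per lane, reduce once.
//
//     half local_max = __float2half(-INFINITY);
//     for (int p = lane_id; p < C; p += WARP_SIZE) {
//         local_max = __hmax(local_max, row[p]);
//     }
//     local_max = warp_reduce_max_sketch(local_max);   // one reduction per row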
@@ -6669,15 +6674,19 @@ static __global__ void flash_attn_ext_f16(
                for (int cc = 0; cc < C/16; ++cc) {
                    const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23));

-                   for (int64_t i = 0; i < D16; ++i) {
-                       half16x16_b mk;
-                       nvcuda::wmma::load_matrix_sync(mk, pv + i*16, nb21/sizeof(half));
-
-                       for (int64_t j = 0; j < Q16; ++j) {
-                           half16x16_a mv;
-                           nvcuda::wmma::load_matrix_sync(mv, ss + 16*j*T + 16*cc, T);
-
-                           nvcuda::wmma::mma_sync(lo[j][i], mv, mk, lo[j][i]);
+                   half16x16_b mk[D16];
+                   for (int64_t i = 0; i < D16; ++i) {
+                       nvcuda::wmma::load_matrix_sync(mk[i], pv + i*16, nb21/sizeof(half));
+                   }
+
+                   half16x16_a mv[Q16];
+                   for (int64_t j = 0; j < Q16; ++j) {
+                       nvcuda::wmma::load_matrix_sync(mv[j], ss + 16*j*T + 16*cc, T);
+                   }
+
+                   for (int64_t j = 0; j < Q16; ++j) {
+                       for (int64_t i = 0; i < D16; ++i) {
+                           nvcuda::wmma::mma_sync(lo[j][i], mv[j], mk[i], lo[j][i]);
                        }
                    }
                }
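This restructuring hoists the wmma fragment loads out of the MMA loop: the D16 fragments of V and the Q16 fragments of the P matrix are each loaded once per cc block and then reused across all Q16 x D16 mma_sync calls, instead of fragments being reloaded inside the nested loop. A self-contained sketch of the same load-once / multiply-many pattern with the WMMA API; the function name, tile counts and layouts below are illustrative, not taken from the kernel:

#include <mma.h>
#include <cuda_fp16.h>

using namespace nvcuda;

#define NTILES_A 2   // stands in for Q16: 16x16 tiles of the A operand
#define NTILES_B 4   // stands in for D16: 16x16 tiles of the B operand

// Must be called by a full warp, like any wmma code.
__device__ void mma_hoisted_sketch(
        const half * A, int lda,   // NTILES_A row blocks of a row-major matrix
        const half * B, int ldb,   // NTILES_B column blocks of a row-major matrix
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> acc[NTILES_A][NTILES_B]) {
    // load every B tile once ...
    wmma::fragment<wmma::matrix_b, 16, 16, 16, half, wmma::row_major> mb[NTILES_B];
    for (int i = 0; i < NTILES_B; ++i) {
        wmma::load_matrix_sync(mb[i], B + 16*i, ldb);
    }

    // ... and every A tile once ...
    wmma::fragment<wmma::matrix_a, 16, 16, 16, half, wmma::row_major> ma[NTILES_A];
    for (int j = 0; j < NTILES_A; ++j) {
        wmma::load_matrix_sync(ma[j], A + 16*j*lda, lda);
    }

    // ... then reuse the fragments for every accumulator tile.
    for (int j = 0; j < NTILES_A; ++j) {
        for (int i = 0; i < NTILES_B; ++i) {
            wmma::mma_sync(acc[j][i], ma[j], mb[i], acc[j][i]);
        }
    }
}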
@@ -6695,8 +6704,8 @@ static __global__ void flash_attn_ext_f16(

        // reduce the warps sequentially
        for (int64_t sg = 1; sg < num_warps; ++sg) {
-           half S = 0.0f;
-           half M = -INFINITY;
+           half S = __float2half(0.0f);
+           half M = __float2half(-INFINITY);

            __syncthreads();

@@ -6722,8 +6731,8 @@ static __global__ void flash_attn_ext_f16(

                M = __hmax(M0, M1);

-               const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M);
-               const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M);
+               const half ms0 = __hisinf(M0) ? __float2half(0.0f) : hexp(M0 - M);
+               const half ms1 = __hisinf(M1) ? __float2half(0.0f) : hexp(M1 - M);

                S = S0*ms0 + S1*ms1;

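The S0/M0 and S1/M1 merge above is the standard way to combine two partial online-softmax states. With the invariant S = sum_p exp(s_p - M) for each partial state, taking the larger maximum and rescaling each partial sum preserves the invariant for the combined state; written out (the __hisinf guards implement the convention exp(-inf - M) = 0):

// M = max(M0, M1)
// S = S0*exp(M0 - M) + S1*exp(M1 - M)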
@@ -6770,8 +6779,6 @@ static __global__ void flash_attn_ext_f16(
            }
        }

-       // float2 * dst2 = (float2 *) dst;
-
        // final rescale with 1/S and store to global memory
        if (warp_id == 0) {
            for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
@@ -9637,7 +9644,7 @@ static void ggml_cuda_op_soft_max(

    const int64_t ne00    = src0->ne[0];
    const int64_t nrows_x = ggml_nrows(src0);
-   const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+   const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded!

    float scale = 1.0f;
    memcpy(&scale, dst->op_params, sizeof(float));
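nrows_y is the period with which mask rows are broadcast over the rows of src0 inside the soft_max kernel. A host-side sketch of that broadcast (illustrative only; the real CUDA kernel is structured differently), showing why the number of queries src0->ne[1] is the right period once the mask tensor itself may be padded to extra rows:

#include <stdint.h>

// vals, x: nrows_x rows of ncols logits (one row per (query, head) pair)
// mask   : nrows_y meaningful rows of ncols values; any padded rows are unused
static void soft_max_apply_mask_sketch(
        float * vals, const float * x, const float * mask,
        int64_t nrows_x, int64_t nrows_y, int64_t ncols, float scale) {
    for (int64_t rowx = 0; rowx < nrows_x; ++rowx) {
        const int64_t rowy = rowx % nrows_y; // mask row applied to this logits row
        for (int64_t c = 0; c < ncols; ++c) {
            vals[rowx*ncols + c] = x[rowx*ncols + c]*scale + mask[rowy*ncols + c];
        }
    }
}

If nrows_y were taken from ggml_nrows(src1), a padded mask would change the broadcast period and mis-align mask rows with query rows.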
@@ -10932,7 +10939,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
    memcpy(&scale, KQV->op_params, sizeof(float));

#define NQPB 16
-#define NCPW 32
+#define NCPW 128

    const int nqpb = NQPB; // queries per block
    const int ncpw = NCPW; // cache values per warp (does not work for other values)
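NCPW is C, the number of KV cache values processed per warp. Going from 32 to 128 means each lane of a 32-wide warp covers four KV positions per pass of the strided for (p = lane_id; p < C; p += NW) loops, which is what the deferred warp reductions above are amortized over. A trivial host-side check of that arithmetic:

#include <stdio.h>

// Iterations of "for (p = lane_id; p < C; p += NW)" executed by each lane.
int main(void) {
    const int NW = 32; // warp size
    for (int C = 32; C <= 128; C *= 2) {
        printf("C = %3d -> %d KV positions per lane per pass\n", C, (C + NW - 1)/NW);
    }
    return 0;
}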
ggml.c (2 changed lines)
@@ -5089,7 +5089,7 @@ static struct ggml_tensor * ggml_soft_max_impl(
        GGML_ASSERT(ggml_is_contiguous(mask));
        GGML_ASSERT(mask->ne[2] == 1);
        GGML_ASSERT(mask->ne[3] == 1);
-       GGML_ASSERT(ggml_can_repeat_rows(mask, a));
+       GGML_ASSERT(mask->ne[1] >= a->ne[1]);
    }

    bool is_node = false;
llama.cpp
@@ -6881,7 +6881,8 @@ static int llama_decode_internal(
        // a heuristic, to avoid attending the full cache if it is not yet utilized
        // after enough generations, the benefit from this heuristic disappears
        // if we start defragmenting the cache, the benefit from this will be more important
-       kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+       // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext)
+       kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
        //kv_self.n = llama_kv_cache_cell_max(kv_self);

        //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
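kv_self.n is now rounded up to a multiple of 128 (and clamped to at least 128), matching the new NCPW, so the padded n_kv always tiles cleanly in the CUDA flash-attention kernel. A quick sketch of the padding arithmetic, assuming GGML_PAD has its usual round-up-to-a-multiple definition from ggml.h:

#include <stdio.h>

// Assumed to match the GGML_PAD macro in ggml.h: round x up to a multiple of n
// (n a power of two).
#define GGML_PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main(void) {
    printf("%d\n", GGML_PAD(300, 128)); // 384: a cache with 300 used cells
    printf("%d\n", GGML_PAD(128, 128)); // 128
    printf("%d\n", GGML_PAD(  0, 128)); //   0: this is why std::max(128, ...) floors the result
    return 0;
}

The surrounding std::min/std::max then clamp the value to the range [128, n_ctx].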
tests/test-backend-ops.cpp
@@ -2210,7 +2210,7 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
    test_cases.emplace_back(new test_leaky_relu());

#if 1
-   for (int hs : { 64, 80, 128, }) {
+   for (int hs : { 128, 64, 80, }) {
        for (int nh : { 32, }) {
            for (int kv : { 512, 1024, 2048, 4096, }) {
                for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {