cuda : switch to F16 scalars + tune warps for RTX 2060

parent 2c04beeb81
commit 9a5c2a1681

2 changed files with 61 additions and 47 deletions

ggml-cuda.cu (94 changed lines):
@@ -6491,8 +6491,8 @@ static __global__ void flash_attn_ext_f16(
     __syncthreads();

     {
-        float S[Q];
-        float M[Q];
+        half S[Q];
+        half M[Q];

         for(int i = 0; i < Q; i++) {
             S[i] = 0.0f;
@@ -6579,67 +6579,68 @@ static __global__ void flash_attn_ext_f16(
         }

         // used to detect blocks full of -INF
-        float smax = -INFINITY;
+        half smax = -INFINITY;

         // online softmax
         if (C == 32) {
             for (int64_t j = 0; j < Q; ++j) {
                 const int64_t p = lane_id;

-                const float m = M[j];
-                const float s = __half2float(ss[j*T + p]);
+                const half m = M[j];
+                const half s = ss[j*T + p];

-                smax = warp_reduce_max(max(smax, s));
-                M[j] = warp_reduce_max(max(M[j], s));
+                smax = warp_reduce_max(__hmax(smax, s));
+                M[j] = warp_reduce_max(__hmax(M[j], s));

-                const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
-                const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);
+                const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);

                 S[j] = S[j]*ms + warp_reduce_sum(vs);

                 // create a QxQ diagonal matrix for rescaling the output
                 if (p == j) {
-                    ss[j*T + C + j] = __float2half(ms);
+                    ss[j*T + C + j] = ms;
                 }

                 // the P matrix from the paper (Q rows, C columns)
-                ss[j*T + p] = __float2half(vs);
+                ss[j*T + p] = vs;
             }
         } else {
             for (int64_t j = 0; j < Q; ++j) {
-                const float m = M[j];
+                const half m = M[j];

                 for (int64_t p = lane_id; p < C; p += NW) {
-                    const float s = __half2float(ss[j*T + p]);
+                    const half s = ss[j*T + p];

-                    smax = warp_reduce_max(max(smax, s));
-                    M[j] = warp_reduce_max(max(M[j], s));
+                    smax = warp_reduce_max(__hmax(smax, s));
+                    M[j] = warp_reduce_max(__hmax(M[j], s));
                 }

-                const float ms = m == -INFINITY ? 0.0f : expf(m - M[j]);
+                const half ms = __hisinf(m) ? 0.0f : expf(m - M[j]);

                 S[j] = S[j]*ms;

                 // create a QxQ diagonal matrix for rescaling the output
                 if (lane_id == j) {
-                    ss[j*T + C + j] = __float2half(ms);
+                    ss[j*T + C + j] = ms;
                 }

                 for (int64_t p = lane_id; p < C; p += NW) {
-                    const float s = __half2float(ss[j*T + p]);
+                    const half s = ss[j*T + p];

-                    const float vs = s == -INFINITY ? 0.0f : expf(s - M[j]);
+                    const half vs = __hisinf(s) ? 0.0f : expf(s - M[j]);

                     S[j] = S[j] + warp_reduce_sum(vs);

                     // the P matrix from the paper (Q rows, C columns)
-                    ss[j*T + p] = __float2half(vs);
+                    ss[j*T + p] = vs;
                 }
             }
         }


         // skip -INF blocks
-        if (smax == -INFINITY) {
+        if (__hisinf(smax)) {
             continue;
         }

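For reference, the S[j] and M[j] scalars that this hunk narrows from float to half implement the usual online-softmax recurrence: M[j] tracks the running row maximum, S[j] tracks the running sum of exponentials relative to that maximum, and the old sum is rescaled by ms = expf(m - M[j]) whenever the maximum grows. Below is a minimal host-side sketch of that recurrence in plain C++ (illustrative only; it is not the kernel code and leaves out the warp reductions and the half-precision storage):

    #include <cmath>
    #include <cstdio>
    #include <vector>

    // Online softmax over a stream of scores: after consuming all scores,
    // M holds max(s_i) and S holds sum(exp(s_i - M)), so the softmax weight
    // of score s_i is exp(s_i - M) / S. The per-score update mirrors the
    // kernel: ms rescales the old sum, vs adds the new term.
    int main() {
        const std::vector<float> scores = { 0.5f, 2.0f, -1.0f, 3.0f };

        float M = -INFINITY; // running maximum
        float S = 0.0f;      // running sum of exponentials relative to M

        for (float s : scores) {
            const float M_new = std::fmax(M, s);
            const float ms = (M == -INFINITY) ? 0.0f : std::exp(M - M_new); // rescale old sum
            const float vs = std::exp(s - M_new);                           // new term
            S = S*ms + vs;
            M = M_new;
        }

        // sanity check against the two-pass softmax denominator
        float M_ref = -INFINITY, S_ref = 0.0f;
        for (float s : scores) M_ref = std::fmax(M_ref, s);
        for (float s : scores) S_ref += std::exp(s - M_ref);

        std::printf("online: M=%.6f S=%.6f   two-pass: M=%.6f S=%.6f\n", M, S, M_ref, S_ref);
        return 0;
    }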
@@ -6686,16 +6687,16 @@ static __global__ void flash_attn_ext_f16(
         // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
         for (int64_t j = 0; j < Q; ++j) {
             if (lane_id == 0) {
-                ss[j*T + 0] = __float2half(S[j]);
-                ss[j*T + 1] = __float2half(M[j]);
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
             }
         }
     }

     // reduce the warps sequentially
     for (int64_t sg = 1; sg < num_warps; ++sg) {
-        float S = 0.0f;
-        float M = -INFINITY;
+        half S = 0.0f;
+        half M = -INFINITY;

         __syncthreads();

@@ -6713,25 +6714,25 @@ static __global__ void flash_attn_ext_f16(
         // the first simdgroup accumulates the results from the other simdgroups
         if (warp_id == 0) {
             for (int64_t j = 0; j < Q; ++j) {
-                const float S0 = __half2float(ss[j*T + 0]);
-                const float S1 = __half2float(ss[j*T + sg*SH + 0]);
+                const half S0 = ss[j*T + 0];
+                const half S1 = ss[j*T + sg*SH + 0];

-                const float M0 = __half2float(ss[j*T + 1]);
-                const float M1 = __half2float(ss[j*T + sg*SH + 1]);
+                const half M0 = ss[j*T + 1];
+                const half M1 = ss[j*T + sg*SH + 1];

-                M = max(M0, M1);
+                M = __hmax(M0, M1);

-                const float ms0 = M0 == -INFINITY ? 0.0f : expf(M0 - M);
-                const float ms1 = M1 == -INFINITY ? 0.0f : expf(M1 - M);
+                const half ms0 = __hisinf(M0) ? 0.0f : expf(M0 - M);
+                const half ms1 = __hisinf(M1) ? 0.0f : expf(M1 - M);

                 S = S0*ms0 + S1*ms1;

                 if (lane_id == 0) {
-                    ss[j*T + 0] = __float2half(S);
-                    ss[j*T + 1] = __float2half(M);
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;

-                    ss[j*T + C + j ] = __float2half(ms0);
-                    ss[j*T + C + j + sg*SH] = __float2half(ms1);
+                    ss[j*T + C + j ] = ms0;
+                    ss[j*T + C + j + sg*SH] = ms1;
                 }
             }

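The accumulation above is the associative merge of two partial softmax states (S0, M0) and (S1, M1), one per warp. A small standalone sketch of the merge rule in plain C++ (illustrative only; the kernel performs it per query row on half scalars and also stores ms0/ms1 on the diagonal so the accumulated output can be rescaled):

    #include <cmath>
    #include <cstdio>

    struct SoftmaxState { float S; float M; }; // S = sum of exp(s - M), M = running max

    // Merge two partial online-softmax states; matches S = S0*ms0 + S1*ms1
    // with ms_i = exp(M_i - max(M0, M1)) in the hunk above.
    static SoftmaxState merge_states(SoftmaxState a, SoftmaxState b) {
        const float M   = std::fmax(a.M, b.M);
        const float ms0 = (a.M == -INFINITY) ? 0.0f : std::exp(a.M - M);
        const float ms1 = (b.M == -INFINITY) ? 0.0f : std::exp(b.M - M);
        return { a.S*ms0 + b.S*ms1, M };
    }

    int main() {
        // e.g. warp 0 processed scores {1, 2}, warp 1 processed score {3}
        const SoftmaxState w0 = { std::exp(1.0f - 2.0f) + 1.0f, 2.0f };
        const SoftmaxState w1 = { 1.0f, 3.0f };
        const SoftmaxState r  = merge_states(w0, w1);
        std::printf("S=%.6f M=%.6f\n", r.S, r.M); // equals a single pass over {1, 2, 3}
        return 0;
    }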
@@ -6774,10 +6775,10 @@ static __global__ void flash_attn_ext_f16(
     // final rescale with 1/S and store to global memory
     if (warp_id == 0) {
         for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
-            const float S = __half2float(ss[j*T + 0]);
+            const half S = ss[j*T + 0];

             for (int64_t i = lane_id; i < D; i += NW) {
-                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i]) / S;
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
             }
         }
     }
@@ -10930,12 +10931,15 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     float scale;
     memcpy(&scale, KQV->op_params, sizeof(float));

-    const int nqpb = 16; // queries per block
-    const int ncpw = 32; // cache values per warp (does not work for other values)
+#define NQPB 16
+#define NCPW 32

+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)
+
     const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
                               // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
-    const int nwarps = Q->ne[1] <= nqpb ? MAX(4, MIN(K->ne[1]/ncpw, nwarps_max)) : 4;
+    const int nwarps = Q->ne[1] <= nqpb ? MAX(2, MIN(K->ne[1]/ncpw, nwarps_max)) : 2;

     dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
     dim3 block_dim(32, nwarps, 1);
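To make the retuned heuristic concrete, the snippet below reproduces the nwarps computation for a few hypothetical shapes (n_q = Q->ne[1], n_kv = K->ne[1]; the example shapes are made up, not taken from the commit). With the floor lowered from 4 to 2 warps, single-token decoding against a long KV cache still gets up to 8 warps per block, while large batches now run with 2 warps per block:

    #include <algorithm>
    #include <cstdio>

    // Illustrative only: the nwarps heuristic from the hunk above, evaluated on
    // a few hypothetical (n_q, n_kv) shapes.
    int main() {
        const int nqpb = 16, ncpw = 32, nwarps_max = 8;

        const int shapes[][2] = { {1, 4096}, {1, 64}, {512, 4096} }; // {n_q, n_kv}
        for (const auto & s : shapes) {
            const int n_q = s[0], n_kv = s[1];
            const int nwarps = n_q <= nqpb ? std::max(2, std::min(n_kv/ncpw, nwarps_max)) : 2;
            std::printf("n_q=%4d n_kv=%5d -> nwarps=%d (block_dim = 32 x %d)\n", n_q, n_kv, nwarps, nwarps);
        }
        return 0;
    }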
@@ -10945,7 +10949,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
     switch (Q->ne[0])
     {
         case 16:
-            flash_attn_ext_f16<16, 16, 32>
+            flash_attn_ext_f16<16, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10962,7 +10966,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 64:
-            flash_attn_ext_f16<64, 16, 32>
+            flash_attn_ext_f16<64, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10979,7 +10983,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 80:
-            flash_attn_ext_f16<80, 16, 32>
+            flash_attn_ext_f16<80, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
@@ -10996,7 +11000,7 @@ inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor *
                 );
             break;
         case 128:
-            flash_attn_ext_f16<128, 16, 32>
+            flash_attn_ext_f16<128, NQPB, NCPW>
                 <<<blocks_num, block_dim, shmem, main_stream>>> (
                         (const char *) src0_extra->data_device[g_main_device], // Query
                         (const char *) src1_extra->data_device[g_main_device], // Key
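The three template arguments are the head size D (selected by the switch on Q->ne[0]), the queries per block (NQPB) and the cache values per warp (NCPW). Below is a stand-in sketch of the dispatch pattern in plain C++ (illustrative only; flash_attn_stub and dispatch are made-up names, and the real kernel takes the full list of device pointers and strides shown in the hunks):

    #include <cstdio>

    #define NQPB 16
    #define NCPW 32

    // Stand-in for the templated kernel: only the head size varies at runtime,
    // while the tile sizes are compile-time constants shared by all cases.
    template<int D, int Q, int C>
    static void flash_attn_stub() {
        std::printf("would launch the D=%d kernel with Q=%d, C=%d\n", D, Q, C);
    }

    static void dispatch(int head_size) {
        switch (head_size) {
            case  16: flash_attn_stub< 16, NQPB, NCPW>(); break;
            case  64: flash_attn_stub< 64, NQPB, NCPW>(); break;
            case  80: flash_attn_stub< 80, NQPB, NCPW>(); break;
            case 128: flash_attn_stub<128, NQPB, NCPW>(); break;
            default:  std::printf("unsupported head size %d\n", head_size); break;
        }
    }

    int main() {
        dispatch(128);
        return 0;
    }

The remaining two hunks below are in the backend test harness: the op-duplication loop in struct test_case and the test list in test_backend.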
@@ -572,9 +572,19 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+#if 1
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
+#else
+        n_runs = 1000;
+        int n_nodes = gf->n_nodes;
+        for (int i = 1; i < n_runs; i++) {
+            for (int j = 0; j < n_nodes; j++) {
+                gf->nodes[gf->n_nodes++] = gf->nodes[j];
+            }
+        }
+#endif

         // calculate memory
         size_t mem = n_runs * op_size(out);
@@ -2199,8 +2209,8 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());

-#if 0
-    for (int hs : { 64, 80, 96, 112, 128, 256, }) {
+#if 1
+    for (int hs : { 64, 80, 128, }) {
         for (int nh : { 32, }) {
             for (int kv : { 512, 1024, 2048, 4096, }) {
                 for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
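With the sweep enabled (#if 1), the loops above enumerate 3 head sizes × 1 head count × 4 KV lengths × 7 batch sizes, i.e. 84 flash-attention test configurations. A trivial standalone count of that sweep (illustrative only, independent of the test harness):

    #include <cstdio>
    #include <initializer_list>

    int main() {
        int n = 0;
        for (int hs : { 64, 80, 128 }) {
            for (int nh : { 32 }) {
                for (int kv : { 512, 1024, 2048, 4096 }) {
                    for (int nb : { 1, 2, 4, 8, 512, 1024, 2048 }) {
                        (void) hs; (void) nh; (void) kv; (void) nb;
                        n++;
                    }
                }
            }
        }
        std::printf("%d flash-attn test configurations\n", n); // prints 84
        return 0;
    }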