apply suggestions

This commit is contained in:
parent 09db1a7cf3
commit fded2e6a11

2 changed files with 83 additions and 34 deletions

ggml-cuda.cu: 95 changed lines
@@ -5989,38 +5989,55 @@ static __global__ void im2col_f32_f16(
 
 #define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
 
-template<int block_size>
-static __global__ void flash_attn_f32(const float* q, const float* k,const float* v, float* dst, float kq_scale,
-        int d_head, int seq_len, int num_heads) {
+template<int block_size, int k_seq_len>
+static __global__ void flash_attn_f32(
+        const float* __restrict__ q,
+        const float* __restrict__ k,
+        const float* __restrict__ v,
+        float* __restrict__ kqv,
+        float kq_scale,
+        int head_dim, int seq_len, int num_heads) {
     const int head = blockIdx.x / seq_len;
-    const int head_size = d_head * seq_len;
+    const int head_size = head_dim * seq_len;
     const int s = blockIdx.x % seq_len;
-    const int tid = threadIdx.x;
 
-    extern __shared__ char work_data[];
-    float* S = (float*)work_data; // theorical sequent length: 12848, due memory per block limit
-    float* warp_data = (float*)(work_data + seq_len * sizeof(float));
+    extern __shared__ char shmem__[];
+    float* S = (float*)shmem__;
+    float* warp_data = (float*)(shmem__ + seq_len * sizeof(float));
 
     // QK^T
-    for(int is = tid; is < seq_len; is += block_size) {
+#pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
+        const int key_offset = is * head_dim + head * head_size;
+        const int query_offset = s * head_dim + head * head_size;
+
         S[is] = 0.0f;
-        int key_offset = is * d_head + head * head_size;
-        int query_offset = s * d_head + head * head_size;
-        for(int d = 0; d < d_head; d++) {
+        for(int d = 0; d < head_dim; d++) {
             S[is] += k[key_offset + d] * q[query_offset + d];
         }
         S[is] *= kq_scale;
     }
 
     __syncthreads();
 
     float max_val = -INFINITY;
     // get the max
-    for(int is = tid; is < seq_len; is += block_size) {
+#pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
         max_val = fmaxf(max_val , S[is]);
     }
 
     max_val = warp_reduce_max(max_val);
 
     { // get max from all threads
         int warp_id = threadIdx.x / WARP_SIZE;
         int lane_id = threadIdx.x % WARP_SIZE;
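The hunk stops right where the cross-warp combination of the per-warp maxima begins, so those lines are not visible here. For orientation only, here is a minimal sketch of the usual two-level block reduction that this kind of code performs with a per-warp scratch buffer; the helper name block_reduce_max is hypothetical, while warp_reduce_max and WARP_SIZE are taken from ggml-cuda.cu:

// Illustrative sketch, not the commit's code: block-wide max reduction using
// warp shuffles plus one float of shared memory per warp.
static __device__ float block_reduce_max(float val, float * warp_data) {
    const int warp_id = threadIdx.x / WARP_SIZE;
    const int lane_id = threadIdx.x % WARP_SIZE;

    val = warp_reduce_max(val);              // reduce within each warp
    if (lane_id == 0) {
        warp_data[warp_id] = val;            // lane 0 publishes its warp's maximum
    }
    __syncthreads();

    if (warp_id == 0) {                      // first warp reduces the per-warp partials
        val = lane_id < blockDim.x / WARP_SIZE ? warp_data[lane_id] : -INFINITY;
        val = warp_reduce_max(val);
        if (lane_id == 0) {
            warp_data[0] = val;              // publish the block-wide maximum
        }
    }
    __syncthreads();
    return warp_data[0];                     // every thread reads the same result
}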
@@ -6034,14 +6051,20 @@ static __global__ void flash_attn_f32(const float* q, const float* k,const float
 
     // softmax(QK^T)
     float sum = 0.0f;
-    for(int is = tid; is < seq_len;is += block_size) {
-        const float val = expf(S[is] - max_val);
-        S[is] = val;
-        sum += val;
+#pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
+        S[is] = expf(S[is] - max_val);
+        sum += S[is];
     }
+    __syncthreads();
 
     sum = warp_reduce_sum(sum);
-    { // sum partials
+    { // softmax sum partials
         int warp_id = threadIdx.x / WARP_SIZE;
         int lane_id = threadIdx.x % WARP_SIZE;
         if (lane_id == 0) {
@@ -6053,19 +6076,31 @@ static __global__ void flash_attn_f32(const float* q, const float* k,const float
     }
 
     float inv_sum = 1.0f / sum;
-    for(int is = tid; is < seq_len; is += block_size) {
-        S[is] *= inv_sum;
+#pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+        S[is] *= inv_sum;
     }
 
-    __syncthreads();
-    // softmax(QK^T)V
-    for (int d = tid; d < d_head; d += block_size) {
-        int dst_index = d + s * d_head + head * head_size;
-        int value_offset = d * seq_len + head * head_size;
-        dst[dst_index] = 0.0f;
-        for(int ic = 0; ic < seq_len; ic++) {
-            dst[dst_index] += v[value_offset + ic] * S[ic];
+    __syncthreads();
+
+    // softmax(QK^T)V
+    for (int d = threadIdx.x; d < head_dim; d += block_size) {
+        const int dst_index = d + s * head_dim + head * head_size;
+        const int value_offset = d * seq_len + head * head_size;
+
+        float temp = 0.0f;
+#pragma unroll
+        for(int ic = 0; ic < k_seq_len;ic++) {
+            if(ic >= seq_len) {
+                break;
+            }
+            temp += v[value_offset + ic] * S[ic];
         }
+        kqv[dst_index] = temp;
     }
 }
 
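For reference, each CUDA block of the updated kernel handles one (head, query position) pair, indexed by head and s above. Writing o_h = head * head_size for the per-head offset, the kernel computes, entirely through the shared buffer S:

    S_i = kq\_scale \cdot \sum_{d=0}^{head\_dim-1} k[o_h + i \cdot head\_dim + d] \, q[o_h + s \cdot head\_dim + d]

    P_i = \frac{\exp(S_i - \max_j S_j)}{\sum_j \exp(S_j - \max_j S_j)}

    kqv[o_h + s \cdot head\_dim + d] = \sum_{i=0}^{seq\_len-1} v[o_h + d \cdot seq\_len + i] \, P_i

that is, scaled dot-product attention for a single query row, with the softmax evaluated in the numerically stable max-subtracted form.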
@@ -7462,7 +7497,7 @@ static void im2col_f32_f16_cuda(const float* x, half* dst,
 static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
     int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
     int num_blocks = num_heads * seq_len;
-    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
+    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
         q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
 }
 
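On the host side the only change is the hard-coded 1024 passed as the new k_seq_len template argument, while the dynamic shared memory is still sized from the runtime seq_len. A hedged sketch of how that coupling might be made explicit; the wrapper name and the GGML_ASSERT placement are illustrative assumptions, not part of the commit:

// Minimal sketch (not the commit's code): guard the runtime sequence length
// against the compile-time k_seq_len template argument that drives the
// #pragma unroll bounds inside the kernel.
static void flash_attn_f32_cuda_checked(const float* q, const float* k, const float* v,
                                        float* dst, float kq_scale, const int d_head,
                                        const int seq_len, const int num_heads,
                                        cudaStream_t stream) {
    constexpr int k_seq_len = 1024;        // must match the template argument below
    GGML_ASSERT(seq_len <= k_seq_len);     // the unrolled loops break at seq_len but never iterate past k_seq_len
    int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
    int num_blocks = num_heads * seq_len;
    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, k_seq_len><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
        q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
}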
The remaining hunks belong to the commit's second file, a flash-attention test program whose path is not shown in this capture.

@@ -23,8 +23,9 @@ struct test_model {
     struct ggml_tensor * k;
     struct ggml_tensor * v;
     ggml_backend_t backend = NULL;
-    ggml_backend_buffer_t buffer;
-    struct ggml_context * ctx;
+    ggml_backend_buffer_t buffer = NULL;
+    struct ggml_context * ctx = NULL;
+    bool naive_attn = false;
 };
 
 static std::vector<float> tensor_to_float(const ggml_tensor * t) {
@@ -216,8 +217,16 @@ struct ggml_cgraph * build_graph(const test_model& model, struct ggml_allocr * a
 
     struct ggml_cgraph * gf = ggml_new_graph(ctx0);
 
+    if(!model.naive_attn) {
     struct ggml_tensor* result = ggml_flash_attn(ctx0, model.q, model.k, model.v, false);
     ggml_build_forward_expand(gf, result);
+    } else {
+        struct ggml_tensor* kq = ggml_mul_mat(ctx0, model.k, model.q);
+        kq = ggml_scale_inplace(ctx0, kq, 1.0f / sqrtf((float)model.q->ne[0]));
+        kq = ggml_soft_max(ctx0, kq);
+        kq = ggml_mul_mat(ctx0, model.v, kq);
+        ggml_build_forward_expand(gf, kq);
+    }
 
     // delete the temporally context used to build the graph
     ggml_free(ctx0);
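Setting aside ggml's internal layout and transposition conventions, the new naive branch assembles the same quantity the fused CUDA kernel produces, with model.q->ne[0] playing the role of the head dimension:

    KQV = V \cdot \mathrm{softmax}\!\left( \frac{K^{\top} Q}{\sqrt{d_{head}}} \right)

so running the two paths side by side checks the flash-attention kernel against a reference built from generic ggml ops.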
@@ -330,15 +339,18 @@ struct ggml_tensor* compute_graph(const test_model & model, ggml_backend_t backe
 int main(int argc, char ** argv)
 {
     bool compare_backend = false;
+    test_model model;
     for (int i = 1; i < argc; i++) {
         if (strcmp(argv[i], "comp") == 0) {
             compare_backend = true;
+        } else if (strcmp(argv[i], "naive") == 0) {
+            model.naive_attn = true;
         }
     }
 
     ggml_time_init();
 
-    test_model model;
     load_model(model, true);
 
     ggml_backend_buffer_t buf_compute; // for compute
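With this change the test binary (its name is not shown in this diff) accepts two optional arguments: "comp", which sets compare_backend so that compute_graph cross-checks the result, and the new "naive", which switches build_graph to the unfused mul_mat/scale/soft_max path instead of ggml_flash_attn.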
@@ -359,9 +371,11 @@ int main(int argc, char ** argv)
     }
 
     ggml_backend_t backend_cpu = ggml_backend_cpu_init();
+    uint64_t compute_time_us__ = ggml_time_us();
     struct ggml_tensor * result = compute_graph(model, backend_cpu, allocr, compare_backend);
     if(!compare_backend) {
+        ggml_backend_synchronize(model.backend);
+        printf("computing time: %.4f ms\n", (ggml_time_us() - compute_time_us__) / 1000.0);
         float* data = new float[ggml_nelements(result)];
 
         ggml_backend_tensor_get(result, data, 0, ggml_nbytes(result));