diff --git a/examples/batched-bench/batched-bench.cpp b/examples/batched-bench/batched-bench.cpp index 19674dfd3..78ba9d037 100644 --- a/examples/batched-bench/batched-bench.cpp +++ b/examples/batched-bench/batched-bench.cpp @@ -100,7 +100,7 @@ int main(int argc, char ** argv) { ctx_params.seed = 1234; ctx_params.n_ctx = n_kv_max; - ctx_params.n_batch = 512; + ctx_params.n_batch = 2048; ctx_params.n_threads = params.n_threads; ctx_params.n_threads_batch = params.n_threads_batch == -1 ? params.n_threads : params.n_threads_batch; diff --git a/ggml-cuda.cu b/ggml-cuda.cu index be8e33a56..31bfc43c1 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -14,6 +14,7 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/dmmv.cuh" +#include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" @@ -2382,6 +2383,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_cuda_flash_attn_ext(ctx, dst); + break; default: return false; } @@ -2656,6 +2660,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons case GGML_OP_ARANGE: case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; default: return false; diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu new file mode 100644 index 000000000..bcf27fd79 --- /dev/null +++ b/ggml-cuda/fattn.cu @@ -0,0 +1,593 @@ +#include "fattn.cuh" + +#include + +static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + a = __hadd2(a, __shfl_xor_sync(0xffffffff, a, mask, 32)); + } + return a; +#else + GGML_UNUSED(a); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL +} + +static __device__ __forceinline__ half2 warp_reduce_max(half2 x) { +#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +#pragma unroll + for (int mask = 16; mask > 0; mask >>= 1) { + x = __hmax2(x, __shfl_xor_sync(0xffffffff, x, mask, 32)); + } + return x; +#else + GGML_UNUSED(x); + NO_DEVICE_CODE; +#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX +} + +#if __CUDA_ARCH__ >= CC_VOLTA +typedef nvcuda::wmma::fragment half16x16_a; +typedef nvcuda::wmma::fragment half16x16_b; +typedef nvcuda::wmma::fragment half16x16_bT; +typedef nvcuda::wmma::fragment half16x16_acc; +#endif + +// based on metal version +template // D head size, Q queries per block, C cache items per block +static __global__ void flash_attn_ext_f16( + const char * __restrict__ q, + const char * __restrict__ k, + const char * __restrict__ v, + const char * __restrict__ mask, + float * __restrict__ dst, + float scale, + int ne00, + int ne01, + int ne02, + int ne03, + int ne10, + int ne11, + int ne12, + int ne13, + int ne31, + int nb31, + int nb01, + int nb02, + int nb03, + int nb11, + int nb12, + int nb13, + int ne0, + int ne1, + int ne2, + int ne3) { +#if __CUDA_ARCH__ >= CC_VOLTA + const int warp_id = threadIdx.y; + const int lane_id = threadIdx.x; + + const int num_warps = blockDim.y; // number of warps + const int iq3 = blockIdx.z; + const int iq2 = blockIdx.y; + const int iq1 = blockIdx.x * Q; + + const int D16 = D/16; + const int Q16 = Q/16; + const int C16 = C/16; + + const int NW = WARP_SIZE; + const int SH = (C + Q); // shared memory per simdgroup in (half) + + const int T = D + num_warps*SH; // shared memory size per query in (half) + const int T2 = T/2; // shared memory size per query in (half2) + const int C2 = C/2; + const int D2 = D/2; + + extern __shared__ half __flash_attn_f16_shmem[]; + // pq + half * sq = (half *) (__flash_attn_f16_shmem + 0*D); // holds the query data + half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D); // same as above but in half2 + half * ss = (half *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix + half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2 + + half16x16_acc zr; + half16x16_acc lo[Q16][D16]; + + // load heads from Q to shared memory +#pragma unroll + for (int j0 = 0; j0 < Q; j0 += num_warps) { + const int j = j0 + warp_id; + if (j >= Q) { + break; + } + + const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + +#pragma unroll + for (int i0 = 0; i0 < D2; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D2) { + break; + } + + if (iq1 + j < ne01) { + sq2[j*T2 + i] = __float22half2_rn(q2[i]); + } else { + sq2[j*T2 + i] = make_half2(0.0, 0.0); + } + } + } + + nvcuda::wmma::fill_fragment(zr, 0.0); + + // zero out lo + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::fill_fragment(lo[j][i], 0.0); + } + } + + // zero out shared memory SH + for (int j = 0; j < Q; ++j) { + for (int i0 = 0; i0 < SH; i0 += NW) { + const int i = i0 + lane_id; + if (i >= SH) { + break; + } + + ss[j*T + i] = 0.0; + } + } + + __syncthreads(); + + { + half S = __float2half(0.0f); + half M[Q]; + + for (int i = 0; i < Q; ++i) { + M[i] = CUDART_MIN_DENORM_FP16; + } + + // assume K and V are same shape + const int ne22 = ne12; + const int ne23 = ne13; + + const int nb21 = nb11; + const int nb22 = nb12; + const int nb23 = nb13; + + // broadcast + const int rk2 = ne02/ne12; + const int rk3 = ne03/ne13; + + const int rv2 = ne02/ne22; + const int rv3 = ne03/ne23; + + // k indices + const int ik2 = iq2 / rk2; + const int ik3 = iq3 / rk3; + + // v indices + const int iv2 = iq2 / rv2; + const int iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + half16x16_a mq[Q16][D16]; + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T); + } + } + + // pointer to the mask + const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr; + + // prepare diagonal scale matrix + half16x16_b mscale; + for (int i = 0; i < 16; ++i) { + ss[i*T + i] = __float2half(scale); + } + nvcuda::wmma::load_matrix_sync(mscale, ss, T); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) { + const int ic = ic0 + warp_id*C; + if (ic >= ne11) { + break; + } + + // Q*K^T + { +#pragma unroll + for (int cc = 0; cc < C16; ++cc) { + half16x16_acc mqk[Q16]; + for (int j = 0; j < Q16; ++j) { + nvcuda::wmma::fill_fragment(mqk[j], 0); + } + + const half * pk = (const half *) ((const char *) k + ((ic + 16*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (int i = 0; i < D16; ++i) { + half16x16_bT mk; // transposed key + nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half)); + + for (int j = 0; j < Q16; ++j) { + nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]); + } + } + + // mqk = mqk*scale + mask + for (int j = 0; j < Q16; ++j) { + half16x16_a mqka; + half16x16_acc mm; + + if (mp) { + nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major); + } + + // convert accumulator to matrix_a + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T); + + nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr); + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major); + } + } + } + + // used to detect blocks full of -INF + half2 smax = make_half2(-INFINITY, -INFINITY); + + // online softmax + for (int j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int p0 = 0; p0 < C2; p0 += NW) { + const int p = p0 + lane_id; + + const half2 s = ss2[j*T2 + p]; + + smax = __hmax2(smax, s); + M[j] = __hmax(M[j], __hmax(s.x, s.y)); + } + + M[j] = warp_reduce_max(M[j]); + + // local sum + half2 ls = make_half2(0.0f, 0.0f); + half2 M2 = make_half2(M[j], M[j]); + + for (int p0 = 0; p0 < C2; p0 += NW) { + const int p = p0 + lane_id; + + const half2 s = ss2[j*T2 + p]; + + const half2 vs = h2exp(s - M2); + + ls += vs; + + // the P matrix from the paper (Q rows, C columns) + ss2[j*T2 + p] = vs; + } + + ls = warp_reduce_sum(ls); + + const half ms = hexp(m - M[j]); + + // create a QxQ diagonal matrix for rescaling the output + if (lane_id == j) { + ss[j*T + C + j] = ms; + + S = S*ms + ls.x + ls.y; + } + } + + smax = warp_reduce_max(smax); + + // skip -INF blocks + if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) { + continue; + } + + // O = diag(ms)*O + for (int j = 0; j < Q16; ++j) { + half16x16_a mm; + half16x16_b lob; + + nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T); + + for (int i = 0; i < D16; ++i) { + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T); + + nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr); + } + } + + // restore zeros + for (int j = 0; j < Q16; ++j) { + nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major); + } + + // O = O + (Q*K^T)*V + { + for (int cc = 0; cc < C16; ++cc) { + const half * pv = (const half *) ((const char *) v + ((ic + 16*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + half16x16_b mv[D16]; + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(mv[i], pv + i*16, nb21/sizeof(half)); + } + + half16x16_a ms[Q16]; + for (int j = 0; j < Q16; ++j) { + nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T); + } + + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[i], lo[j][i]); + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + if (lane_id < Q) { + ss[lane_id*T + 0] = S; + ss[lane_id*T + 1] = M[lane_id]; + } + } + + // reduce the warps sequentially + for (int sg = 1; sg < num_warps; ++sg) { + __syncthreads(); + + // each simdgroup stores its output to shared memory, reusing sq + if (warp_id == sg) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + __syncthreads(); + + // the first simdgroup accumulates the results from the other simdgroups + if (warp_id == 0) { + for (int j = lane_id; j < Q; j += NW) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; + + const half M = __hmax(M0, M1); + + const half ms0 = hexp(M0 - M); + const half ms1 = hexp(M1 - M); + + const half S = S0*ms0 + S1*ms1; + + ss[j*T + 0] = S; + ss[j*T + 1] = M; + + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (int j = 0; j < Q16; ++j) { + half16x16_a ms0; + half16x16_a ms1; + half16x16_b t; + half16x16_acc t2; + + nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j, T); + nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T); + + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T); + nvcuda::wmma::mma_sync(t2, ms1, t, zr); + + // convert accumulator to matrix_b + nvcuda::wmma::store_matrix_sync( sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T); + + nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2); + } + } + } + } + + // store result to shared memory (reuse sq) + if (warp_id == 0) { + for (int j = 0; j < Q16; ++j) { + for (int i = 0; i < D16; ++i) { + nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major); + } + } + } + + // final rescale with 1/S and store to global memory + if (warp_id == 0) { + for (int j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int i0 = 0; i0 < D; i0 += NW) { + const int i = i0 + lane_id; + if (i >= D) { + break; + } + + dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S); + } + } + } +#else + NO_DEVICE_CODE; +#endif +} + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + + const ggml_tensor * mask = dst->src[3]; + + ggml_tensor * KQV = dst; + + GGML_ASSERT(Q->type == GGML_TYPE_F32); + GGML_ASSERT(K->type == GGML_TYPE_F16); + GGML_ASSERT(V->type == GGML_TYPE_F16); + GGML_ASSERT(KQV->type == GGML_TYPE_F32); + + GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16); + GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && + "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); + + ggml_cuda_set_device(ctx.device); + + const cudaStream_t main_stream = ctx.stream(); + + float scale; + memcpy(&scale, KQV->op_params, sizeof(float)); + +#define NQPB 16 +#define NCPW 128 + + const int nqpb = NQPB; // queries per block + const int ncpw = NCPW; // cache values per warp (does not work for other values) + + GGML_ASSERT(NQPB <= 32); + + const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much? + // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why + const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1; + + dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]); + dim3 block_dim(32, nwarps, 1); + + const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2); + + // increase shared memory limit to 96KB + //const size_t shmem_max = 96*1024; + //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max); + + switch (Q->ne[0]) { + case 64: + flash_attn_ext_f16<64, NQPB, NCPW> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? (const char *) mask->data : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 80: + flash_attn_ext_f16<80, NQPB, NCPW> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? (const char *) mask->data : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 96: + flash_attn_ext_f16<96, NQPB, NCPW> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? (const char *) mask->data : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 112: + flash_attn_ext_f16<112, NQPB, NCPW> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? (const char *) mask->data : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 128: + flash_attn_ext_f16<128, NQPB, NCPW> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? (const char *) mask->data : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + case 256: + flash_attn_ext_f16<256, NQPB, NCPW> + <<>> ( + (const char *) Q->data, // Query + (const char *) K->data, // Key + (const char *) V->data, // Value + mask ? (const char *) mask->data : nullptr, // Mask + (float *) KQV->data, // dst + scale, + Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], + K->ne[0], K->ne[1], K->ne[2], K->ne[3], + mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0, + Q->nb[1], Q->nb[2], Q->nb[3], + K->nb[1], K->nb[2], K->nb[3], + KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3] + ); + break; + default: + break; + } + + CUDA_CHECK(cudaGetLastError()); +} diff --git a/ggml-cuda/fattn.cuh b/ggml-cuda/fattn.cuh new file mode 100644 index 000000000..ad3ca7a8d --- /dev/null +++ b/ggml-cuda/fattn.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml-metal.m b/ggml-metal.m index a08abbc29..a578fb06a 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -173,6 +173,12 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, + GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, GGML_METAL_KERNEL_TYPE_CPY_F32_F16, GGML_METAL_KERNEL_TYPE_CPY_F32_F32, GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, @@ -449,6 +455,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { id metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \ kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \ [metal_function release]; \ + GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \ + (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \ + (int) kernel->pipeline.threadExecutionWidth); \ if (error) { \ GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \ [metal_library release]; \ @@ -604,6 +613,12 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) { GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true); @@ -734,6 +749,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: return true; case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: @@ -1277,6 +1293,8 @@ static enum ggml_status ggml_metal_graph_compute( } break; case GGML_OP_SOFT_MAX: { + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); + int nth = 32; // SIMD width id pipeline = nil; @@ -2465,6 +2483,106 @@ static enum ggml_status ggml_metal_graph_compute( [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; } break; + case GGML_OP_FLASH_ATTN_EXT: + { + GGML_ASSERT(ne00 % 4 == 0); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + + struct ggml_tensor * src3 = gf->nodes[i]->src[3]; + + GGML_ASSERT(ggml_are_same_shape(src1, src2)); + GGML_ASSERT(src3); + + size_t offs_src3 = 0; + + id id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil; + + GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16); + GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) && + "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + + const int64_t ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30); + const int64_t ne31 = src3 ? src3->ne[1] : 0; + const int64_t ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32); + const int64_t ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33); + + const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30); + const uint64_t nb31 = src3 ? src3->nb[1] : 0; + const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32); + const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33); + + const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t); + + float scale; + memcpy(&scale, dst->op_params, sizeof(float)); + + id pipeline = nil; + + switch (ne00) { + case 64: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break; + case 80: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break; + case 96: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break; + case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break; + case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break; + case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break; + default: + { + GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00); + GGML_METAL_LOG_ERROR("add template specialization for this size\n"); + GGML_ASSERT(false && "add template specialization for this size"); + } + } + + // TODO: extend if necessary + [encoder setComputePipelineState:pipeline]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; + [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5]; + [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6]; + [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7]; + [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8]; + [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9]; + [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10]; + [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11]; + [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12]; + [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13]; + [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14]; + [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15]; + [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16]; + [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18]; + [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19]; + [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20]; + [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21]; + [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22]; + [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:23]; + [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:24]; + [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:25]; + [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:26]; + [encoder setBytes:&scale length:sizeof( float) atIndex:27]; + + const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! + const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! + + GGML_ASSERT(nqptg <= 32); + GGML_ASSERT(nqptg % 8 == 0); + GGML_ASSERT(ncpsg % 32 == 0); + + // simdgroups per threadgroup (a.k.a. warps) + // for small batches use more simdgroups (needs more tests, to confirm if it's worth it) + const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4; + + const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2); + + //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength); + GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength); + [encoder setThreadgroupMemoryLength:smem atIndex:0]; + + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)]; + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -2668,10 +2786,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe UNUSED(buft); } -static void ggml_backend_metal_log_allocated_size(id device) { +static void ggml_backend_metal_log_allocated_size(id device, size_t size_aligned) { +#ifndef GGML_METAL_NDEBUG #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15) if (@available(macOS 10.12, iOS 16.0, *)) { - GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)", + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)", + __func__, + size_aligned / 1024.0 / 1024.0, device.currentAllocatedSize / 1024.0 / 1024.0, device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0); @@ -2681,10 +2802,15 @@ static void ggml_backend_metal_log_allocated_size(id device) { GGML_METAL_LOG_INFO("\n"); } } else { - GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0); + GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n", + __func__, + size_aligned / 1024.0 / 1024.0, + device.currentAllocatedSize / 1024.0 / 1024.0); } +#endif #endif UNUSED(device); + UNUSED(size_aligned); } GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { @@ -2718,8 +2844,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff return NULL; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); - ggml_backend_metal_log_allocated_size(device); + ggml_backend_metal_log_allocated_size(device, size_aligned); return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size); } @@ -2806,7 +2931,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0); + ggml_backend_metal_log_allocated_size(device, size_aligned); ++ctx->n_buffers; } else { @@ -2829,7 +2954,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, return false; } - GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i); + ggml_backend_metal_log_allocated_size(device, size_step_aligned); + if (i + size_step < size) { GGML_METAL_LOG_INFO("\n"); } @@ -2838,8 +2964,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, } } - ggml_backend_metal_log_allocated_size(device); - return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size); } diff --git a/ggml-metal.metal b/ggml-metal.metal index 744b2a8b4..dc6ad3141 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -318,10 +318,10 @@ kernel void kernel_sum_rows( } kernel void kernel_soft_max( - device const float * src0, - device const float * src1, - device const float * src2, - device float * dst, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -340,10 +340,10 @@ kernel void kernel_soft_max( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; - device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr; - device const float * ppos = src2 != src0 ? src2 : nullptr; - device float * pdst = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const half * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr; + device const half * ppos = src2 != src0 ? (device const half *) src2 : nullptr; + device float * pdst = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); float slope = 0.0f; @@ -422,10 +422,10 @@ kernel void kernel_soft_max( } kernel void kernel_soft_max_4( - device const float * src0, - device const float * src1, - device const float * src2, - device float * dst, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, constant int64_t & ne00, constant int64_t & ne01, constant int64_t & ne02, @@ -444,10 +444,10 @@ kernel void kernel_soft_max_4( const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01; const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01); - device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); - device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr; - device const float4 * ppos = src2 != src0 ? (device const float4 *)(src2) : nullptr; - device float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00); + device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; + device const half4 * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr; + device const half4 * ppos = src2 != src0 ? (device const half4 *) src2 : nullptr; + device float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4; float slope = 0.0f; @@ -464,7 +464,7 @@ kernel void kernel_soft_max_4( float4 lmax4 = -INFINITY; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)); + lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))); } const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3])); @@ -490,7 +490,7 @@ kernel void kernel_soft_max_4( // parallel sum float4 lsum4 = 0.0f; for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f)) - max_val); + const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4)((pmask ? pmask[i00] : 0.0f) + (ppos ? slope*ppos[i00] : 0.0f))) - max_val); lsum4 += exp_psrc4; pdst4[i00] = exp_psrc4; } @@ -2031,6 +2031,411 @@ kernel void kernel_leaky_relu_f32( dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope; } +typedef void (flash_attn_ext_f16_t)( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]); + +// ref: https://arxiv.org/pdf/2307.08691.pdf +template // head size, queries per threadgroup, cache items per threadgroup +kernel void kernel_flash_attn_ext_f16( + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne01, + constant int64_t & ne02, + constant int64_t & ne03, + constant uint64_t & nb00, + constant uint64_t & nb01, + constant uint64_t & nb02, + constant uint64_t & nb03, + constant int64_t & ne10, + constant int64_t & ne11, + constant int64_t & ne12, + constant int64_t & ne13, + constant uint64_t & nb10, + constant uint64_t & nb11, + constant uint64_t & nb12, + constant uint64_t & nb13, + constant int64_t & ne31, + constant uint64_t & nb31, + constant int64_t & ne0, + constant int64_t & ne1, + constant int64_t & ne2, + constant int64_t & ne3, + constant float & scale, + threadgroup half * shared [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + const uint nsg = ntg.y; // number of simdgroups + + const int64_t iq3 = tgpig[2]; + const int64_t iq2 = tgpig[1]; + const int64_t iq1 = tgpig[0]*Q; + + const int64_t D4 = D/4; + const int64_t D8 = D/8; + const int64_t Q8 = Q/8; + const int64_t NW = N_SIMDWIDTH; + const int64_t SH = (C + Q); // shared memory per simdgroup in (half) + + const int64_t T = D + nsg*SH; // shared memory size per query in (half) + const int64_t T4 = T/4; // shared memory size per query in (half4) + + threadgroup half * sq = (threadgroup half *) (shared + 0*D); // holds the query data + threadgroup half4 * sq4 = (threadgroup half4 *) (shared + 0*D); // same as above but in half4 + threadgroup half * ss = (threadgroup half *) (shared + sgitg*SH + 1*D); // scratch buffer for attention and diagonal matrix + + // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) + simdgroup_half8x8 lo[Q8][D8]; + + // load heads from Q to shared memory + for (int64_t j = sgitg; j < Q; j += nsg) { + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + + for (int64_t i = tiisg; i < D4; i += NW) { + if (iq1 + j < ne01) { + sq4[j*T4 + i] = (half4) q4[i]; + } else { + sq4[j*T4 + i] = 0.0h; + } + } + } + + // zero out lo + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + lo[j][i] = make_filled_simdgroup_matrix(0.0h); + } + } + + // zero out shared memory SH + for (int64_t j = 0; j < Q; ++j) { + for (int64_t i = tiisg; i < SH; i += NW) { + ss[j*T + i] = 0.0h; + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + { + half S[Q] = { [0 ... Q-1] = 0.0h }; + half M[Q] = { [0 ... Q-1] = -INFINITY }; + + // assume K and V are same shape + const int64_t ne22 = ne12; + const int64_t ne23 = ne13; + + const uint64_t nb21 = nb11; + const uint64_t nb22 = nb12; + const uint64_t nb23 = nb13; + + // broadcast + const int64_t rk2 = ne02/ne12; + const int64_t rk3 = ne03/ne13; + + const int64_t rv2 = ne02/ne22; + const int64_t rv3 = ne03/ne23; + + // k indices + const int64_t ik2 = iq2 / rk2; + const int64_t ik3 = iq3 / rk3; + + // v indices + const int64_t iv2 = iq2 / rv2; + const int64_t iv3 = iq3 / rv3; + + // load the queries from shared memory into local memory + simdgroup_half8x8 mq[Q8][D8]; + + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T); + } + } + + // pointer to the mask + device const half * mp = (device const half *) (mask + iq1*nb31); + + // prepare diagonal scale matrix + simdgroup_half8x8 mscale(scale); + + // loop over the KV cache + // each simdgroup handles blocks of Q rows and C columns + for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) { + // Q*K^T + { + for (int cc = 0; cc < C/8; ++cc) { + simdgroup_half8x8 mqk[Q8]; + for (int64_t j = 0; j < Q8; ++j) { + mqk[j] = make_filled_simdgroup_matrix(0.h); + } + + device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13)); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose + + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]); + } + } + + // mqk = mqk*scale + mask + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 mm; + simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false); + simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm); + + simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false); + } + } + } + + // used to detect blocks full of -INF + half smax = -INFINITY; + + // online softmax + if (C == 32) { + half ms[Q]; + + for (int64_t j = 0; j < Q; ++j) { + const int64_t p = tiisg; + + const half m = M[j]; + const half s = ss[j*T + p]; + + smax = simd_max(max(smax, s)); + M[j] = simd_max(max(M[j], s)); + + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + + S[j] = S[j]*ms[j] + simd_sum(vs); + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } + } else { + half ms[Q]; + + for (int64_t j = 0; j < Q; ++j) { + const half m = M[j]; + + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + smax = max(smax, s); + M[j] = max(M[j], s); + } + + smax = simd_max(smax); + M[j] = simd_max(M[j]); + + ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]); + + // local sum + half ls = 0.0h; + + for (int64_t p = tiisg; p < C; p += NW) { + const half s = ss[j*T + p]; + + const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]); + + ls += vs; + + // the P matrix from the paper (Q rows, C columns) + ss[j*T + p] = vs; + } + + S[j] = S[j]*ms[j] + simd_sum(ls); + } + + // create a QxQ diagonal matrix for rescaling the output + if (tiisg < Q) { + ss[tiisg*T + C + tiisg] = ms[tiisg]; + } + } + + // skip -INF blocks + if (smax == -INFINITY) { + continue; + } + + // O = diag(ms)*O + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 mm; + simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_multiply(lo[j][i], mm, lo[j][i]); + } + } + + // O = O + (Q*K^T)*V + { + for (int cc = 0; cc < C/8; ++cc) { + device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23)); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_half8x8 mk; + simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false); + + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 mv; + simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false); + + simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]); + } + } + } + } + } + + // these are needed for reducing the results from the simdgroups (reuse the ss buffer) + for (int64_t j = 0; j < Q; ++j) { + if (tiisg == 0) { + ss[j*T + 0] = S[j]; + ss[j*T + 1] = M[j]; + } + } + } + + // reduce the warps sequentially + for (int64_t sg = 1; sg < nsg; ++sg) { + half S = { 0.0h }; + half M = { -INFINITY }; + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // each simdgroup stores its output to shared memory, reusing sq + if (sgitg == sg) { + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } + } + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + // the first simdgroup accumulates the results from the other simdgroups + if (sgitg == 0) { + for (int64_t j = 0; j < Q; ++j) { + const half S0 = ss[j*T + 0]; + const half S1 = ss[j*T + sg*SH + 0]; + + const half M0 = ss[j*T + 1]; + const half M1 = ss[j*T + sg*SH + 1]; + + M = max(M0, M1); + + const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M); + const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M); + + S = S0*ms0 + S1*ms1; + + if (tiisg == 0) { + ss[j*T + 0] = S; + ss[j*T + 1] = M; + + ss[j*T + C + j ] = ms0; + ss[j*T + C + j + sg*SH] = ms1; + } + } + + // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1 + for (int64_t j = 0; j < Q8; ++j) { + simdgroup_half8x8 t; + simdgroup_half8x8 ms0; + simdgroup_half8x8 ms1; + + simdgroup_load(ms0, ss + 8*j*T + C + 8*j, T, 0, false); + simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false); + + for (int64_t i = 0; i < D8; ++i) { + simdgroup_load (t, sq + 8*j*T + i*8, T, 0, false); + simdgroup_multiply(t, ms1, t); + + simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t); + } + } + } + } + + // store result to shared memory (reuse sq) + if (sgitg == 0) { + for (int64_t j = 0; j < Q8; ++j) { + for (int64_t i = 0; i < D8; ++i) { + simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false); + } + } + } + + device float4 * dst4 = (device float4 *) dst; + + // final rescale with 1/S and store to global memory + if (sgitg == 0) { + for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) { + const half S = ss[j*T + 0]; + + for (int64_t i = tiisg; i < D4; i += NW) { + dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S; + } + } + } +} + +template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>; +template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>; + kernel void kernel_cpy_f16_f16( device const half * src0, device half * dst, diff --git a/ggml.c b/ggml.c index eb469d0f7..da67e7d75 100644 --- a/ggml.c +++ b/ggml.c @@ -962,7 +962,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F16x8_ZERO #define GGML_F16_VEC_SET1 GGML_F16x8_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F16x8_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F16x8_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F16x8_FMA #define GGML_F16_VEC_ADD GGML_F16x8_ADD #define GGML_F16_VEC_MUL GGML_F16x8_MUL @@ -988,7 +988,7 @@ inline static float vaddvq_f32(float32x4_t v) { #define GGML_F16_VEC_ZERO GGML_F32Cx4_ZERO #define GGML_F16_VEC_SET1 GGML_F32Cx4_SET1 #define GGML_F16_VEC_LOAD(p, i) GGML_F32Cx4_LOAD(p) - #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE(p, r[i]) + #define GGML_F16_VEC_STORE(p, r, i) GGML_F32Cx4_STORE((__fp16 *)(p), r[i]) #define GGML_F16_VEC_FMA GGML_F32Cx4_FMA #define GGML_F16_VEC_ADD GGML_F32Cx4_ADD #define GGML_F16_VEC_MUL GGML_F32Cx4_MUL @@ -1155,7 +1155,7 @@ do { \ #if defined(__F16C__) // the _mm256_cvt intrinsics require F16C -#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x))) +#define GGML_F32Cx8_LOAD(x) _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x))) #define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0)) #else static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) { @@ -1673,6 +1673,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float #endif } +inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ax[GGML_F16_ARR]; + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v); + } +#endif +} + // xs and vs are byte strides of x and v inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) { @@ -1757,6 +1788,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { #endif } +inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) { +#if defined(GGML_SIMD) + const int np = (n & ~(GGML_F16_STEP - 1)); + + GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); + + GGML_F16_VEC ay[GGML_F16_ARR]; + + for (int i = 0; i < np; i += GGML_F16_STEP) { + for (int j = 0; j < GGML_F16_ARR; j++) { + ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); + ay[j] = GGML_F16_VEC_MUL(ay[j], vx); + + GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); + } + } + + // leftovers + for (int i = np; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#else + // scalar + for (int i = 0; i < n; ++i) { + y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v); + } +#endif +} + inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, 0, x, 0, x, 0, 1); *s = sqrtf(*s); } inline static void ggml_vec_sqr_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i]; } inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); } @@ -2011,6 +2071,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "LEAKY_RELU", "FLASH_ATTN", + "FLASH_ATTN_EXT", "FLASH_FF", "FLASH_ATTN_BACK", "SSM_CONV", @@ -2037,7 +2098,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = { "CROSS_ENTROPY_LOSS_BACK", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "none", @@ -2101,6 +2162,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "leaky_relu(x)", "flash_attn(x)", + "flash_attn_ext(x)", "flash_ff(x)", "flash_attn_back(x)", "ssm_conv(x)", @@ -2127,7 +2189,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { "cross_entropy_loss_back(x,y)", }; -static_assert(GGML_OP_COUNT == 76, "GGML_OP_COUNT != 76"); +static_assert(GGML_OP_COUNT == 77, "GGML_OP_COUNT != 77"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); @@ -4572,6 +4634,8 @@ struct ggml_tensor * ggml_mul_mat( void ggml_mul_mat_set_prec( struct ggml_tensor * a, enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_MUL_MAT); + const int32_t prec_i32 = (int32_t) prec; ggml_set_op_params_i32(a, 0, prec_i32); @@ -5408,14 +5472,15 @@ static struct ggml_tensor * ggml_soft_max_impl( GGML_ASSERT(ggml_is_contiguous(a)); if (mask) { + GGML_ASSERT(mask->type == GGML_TYPE_F16); GGML_ASSERT(ggml_is_contiguous(mask)); GGML_ASSERT(ggml_is_matrix(mask)); - GGML_ASSERT(ggml_can_repeat_rows(mask, a)); + GGML_ASSERT(mask->ne[1] >= a->ne[1]); } if (pos) { GGML_ASSERT(ggml_is_vector(pos)); - GGML_ASSERT(pos->type == GGML_TYPE_F32); + GGML_ASSERT(pos->type == GGML_TYPE_F16); GGML_ASSERT(pos->ne[0] == a->ne[0]); } @@ -6227,6 +6292,59 @@ struct ggml_tensor * ggml_flash_attn( return result; } +// ggml_flash_attn_ext + +struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale) { + GGML_ASSERT(ggml_can_mul_mat(k, q)); + // TODO: check if vT can be multiplied by (k*qT) + if (mask) { + GGML_ASSERT(ggml_is_contiguous(mask)); + GGML_ASSERT(mask->ne[2] == 1); + GGML_ASSERT(mask->ne[3] == 1); + GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) && + "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big"); + //GGML_ASSERT(ggml_can_repeat_rows(mask, qk)); + } + + bool is_node = false; + + if (q->grad || k->grad || v->grad) { + is_node = true; + } + + // permute(0, 2, 1, 3) + int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; + struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne); + + float params[] = { scale }; + ggml_set_op_params(result, params, sizeof(params)); + + result->op = GGML_OP_FLASH_ATTN_EXT; + result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->src[0] = q; + result->src[1] = k; + result->src[2] = v; + result->src[3] = mask; + + return result; +} + +void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec) { + GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT); + + const int32_t prec_i32 = (int32_t) prec; + + ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos +} + // ggml_flash_ff struct ggml_tensor * ggml_flash_ff( @@ -12235,7 +12353,7 @@ static void ggml_compute_forward_soft_max_f32( GGML_TENSOR_UNARY_OP_LOCALS - const int64_t ne11 = src1 ? src1->ne[1] : 1; + //const int64_t ne11 = src1 ? src1->ne[1] : 1; // TODO: is this supposed to be ceil instead of floor? // https://huggingface.co/mosaicml/mpt-7b/blob/main/attention.py#L370 @@ -12265,12 +12383,14 @@ static void ggml_compute_forward_soft_max_f32( float * dp = (float *)((char *) dst->data + i1*dst->nb[1]); // broadcast the mask across rows - float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL; + ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data) + (i1%ne01)*ne00 : NULL; ggml_vec_cpy_f32 (nc, wp, sp); ggml_vec_scale_f32(nc, wp, scale); if (mp) { - ggml_vec_acc_f32(nc, wp, mp); + for (int i = 0; i < nc; ++i) { + wp[i] += GGML_FP16_TO_FP32(mp[i]); + } } // ALiBi bias @@ -14549,6 +14669,197 @@ static void ggml_compute_forward_flash_attn( } } +// ggml_compute_forward_flash_attn_ext + +static void ggml_compute_forward_flash_attn_ext_f16( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + int64_t t0 = ggml_perf_time_us(); + UNUSED(t0); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const int ith = params->ith; + const int nth = params->nth; + + const int64_t D = neq0; + const int64_t N = neq1; + + GGML_ASSERT(ne0 == D); + GGML_ASSERT(ne2 == N); + + GGML_ASSERT(nbq0 == sizeof(float)); + GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t)); + + GGML_ASSERT(neq0 == D); + GGML_ASSERT(nek0 == D); + GGML_ASSERT(nev0 == D); + + GGML_ASSERT(neq1 == N); + GGML_ASSERT(nev0 == D); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + // broadcast factors + const int64_t rk2 = neq2/nek2; + const int64_t rk3 = neq3/nek3; + + const int64_t rv2 = neq2/nev2; + const int64_t rv3 = neq3/nev3; + + if (params->type == GGML_TASK_TYPE_INIT) { + return; + } + + if (params->type == GGML_TASK_TYPE_FINALIZE) { + return; + } + + // parallelize by q rows using ggml_vec_dot_f32 + + // total rows in q + const int nr = neq1*neq2*neq3; + + // rows per thread + const int dr = (nr + nth - 1)/nth; + + // row range for this thread + const int ir0 = dr*ith; + const int ir1 = MIN(ir0 + dr, nr); + + float scale = 1.0f; + memcpy(&scale, (float *) dst->op_params + 0, sizeof(float)); + + // loop over n_batch and n_head + for (int ir = ir0; ir < ir1; ++ir) { + // q indices + const int iq3 = ir/(neq2*neq1); + const int iq2 = (ir - iq3*neq2*neq1)/neq1; + const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1); + + float S = 0.0f; + float M = -INFINITY; + + float * V32 = (float *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32); + ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory + ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D); + + memset(V16, 0, D*sizeof(ggml_fp16_t)); + + const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL; + + // k indices + const int ik3 = iq3 / rk3; + const int ik2 = iq2 / rk2; + + // v indices + const int iv3 = iq3 / rv3; + const int iv2 = iq2 / rv2; + + // online softmax / attention + // loop over n_kv and n_head_kv + // ref: https://arxiv.org/pdf/2112.05682.pdf + for (int64_t ic = 0; ic < nek1; ++ic) { + const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f; + if (mv == -INFINITY) { + continue; + } + + float s; + + // convert Q to F16 in V32 + { + const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3)); + + for (int64_t d = 0; d < D; ++d) { + Q16[d] = GGML_FP32_TO_FP16(pq[d]); + } + } + + ggml_vec_dot_f16(D, + &s, 0, + (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)), 0, + Q16, 0, 1); + + s = s*scale + mv; + + const float Mold = M; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + M = s; + ms = expf(Mold - M); + + // V = V*expf(Mold - M) + ggml_vec_scale_f16(D, V16, ms); + } else { + vs = expf(s - M); + } + + const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3)); + + // V += v*expf(s - M) + ggml_vec_mad_f16(D, V16, v16, vs); + + S = S*ms + vs; + } + + // V /= S + for (int64_t d = 0; d < D; ++d) { + V32[d] = GGML_FP16_TO_FP32(V16[d])/S; + } + + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // original + //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); + + // permute(0, 2, 1, 3) + memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1); + } +} + +static void ggml_compute_forward_flash_attn_ext( + const struct ggml_compute_params * params, + const struct ggml_tensor * q, + const struct ggml_tensor * k, + const struct ggml_tensor * v, + const struct ggml_tensor * mask, + struct ggml_tensor * dst) { + switch (dst->op_params[1]) { + case GGML_PREC_DEFAULT: + { + ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst); + } break; + default: + { + // TODO: implement F32 precision + GGML_ASSERT(false); + } break; + } +} + // ggml_compute_forward_flash_ff static void ggml_compute_forward_flash_ff_f16( @@ -16370,6 +16681,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm const bool masked = t != 0; ggml_compute_forward_flash_attn(params, masked, tensor); } break; + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor); + } break; case GGML_OP_FLASH_FF: { ggml_compute_forward_flash_ff(params, tensor); @@ -17382,6 +17697,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor GGML_ASSERT(false); // TODO: not implemented } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { struct ggml_tensor * flash_grad = NULL; if (src0->grad || src1->grad || tensor->src[2]->grad) { @@ -18154,6 +18470,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads, int n_cur_ n_tasks = n_threads; } break; case GGML_OP_FLASH_ATTN: + case GGML_OP_FLASH_ATTN_EXT: { n_tasks = n_threads; } break; @@ -18557,6 +18874,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2 } } break; + case GGML_OP_FLASH_ATTN_EXT: + { + const int64_t ne00 = node->src[0]->ne[0]; // D + + cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size + } break; case GGML_OP_FLASH_FF: { if (node->src[1]->type == GGML_TYPE_F32) { diff --git a/ggml.h b/ggml.h index 5d4a4ceb6..fc32f9fa8 100644 --- a/ggml.h +++ b/ggml.h @@ -475,6 +475,7 @@ extern "C" { GGML_OP_LEAKY_RELU, GGML_OP_FLASH_ATTN, + GGML_OP_FLASH_ATTN_EXT, GGML_OP_FLASH_FF, GGML_OP_FLASH_ATTN_BACK, GGML_OP_SSM_CONV, @@ -1723,6 +1724,25 @@ extern "C" { struct ggml_tensor * v, bool masked); +#define GGML_KQ_MASK_PAD 32 + + // q: [n_embd, n_batch, n_head, 1] + // k: [n_embd, n_kv, n_head_kv, 1] + // v: [n_embd, n_kv, n_head_kv, 1] !! not transposed !! + // mask: [n_kv, n_batch_pad, 1, 1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !! + // res: [n_embd, n_head, n_batch, 1] !! permuted !! + GGML_API struct ggml_tensor * ggml_flash_attn_ext( + struct ggml_context * ctx, + struct ggml_tensor * q, + struct ggml_tensor * k, + struct ggml_tensor * v, + struct ggml_tensor * mask, + float scale); + + GGML_API void ggml_flash_attn_ext_set_prec( + struct ggml_tensor * a, + enum ggml_prec prec); + GGML_API struct ggml_tensor * ggml_flash_attn_back( struct ggml_context * ctx, struct ggml_tensor * q, diff --git a/llama.cpp b/llama.cpp index 892d46fbc..e1d12bbd5 100644 --- a/llama.cpp +++ b/llama.cpp @@ -107,6 +107,7 @@ #define LLAMA_MAX_NODES 8192 #define LLAMA_MAX_EXPERTS 8 +#define LLAMA_FLASH_ATTN // // logging @@ -5522,23 +5523,34 @@ static void llm_build_kv_store( GGML_ASSERT(kv.size == n_ctx); - // compute the transposed [n_tokens, n_embd] V matrix - struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); - //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed - cb(v_cur_t, "v_cur_t", il); - struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa, (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head); cb(k_cache_view, "k_cache_view", il); + // important: storing RoPE-ed version of K in the KV cache! + ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); + +#if defined(LLAMA_FLASH_ATTN) + // NOTE: the V cache is not transposed when using FLASH attention !! + struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa, + (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head); + cb(v_cache_view, "v_cache_view", il); + + ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view)); + + GGML_UNUSED(n_ctx); +#else + // compute the transposed [n_tokens, n_embd] V matrix + //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens)); + struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed + cb(v_cur_t, "v_cur_t", il); + struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa, ( n_ctx)*ggml_element_size(kv.v_l[il]), (kv_head)*ggml_element_size(kv.v_l[il])); - cb(v_cache_view, "v_cache_view", il); - // important: storing RoPE-ed version of K in the KV cache! - ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view)); ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view)); +#endif } static struct ggml_tensor * llm_build_norm( @@ -5699,6 +5711,33 @@ static struct ggml_tensor * llm_build_kqv( 0); cb(k, "k", il); + struct ggml_tensor * cur; + +#if defined(LLAMA_FLASH_ATTN) + GGML_UNUSED(model); + GGML_UNUSED(n_ctx); + + GGML_ASSERT(kq_pos == nullptr && "ALiBi is not yet supported with Flash Attention"); + + // split cached v into n_head heads (not transposed) + struct ggml_tensor * v = + ggml_view_3d(ctx, kv.v_l[il], + n_embd_head_v, n_kv, n_head_kv, + ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa), + ggml_row_size(kv.v_l[il]->type, n_embd_head_k), + 0); + cb(v, "v", il); + + cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale); + ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT); + //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]); + //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]); + //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]); + //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]); + //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]); + + cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens); +#else struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q); cb(kq, "kq", il); @@ -5762,8 +5801,9 @@ static struct ggml_tensor * llm_build_kqv( struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3); cb(kqv_merged, "kqv_merged", il); - struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); + cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens); cb(cur, "kqv_merged_cont", il); +#endif ggml_build_forward_expand(graph, cur); @@ -6051,20 +6091,20 @@ struct llm_build_context { struct ggml_tensor * build_inp_KQ_mask(bool causal = true) { if (causal) { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, n_tokens); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } else { - lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, n_tokens); + lctx.inp_KQ_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_tokens, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD)); } cb(lctx.inp_KQ_mask, "KQ_mask", -1); ggml_set_input(lctx.inp_KQ_mask); - return lctx.inp_KQ_mask; + return ggml_cast(ctx0, lctx.inp_KQ_mask, GGML_TYPE_F16); } struct ggml_tensor * build_inp_KQ_pos() { lctx.inp_KQ_pos = ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, n_kv); cb(lctx.inp_KQ_pos, "KQ_pos", -1); ggml_set_input(lctx.inp_KQ_pos); - return lctx.inp_KQ_pos; + return ggml_cast(ctx0, lctx.inp_KQ_pos, GGML_TYPE_F16); } struct ggml_tensor * build_inp_mean() { @@ -9932,7 +9972,7 @@ static int llama_decode_internal( // a heuristic, to avoid attending the full cache if it is not yet utilized // after enough generations, the benefit from this heuristic disappears // if we start defragmenting the cache, the benefit from this will be more important - kv_self.n = std::min(kv_self.size, std::max(32u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32))); + kv_self.n = std::min(kv_self.size, std::max(128u, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128))); //kv_self.n = llama_kv_cache_cell_max(kv_self); } } @@ -13847,6 +13887,10 @@ struct llama_context * llama_new_context_with_model( const auto & hparams = model->hparams; auto & cparams = ctx->cparams; + // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask + // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext) + cparams.n_batch = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch); + cparams.n_seq_max = std::max(1u, params.n_seq_max); cparams.n_threads = params.n_threads; cparams.n_threads_batch = params.n_threads_batch; diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5dfea5662..89afe3b95 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -571,9 +571,19 @@ struct test_case { // duplicate the op size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1; +#if 0 for (int i = 1; i < n_runs; i++) { gf->nodes[gf->n_nodes++] = out; } +#else + int n_nodes = gf->n_nodes; + n_runs = 1000; + for (int i = 1; i < n_runs; i++) { + for (int j = 0; j < n_nodes; j++) { + gf->nodes[gf->n_nodes++] = gf->nodes[j]; + } + } +#endif // calculate memory size_t mem = n_runs * op_size(out); @@ -1103,11 +1113,11 @@ struct test_soft_max : public test_case { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_tensor * mask = nullptr; if (this->mask) { - mask = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); + mask = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); } ggml_tensor * pos = nullptr; if (max_bias > 0.0f) { - pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, ne[0]); + pos = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, ne[0]); } ggml_tensor * out = ggml_soft_max_ext(ctx, a, mask, pos, scale, max_bias); return out; @@ -1477,6 +1487,76 @@ struct test_leaky_relu : public test_case { } }; +// GGML_OP_FLASH_ATTN_EXT +struct test_flash_attn_ext : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1); + ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs)); + return out; + } +}; + +// Attention +struct test_attn : public test_case { + const int64_t hs; // head size + const int64_t nh; // num heads + const int64_t kv; // kv size + const int64_t nb; // batch size + + std::string op_desc(ggml_tensor * t) override { + return "ATTN"; + + GGML_UNUSED(t); + } + + std::string vars() override { + return VARS_TO_STR4(hs, nh, kv, nb); + } + + double max_nmse_err() override { + return 5e-4; + } + + test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8) + : hs(hs), nh(nh), kv(kv), nb(nb) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1); + ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1); + ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed + ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1); + + struct ggml_tensor * cur; + + cur = ggml_mul_mat (ctx, k, q); + cur = ggml_soft_max_ext(ctx, cur, mask, nullptr, 1.0f/sqrtf(hs), 0.0f); + cur = ggml_mul_mat (ctx, v, cur); + cur = ggml_permute (ctx, cur, 0, 2, 1, 3); + cur = ggml_cont_2d (ctx, cur, hs*nh, nb); + + return cur; + } +}; + // Mixtral MOE struct test_moe : public test_case { const int n_experts; @@ -1749,7 +1829,7 @@ struct test_llama : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -1871,7 +1951,7 @@ struct test_falcon : public test_llm { struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens); // KQ_mask (mask for 1 head, it will be broadcasted to all heads) - struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1); + struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1); ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400); @@ -2180,6 +2260,30 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op test_cases.emplace_back(new test_timestep_embedding()); test_cases.emplace_back(new test_leaky_relu()); +#if 1 + for (int hs : { 128, 256, 64, 80, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, 2048, 4096, }) { + for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } +#else + for (int hs : { 128, }) { + for (int nh : { 32, }) { + for (int kv : { 512, 1024, }) { + for (int nb : { 1, 2, 4, 8, 512 }) { + test_cases.emplace_back(new test_attn (hs, nh, kv, nb)); + test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb)); + } + } + } + } +#endif + // these tests are disabled to save execution time, but they can be handy for debugging #if 0 #if !defined(__SANITIZE_THREAD__)