Compare commits
master ... gg/flash-a
60 commits

Commits (SHA1):
49a483e0f2, 1846e92a90, ef68fac2a8, cfd9732b2e, e04ff39181, 5b263dd83a, 3b1c4e7673, a7b471569b,
b958151e3f, c51f27c0db, 92472ea22c, 1f8a592482, 7c34655b36, b150abe83e, b68a112204, 12eaa22628,
db1f3c482e, c6769b9422, cda5a60a41, 56e45a239e, 41d136b602, 5a19a9f6d0, 2e46013749, 910b15bb40,
8ad92dc1ec, 2ddc9bbef1, 3d03bcb7af, 78df5527e4, d073e4f933, 5fcb9c1c5a, c6c1132e5e, abeaf0d90e,
4794821a31, 1db22d7032, 134c81c78d, 0ad44baf33, 8612864108, 3a428a1097, ecc466a460, 77f6976a87,
b3dd7d975f, 6fea843b24, f9ca5dcbe8, 40ea8cd1ac, 432ad04ffa, d917746ddb, 1446a12b29, 17720fad66,
77d08f3272, a4b6341c7b, f31955f5d1, 8cde449b8b, b97325800a, 52ae085750, 528da7515e, 1173f49c3b,
a9681febd6, c3cdfffa88, fa7ebcca99, a1c004ef2e
8 changed files with 2006 additions and 200 deletions
@@ -104,7 +104,7 @@ int main(int argc, char ** argv) {
     ctx_params.seed = 1234;
     ctx_params.n_ctx = n_kv_max;
-    ctx_params.n_batch = 512;
+    ctx_params.n_batch = 2048;
     ctx_params.mul_mat_q = mmq;

     ctx_params.n_threads = params.n_threads;

ggml-cuda.cu (807 changes)
@@ -108,6 +108,7 @@
 #include <cuda.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <mma.h>

 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED
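Note: the new <mma.h> include pulls in the nvcuda::wmma tensor-core API that the flash_attn_ext_f16 kernel further down relies on. As a rough, self-contained sketch of that API (not part of this diff; kernel and buffer names are illustrative, and it requires compute capability 7.0+ and a one-warp launch):

    #include <cuda_fp16.h>
    #include <mma.h>

    using namespace nvcuda;

    // one warp computes C(16x16) = A(16x16) * B(16x16) in half precision
    __global__ void wmma_16x16x16(const half * A, const half * B, half * C) {
        wmma::fragment<wmma::matrix_a,    16, 16, 16, half, wmma::row_major> a;
        wmma::fragment<wmma::matrix_b,    16, 16, 16, half, wmma::row_major> b;
        wmma::fragment<wmma::accumulator, 16, 16, 16, half> c;

        wmma::fill_fragment(c, 0.0f);
        wmma::load_matrix_sync(a, A, 16);   // leading dimension 16
        wmma::load_matrix_sync(b, B, 16);
        wmma::mma_sync(c, a, b, c);
        wmma::store_matrix_sync(C, c, 16, wmma::mem_row_major);
    }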
@@ -655,6 +656,19 @@ static __device__ __forceinline__ half2 warp_reduce_sum(half2 a) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL
 }

+static __device__ __forceinline__ half warp_reduce_sum(half x) {
+#if __CUDA_ARCH__ >= CC_VOLTA
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hadd(__shfl_xor_sync(0xffffffff, x, mask, 32), x);
+    }
+    return x;
+#else
+    (void) x;
+    NO_DEVICE_CODE;
+#endif
+}
+
 static __device__ __forceinline__ float warp_reduce_max(float x) {
 #pragma unroll
     for (int mask = 16; mask > 0; mask >>= 1) {
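The half overloads added here mirror the existing float/half2 warp reductions: each __shfl_xor_sync step halves the number of distinct partial results, so after log2(32) = 5 steps every lane of the warp holds the warp-wide value. A minimal float-only sketch of the same butterfly pattern (illustrative only, not text from this diff):

    __device__ float warp_sum(float x) {
        // butterfly reduction over the 32 lanes of a warp
        for (int mask = 16; mask > 0; mask >>= 1) {
            x += __shfl_xor_sync(0xffffffff, x, mask, 32);
        }
        return x; // every lane now holds the full sum
    }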
@@ -676,6 +690,18 @@ static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
 #endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
 }

+static __device__ __forceinline__ half warp_reduce_max(half x) {
+#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1) {
+        x = __hmax(x, __shfl_xor_sync(0xffffffff, x, mask, 32));
+    }
+    return x;
+#else
+    (void) x;
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
+}
+
 static __device__ __forceinline__ float op_repeat(const float a, const float b) {
     return b;
     GGML_UNUSED(a);
@@ -989,6 +1015,7 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
     if (lane_id == 0) {
         s_sum[warp_id] = tmp;
     }

     __syncthreads();
     tmp = s_sum[lane_id];
     tmp = warp_reduce_sum(tmp);
@@ -5917,7 +5944,7 @@ static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int
 }

 template <bool vals_smem, int ncols_template, int block_size_template, bool need_check>
-static __global__ void soft_max_f16(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+static __global__ void soft_max_f16(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
 #if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_PASCAL && CUDART_VERSION >= CUDART_HMAX
     const int ncols_data = ncols_template == 0 ? ncols_par : ncols_template;
     const int ncols_smem = GGML_PAD(ncols_data, 2*WARP_SIZE)/2;
@@ -5952,12 +5979,12 @@ static __global__ void soft_max_f16(const float * x, const half * y, float * ds
         if (need_check && col_data + 0 >= ncols_data) {
             val.x = -INFINITY;
         } else {
-            val.x = x[ix + 0]*scale + (y ? y[iy + 0] : 0.0f);
+            val.x = x[ix + 0]*scale + (y ? __half2float(y[iy + 0]) : 0.0f);
         }
         if (need_check && col_data + WARP_SIZE >= ncols_data) {
             val.y = -INFINITY;
         } else {
-            val.y = x[ix + WARP_SIZE]*scale + (y ? y[iy + WARP_SIZE] : 0.0f);
+            val.y = x[ix + WARP_SIZE]*scale + (y ? __half2float(y[iy + WARP_SIZE]) : 0.0f);
         }
         if (!need_check || col_smem < (vals_smem ? ncols_smem : ncols_data)) {
             vals[col_smem] = val;
@@ -6047,7 +6074,7 @@ static __global__ void soft_max_f16(const float * x, const half * y, float * ds
 }

 template <bool vals_smem, int ncols_template, int block_size_template>
-static __global__ void soft_max_f32(const float * x, const float * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
+static __global__ void soft_max_f32(const float * x, const half * y, float * dst, const int ncols_par, const int nrows_y, const float scale) {
     const int ncols = ncols_template == 0 ? ncols_par : ncols_template;

     const int tid = threadIdx.x;
@@ -6077,7 +6104,7 @@ static __global__ void soft_max_f32(const float * x, const half * y, float * ds
         const int ix = rowx*ncols + col;
         const int iy = rowy*ncols + col;

-        const float val = x[ix]*scale + (y ? y[iy] : 0.0f);
+        const float val = x[ix]*scale + (y ? __half2float(y[iy]) : 0.0f);
         vals[col] = val;
         max_val = max(max_val, val);
     }
@@ -6249,6 +6276,539 @@ static __global__ void pool2d_nchw_kernel(
         o_ptr[cur_oh * ow + cur_ow] = res;
 }

+#define CUDA_FLASH_ATTENTION_BLOCK_SIZE 256
+
+template<int block_size, int k_seq_len, int k_head_dim>
+static __global__ void flash_attn_f32(
+        const float* __restrict__ q,
+        const float* __restrict__ k,
+        const float* __restrict__ v,
+        float* __restrict__ kqv,
+        float kq_scale,
+        int head_dim, int seq_len, int num_heads) {
+    const int head = blockIdx.x / seq_len;
+    const int head_size = head_dim * seq_len;
+    const int s = blockIdx.x % seq_len;
+
+    extern __shared__ char flash_attn_shmem_f32[];
+    float* S = (float*)flash_attn_shmem_f32;
+    float* warp_data = (float*)(flash_attn_shmem_f32 + seq_len * sizeof(float));
+
+    // QK^T
+#pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
+        const int key_offset = is * head_dim + head * head_size;
+        const int query_offset = s * head_dim + head * head_size;
+
+        float tmp = 0.0f;
+        for(int d = 0; d < head_dim; d++) {
+            tmp += k[key_offset + d] * q[query_offset + d];
+        }
+        S[is] = tmp * kq_scale;
+    }
+    __syncthreads();
+
+    float max_val = -INFINITY;
+    // get the max
+#pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
+        max_val = fmaxf(max_val , S[is]);
+    }
+
+    max_val = warp_reduce_max(max_val);
+
+    { // get max from all threads
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            warp_data[warp_id] = max_val;
+        }
+        __syncthreads();
+        max_val = warp_data[lane_id];
+        max_val = warp_reduce_max(max_val);
+    }
+
+    // softmax(QK^T)
+    float sum = 0.0f;
+#pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+        float tmp = expf(S[is] - max_val);
+        sum += tmp;
+        S[is] = tmp;
+    }
+    __syncthreads();
+
+    sum = warp_reduce_sum(sum);
+    { // softmax sum partials
+        int warp_id = threadIdx.x / WARP_SIZE;
+        int lane_id = threadIdx.x % WARP_SIZE;
+        if (lane_id == 0) {
+            warp_data[warp_id] = sum;
+        }
+        __syncthreads();
+        sum = warp_data[lane_id];
+        sum = warp_reduce_sum(sum);
+    }
+
+    float inv_sum = 1.0f / sum;
+#pragma unroll
+    for(int is0 = 0; is0 < k_seq_len; is0 += block_size) {
+        const int is = threadIdx.x + is0;
+        if(is >= seq_len) {
+            break;
+        }
+
+        S[is] *= inv_sum;
+    }
+    __syncthreads();
+
+    // softmax(QK^T)V
+#pragma unroll
+    for (int d0 = threadIdx.x; d0 < k_head_dim; d0 += block_size) {
+        const int d = threadIdx.x + d0;
+        if(d >= head_dim) {
+            break;
+        }
+        const int dst_index = d + s * head_dim + head * head_size;
+        const int value_offset = d * seq_len + head * head_size;
+
+        float temp = 0.0f;
+#pragma unroll
+        for(int ic = 0; ic < k_seq_len;ic++) {
+            if(ic >= seq_len) {
+                break;
+            }
+            temp += v[value_offset + ic] * S[ic];
+        }
+        kqv[dst_index] = temp;
+    }
+}
+
+#if __CUDA_ARCH__ >= CC_VOLTA
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_a, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_a;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::row_major> half16x16_b;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::matrix_b, 16, 16, 16, half, nvcuda::wmma::col_major> half16x16_bT;
+typedef nvcuda::wmma::fragment<nvcuda::wmma::accumulator, 16, 16, 16, half> half16x16_acc;
+#endif
+
+// based on metal version
+template<int D, int Q, int C> // D head size, Q queries per block, C cache items per block
+static __global__ void flash_attn_ext_f16(
+        const char* __restrict__ q,
+        const char* __restrict__ k,
+        const char* __restrict__ v,
+        const char* __restrict__ mask,
+        float* __restrict__ dst,
+        float scale,
+        int ne00,
+        int ne01,
+        int ne02,
+        int ne03,
+        int ne10,
+        int ne11,
+        int ne12,
+        int ne13,
+        int ne31,
+        int nb31,
+        int nb01,
+        int nb02,
+        int nb03,
+        int nb11,
+        int nb12,
+        int nb13,
+        int ne0,
+        int ne1,
+        int ne2,
+        int ne3) {
+#if __CUDA_ARCH__ >= CC_VOLTA
+    const int warp_id = threadIdx.y;
+    const int lane_id = threadIdx.x;
+
+    const int num_warps = blockDim.y; // number of warps
+    const int iq3 = blockIdx.z;
+    const int iq2 = blockIdx.y;
+    const int iq1 = blockIdx.x * Q;
+
+    const int D16 = D/16;
+    const int Q16 = Q/16;
+    const int C16 = C/16;
+
+    const int NW = WARP_SIZE;
+    const int SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const int T  = D + num_warps*SH; // shared memory size per query in (half)
+    const int T2 = T/2;              // shared memory size per query in (half2)
+    const int C2 = C/2;
+    const int D2 = D/2;
+
+    extern __shared__ half __flash_attn_f16_shmem[];
+    // pq
+    half  * sq  = (half  *) (__flash_attn_f16_shmem + 0*D);              // holds the query data
+    half2 * sq2 = (half2 *) (__flash_attn_f16_shmem + 0*D);              // same as above but in half2
+    half  * ss  = (half  *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // scratch buffer for attention and diagonal matrix
+    half2 * ss2 = (half2 *) (__flash_attn_f16_shmem + warp_id*SH + 1*D); // same as above but in half2
+
+    half16x16_acc zr;
+    half16x16_acc lo[Q16][D16];
+
+    // load heads from Q to shared memory
+#pragma unroll
+    for (int j0 = 0; j0 < Q; j0 += num_warps) {
+        const int j = j0 + warp_id;
+        if (j >= Q) {
+            break;
+        }
+
+        const float2 * q2 = (const float2 *) (q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+#pragma unroll
+        for (int i0 = 0; i0 < D2; i0 += NW) {
+            const int i = i0 + lane_id;
+            if (i >= D2) {
+                break;
+            }
+
+            if (iq1 + j < ne01) {
+                sq2[j*T2 + i] = __float22half2_rn(q2[i]);
+            } else {
+                sq2[j*T2 + i] = make_half2(0.0, 0.0);
+            }
+        }
+    }
+
+    nvcuda::wmma::fill_fragment(zr, 0.0);
+
+    // zero out lo
+    for (int j = 0; j < Q16; ++j) {
+        for (int i = 0; i < D16; ++i) {
+            nvcuda::wmma::fill_fragment(lo[j][i], 0.0);
+        }
+    }
+
+    // zero out shared memory SH
+    for (int j = 0; j < Q; ++j) {
+        for (int i0 = 0; i0 < SH; i0 += NW) {
+            const int i = i0 + lane_id;
+            if (i >= SH) {
+                break;
+            }
+
+            ss[j*T + i] = 0.0;
+        }
+    }
+
+    __syncthreads();
+
+    {
+        half S = __float2half(0.0f);
+        half M[Q];
+
+        for (int i = 0; i < Q; ++i) {
+            M[i] = CUDART_MIN_DENORM_FP16;
+        }
+
+        // assume K and V are same shape
+        const int ne22 = ne12;
+        const int ne23 = ne13;
+
+        const int nb21 = nb11;
+        const int nb22 = nb12;
+        const int nb23 = nb13;
+
+        // broadcast
+        const int rk2 = ne02/ne12;
+        const int rk3 = ne03/ne13;
+
+        const int rv2 = ne02/ne22;
+        const int rv3 = ne03/ne23;
+
+        // k indices
+        const int ik2 = iq2 / rk2;
+        const int ik3 = iq3 / rk3;
+
+        // v indices
+        const int iv2 = iq2 / rv2;
+        const int iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        half16x16_a mq[Q16][D16];
+        for (int j = 0; j < Q16; ++j) {
+            for (int i = 0; i < D16; ++i) {
+                nvcuda::wmma::load_matrix_sync(mq[j][i], sq + 16*j*T + i*16, T);
+            }
+        }
+
+        // pointer to the mask
+        const half * mp = mask ? (const half *) (mask + iq1*nb31) : nullptr;
+
+        // prepare diagonal scale matrix
+        half16x16_b mscale;
+        for (int i = 0; i < 16; ++i) {
+            ss[i*T + i] = __float2half(scale);
+        }
+        nvcuda::wmma::load_matrix_sync(mscale, ss, T);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int ic0 = 0; ic0 < ne11; ic0 += C*num_warps) {
+            const int ic = ic0 + warp_id*16;
+            if (ic >= ne11) {
+                break;
+            }
+
+            // Q*K^T
+            {
+#pragma unroll
+                for (int cc = 0; cc < C16; ++cc) {
+                    half16x16_acc mqk[Q16];
+                    for (int j = 0; j < Q16; ++j) {
+                        nvcuda::wmma::fill_fragment(mqk[j], 0);
+                    }
+
+                    const half * pk = (const half *) ((const char *) k + ((ic + 16*num_warps*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (int i = 0; i < D16; ++i) {
+                        half16x16_bT mk; // transposed key
+                        nvcuda::wmma::load_matrix_sync(mk, pk + i*16, nb11/sizeof(half));
+
+                        for (int j = 0; j < Q16; ++j) {
+                            nvcuda::wmma::mma_sync(mqk[j], mq[j][i], mk, mqk[j]);
+                        }
+                    }
+
+                    // mqk = mqk*scale + mask
+                    for (int j = 0; j < Q16; ++j) {
+                        half16x16_a mqka;
+                        half16x16_acc mm;
+
+                        if (mp) {
+                            nvcuda::wmma::load_matrix_sync(mm, mp + 16*j*(nb31/sizeof(half)) + ic + 16*num_warps*cc, nb31/sizeof(half), nvcuda::wmma::mem_row_major);
+                        }
+
+                        // convert accumulator to matrix_a
+                        nvcuda::wmma::store_matrix_sync(      ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                        nvcuda::wmma::load_matrix_sync (mqka, ss + 16*j*T + 16*cc, T);
+
+                        nvcuda::wmma::mma_sync(mqk[j], mqka, mscale, mp ? mm : zr);
+                        nvcuda::wmma::store_matrix_sync(ss + 16*j*T + 16*cc, mqk[j], T, nvcuda::wmma::mem_row_major);
+                    }
+                }
+            }
+
+            // used to detect blocks full of -INF
+            half2 smax = make_half2(-INFINITY, -INFINITY);
+
+            // online softmax
+            for (int j = 0; j < Q; ++j) {
+                const half m = M[j];
+
+                for (int p0 = 0; p0 < C2; p0 += NW) {
+                    const int p = p0 + lane_id;
+
+                    const half2 s = ss2[j*T2 + p];
+
+                    smax = __hmax2(smax, s);
+                    M[j] = __hmax(M[j], __hmax(s.x, s.y));
+                }
+
+                M[j] = warp_reduce_max(M[j]);
+
+                // local sum
+                half2 ls = make_half2(0.0f, 0.0f);
+                half2 M2 = make_half2(M[j], M[j]);
+
+                for (int p0 = 0; p0 < C2; p0 += NW) {
+                    const int p = p0 + lane_id;
+
+                    const half2 s = ss2[j*T2 + p];
+
+                    const half2 vs = h2exp(s - M2);
+
+                    ls += vs;
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss2[j*T2 + p] = vs;
+                }
+
+                ls = warp_reduce_sum(ls);
+
+                const half ms = hexp(m - M[j]);
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (lane_id == j) {
+                    ss[j*T + C + j] = ms;
+
+                    S = S*ms + ls.x + ls.y;
+                }
+            }
+
+            smax = warp_reduce_max(smax);
+
+            // skip -INF blocks
+            if (__hisinf(smax.x) == -1 && __hisinf(smax.y) == -1) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            for (int j = 0; j < Q16; ++j) {
+                half16x16_a mm;
+                half16x16_b lob;
+
+                nvcuda::wmma::load_matrix_sync(mm, ss + 16*j*T + C + 16*j, T);
+
+                for (int i = 0; i < D16; ++i) {
+                    // convert accumulator to matrix_b
+                    nvcuda::wmma::store_matrix_sync(     ss + 16*j*T + C + 16*j, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (lob, ss + 16*j*T + C + 16*j, T);
+
+                    nvcuda::wmma::mma_sync(lo[j][i], mm, lob, zr);
+                }
+            }
+
+            // restore zeros
+            for (int j = 0; j < Q16; ++j) {
+                nvcuda::wmma::store_matrix_sync(ss + 16*j*T + C + 16*j, zr, T, nvcuda::wmma::mem_row_major);
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                half16x16_b mv[C16][D16];
+                for (int i = 0; i < D16; ++i) {
+                    for (int cc = 0; cc < C16; ++cc) {
+                        const half * pv = (const half *) ((const char *) v + ((ic + 16*num_warps*cc)*256 + iv2*nb22 + iv3*nb23));
+
+                        nvcuda::wmma::load_matrix_sync(mv[cc][i], pv + i*16, 256/sizeof(half));
+                    }
+                }
+
+                for (int cc = 0; cc < C16; ++cc) {
+                    half16x16_a ms[Q16];
+                    for (int j = 0; j < Q16; ++j) {
+                        nvcuda::wmma::load_matrix_sync(ms[j], ss + 16*j*T + 16*cc, T);
+                    }
+
+                    for (int j = 0; j < Q16; ++j) {
+#pragma unroll
+                        for (int i = 0; i < D16; ++i) {
+                            nvcuda::wmma::mma_sync(lo[j][i], ms[j], mv[cc][i], lo[j][i]);
+                        }
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        if (lane_id < Q) {
+            ss[lane_id*T + 0] = S;
+            ss[lane_id*T + 1] = M[lane_id];
+        }
+    }
+
+    // reduce the warps sequentially
+    for (int sg = 1; sg < num_warps; ++sg) {
+        __syncthreads();
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (warp_id == sg) {
+            for (int j = 0; j < Q16; ++j) {
+                for (int i = 0; i < D16; ++i) {
+                    nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                }
+            }
+        }
+
+        __syncthreads();
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (warp_id == 0) {
+            for (int j = lane_id; j < Q; j += NW) {
+                const half S0 = ss[j*T +         0];
+                const half S1 = ss[j*T + sg*SH + 0];
+
+                const half M0 = ss[j*T +         1];
+                const half M1 = ss[j*T + sg*SH + 1];
+
+                const half M = __hmax(M0, M1);
+
+                const half ms0 = hexp(M0 - M);
+                const half ms1 = hexp(M1 - M);
+
+                const half S = S0*ms0 + S1*ms1;
+
+                ss[j*T + 0] = S;
+                ss[j*T + 1] = M;
+
+                ss[j*T + C + j        ] = ms0;
+                ss[j*T + C + j + sg*SH] = ms1;
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (int j = 0; j < Q16; ++j) {
+                half16x16_a ms0;
+                half16x16_a ms1;
+                half16x16_b t;
+                half16x16_acc t2;
+
+                nvcuda::wmma::load_matrix_sync(ms0, ss + 16*j*T + C + 16*j,         T);
+                nvcuda::wmma::load_matrix_sync(ms1, ss + 16*j*T + C + 16*j + sg*SH, T);
+
+                for (int i = 0; i < D16; ++i) {
+                    nvcuda::wmma::load_matrix_sync(t, sq + 16*j*T + i*16, T);
+                    nvcuda::wmma::mma_sync(t2, ms1, t, zr);
+
+                    // convert accumulator to matrix_b
+                    nvcuda::wmma::store_matrix_sync(   sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+                    nvcuda::wmma::load_matrix_sync (t, sq + 16*j*T + i*16, T);
+
+                    nvcuda::wmma::mma_sync(lo[j][i], ms0, t, t2);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (warp_id == 0) {
+        for (int j = 0; j < Q16; ++j) {
+            for (int i = 0; i < D16; ++i) {
+                nvcuda::wmma::store_matrix_sync(sq + 16*j*T + i*16, lo[j][i], T, nvcuda::wmma::mem_row_major);
+            }
+        }
+    }
+
+    // final rescale with 1/S and store to global memory
+    if (warp_id == 0) {
+        for (int j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const half S = ss[j*T + 0];
+
+            for (int i0 = 0; i0 < D; i0 += NW) {
+                const int i = i0 + lane_id;
+                if (i >= D) {
+                    break;
+                }
+
+                dst[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D + i] = __half2float(sq[j*T + i] / S);
+            }
+        }
+    }
+#else
+    NO_DEVICE_CODE;
+#endif
+}
+
 template<int qk, int qr, dequantize_kernel_t dq>
 static void get_rows_cuda(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst,
     const void * src0_dd, const int32_t * src1_dd, float * dst_dd, cudaStream_t stream) {
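Note: both kernels above compute the same quantity, dst = softmax(scale*Q*K^T + mask) * V per head, with flash_attn_f32 taking FP32 Q/K/V and no mask and flash_attn_ext_f16 streaming an FP16 KV cache through tensor cores with an online softmax. For reference, a plain C++ sketch of the single-head computation they target (illustrative only; the buffer layout mirrors flash_attn_f32, where V is stored [head_dim][seq_len], and ignores the mask):

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // q, k: [seq_len][head_dim]; v: [head_dim][seq_len]; out: [seq_len][head_dim]
    void attention_ref(const float * q, const float * k, const float * v, float * out,
                       int seq_len, int head_dim, float scale) {
        std::vector<float> s(seq_len);
        for (int iq = 0; iq < seq_len; ++iq) {
            // scores = scale * (q_iq . k_j), tracking the max for numerical stability
            float max_val = -INFINITY;
            for (int j = 0; j < seq_len; ++j) {
                float dot = 0.0f;
                for (int d = 0; d < head_dim; ++d) {
                    dot += q[iq*head_dim + d] * k[j*head_dim + d];
                }
                s[j] = dot * scale;
                max_val = std::max(max_val, s[j]);
            }
            // softmax
            float sum = 0.0f;
            for (int j = 0; j < seq_len; ++j) {
                s[j] = std::exp(s[j] - max_val);
                sum += s[j];
            }
            // out_iq = softmax(s) . V
            for (int d = 0; d < head_dim; ++d) {
                float acc = 0.0f;
                for (int j = 0; j < seq_len; ++j) {
                    acc += v[d*seq_len + j] * s[j];
                }
                out[iq*head_dim + d] = acc / sum;
            }
        }
    }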
@@ -7585,7 +8145,7 @@ static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols
     diag_mask_inf_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols_x, rows_per_channel, n_past);
 }

-static void soft_max_f16_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+static void soft_max_f16_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x/2 && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -7628,7 +8188,7 @@ static void soft_max_f16_cuda(const float * x, const float * y, float * dst, con
     }
 }

-static void soft_max_f32_cuda(const float * x, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
+static void soft_max_f32_cuda(const float * x, const half * y, float * dst, const int ncols_x, const int nrows_x, const int nrows_y, const float scale, cudaStream_t stream) {
     int nth = WARP_SIZE;
     while (nth < ncols_x && nth < CUDA_SOFT_MAX_BLOCK_SIZE) nth *= 2;
     const dim3 block_dims(nth, 1, 1);
@@ -7682,6 +8242,13 @@ static void im2col_cuda(const float* x, T* dst,
     im2col_kernel<<<block_nums, CUDA_IM2COL_BLOCK_SIZE, 0, stream>>>(x, dst, batch_offset, offset_delta, IC, IW, IH, OH, OW, KW, KH, parallel_elements, (IC * KH * KW), s0, s1, p0, p1, d0, d1);
 }

+static void flash_attn_f32_cuda(const float* q, const float* k,const float* v, float* dst, float kq_scale, const int d_head, const int seq_len, const int num_heads, cudaStream_t stream) {
+    int sram_memory_size = seq_len*sizeof(float) + WARP_SIZE * sizeof(float);
+    int num_blocks = num_heads * seq_len;
+    flash_attn_f32<CUDA_FLASH_ATTENTION_BLOCK_SIZE, 1024, 64><<<num_blocks, CUDA_FLASH_ATTENTION_BLOCK_SIZE, sram_memory_size, stream>>>(
+            q, k, v, dst, kq_scale, d_head, seq_len, num_heads);
+}
+
 // buffer pool for cuda
 #define MAX_CUDA_BUFFERS 256

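Note: the launcher above hard-codes the template bounds (k_seq_len = 1024, k_head_dim = 64), requests seq_len + WARP_SIZE floats of dynamic shared memory, and runs one block of 256 threads per (head, query) pair. A hedged host-side usage sketch with made-up sizes; d_q, d_k, d_v, d_out are illustrative names, not from the diff:

    // assumed example: 8 heads, 256 tokens, head size 64 (must fit the 1024/64 template bounds)
    const int num_heads = 8, seq_len = 256, d_head = 64;
    const size_t n = (size_t) num_heads*seq_len*d_head;

    float *d_q, *d_k, *d_v, *d_out;
    cudaMalloc(&d_q,   n*sizeof(float));
    cudaMalloc(&d_k,   n*sizeof(float));
    cudaMalloc(&d_v,   n*sizeof(float));
    cudaMalloc(&d_out, n*sizeof(float));

    const float kq_scale = 1.0f/sqrtf((float) d_head);
    cudaStream_t stream = 0; // default stream

    // launches num_heads*seq_len blocks of CUDA_FLASH_ATTENTION_BLOCK_SIZE (256) threads
    flash_attn_f32_cuda(d_q, d_k, d_v, d_out, kq_scale, d_head, seq_len, num_heads, stream);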
@@ -9060,11 +9627,11 @@ static void ggml_cuda_op_soft_max(
     GGML_ASSERT(src0->type == GGML_TYPE_F32);
     GGML_ASSERT( dst->type == GGML_TYPE_F32);

-    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32); // src1 contains mask and it is optional
+    GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16); // src1 contains mask and it is optional

     const int64_t ne00    = src0->ne[0];
     const int64_t nrows_x = ggml_nrows(src0);
-    const int64_t nrows_y = src1 ? ggml_nrows(src1) : 1;
+    const int64_t nrows_y = src1 ? src0->ne[1] : 1; // note: using number of queries since mask can be padded!

     float scale = 1.0f;
     memcpy(&scale, dst->op_params, sizeof(float));
@@ -9080,9 +9647,9 @@ static void ggml_cuda_op_soft_max(
 #endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && CUDART_VERSION >= CUDART_HMAX

     if (use_f16_soft_max) {
-        soft_max_f16_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+        soft_max_f16_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
     } else {
-        soft_max_f32_cuda(src0_dd, src1 ? src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
+        soft_max_f32_cuda(src0_dd, src1 ? (const half *) src1_dd : nullptr, dst_dd, ne00, nrows_x, nrows_y, scale, main_stream);
     }

     (void) dst;
@@ -10284,6 +10851,211 @@ static void ggml_cuda_mul_mat_id(const ggml_tensor * src0, const ggml_tensor * s
     }
 }

+inline void ggml_cuda_flash_attn(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, ggml_tensor * KQV) {
+    GGML_ASSERT(Q->type   == GGML_TYPE_F32);
+    GGML_ASSERT(K->type   == GGML_TYPE_F32);
+    GGML_ASSERT(V->type   == GGML_TYPE_F32);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(Q->backend   == GGML_BACKEND_GPU);
+    GGML_ASSERT(K->backend   == GGML_BACKEND_GPU);
+    GGML_ASSERT(V->backend   == GGML_BACKEND_GPU);
+    GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
+
+    ggml_cuda_set_device(g_main_device);
+    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
+    ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
+    ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;
+
+    const int64_t d_head = Q->ne[0];
+    const int64_t sequence_length = Q->ne[1];
+    const int64_t num_heads = Q->ne[2];
+
+    GGML_ASSERT(Q->ne[0] == d_head);
+    GGML_ASSERT(K->ne[0] == d_head);
+    GGML_ASSERT(V->ne[1] == d_head);
+
+    GGML_ASSERT(Q->ne[1] == sequence_length);
+    GGML_ASSERT(K->ne[1] == sequence_length);
+    GGML_ASSERT(V->ne[0] == sequence_length);
+
+    GGML_ASSERT(Q->ne[2] == num_heads);
+    GGML_ASSERT(K->ne[2] == num_heads);
+    GGML_ASSERT(V->ne[2] == num_heads);
+
+    float KQ_scale = 1.0f / sqrtf((float)d_head);
+
+    flash_attn_f32_cuda(
+            (float *) src0_extra->data_device[g_main_device], // Query
+            (float *) src1_extra->data_device[g_main_device], // Key
+            (float *) src2_extra->data_device[g_main_device], // Value
+            (float *) dst_extra->data_device[g_main_device],  // dst
+            KQ_scale, d_head, sequence_length, num_heads, main_stream);
+}
+
+
+inline void ggml_cuda_flash_attn_ext(const ggml_tensor * Q, const ggml_tensor * K, const ggml_tensor * V, const ggml_tensor * mask, ggml_tensor * KQV) {
+    GGML_ASSERT(Q->type   == GGML_TYPE_F32);
+    GGML_ASSERT(K->type   == GGML_TYPE_F16);
+    GGML_ASSERT(V->type   == GGML_TYPE_F16);
+    GGML_ASSERT(KQV->type == GGML_TYPE_F32);
+
+    GGML_ASSERT(Q->backend   == GGML_BACKEND_GPU);
+    GGML_ASSERT(K->backend   == GGML_BACKEND_GPU);
+    GGML_ASSERT(V->backend   == GGML_BACKEND_GPU);
+    GGML_ASSERT(KQV->backend == GGML_BACKEND_GPU);
+
+    GGML_ASSERT(!mask || mask->type == GGML_TYPE_F16);
+    GGML_ASSERT(!mask || mask->backend == GGML_BACKEND_GPU);
+    GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) &&
+            "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big");
+
+    ggml_cuda_set_device(g_main_device);
+    const cudaStream_t main_stream = g_cudaStreams[g_main_device][0];
+
+    ggml_tensor_extra_gpu * src0_extra = (ggml_tensor_extra_gpu *) Q->extra;
+    ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu *) K->extra;
+    ggml_tensor_extra_gpu * src2_extra = (ggml_tensor_extra_gpu *) V->extra;
+    ggml_tensor_extra_gpu * src3_extra = mask ? (ggml_tensor_extra_gpu *) mask->extra : nullptr;
+    ggml_tensor_extra_gpu * dst_extra  = (ggml_tensor_extra_gpu *) KQV->extra;
+
+    float scale;
+    memcpy(&scale, KQV->op_params, sizeof(float));
+
+#define NQPB 16
+#define NCPW 128
+
+    const int nqpb = NQPB; // queries per block
+    const int ncpw = NCPW; // cache values per warp (does not work for other values)
+
+    GGML_ASSERT(NQPB <= 32);
+
+    const int nwarps_max = 8; // TODO: we don't want to launch too much warps. how much is too much?
+    // TODO: produces wrong results for nwarps > 8 (RTX 2060) - not sure why
+    const int nwarps = Q->ne[1] <= nqpb ? std::max(2, std::min((int) K->ne[1]/ncpw, nwarps_max)) : 1;
+
+    dim3 blocks_num((Q->ne[1] + nqpb - 1) / nqpb, Q->ne[2], Q->ne[3]);
+    dim3 block_dim(32, nwarps, 1);
+
+    const size_t shmem = nqpb*(Q->ne[0] + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
+
+    // increase shared memory limit to 96KB
+    //const size_t shmem_max = 96*1024;
+    //cudaFuncSetAttribute(flash_attn_ext_f16<128, NQPB, NCPW>, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_max);
+
+    switch (Q->ne[0]) {
+        case 64:
+            flash_attn_ext_f16<64, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
+        case 80:
+            flash_attn_ext_f16<80, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
+        case 96:
+            flash_attn_ext_f16<96, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
+        case 112:
+            flash_attn_ext_f16<112, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
+        case 128:
+            flash_attn_ext_f16<128, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
+        case 256:
+            flash_attn_ext_f16<256, NQPB, NCPW>
+                <<<blocks_num, block_dim, shmem, main_stream>>> (
+                        (const char *) src0_extra->data_device[g_main_device], // Query
+                        (const char *) src1_extra->data_device[g_main_device], // Key
+                        (const char *) src2_extra->data_device[g_main_device], // Value
+                        mask ? ((const char *) src3_extra->data_device[g_main_device]) : nullptr, // Mask
+                        (float *) dst_extra->data_device[g_main_device], // dst
+                        scale,
+                        Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
+                        K->ne[0], K->ne[1], K->ne[2], K->ne[3],
+                        mask ? mask->ne[1] : 0, mask ? mask->nb[1] : 0,
+                        Q->nb[1], Q->nb[2], Q->nb[3],
+                        K->nb[1], K->nb[2], K->nb[3],
+                        KQV->ne[0], KQV->ne[1], KQV->ne[2], KQV->ne[3]
+                        );
+            break;
+        default:
+            break;
+    }
+
+    CUDA_CHECK(cudaGetLastError());
+}
+
 static void ggml_cuda_scale(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
     ggml_cuda_op_flatten(src0, src1, dst, ggml_cuda_op_scale);
 }
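Note: a quick sanity check on the launch configuration above (my arithmetic with assumed values, not text from the diff). The sizeof(float)/2 factor is sizeof(half), so for head size 128 with 2 warps the per-block dynamic shared memory stays well under the default 48 KiB limit; the commented-out cudaFuncSetAttribute call would only be needed for configurations that exceed it:

    #include <cstdio>

    int main() {
        // assumed example: head size 128, NQPB = 16, NCPW = 128, nwarps = 2
        const int ne00 = 128, nqpb = 16, ncpw = 128, nwarps = 2;
        const size_t shmem = nqpb*(ne00 + nwarps*(ncpw + nqpb))*(sizeof(float)/2);
        printf("%zu bytes\n", shmem); // 13312 bytes (~13 KiB)
        return 0;
    }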
@@ -10573,6 +11345,10 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
         case GGML_OP_ARGSORT:
             func = ggml_cuda_argsort;
             break;
+        case GGML_OP_FLASH_ATTN:
+            break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            break;
         default:
             return false;
     }
@@ -10587,7 +11363,13 @@ GGML_CALL bool ggml_cuda_compute_forward(struct ggml_compute_params * params, st
     if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
         return true;
     }
-    func(tensor->src[0], tensor->src[1], tensor);
+    if(tensor->op == GGML_OP_FLASH_ATTN) {
+        ggml_cuda_flash_attn(tensor->src[0], tensor->src[1], tensor->src[2], tensor);
+    } else if(tensor->op == GGML_OP_FLASH_ATTN_EXT) {
+        ggml_cuda_flash_attn_ext(tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+    } else {
+        func(tensor->src[0], tensor->src[1], tensor);
+    }
     return true;
 }

@@ -11403,6 +12185,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
         case GGML_OP_UPSCALE:
         case GGML_OP_PAD:
         case GGML_OP_LEAKY_RELU:
+        case GGML_OP_FLASH_ATTN_EXT:
             return true;
         default:
             return false;
ggml-metal.m (395 changes)
@@ -141,6 +141,12 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC,
     GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC,
     GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128,
+    GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F16,
     GGML_METAL_KERNEL_TYPE_CPY_F32_F32,
     GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
@@ -390,6 +396,9 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {
         id<MTLFunction> metal_function = [metal_library newFunctionWithName:@"kernel_"#name]; \
         kernel->pipeline = [ctx->device newComputePipelineStateWithFunction:metal_function error:&error]; \
         [metal_function release]; \
+        GGML_METAL_LOG_INFO("%s: loaded %-32s %16p | th_max = %4d | th_width = %4d\n", __func__, "kernel_"#name, (void *) kernel->pipeline, \
+                (int) kernel->pipeline.maxTotalThreadsPerThreadgroup, \
+                (int) kernel->pipeline.threadExecutionWidth); \
         if (error) { \
             GGML_METAL_LOG_ERROR("%s: error: load pipeline error: %s\n", __func__, [[error description] UTF8String]); \
             [metal_library release]; \
@@ -401,130 +410,136 @@ static struct ggml_metal_context * ggml_metal_init(int n_cb) {

         // simd_sum and simd_max requires MTLGPUFamilyApple7

         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD, add, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ADD_ROW, add_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL, mul, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_ROW, mul_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV, div, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIV_ROW, div_row, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE, scale, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SCALE_4, scale_4, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_TANH, tanh, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RELU, relu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU, gelu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GELU_QUICK, gelu_quick, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SILU, silu, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX, soft_max, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SOFT_MAX_4, soft_max_4, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF, diag_mask_inf, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF_8, diag_mask_inf_8, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F32, get_rows_f32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_F16, get_rows_f16, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_0, get_rows_q4_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_1, get_rows_q4_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_0, get_rows_q5_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_1, get_rows_q5_1, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q8_0, get_rows_q8_0, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q2_K, get_rows_q2_K, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q3_K, get_rows_q3_K, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q4_K, get_rows_q4_K, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q5_K, get_rows_q5_K, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_Q6_K, get_rows_q6_K, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XXS, get_rows_iq2_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ2_XS, get_rows_iq2_xs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_IQ3_XXS, get_rows_iq3_xxs, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GET_ROWS_I32, get_rows_i32, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_RMS_NORM, rms_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_GROUP_NORM, group_norm, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_NORM, norm, true);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F32_F32, mul_mv_f32_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F16, mul_mv_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32, mul_mv_f16_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_1ROW, mul_mv_f16_f32_1row, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_F16_F32_L4, mul_mv_f16_f32_l4, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_0_F32, mul_mv_q4_0_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_1_F32, mul_mv_q4_1_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_0_F32, mul_mv_q5_0_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_1_F32, mul_mv_q5_1_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q8_0_F32, mul_mv_q8_0_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q2_K_F32, mul_mv_q2_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q3_K_F32, mul_mv_q3_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q4_K_F32, mul_mv_q4_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q5_K_F32, mul_mv_q5_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_Q6_K_F32, mul_mv_q6_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XXS_F32, mul_mv_iq2_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ2_XS_F32, mul_mv_iq2_xs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_IQ3_XXS_F32, mul_mv_iq3_xxs_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F32_F32, mul_mv_id_f32_f32, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F16, mul_mv_id_f16_f16, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32, mul_mv_id_f16_f32, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_1ROW, mul_mv_id_f16_f32_1row, ctx->support_simdgroup_reduction);
         //GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_F16_F32_L4, mul_mv_id_f16_f32_l4, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_0_F32, mul_mv_id_q4_0_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_1_F32, mul_mv_id_q4_1_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_0_F32, mul_mv_id_q5_0_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_1_F32, mul_mv_id_q5_1_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q8_0_F32, mul_mv_id_q8_0_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q2_K_F32, mul_mv_id_q2_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q3_K_F32, mul_mv_id_q3_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q4_K_F32, mul_mv_id_q4_K_f32, ctx->support_simdgroup_reduction);
         GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q5_K_F32, mul_mv_id_q5_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_Q6_K_F32, mul_mv_id_q6_K_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XXS_F32, mul_mv_id_iq2_xxs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ2_XS_F32, mul_mv_id_iq2_xs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MV_ID_IQ3_XXS_F32, mul_mv_id_iq3_xxs_f32, ctx->support_simdgroup_reduction);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F32_F32, mul_mm_f32_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_F16_F32, mul_mm_f16_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_0_F32, mul_mm_q4_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_1_F32, mul_mm_q4_1_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_0_F32, mul_mm_q5_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_1_F32, mul_mm_q5_1_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q8_0_F32, mul_mm_q8_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q2_K_F32, mul_mm_q2_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q3_K_F32, mul_mm_q3_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q4_K_F32, mul_mm_q4_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q5_K_F32, mul_mm_q5_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_Q6_K_F32, mul_mm_q6_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XXS_F32, mul_mm_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ2_XS_F32, mul_mm_iq2_xs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ3_XXS_F32, mul_mm_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F32, mul_mm_id_f32_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F32, mul_mm_id_f16_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_0_F32, mul_mm_id_q4_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_1_F32, mul_mm_id_q4_1_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_0_F32, mul_mm_id_q5_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_1_F32, mul_mm_id_q5_1_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q8_0_F32, mul_mm_id_q8_0_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q2_K_F32, mul_mm_id_q2_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q3_K_F32, mul_mm_id_q3_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q4_K_F32, mul_mm_id_q4_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q5_K_F32, mul_mm_id_q5_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_Q6_K_F32, mul_mm_id_q6_K_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XXS_F32, mul_mm_id_iq2_xxs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ2_XS_F32, mul_mm_id_iq2_xs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_IQ3_XXS_F32, mul_mm_id_iq3_xxs_f32, ctx->support_simdgroup_mm);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F32, rope_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_F16, rope_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ALIBI_F32, alibi_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_UPSCALE_F32, upscale_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_PAD_F32, pad_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_ASC, argsort_f32_i32_asc, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ARGSORT_F32_I32_DESC, argsort_f32_i32_desc, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32, leaky_relu_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64, flash_attn_ext_f16_h64, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80, flash_attn_ext_f16_h80, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96, flash_attn_ext_f16_h96, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112, flash_attn_ext_f16_h112, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128, flash_attn_ext_f16_h128, true);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256, flash_attn_ext_f16_h256, true);
|
||||||
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F16, cpy_f32_f16, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_F32, cpy_f32_f32, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_0, cpy_f32_q5_0, true);
|
||||||
|
//GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q5_1, cpy_f32_q5_1, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F32, cpy_f16_f32, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CONCAT, concat, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SQR, sqr, true);
|
||||||
|
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_SUM_ROWS, sum_rows, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
[metal_library release];
|
[metal_library release];
|
||||||
|
@@ -640,6 +655,7 @@ static bool ggml_metal_supports_op(const struct ggml_metal_context * ctx, const
        case GGML_OP_PAD:
        case GGML_OP_ARGSORT:
        case GGML_OP_LEAKY_RELU:
+       case GGML_OP_FLASH_ATTN_EXT:
            return true;
        case GGML_OP_MUL_MAT:
        case GGML_OP_MUL_MAT_ID:

@@ -1171,6 +1187,8 @@ static bool ggml_metal_graph_compute(
            } break;
        case GGML_OP_SOFT_MAX:
            {
+               GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F16);
+
                int nth = 32; // SIMD width

                id<MTLComputePipelineState> pipeline = nil;
@@ -2178,6 +2196,111 @@ static bool ggml_metal_graph_compute(

                [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
            } break;
+       case GGML_OP_FLASH_ATTN_EXT:
+           {
+               GGML_ASSERT(ne00 % 4 == 0);
+               GGML_ASSERT(src0->type == GGML_TYPE_F32);
+
+               struct ggml_tensor * src2 = gf->nodes[i]->src[2];
+               struct ggml_tensor * src3 = gf->nodes[i]->src[3];
+
+               GGML_ASSERT(ggml_are_same_shape(src1, src2));
+               GGML_ASSERT(src3);
+
+               size_t offs_src2 = 0;
+               size_t offs_src3 = 0;
+
+               GGML_ASSERT(src2);
+               id<MTLBuffer> id_src2 = ggml_metal_get_buffer(src2, &offs_src2);
+
+               id<MTLBuffer> id_src3 = src3 ? ggml_metal_get_buffer(src3, &offs_src3) : nil;
+
+               GGML_ASSERT(!src3 || src3->type == GGML_TYPE_F16);
+               GGML_ASSERT(!src3 || src3->ne[1] >= GGML_PAD(src0->ne[1], 8) &&
+                       "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big");
+
+               const int64_t  ne30 = src3 ? src3->ne[0] : 0; GGML_UNUSED(ne30);
+               const int64_t  ne31 = src3 ? src3->ne[1] : 0;
+               const int64_t  ne32 = src3 ? src3->ne[2] : 0; GGML_UNUSED(ne32);
+               const int64_t  ne33 = src3 ? src3->ne[3] : 0; GGML_UNUSED(ne33);
+
+               const uint64_t nb30 = src3 ? src3->nb[0] : 0; GGML_UNUSED(nb30);
+               const uint64_t nb31 = src3 ? src3->nb[1] : 0;
+               const uint64_t nb32 = src3 ? src3->nb[2] : 0; GGML_UNUSED(nb32);
+               const uint64_t nb33 = src3 ? src3->nb[3] : 0; GGML_UNUSED(nb33);
+
+               const enum ggml_type src2t = src2 ? src2->type : GGML_TYPE_COUNT; GGML_UNUSED(src2t);
+
+               float scale;
+               memcpy(&scale, dst->op_params, sizeof(float));
+
+               id<MTLComputePipelineState> pipeline = nil;
+
+               switch (ne00) {
+                   case 64:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H64 ].pipeline; break;
+                   case 80:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H80 ].pipeline; break;
+                   case 96:  pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H96 ].pipeline; break;
+                   case 112: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H112].pipeline; break;
+                   case 128: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H128].pipeline; break;
+                   case 256: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_F16_H256].pipeline; break;
+                   default:
+                       {
+                           GGML_METAL_LOG_ERROR("unsupported size: %lld\n", ne00);
+                           GGML_METAL_LOG_ERROR("add template specialization for this size\n");
+                           GGML_ASSERT(false && "add template specialization for this size");
+                       }
+               }
+
+               // TODO: extend if necessary
+               [encoder setComputePipelineState:pipeline];
+               [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+               [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
+               [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
+               [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3];
+               [encoder setBuffer:id_dst  offset:offs_dst  atIndex:4];
+               [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:5];
+               [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:6];
+               [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:7];
+               [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:8];
+               [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:9];
+               [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:10];
+               [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:11];
+               [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:12];
+               [encoder setBytes:&ne10 length:sizeof( int64_t) atIndex:13];
+               [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:14];
+               [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:15];
+               [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:16];
+               [encoder setBytes:&nb10 length:sizeof(uint64_t) atIndex:17];
+               [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:18];
+               [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:19];
+               [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:20];
+               [encoder setBytes:&ne31 length:sizeof( int64_t) atIndex:21];
+               [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:22];
+               [encoder setBytes:&ne0  length:sizeof( int64_t) atIndex:23];
+               [encoder setBytes:&ne1  length:sizeof( int64_t) atIndex:24];
+               [encoder setBytes:&ne2  length:sizeof( int64_t) atIndex:25];
+               [encoder setBytes:&ne3  length:sizeof( int64_t) atIndex:26];
+               [encoder setBytes:&scale length:sizeof( float)  atIndex:27];
+
+               const int64_t nqptg = 8;  // queries per threadgroup    !! sync with kernel template arguments !!
+               const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !!
+
+               GGML_ASSERT(nqptg <= 32);
+               GGML_ASSERT(nqptg  % 8  == 0);
+               GGML_ASSERT(ncpsg  % 32 == 0);
+
+               // simdgroups per threadgroup (a.k.a. warps)
+               // for small batches use more simdgroups (needs more tests, to confirm if it's worth it)
+               const int64_t nsg = ne01 <= nqptg ? MAX(4, MIN(ne11/ncpsg, (int64_t) pipeline.maxTotalThreadsPerThreadgroup/32)) : 4;
+
+               const size_t smem = nqptg*(ne00 + nsg*(ncpsg + nqptg))*(sizeof(float)/2);
+
+               //printf("smem: %zu, max: %zu\n", smem, ctx->device.maxThreadgroupMemoryLength);
+               GGML_ASSERT(smem <= ctx->device.maxThreadgroupMemoryLength);
+               [encoder setThreadgroupMemoryLength:smem atIndex:0];
+
+               [encoder dispatchThreadgroups:MTLSizeMake((ne01 + nqptg - 1)/nqptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(32, nsg, 1)];
+           } break;
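Editor's note on the dispatch above: the threadgroup memory request is nqptg*(ne00 + nsg*(ncpsg + nqptg)) half-precision elements (the sizeof(float)/2 factor is simply the size of a half in bytes). A minimal arithmetic sketch follows; the concrete values (head size 128, nsg = 4, the default used for large batches) are illustrative assumptions, not part of the patch.

#include <stdio.h>

int main(void) {
    // assumed values: D (head size) = 128, nsg = 4 simdgroups, nqptg = 8 queries, ncpsg = 32 cache values
    const long D = 128, nsg = 4, nqptg = 8, ncpsg = 32;

    // same sizing expression as the dispatch code above, with halves being 2 bytes each
    const long smem = nqptg*(D + nsg*(ncpsg + nqptg))*2;

    printf("threadgroup memory: %ld bytes\n", smem); // prints 4608 for these assumed values
    return 0;
}

For the head sizes registered in this patch the request stays in the low-kilobyte range, which is why the assert against device.maxThreadgroupMemoryLength is expected to hold on current hardware.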
        case GGML_OP_DUP:
        case GGML_OP_CPY:
        case GGML_OP_CONT:
@@ -2379,10 +2502,13 @@ GGML_CALL static const char * ggml_backend_metal_buffer_type_get_name(ggml_backe
    UNUSED(buft);
}

-static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
+static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device, size_t size_aligned) {
+#ifndef GGML_METAL_NDEBUG
#if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
    if (@available(macOS 10.12, iOS 16.0, *)) {
-       GGML_METAL_LOG_INFO(", (%8.2f / %8.2f)",
+       GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f / %8.2f)",
+               __func__,
+               size_aligned / 1024.0 / 1024.0,
                device.currentAllocatedSize / 1024.0 / 1024.0,
                device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);

@@ -2392,10 +2518,15 @@ static void ggml_backend_metal_log_allocated_size(id<MTLDevice> device) {
            GGML_METAL_LOG_INFO("\n");
        }
    } else {
-       GGML_METAL_LOG_INFO(", (%8.2f)\n", device.currentAllocatedSize / 1024.0 / 1024.0);
+       GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, (%8.2f)\n",
+               __func__,
+               size_aligned / 1024.0 / 1024.0,
+               device.currentAllocatedSize / 1024.0 / 1024.0);
    }
+#endif
#endif
    UNUSED(device);
+   UNUSED(size_aligned);
}

GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) {

@@ -2429,8 +2560,7 @@ GGML_CALL static ggml_backend_buffer_t ggml_backend_metal_buffer_type_alloc_buff
        return NULL;
    }

-   GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
-   ggml_backend_metal_log_allocated_size(device);
+   ggml_backend_metal_log_allocated_size(device, size_aligned);

    return ggml_backend_buffer_init(buft, ggml_backend_metal_buffer_i, ctx, size);
}

@@ -2517,7 +2647,7 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
            return false;
        }

-       GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB", __func__, size_aligned / 1024.0 / 1024.0);
+       ggml_backend_metal_log_allocated_size(device, size_aligned);

        ++ctx->n_buffers;
    } else {

@@ -2540,7 +2670,8 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
                return false;
            }

-           GGML_METAL_LOG_INFO("%s: allocated buffer, size = %8.2f MiB, offs = %12ld", __func__, size_step_aligned / 1024.0 / 1024.0, i);
+           ggml_backend_metal_log_allocated_size(device, size_step_aligned);

            if (i + size_step < size) {
                GGML_METAL_LOG_INFO("\n");
            }

@@ -2549,8 +2680,6 @@ GGML_CALL ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data,
        }
    }

-   ggml_backend_metal_log_allocated_size(device);
-
    return ggml_backend_buffer_init(ggml_backend_metal_buffer_type(), ggml_backend_metal_buffer_i, ctx, size);
}

435 ggml-metal.metal
@@ -349,9 +349,9 @@ kernel void kernel_sum_rows(
}

kernel void kernel_soft_max(
-       device const float * src0,
-       device const float * src1,
-       device       float * dst,
+       device const  char * src0,
+       device const  char * src1,
+       device        char * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,

@@ -366,9 +366,9 @@ kernel void kernel_soft_max(
    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-   device const float * psrc0 = src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
-   device const float * pmask = src1 != src0 ? src1 + i01*ne00 : nullptr;
-   device       float * pdst  = dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
+   device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+   device const half  * pmask = src1 != src0 ? (device const half *) src1 + i01*ne00 : nullptr;
+   device       float * pdst  = (device float *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

    // parallel max
    float lmax = -INFINITY;

@@ -435,14 +435,14 @@ kernel void kernel_soft_max(
}

kernel void kernel_soft_max_4(
-       device const float * src0,
-       device const float * src1,
-       device       float * dst,
+       device const  char * src0,
+       device const  char * src1,
+       device        char * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant     float & scale,
        threadgroup  float * buf [[threadgroup(0)]],
        uint  tgpig[[threadgroup_position_in_grid]],
        uint  tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],

@@ -452,15 +452,15 @@ kernel void kernel_soft_max_4(
    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

-   device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
-   device const float4 * pmask = src1 != src0 ? (device const float4 *)(src1 + i01*ne00) : nullptr;
-   device       float4 * pdst4 = (device float4 *)(dst + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+   device const float4 * psrc4 = (device const float4 *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;
+   device const half4  * pmask = src1 != src0 ? (device const half4 *) src1 + i01*ne00/4 : nullptr;
+   device       float4 * pdst4 = (device float4 *) dst + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00)/4;

    // parallel max
    float4 lmax4 = -INFINITY;

    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-       lmax4 = fmax(lmax4, psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f));
+       lmax4 = fmax(lmax4, psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f));
    }

    const float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));

@@ -486,7 +486,7 @@ kernel void kernel_soft_max_4(
    // parallel sum
    float4 lsum4 = 0.0f;
    for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) {
-       const float4 exp_psrc4 = exp((psrc4[i00]*scale + (pmask ? pmask[i00] : 0.0f)) - max_val);
+       const float4 exp_psrc4 = exp((psrc4[i00]*scale + (float4) (pmask ? pmask[i00] : 0.0f)) - max_val);
        lsum4 += exp_psrc4;
        pdst4[i00] = exp_psrc4;
    }
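Editor's note: both soft_max kernels above compute the same scaled, masked softmax; the only change here is that the mask values m_i are now read as half instead of float. Restated in math form (this is a rewording of the code, not new behavior):

p_i = \frac{\exp(\mathrm{scale}\cdot x_i + m_i - M)}{\sum_j \exp(\mathrm{scale}\cdot x_j + m_j - M)},
\qquad M = \max_j \left(\mathrm{scale}\cdot x_j + m_j\right)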
@@ -1984,6 +1984,411 @@ kernel void kernel_leaky_relu_f32(
    dst[tpig] = src0[tpig] > 0.0f ? src0[tpig] : src0[tpig] * slope;
}

+typedef void (flash_attn_ext_f16_t)(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared,
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]);
+
+// ref: https://arxiv.org/pdf/2307.08691.pdf
+template<int64_t D, int64_t Q, int64_t C> // head size, queries per threadgroup, cache items per threadgroup
+kernel void kernel_flash_attn_ext_f16(
+        device const  char * q,
+        device const  char * k,
+        device const  char * v,
+        device const  char * mask,
+        device       float * dst,
+        constant   int64_t & ne00,
+        constant   int64_t & ne01,
+        constant   int64_t & ne02,
+        constant   int64_t & ne03,
+        constant  uint64_t & nb00,
+        constant  uint64_t & nb01,
+        constant  uint64_t & nb02,
+        constant  uint64_t & nb03,
+        constant   int64_t & ne10,
+        constant   int64_t & ne11,
+        constant   int64_t & ne12,
+        constant   int64_t & ne13,
+        constant  uint64_t & nb10,
+        constant  uint64_t & nb11,
+        constant  uint64_t & nb12,
+        constant  uint64_t & nb13,
+        constant   int64_t & ne31,
+        constant  uint64_t & nb31,
+        constant   int64_t & ne0,
+        constant   int64_t & ne1,
+        constant   int64_t & ne2,
+        constant   int64_t & ne3,
+        constant     float & scale,
+        threadgroup   half * shared [[threadgroup(0)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint3 tpitg[[thread_position_in_threadgroup]],
+        uint3   ntg[[threads_per_threadgroup]],
+        uint  tiisg[[thread_index_in_simdgroup]],
+        uint  sgitg[[simdgroup_index_in_threadgroup]]) {
+    const uint nsg = ntg.y; // number of simdgroups
+
+    const int64_t iq3 = tgpig[2];
+    const int64_t iq2 = tgpig[1];
+    const int64_t iq1 = tgpig[0]*Q;
+
+    const int64_t D4 = D/4;
+    const int64_t D8 = D/8;
+    const int64_t Q8 = Q/8;
+    const int64_t NW = N_SIMDWIDTH;
+    const int64_t SH = (C + Q); // shared memory per simdgroup in (half)
+
+    const int64_t T  = D + nsg*SH; // shared memory size per query in (half)
+    const int64_t T4 = T/4;        // shared memory size per query in (half4)
+
+    threadgroup half  * sq  = (threadgroup half  *) (shared +              0*D); // holds the query data
+    threadgroup half4 * sq4 = (threadgroup half4 *) (shared +              0*D); // same as above but in half4
+    threadgroup half  * ss  = (threadgroup half  *) (shared + sgitg*SH + 1*D);   // scratch buffer for attention and diagonal matrix
+
+    // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
+    simdgroup_half8x8 lo[Q8][D8];
+
+    // load heads from Q to shared memory
+    for (int64_t j = sgitg; j < Q; j += nsg) {
+        device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03));
+
+        for (int64_t i = tiisg; i < D4; i += NW) {
+            if (iq1 + j < ne01) {
+                sq4[j*T4 + i] = (half4) q4[i];
+            } else {
+                sq4[j*T4 + i] = 0.0h;
+            }
+        }
+    }
+
+    // zero out lo
+    for (int64_t j = 0; j < Q8; ++j) {
+        for (int64_t i = 0; i < D8; ++i) {
+            lo[j][i] = make_filled_simdgroup_matrix<half, 8>(0.0h);
+        }
+    }
+
+    // zero out shared memory SH
+    for (int64_t j = 0; j < Q; ++j) {
+        for (int64_t i = tiisg; i < SH; i += NW) {
+            ss[j*T + i] = 0.0h;
+        }
+    }
+
+    threadgroup_barrier(mem_flags::mem_threadgroup);
+
+    {
+        half S[Q] = { [0 ... Q-1] = 0.0h };
+        half M[Q] = { [0 ... Q-1] = -INFINITY };
+
+        // assume K and V are same shape
+        const int64_t ne22 = ne12;
+        const int64_t ne23 = ne13;
+
+        const uint64_t nb21 = nb11;
+        const uint64_t nb22 = nb12;
+        const uint64_t nb23 = nb13;
+
+        // broadcast
+        const int64_t rk2 = ne02/ne12;
+        const int64_t rk3 = ne03/ne13;
+
+        const int64_t rv2 = ne02/ne22;
+        const int64_t rv3 = ne03/ne23;
+
+        // k indices
+        const int64_t ik2 = iq2 / rk2;
+        const int64_t ik3 = iq3 / rk3;
+
+        // v indices
+        const int64_t iv2 = iq2 / rv2;
+        const int64_t iv3 = iq3 / rv3;
+
+        // load the queries from shared memory into local memory
+        simdgroup_half8x8 mq[Q8][D8];
+
+        for (int64_t j = 0; j < Q8; ++j) {
+            for (int64_t i = 0; i < D8; ++i) {
+                simdgroup_load(mq[j][i], sq + 8*j*T + i*8, T);
+            }
+        }
+
+        // pointer to the mask
+        device const half * mp = (device const half *) (mask + iq1*nb31);
+
+        // prepare diagonal scale matrix
+        simdgroup_half8x8 mscale(scale);
+
+        // loop over the KV cache
+        // each simdgroup handles blocks of Q rows and C columns
+        for (int64_t ic = C*sgitg; ic < ne11; ic += C*nsg) {
+            // Q*K^T
+            {
+                for (int cc = 0; cc < C/8; ++cc) {
+                    simdgroup_half8x8 mqk[Q8];
+                    for (int64_t j = 0; j < Q8; ++j) {
+                        mqk[j] = make_filled_simdgroup_matrix<half, 8>(0.h);
+                    }
+
+                    device const half * pk = (device const half *) ((device const char *) k + ((ic + 8*cc)*nb11 + ik2*nb12 + ik3*nb13));
+
+                    for (int64_t i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pk + i*8, nb11/sizeof(half), 0, true); // transpose
+
+                        for (int64_t j = 0; j < Q8; ++j) {
+                            simdgroup_multiply_accumulate(mqk[j], mq[j][i], mk, mqk[j]);
+                        }
+                    }
+
+                    // mqk = mqk*scale + mask
+                    for (int64_t j = 0; j < Q8; ++j) {
+                        simdgroup_half8x8 mm;
+                        simdgroup_load(mm, mp + 8*j*(nb31/sizeof(half)) + ic + 8*cc, nb31/sizeof(half), 0, false);
+                        simdgroup_multiply_accumulate(mqk[j], mqk[j], mscale, mm);
+
+                        simdgroup_store(mqk[j], ss + 8*j*T + 8*cc, T, 0, false);
+                    }
+                }
+            }
+
+            // used to detect blocks full of -INF
+            half smax = -INFINITY;
+
+            // online softmax
+            if (C == 32) {
+                half ms[Q];
+
+                for (int64_t j = 0; j < Q; ++j) {
+                    const int64_t p = tiisg;
+
+                    const half m = M[j];
+                    const half s = ss[j*T + p];
+
+                    smax = simd_max(max(smax, s));
+                    M[j] = simd_max(max(M[j], s));
+
+                    ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
+                    const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
+
+                    S[j] = S[j]*ms[j] + simd_sum(vs);
+
+                    // the P matrix from the paper (Q rows, C columns)
+                    ss[j*T + p] = vs;
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*T + C + tiisg] = ms[tiisg];
+                }
+            } else {
+                half ms[Q];
+
+                for (int64_t j = 0; j < Q; ++j) {
+                    const half m = M[j];
+
+                    for (int64_t p = tiisg; p < C; p += NW) {
+                        const half s = ss[j*T + p];
+
+                        smax = max(smax, s);
+                        M[j] = max(M[j], s);
+                    }
+
+                    smax = simd_max(smax);
+                    M[j] = simd_max(M[j]);
+
+                    ms[j] = m == -INFINITY ? 0.0h : exp(m - M[j]);
+
+                    // local sum
+                    half ls = 0.0h;
+
+                    for (int64_t p = tiisg; p < C; p += NW) {
+                        const half s = ss[j*T + p];
+
+                        const half vs = s == -INFINITY ? 0.0h : exp(s - M[j]);
+
+                        ls += vs;
+
+                        // the P matrix from the paper (Q rows, C columns)
+                        ss[j*T + p] = vs;
+                    }
+
+                    S[j] = S[j]*ms[j] + simd_sum(ls);
+                }
+
+                // create a QxQ diagonal matrix for rescaling the output
+                if (tiisg < Q) {
+                    ss[tiisg*T + C + tiisg] = ms[tiisg];
+                }
+            }
+
+            // skip -INF blocks
+            if (smax == -INFINITY) {
+                continue;
+            }
+
+            // O = diag(ms)*O
+            for (int64_t j = 0; j < Q8; ++j) {
+                simdgroup_half8x8 mm;
+                simdgroup_load(mm, ss + 8*j*T + C + 8*j, T, 0, false);
+
+                for (int64_t i = 0; i < D8; ++i) {
+                    simdgroup_multiply(lo[j][i], mm, lo[j][i]);
+                }
+            }
+
+            // O = O + (Q*K^T)*V
+            {
+                for (int cc = 0; cc < C/8; ++cc) {
+                    device const half * pv = (device const half *) ((device const char *) v + ((ic + 8*cc)*nb21 + iv2*nb22 + iv3*nb23));
+
+                    for (int64_t i = 0; i < D8; ++i) {
+                        simdgroup_half8x8 mk;
+                        simdgroup_load(mk, pv + i*8, nb21/sizeof(half), 0, false);
+
+                        for (int64_t j = 0; j < Q8; ++j) {
+                            simdgroup_half8x8 mv;
+                            simdgroup_load(mv, ss + 8*j*T + 8*cc, T, 0, false);
+
+                            simdgroup_multiply_accumulate(lo[j][i], mv, mk, lo[j][i]);
+                        }
+                    }
+                }
+            }
+        }
+
+        // these are needed for reducing the results from the simdgroups (reuse the ss buffer)
+        for (int64_t j = 0; j < Q; ++j) {
+            if (tiisg == 0) {
+                ss[j*T + 0] = S[j];
+                ss[j*T + 1] = M[j];
+            }
+        }
+    }
+
+    // reduce the warps sequentially
+    for (int64_t sg = 1; sg < nsg; ++sg) {
+        half S = { 0.0h };
+        half M = { -INFINITY };
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // each simdgroup stores its output to shared memory, reusing sq
+        if (sgitg == sg) {
+            for (int64_t j = 0; j < Q8; ++j) {
+                for (int64_t i = 0; i < D8; ++i) {
+                    simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
+                }
+            }
+        }
+
+        threadgroup_barrier(mem_flags::mem_threadgroup);
+
+        // the first simdgroup accumulates the results from the other simdgroups
+        if (sgitg == 0) {
+            for (int64_t j = 0; j < Q; ++j) {
+                const half S0 = ss[j*T +         0];
+                const half S1 = ss[j*T + sg*SH + 0];
+
+                const half M0 = ss[j*T +         1];
+                const half M1 = ss[j*T + sg*SH + 1];
+
+                M = max(M0, M1);
+
+                const half ms0 = M0 == -INFINITY ? 0.0h : exp(M0 - M);
+                const half ms1 = M1 == -INFINITY ? 0.0h : exp(M1 - M);
+
+                S = S0*ms0 + S1*ms1;
+
+                if (tiisg == 0) {
+                    ss[j*T + 0] = S;
+                    ss[j*T + 1] = M;
+
+                    ss[j*T + C + j        ] = ms0;
+                    ss[j*T + C + j + sg*SH] = ms1;
+                }
+            }
+
+            // O_0 = diag(ms0)*O_0 + diag(ms1)*O_1
+            for (int64_t j = 0; j < Q8; ++j) {
+                simdgroup_half8x8 t;
+                simdgroup_half8x8 ms0;
+                simdgroup_half8x8 ms1;
+
+                simdgroup_load(ms0, ss + 8*j*T + C + 8*j,         T, 0, false);
+                simdgroup_load(ms1, ss + 8*j*T + C + 8*j + sg*SH, T, 0, false);
+
+                for (int64_t i = 0; i < D8; ++i) {
+                    simdgroup_load    (t, sq + 8*j*T + i*8, T, 0, false);
+                    simdgroup_multiply(t, ms1, t);
+
+                    simdgroup_multiply_accumulate(lo[j][i], ms0, lo[j][i], t);
+                }
+            }
+        }
+    }
+
+    // store result to shared memory (reuse sq)
+    if (sgitg == 0) {
+        for (int64_t j = 0; j < Q8; ++j) {
+            for (int64_t i = 0; i < D8; ++i) {
+                simdgroup_store(lo[j][i], sq + 8*j*T + i*8, T, 0, false);
+            }
+        }
+    }
+
+    device float4 * dst4 = (device float4 *) dst;
+
+    // final rescale with 1/S and store to global memory
+    if (sgitg == 0) {
+        for (int64_t j = 0; j < Q && iq1 + j < ne01; ++j) {
+            const half S = ss[j*T + 0];
+
+            for (int64_t i = tiisg; i < D4; i += NW) {
+                dst4[(iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) sq4[j*T4 + i]/S;
+            }
+        }
+    }
+}
+
+template [[host_name("kernel_flash_attn_ext_f16_h64" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<64,  8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h80" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<80,  8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h96" )]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<96,  8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h112")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<112, 8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h128")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<128, 8, 32>;
+template [[host_name("kernel_flash_attn_ext_f16_h256")]] kernel flash_attn_ext_f16_t kernel_flash_attn_ext_f16<256, 8, 32>;
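Editor's note: the kernel above keeps, per query row, a running maximum M, a running softmax denominator S and a running output block O (the online softmax of the referenced FlashAttention-2 paper). Restating the update it performs for each block of C cache entries with masked, scaled scores s_i (this is a rewording of the code above, not new behavior):

M' = \max\!\left(M, \max_i s_i\right), \qquad m = e^{\,M - M'}, \qquad P_i = e^{\,s_i - M'},
\qquad S' = m\,S + \sum_i P_i, \qquad O' = \mathrm{diag}(m)\,O + P\,V

and the final output for the row is O/S once the last block has been processed, which is exactly the 1/S rescale in the store loop.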
kernel void kernel_cpy_f16_f16(
        device const half * src0,
        device       half * dst,

335 ggml.c
@@ -865,7 +865,7 @@ do { \

#if defined(__F16C__)
// the  _mm256_cvt intrinsics require F16C
-#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((__m128i *)(x)))
+#define GGML_F32Cx8_LOAD(x)     _mm256_cvtph_ps(_mm_loadu_si128((const __m128i *)(x)))
#define GGML_F32Cx8_STORE(x, y) _mm_storeu_si128((__m128i *)(x), _mm256_cvtps_ph(y, 0))
#else
static inline __m256 __avx_f32cx8_load(ggml_fp16_t *x) {

@@ -1371,6 +1371,37 @@ inline static void ggml_vec_mad_f32(const int n, float * restrict y, const float
#endif
}

+inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * restrict y, const ggml_fp16_t * restrict x, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ax[GGML_F16_ARR];
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i]) + GGML_FP16_TO_FP32(x[i])*v);
+    }
+#endif
+}
+
// xs and vs are byte strides of x and v
inline static void ggml_vec_mad_f32_unroll(const int n, const int xs, const int vs, float * restrict y, const float * restrict xv, const float * restrict vv) {
@@ -1455,6 +1486,35 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
#endif
}

+inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
+#if defined(GGML_SIMD)
+    const int np = (n & ~(GGML_F16_STEP - 1));
+
+    GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
+
+    GGML_F16_VEC ay[GGML_F16_ARR];
+
+    for (int i = 0; i < np; i += GGML_F16_STEP) {
+        for (int j = 0; j < GGML_F16_ARR; j++) {
+            ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
+            ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
+
+            GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
+        }
+    }
+
+    // leftovers
+    for (int i = np; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#else
+    // scalar
+    for (int i = 0; i < n; ++i) {
+        y[i] = GGML_FP32_TO_FP16(GGML_FP16_TO_FP32(y[i])*v);
+    }
+#endif
+}
+
inline static void ggml_vec_norm_f32 (const int n, float * s, const float * x) { ggml_vec_dot_f32(n, s, x, x); *s = sqrtf(*s);   }
inline static void ggml_vec_sqr_f32  (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = x[i]*x[i];   }
inline static void ggml_vec_sqrt_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = sqrtf(x[i]); }
@@ -1701,6 +1761,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "LEAKY_RELU",

    "FLASH_ATTN",
+   "FLASH_ATTN_EXT",
    "FLASH_FF",
    "FLASH_ATTN_BACK",
    "WIN_PART",

@@ -1725,7 +1786,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
    "CROSS_ENTROPY_LOSS_BACK",
};

-static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
+static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");

static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "none",

@@ -1787,6 +1848,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "leaky_relu(x)",

    "flash_attn(x)",
+   "flash_attn_ext(x)",
    "flash_ff(x)",
    "flash_attn_back(x)",
    "win_part(x)",

@@ -1811,7 +1873,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
    "cross_entropy_loss_back(x,y)",
};

-static_assert(GGML_OP_COUNT == 72, "GGML_OP_COUNT != 72");
+static_assert(GGML_OP_COUNT == 73, "GGML_OP_COUNT != 73");

static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
|
||||||
void ggml_mul_mat_set_prec(
|
void ggml_mul_mat_set_prec(
|
||||||
struct ggml_tensor * a,
|
struct ggml_tensor * a,
|
||||||
enum ggml_prec prec) {
|
enum ggml_prec prec) {
|
||||||
|
GGML_ASSERT(a->op == GGML_OP_MUL_MAT);
|
||||||
|
|
||||||
const int32_t prec_i32 = (int32_t) prec;
|
const int32_t prec_i32 = (int32_t) prec;
|
||||||
|
|
||||||
ggml_set_op_params_i32(a, 0, prec_i32);
|
ggml_set_op_params_i32(a, 0, prec_i32);
|
||||||
|
@ -5021,10 +5085,11 @@ static struct ggml_tensor * ggml_soft_max_impl(
|
||||||
bool inplace) {
|
bool inplace) {
|
||||||
GGML_ASSERT(ggml_is_contiguous(a));
|
GGML_ASSERT(ggml_is_contiguous(a));
|
||||||
if (mask) {
|
if (mask) {
|
||||||
|
GGML_ASSERT(mask->type == GGML_TYPE_F16);
|
||||||
GGML_ASSERT(ggml_is_contiguous(mask));
|
GGML_ASSERT(ggml_is_contiguous(mask));
|
||||||
GGML_ASSERT(mask->ne[2] == 1);
|
GGML_ASSERT(mask->ne[2] == 1);
|
||||||
GGML_ASSERT(mask->ne[3] == 1);
|
GGML_ASSERT(mask->ne[3] == 1);
|
||||||
GGML_ASSERT(ggml_can_repeat_rows(mask, a));
|
GGML_ASSERT(mask->ne[1] >= a->ne[1]);
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_node = false;
|
bool is_node = false;
|
||||||
|
@@ -5775,6 +5840,59 @@ struct ggml_tensor * ggml_flash_attn(
     return result;
 }

+// ggml_flash_attn_ext
+
+struct ggml_tensor * ggml_flash_attn_ext(
+        struct ggml_context * ctx,
+        struct ggml_tensor  * q,
+        struct ggml_tensor  * k,
+        struct ggml_tensor  * v,
+        struct ggml_tensor  * mask,
+        float                 scale) {
+    GGML_ASSERT(ggml_can_mul_mat(k, q));
+    // TODO: check if vT can be multiplied by (k*qT)
+    if (mask) {
+        GGML_ASSERT(ggml_is_contiguous(mask));
+        GGML_ASSERT(mask->ne[2] == 1);
+        GGML_ASSERT(mask->ne[3] == 1);
+        GGML_ASSERT(mask->ne[1] >= GGML_PAD(q->ne[1], GGML_KQ_MASK_PAD) &&
+                "the Flash-Attention kernel requires the mask to be padded to GGML_KQ_MASK_PAD and at least n_queries big");
+        //GGML_ASSERT(ggml_can_repeat_rows(mask, qk));
+    }
+
+    bool is_node = false;
+
+    if (q->grad || k->grad || v->grad) {
+        is_node = true;
+    }
+
+    // permute(0, 2, 1, 3)
+    int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] };
+    struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, GGML_MAX_DIMS, ne);
+
+    float params[] = { scale };
+    ggml_set_op_params(result, params, sizeof(params));
+
+    result->op   = GGML_OP_FLASH_ATTN_EXT;
+    result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
+    result->src[0] = q;
+    result->src[1] = k;
+    result->src[2] = v;
+    result->src[3] = mask;
+
+    return result;
+}
+
+void ggml_flash_attn_ext_set_prec(
+        struct ggml_tensor * a,
+        enum ggml_prec prec) {
+    GGML_ASSERT(a->op == GGML_OP_FLASH_ATTN_EXT);
+
+    const int32_t prec_i32 = (int32_t) prec;
+
+    ggml_set_op_params_i32(a, 1, prec_i32); // scale is on first pos
+}
+
 // ggml_flash_ff

 struct ggml_tensor * ggml_flash_ff(
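The constructor above stores the user-provided scale in op_params slot 0, and ggml_flash_attn_ext_set_prec() stores the requested precision in slot 1. A minimal sketch of how a backend could read these back, assuming only the layout shown in this diff (the two helper functions are illustrative, not part of ggml):

#include <string.h>
#include "ggml.h"

// illustrative helpers, not part of ggml
static float flash_attn_ext_scale(const struct ggml_tensor * dst) {
    float scale = 1.0f;
    memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); // slot 0: scale
    return scale;
}

static enum ggml_prec flash_attn_ext_prec(const struct ggml_tensor * dst) {
    return (enum ggml_prec) dst->op_params[1]; // slot 1: precision ("scale is on first pos")
}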
@@ -11437,12 +11555,14 @@ static void ggml_compute_forward_soft_max_f32(
         float * dp = (float *)((char *) dst->data + i1*dst->nb[1]);

         // broadcast the mask across rows
-        float * mp = src1 ? (float *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;
+        ggml_fp16_t * mp = src1 ? (ggml_fp16_t *)((char *) src1->data + (i1%ne11)*src1->nb[1]) : NULL;

         ggml_vec_cpy_f32  (nc, wp, sp);
         ggml_vec_scale_f32(nc, wp, scale);
         if (mp) {
-            ggml_vec_acc_f32(nc, wp, mp);
+            for (int i = 0; i < nc; ++i) {
+                wp[i] += GGML_FP16_TO_FP32(mp[i]);
+            }
         }

 #ifndef NDEBUG
@@ -13552,6 +13672,197 @@ static void ggml_compute_forward_flash_attn(
     }
 }

+// ggml_compute_forward_flash_attn_ext
+
+static void ggml_compute_forward_flash_attn_ext_f16(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    int64_t t0 = ggml_perf_time_us();
+    UNUSED(t0);
+
+    GGML_TENSOR_LOCALS(int64_t, neq, q,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbq, q,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nek, k,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbk, k,   nb)
+    GGML_TENSOR_LOCALS(int64_t, nev, v,   ne)
+    GGML_TENSOR_LOCALS(size_t,  nbv, v,   nb)
+    GGML_TENSOR_LOCALS(int64_t, ne,  dst, ne)
+    GGML_TENSOR_LOCALS(size_t,  nb,  dst, nb)
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t D = neq0;
+    const int64_t N = neq1;
+
+    GGML_ASSERT(ne0 == D);
+    GGML_ASSERT(ne2 == N);
+
+    GGML_ASSERT(nbq0 == sizeof(float));
+    GGML_ASSERT(nbk0 == sizeof(ggml_fp16_t));
+    GGML_ASSERT(nbv0 == sizeof(ggml_fp16_t));
+
+    GGML_ASSERT(neq0 == D);
+    GGML_ASSERT(nek0 == D);
+    GGML_ASSERT(nev0 == D);
+
+    GGML_ASSERT(neq1 == N);
+    GGML_ASSERT(nev0 == D);
+
+    // dst cannot be transposed or permuted
+    GGML_ASSERT(nb0 == sizeof(float));
+    GGML_ASSERT(nb0 <= nb1);
+    GGML_ASSERT(nb1 <= nb2);
+    GGML_ASSERT(nb2 <= nb3);
+
+    // broadcast factors
+    const int64_t rk2 = neq2/nek2;
+    const int64_t rk3 = neq3/nek3;
+
+    const int64_t rv2 = neq2/nev2;
+    const int64_t rv3 = neq3/nev3;
+
+    if (params->type == GGML_TASK_INIT) {
+        return;
+    }
+
+    if (params->type == GGML_TASK_FINALIZE) {
+        return;
+    }
+
+    // parallelize by q rows using ggml_vec_dot_f32
+
+    // total rows in q
+    const int nr = neq1*neq2*neq3;
+
+    // rows per thread
+    const int dr = (nr + nth - 1)/nth;
+
+    // row range for this thread
+    const int ir0 = dr*ith;
+    const int ir1 = MIN(ir0 + dr, nr);
+
+    float scale = 1.0f;
+    memcpy(&scale, (float *) dst->op_params + 0, sizeof(float));
+
+    // loop over n_batch and n_head
+    for (int ir = ir0; ir < ir1; ++ir) {
+        // q indices
+        const int iq3 = ir/(neq2*neq1);
+        const int iq2 = (ir - iq3*neq2*neq1)/neq1;
+        const int iq1 = (ir - iq3*neq2*neq1 - iq2*neq1);
+
+        float S = 0.0f;
+        float M = -INFINITY;
+
+        float       * V32 = (float       *) params->wdata + ith*(2*D + CACHE_LINE_SIZE_F32);
+        ggml_fp16_t * Q16 = (ggml_fp16_t *) (V32); // reuse memory
+        ggml_fp16_t * V16 = (ggml_fp16_t *) (V32 + D);
+
+        memset(V16, 0, D*sizeof(ggml_fp16_t));
+
+        const ggml_fp16_t * mp = mask ? (ggml_fp16_t *)((char *) mask->data + iq1*mask->nb[1]) : NULL;
+
+        // k indices
+        const int ik3 = iq3 / rk3;
+        const int ik2 = iq2 / rk2;
+
+        // v indices
+        const int iv3 = iq3 / rv3;
+        const int iv2 = iq2 / rv2;
+
+        // online softmax / attention
+        // loop over n_kv and n_head_kv
+        // ref: https://arxiv.org/pdf/2112.05682.pdf
+        for (int64_t ic = 0; ic < nek1; ++ic) {
+            const float mv = mp ? GGML_FP16_TO_FP32(mp[ic]) : 0.0f;
+            if (mv == -INFINITY) {
+                continue;
+            }
+
+            float s;
+
+            // convert Q to F16 in V32
+            {
+                const float * pq = (const float *) ((char *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3));
+
+                for (int64_t d = 0; d < D; ++d) {
+                    Q16[d] = GGML_FP32_TO_FP16(pq[d]);
+                }
+            }
+
+            ggml_vec_dot_f16(D,
+                    &s,
+                    (ggml_fp16_t *) ((char *) k->data + ( ic*nbk1 + ik2*nbk2 + ik3*nbk3)),
+                    Q16);
+
+            s = s*scale + mv;
+
+            const float Mold = M;
+
+            float ms = 1.0f;
+            float vs = 1.0f;
+
+            if (s > M) {
+                M = s;
+                ms = expf(Mold - M);
+
+                // V = V*expf(Mold - M)
+                ggml_vec_scale_f16(D, V16, ms);
+            } else {
+                vs = expf(s - M);
+            }
+
+            const ggml_fp16_t * v16 = (const ggml_fp16_t *) ((char *) v->data + (ic*nbv1 + iv2*nbv2 + iv3*nbv3));
+
+            // V += v*expf(s - M)
+            ggml_vec_mad_f16(D, V16, v16, vs);
+
+            S = S*ms + vs;
+        }
+
+        // V /= S
+        for (int64_t d = 0; d < D; ++d) {
+            V32[d] = GGML_FP16_TO_FP32(V16[d])/S;
+        }
+
+        // dst indices
+        const int i1 = iq1;
+        const int i2 = iq2;
+        const int i3 = iq3;
+
+        // original
+        //memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float));
+
+        // permute(0, 2, 1, 3)
+        memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, V32, nb1);
+    }
+}
+
+static void ggml_compute_forward_flash_attn_ext(
+        const struct ggml_compute_params * params,
+        const struct ggml_tensor * q,
+        const struct ggml_tensor * k,
+        const struct ggml_tensor * v,
+        const struct ggml_tensor * mask,
+        struct ggml_tensor * dst) {
+    switch (dst->op_params[1]) {
+        case GGML_PREC_DEFAULT:
+            {
+                ggml_compute_forward_flash_attn_ext_f16(params, q, k, v, mask, dst);
+            } break;
+        default:
+            {
+                // TODO: implement F32 precision
+                GGML_ASSERT(false);
+            } break;
+    }
+}
+
 // ggml_compute_forward_flash_ff

 static void ggml_compute_forward_flash_ff_f16(
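For readers unfamiliar with the streaming-softmax trick referenced above (https://arxiv.org/pdf/2112.05682.pdf), here is a minimal standalone sketch of the same single-pass accumulation using plain floats instead of ggml vectors; the names scores, values, n, D and out are illustrative only:

#include <math.h>
#include <string.h>

// one-pass softmax-weighted sum: out = sum_i softmax(scores)_i * values[i]
static void online_softmax_attn(const float * scores, const float * values,
                                int n, int D, float * out) {
    float M = -INFINITY; // running maximum of the scores seen so far
    float S = 0.0f;      // running sum of exp(score - M)

    memset(out, 0, D*sizeof(float));

    for (int i = 0; i < n; ++i) {
        const float s  = scores[i];
        float       ms = 1.0f; // rescale factor for the old accumulator
        float       vs = 1.0f; // weight of the current value row

        if (s > M) {
            ms = expf(M - s);  // exp(Mold - Mnew)
            M  = s;
            for (int d = 0; d < D; ++d) {
                out[d] *= ms;  // V = V*expf(Mold - M)
            }
        } else {
            vs = expf(s - M);  // exp(s - M)
        }

        for (int d = 0; d < D; ++d) {
            out[d] += vs*values[i*D + d]; // V += v*expf(s - M)
        }

        S = S*ms + vs;
    }

    for (int d = 0; d < D; ++d) {
        out[d] /= S; // V /= S
    }
}

The kernel above performs exactly this per query row, except that the accumulator is kept in F16 (V16) and the vector operations go through ggml_vec_scale_f16 / ggml_vec_mad_f16.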
@@ -15086,6 +15397,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
                 const bool masked = t != 0;
                 ggml_compute_forward_flash_attn(params, tensor->src[0], tensor->src[1], tensor->src[2], masked, tensor);
             } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                ggml_compute_forward_flash_attn_ext(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor);
+            } break;
         case GGML_OP_FLASH_FF:
             {
                 ggml_compute_forward_flash_ff(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor->src[3], tensor->src[4], tensor);
@@ -16082,6 +16397,7 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
                 GGML_ASSERT(false); // TODO: not implemented
             } break;
         case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
             {
                 struct ggml_tensor * flash_grad = NULL;
                 if (src0->grad || src1->grad || tensor->src[2]->grad) {
@@ -16810,6 +17126,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
                 n_tasks = n_threads;
             } break;
         case GGML_OP_FLASH_ATTN:
+        case GGML_OP_FLASH_ATTN_EXT:
             {
                 n_tasks = n_threads;
             } break;
@@ -17204,6 +17521,12 @@ struct ggml_cplan ggml_graph_plan(const struct ggml_cgraph * cgraph, int n_threa
                     cur += sizeof(float)*ne11*n_tasks; // this is overestimated by x2
                 }
             } break;
+        case GGML_OP_FLASH_ATTN_EXT:
+            {
+                const int64_t ne00 = node->src[0]->ne[0]; // D
+
+                cur = 2*sizeof(float)*ne00*n_tasks; // 2x head size
+            } break;
         case GGML_OP_FLASH_FF:
             {
                 if (node->src[1]->type == GGML_TYPE_F32) {
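Worked example for the new GGML_OP_FLASH_ATTN_EXT work-buffer case, with illustrative numbers: for a head size of ne00 = D = 128 and n_tasks = 8 threads, cur = 2*sizeof(float)*128*8 = 8192 bytes of wdata, i.e. roughly one 2*D-float scratch strip per thread, which the F16 kernel above splits into its V32 / Q16 / V16 buffers.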
20 ggml.h
@@ -454,6 +454,7 @@ extern "C" {
         GGML_OP_LEAKY_RELU,

         GGML_OP_FLASH_ATTN,
+        GGML_OP_FLASH_ATTN_EXT,
         GGML_OP_FLASH_FF,
         GGML_OP_FLASH_ATTN_BACK,
         GGML_OP_WIN_PART,
@@ -1645,6 +1646,25 @@ extern "C" {
             struct ggml_tensor  * v,
             bool                  masked);

+#define GGML_KQ_MASK_PAD 32
+
+    // q:    [n_embd, n_batch,     n_head,    1]
+    // k:    [n_embd, n_kv,        n_head_kv, 1]
+    // v:    [n_embd, n_kv,        n_head_kv, 1] !! not transposed !!
+    // mask: [n_kv,   n_batch_pad, 1,         1] !! n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD) !!
+    // res:  [n_embd, n_head,      n_batch,   1] !! permuted !!
+    GGML_API struct ggml_tensor * ggml_flash_attn_ext(
+            struct ggml_context * ctx,
+            struct ggml_tensor  * q,
+            struct ggml_tensor  * k,
+            struct ggml_tensor  * v,
+            struct ggml_tensor  * mask,
+            float                 scale);
+
+    GGML_API void ggml_flash_attn_ext_set_prec(
+            struct ggml_tensor * a,
+            enum ggml_prec       prec);
+
     GGML_API struct ggml_tensor * ggml_flash_attn_back(
             struct ggml_context * ctx,
             struct ggml_tensor  * q,
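A hypothetical end-to-end usage sketch of the new API on the CPU backend, with made-up sizes (D = 128, n_batch = 8, n_head = 32, n_kv = 512) and the tensor data left uninitialized for brevity; everything else follows the shape comments above:

#include "ggml.h"
#include <math.h>

int main(void) {
    struct ggml_init_params ip = {
        /*.mem_size   =*/ 256*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(ip);

    const int D = 128, n_batch = 8, n_head = 32, n_kv = 512;

    struct ggml_tensor * q    = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, D,    n_batch, n_head, 1);
    struct ggml_tensor * k    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, D,    n_kv,    n_head, 1);
    struct ggml_tensor * v    = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, D,    n_kv,    n_head, 1); // not transposed
    struct ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, n_kv, GGML_PAD(n_batch, GGML_KQ_MASK_PAD), 1, 1);

    struct ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(D));
    ggml_flash_attn_ext_set_prec(out, GGML_PREC_DEFAULT);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    ggml_graph_compute_with_ctx(ctx, gf, 4);

    // out is [D, n_head, n_batch, 1] (permuted), as documented above

    ggml_free(ctx);
    return 0;
}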
102 llama.cpp
@@ -102,6 +102,8 @@
 #define LLAMA_MAX_NODES   8192
 #define LLAMA_MAX_EXPERTS 8

+#define LLAMA_FLASH_ATTN
+
 //
 // logging
 //
@@ -4361,23 +4363,34 @@ static void llm_build_kv_store(
     const int64_t n_embd_k_gqa = hparams.n_embd_k_gqa();
     const int64_t n_embd_v_gqa = hparams.n_embd_v_gqa();

-    // compute the transposed [n_tokens, n_embd] V matrix
-    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
-    //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
-    cb(v_cur_t, "v_cur_t", il);
-
     struct ggml_tensor * k_cache_view = ggml_view_1d(ctx, kv.k_l[il], n_tokens*n_embd_k_gqa,
             (ggml_row_size(kv.k_l[il]->type, n_embd_k_gqa))*kv_head);
     cb(k_cache_view, "k_cache_view", il);

+    // important: storing RoPE-ed version of K in the KV cache!
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
+
+#if defined(LLAMA_FLASH_ATTN)
+    // NOTE: the V cache is not transposed when using FLASH attention !!
+    struct ggml_tensor * v_cache_view = ggml_view_1d(ctx, kv.v_l[il], n_tokens*n_embd_v_gqa,
+            (ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa))*kv_head);
+    cb(v_cache_view, "v_cache_view", il);
+
+    ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur, v_cache_view));
+
+    GGML_UNUSED(n_ctx);
+#else
+    // compute the transposed [n_tokens, n_embd] V matrix
+    //struct ggml_tensor * v_cur_t = ggml_transpose(ctx, ggml_reshape_2d(ctx, v_cur, n_embd_v_gqa, n_tokens));
+    struct ggml_tensor * v_cur_t = ggml_transpose(ctx, v_cur); // TODO: reshape above is likely not needed
+    cb(v_cur_t, "v_cur_t", il);
+
     struct ggml_tensor * v_cache_view = ggml_view_2d(ctx, kv.v_l[il], n_tokens, n_embd_v_gqa,
             (   n_ctx)*ggml_element_size(kv.v_l[il]),
             (kv_head)*ggml_element_size(kv.v_l[il]));
     cb(v_cache_view, "v_cache_view", il);
-
-    // important: storing RoPE-ed version of K in the KV cache!
-    ggml_build_forward_expand(graph, ggml_cpy(ctx, k_cur, k_cache_view));
     ggml_build_forward_expand(graph, ggml_cpy(ctx, v_cur_t, v_cache_view));
+#endif
 }

 static struct ggml_tensor * llm_build_norm(
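Under LLAMA_FLASH_ATTN the V cache therefore keeps the same row-per-token layout as the K cache: storing n_tokens new tokens is one contiguous 1-D copy of n_tokens*n_embd_v_gqa elements at byte offset ggml_row_size(kv.v_l[il]->type, n_embd_v_gqa)*kv_head. The fallback path keeps the transposed [n_ctx, n_embd_v_gqa] layout, where the same store is a strided 2-D copy with a stride of n_ctx elements per feature. As an illustrative number, with n_embd_v_gqa = 4096 and an F16 cache, one token occupies ggml_row_size(GGML_TYPE_F16, 4096) = 8192 bytes per layer in the non-transposed layout.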
@@ -4538,6 +4551,28 @@ static struct ggml_tensor * llm_build_kqv(
                 0);
     cb(k, "k", il);

+    struct ggml_tensor * cur;
+
+#if defined(LLAMA_FLASH_ATTN)
+    // split cached v into n_head heads (not transposed)
+    struct ggml_tensor * v =
+        ggml_view_3d(ctx, kv.v_l[il],
+                n_embd_head_v, n_kv, n_head_kv,
+                ggml_row_size(kv.v_l[il]->type, n_embd_k_gqa),
+                ggml_row_size(kv.v_l[il]->type, n_embd_head_k),
+                0);
+    cb(v, "v", il);
+
+    cur = ggml_flash_attn_ext(ctx, q, k, v, kq_mask, kq_scale);
+    ggml_flash_attn_ext_set_prec(cur, GGML_PREC_DEFAULT);
+    //printf("q: %4d %4d %4d %4d\n", q->ne[0], q->ne[1], q->ne[2], q->ne[3]);
+    //printf("k: %4d %4d %4d %4d\n", k->ne[0], k->ne[1], k->ne[2], k->ne[3]);
+    //printf("v: %4d %4d %4d %4d\n", v->ne[0], v->ne[1], v->ne[2], v->ne[3]);
+    //printf("m: %4d %4d %4d %4d\n", kq_mask->ne[0], kq_mask->ne[1], kq_mask->ne[2], kq_mask->ne[3]);
+    //printf("r: %4d %4d %4d %4d\n", kqv->ne[0], kqv->ne[1], kqv->ne[2], kqv->ne[3]);
+
+    cur = ggml_reshape_2d(ctx, cur, n_embd_head_k*n_head, n_tokens);
+#else
     struct ggml_tensor * kq = ggml_mul_mat(ctx, k, q);
     cb(kq, "kq", il);

@@ -4570,7 +4605,7 @@ static struct ggml_tensor * llm_build_kqv(
         cb(kq, "kq_soft_max_ext", il);
     }

-    // split cached v into n_head heads
+    // split cached v into n_head heads (transposed)
     struct ggml_tensor * v =
         ggml_view_3d(ctx, kv.v_l[il],
                 n_kv, n_embd_head_v, n_head_kv,
@@ -4585,8 +4620,9 @@ static struct ggml_tensor * llm_build_kqv(
     struct ggml_tensor * kqv_merged = ggml_permute(ctx, kqv, 0, 2, 1, 3);
     cb(kqv_merged, "kqv_merged", il);

-    struct ggml_tensor * cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
+    cur = ggml_cont_2d(ctx, kqv_merged, n_embd_head_k*n_head, n_tokens);
     cb(cur, "kqv_merged_cont", il);
+#endif

     ggml_build_forward_expand(graph, cur);

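Because ggml_flash_attn_ext() returns its result already permuted to [n_embd_head, n_head, n_tokens, 1] and contiguous (see the res comment in ggml.h above), the flash path merges the heads with a single ggml_reshape_2d() to [n_embd_head_k*n_head, n_tokens]; the fallback path still needs the ggml_permute() + ggml_cont_2d() pair. With, say, n_embd_head_k = 128 and n_head = 32, both paths hand a [4096, n_tokens] tensor to the output projection.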
@@ -4758,7 +4794,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -4942,7 +4978,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -5063,7 +5099,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -5185,7 +5221,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
@@ -5282,7 +5318,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         if (do_rope_shift) {
@@ -5485,7 +5521,7 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         for (int il = 0; il < n_layer; ++il) {
@@ -5575,7 +5611,7 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         inpL = llm_build_norm(ctx0, inpL, hparams,
@@ -5668,7 +5704,7 @@ struct llm_build_context {
         cb(inpL, "inp_embd", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         for (int il = 0; il < n_layer; ++il) {
@@ -5768,7 +5804,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -5891,7 +5927,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -6005,7 +6041,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -6126,7 +6162,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -6248,7 +6284,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -6355,7 +6391,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         pos = ggml_get_rows(ctx0, model.pos_embd, inp_pos);
@@ -6453,7 +6489,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
@@ -6561,7 +6597,7 @@ struct llm_build_context {
         cb(inp_pos, "inp_pos", -1);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, n_tokens, n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0);
+        struct ggml_tensor * KQ_mask = ggml_cast(ctx0, ggml_view_2d(ctx0, lctx.inp_KQ_mask, n_kv, GGML_PAD(n_tokens, GGML_KQ_MASK_PAD), n_kv*ggml_type_size(lctx.inp_KQ_mask->type), 0), GGML_TYPE_F16);
         cb(KQ_mask, "KQ_mask", -1);

         // shift the entire K-cache if needed
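All of these hunks apply the same change: the F32 mask view over n_tokens rows is padded up to a multiple of GGML_KQ_MASK_PAD and cast to F16, so one tensor can feed both ggml_soft_max_ext() and ggml_flash_attn_ext(). Assuming GGML_PAD keeps its usual ggml definition, the padding works out as follows:

// GGML_PAD(x, n) rounds x up to the next multiple of n: (((x) + (n) - 1) & ~((n) - 1))
// with GGML_KQ_MASK_PAD = 32:
//   n_tokens =  1 -> GGML_PAD( 1, 32) = 32 mask rows
//   n_tokens = 33 -> GGML_PAD(33, 32) = 64 mask rows
// the extra rows are never used for real tokens, but they must exist and hold finite
// values (the input buffer is zero-initialized later in this diff) so the padded
// kernels do not read garbage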
@@ -7042,7 +7078,8 @@ static int llama_decode_internal(
         // a heuristic, to avoid attending the full cache if it is not yet utilized
         // after enough generations, the benefit from this heuristic disappears
         // if we start defragmenting the cache, the benefit from this will be more important
-        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(32, GGML_PAD(llama_kv_cache_cell_max(kv_self), 32)));
+        // note: we pad the n_kv because certain GPU kernels require it (e.g. ggml_flash_attn_ext)
+        kv_self.n = std::min((int32_t) cparams.n_ctx, std::max(128, GGML_PAD(llama_kv_cache_cell_max(kv_self), 128)));
         //kv_self.n = llama_kv_cache_cell_max(kv_self);

         //printf("kv_self.n = %5d, kv_self.used = %5d, kv_self.head = %5d\n", kv_self.n, kv_self.used, kv_self.head);
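Worked example of the new padding, with illustrative numbers: with 300 used cells, GGML_PAD(300, 128) = 384, so kv_self.n becomes min(n_ctx, max(128, 384)) = 384 (assuming n_ctx >= 384); with a single used cell it becomes 128. The previous code padded to multiples of 32 instead.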
@@ -10413,7 +10450,10 @@ struct llama_context * llama_new_context_with_model(
     const auto & hparams = model->hparams;
     auto       & cparams = ctx->cparams;

-    cparams.n_batch          = params.n_batch;
+    // the batch has to be at least GGML_KQ_MASK_PAD because we will be padding the KQ_mask
+    // this is required by GPU kernels in order to avoid out-of-bounds accesses (e.g. ggml_flash_attn_ext)
+    cparams.n_batch          = std::max((uint32_t) GGML_KQ_MASK_PAD, params.n_batch);
+
     cparams.n_threads        = params.n_threads;
     cparams.n_threads_batch  = params.n_threads_batch;
     cparams.yarn_ext_factor  = params.yarn_ext_factor;
@@ -10539,8 +10579,7 @@ struct llama_context * llama_new_context_with_model(
         }
         ctx->backends.push_back(ctx->backend_cpu);

-        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v,
-                cparams.n_ctx, cparams.offload_kqv)) {
+        if (!llama_kv_cache_init(ctx->kv_self, ctx->model, type_k, type_v, cparams.n_ctx, cparams.offload_kqv)) {
             LLAMA_LOG_ERROR("%s: llama_kv_cache_init() failed for self-attention cache\n", __func__);
             llama_free(ctx);
             return nullptr;
@@ -10594,6 +10633,9 @@ struct llama_context * llama_new_context_with_model(

         ctx->buf_input = ggml_backend_alloc_ctx_tensors_from_buft(ctx->ctx_input, llama_default_buffer_type_cpu(true));

+        // zero-out the input buffer to prevent NaNs in padded tensors
+        ggml_backend_buffer_clear(ctx->buf_input, 0);
+
         LLAMA_LOG_INFO("%s: %10s input buffer size = %8.2f MiB\n", __func__,
                 ggml_backend_buffer_name(ctx->buf_input),
                 ggml_backend_buffer_get_size(ctx->buf_input) / 1024.0 / 1024.0);
@@ -572,9 +572,19 @@ struct test_case {
         // duplicate the op
         size_t target_size = ggml_backend_is_cpu(backend) ? 1ULL << 33 : 1ULL << 35; // 8 GB CPU, 32 GB GPU
         int n_runs = std::min((size_t)gf->size - gf->n_nodes, target_size / op_size(out)) + 1;
+#if 0
         for (int i = 1; i < n_runs; i++) {
             gf->nodes[gf->n_nodes++] = out;
         }
+#else
+        int n_nodes = gf->n_nodes;
+        n_runs = 1000;
+        for (int i = 1; i < n_runs; i++) {
+            for (int j = 0; j < n_nodes; j++) {
+                gf->nodes[gf->n_nodes++] = gf->nodes[j];
+            }
+        }
+#endif

         // calculate memory
         size_t mem = n_runs * op_size(out);
@@ -1101,7 +1111,7 @@ struct test_soft_max : public test_case {
     ggml_tensor * build_graph(ggml_context * ctx) override {
         ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data());
         ggml_tensor * b = nullptr;
-        if (mask) { b = ggml_new_tensor_2d(ctx, type, ne[0], ne[1]); }
+        if (mask) { b = ggml_new_tensor_2d(ctx, GGML_TYPE_F16, ne[0], ne[1]); }
         ggml_tensor * out = ggml_soft_max_ext(ctx, a, b, scale);
         return out;
     }
@@ -1450,6 +1460,76 @@ struct test_leaky_relu : public test_case {
     }
 };

+// GGML_OP_FLASH_ATTN_EXT
+struct test_flash_attn_ext : public test_case {
+    const int64_t hs; // head size
+    const int64_t nh; // num heads
+    const int64_t kv; // kv size
+    const int64_t nb; // batch size
+
+    std::string vars() override {
+        return VARS_TO_STR4(hs, nh, kv, nb);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    test_flash_attn_ext(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+        : hs(hs), nh(nh), kv(kv), nb(nb) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, GGML_PAD(nb, GGML_KQ_MASK_PAD), 1, 1);
+        ggml_tensor * out = ggml_flash_attn_ext(ctx, q, k, v, mask, 1.0f/sqrtf(hs));
+        return out;
+    }
+};
+
+// Attention
+struct test_attn : public test_case {
+    const int64_t hs; // head size
+    const int64_t nh; // num heads
+    const int64_t kv; // kv size
+    const int64_t nb; // batch size
+
+    std::string op_desc(ggml_tensor * t) override {
+        return "ATTN";
+
+        GGML_UNUSED(t);
+    }
+
+    std::string vars() override {
+        return VARS_TO_STR4(hs, nh, kv, nb);
+    }
+
+    double max_nmse_err() override {
+        return 5e-4;
+    }
+
+    test_attn(int64_t hs = 128, int64_t nh = 32, int64_t kv = 96, int64_t nb = 8)
+        : hs(hs), nh(nh), kv(kv), nb(nb) {}
+
+    ggml_tensor * build_graph(ggml_context * ctx) override {
+        ggml_tensor * q = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, hs, nb, nh, 1);
+        ggml_tensor * k = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, hs, kv, nh, 1);
+        ggml_tensor * v = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, hs, nh, 1); // transposed
+        ggml_tensor * mask = ggml_new_tensor_4d(ctx, GGML_TYPE_F16, kv, nb, 1, 1);
+
+        struct ggml_tensor * cur;
+
+        cur = ggml_mul_mat     (ctx, k, q);
+        cur = ggml_soft_max_ext(ctx, cur, mask, 1.0f/sqrtf(hs));
+        cur = ggml_mul_mat     (ctx, v, cur);
+        cur = ggml_permute     (ctx, cur, 0, 2, 1, 3);
+        cur = ggml_cont_2d     (ctx, cur, hs*nh, nb);
+
+        return cur;
+    }
+};
+
 // Mixtral MOE
 struct test_moe : public test_case {
     const int n_experts;
@@ -1723,7 +1803,7 @@ struct test_llama : public test_llm {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);

         ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
         ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
@@ -1845,7 +1925,7 @@ struct test_falcon : public test_llm {
         struct ggml_tensor * inp_pos = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, hp.n_tokens);

         // KQ_mask (mask for 1 head, it will be broadcasted to all heads)
-        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F32, hp.n_kv, hp.n_tokens, 1);
+        struct ggml_tensor * KQ_mask = ggml_new_tensor_3d(ctx, GGML_TYPE_F16, hp.n_kv, hp.n_tokens, 1);

         ggml_tensor * k_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
         ggml_tensor * v_l = ggml_new_tensor_1d(ctx, GGML_TYPE_F16, 1638400);
@@ -2129,6 +2209,30 @@ static bool test_backend(ggml_backend_t backend, test_mode mode, const char * op
     test_cases.emplace_back(new test_pad());
     test_cases.emplace_back(new test_leaky_relu());

+#if 1
+    for (int hs : { 128, 64, 80, }) {
+        for (int nh : { 32, }) {
+            for (int kv : { 512, 1024, 2048, 4096, }) {
+                for (int nb : { 1, 2, 4, 8, 512, 1024, 2048, }) {
+                    test_cases.emplace_back(new test_attn          (hs, nh, kv, nb));
+                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+                }
+            }
+        }
+    }
+#else
+    for (int hs : { 128, }) {
+        for (int nh : { 32, }) {
+            for (int kv : { 512, 1024, }) {
+                for (int nb : { 1, 2, 4, 8, 512 }) {
+                    test_cases.emplace_back(new test_attn          (hs, nh, kv, nb));
+                    test_cases.emplace_back(new test_flash_attn_ext(hs, nh, kv, nb));
+                }
+            }
+        }
+    }
+#endif
+
 #if !defined(__SANITIZE_THREAD__)
     // FIXME: these tests use too much memory with thread sanitizer
     test_cases.emplace_back(new test_moe(8, 2, 1, 4096, 8*1024));