diff --git a/ggml-cuda/common.cuh b/ggml-cuda/common.cuh
index f3447c477..44e67e040 100644
--- a/ggml-cuda/common.cuh
+++ b/ggml-cuda/common.cuh
@@ -430,7 +430,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const half2 b) {
     GGML_UNUSED(a);
     GGML_UNUSED(b);
     NO_DEVICE_CODE;
-#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
+#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
 }
 
 static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
diff --git a/ggml-cuda/fattn.cu b/ggml-cuda/fattn.cu
index ca760b0a1..455510684 100644
--- a/ggml-cuda/fattn.cu
+++ b/ggml-cuda/fattn.cu
@@ -61,6 +61,7 @@ static __global__ void flash_attn_vec_ext_f16(
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
+    __builtin_assume(tid < D);
 
     __shared__ half KQ[ncols*D];
 #pragma unroll
@@ -106,7 +107,10 @@ static __global__ void flash_attn_vec_ext_f16(
     for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
         // Calculate KQ tile and keep track of new maximum KQ values:
         half kqmax_new[ncols];
-        memcpy(kqmax_new, kqmax, sizeof(kqmax));
+#pragma unroll
+        for (int j = 0; j < ncols; ++j) {
+            kqmax_new[j] = kqmax[j];
+        }
 
 #pragma unroll
         for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
@@ -123,7 +127,7 @@ static __global__ void flash_attn_vec_ext_f16(
             const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
 
 #pragma unroll
-                for (int j = 0; j < ncols; ++j) {
+                for (int j = 0; j < ncols; ++j) {
                     sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
                 }
             }
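
Note on the `__builtin_assume(tid < D)` hunk: the kernel uses nwarps = D/WARP_SIZE warps per block, so the flattened thread id `tid = WARP_SIZE*threadIdx.y + threadIdx.x` is always in [0, D). The builtin records that bound for the compiler, which can tighten the code generated for `tid`-based indexing (e.g. into the `__shared__ half KQ[ncols*D]` buffer). A minimal standalone sketch of the same pattern, with an illustrative kernel name and launch shape that are not part of the patch:

    #include <cuda_fp16.h>

    #define WARP_SIZE 32

    template <int D> // head size; launched as <<<grid, dim3(WARP_SIZE, D/WARP_SIZE)>>>
    static __global__ void assume_demo(const half * __restrict__ src, half * __restrict__ dst) {
        const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
        __builtin_assume(tid < D); // promise to the compiler; holds because
                                   // blockDim.x*blockDim.y == D at every launch site

        __shared__ half buf[D];
        buf[tid] = src[tid];       // with the bound known, the compiler can reason
        __syncthreads();           // more precisely about accesses through buf
        dst[tid] = buf[tid];
    }

    // usage: assume_demo<128><<<1, dim3(WARP_SIZE, 128/WARP_SIZE)>>>(d_src, d_dst);

The expression given to `__builtin_assume` must be free of side effects, and behavior is undefined if it is ever false at runtime, so the patch only asserts a bound that the launch parameters already guarantee.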
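
Note on the `memcpy` hunk: `kqmax` and `kqmax_new` are small per-thread arrays that should live entirely in registers. `memcpy(kqmax_new, kqmax, sizeof(kqmax))` takes the arrays' addresses, which can make the compiler treat them as addressable memory and spill them to local memory on some toolchains. The replacement is a fully unrolled per-element copy: after `#pragma unroll`, every index `j` is a compile-time constant, so each element can stay in its own register. A sketch of the pattern in isolation, with an illustrative helper name (the patch inlines the loop directly):

    #include <cuda_fp16.h>

    template <int ncols> // mirrors the kernel's ncols template parameter
    static __device__ __forceinline__ void copy_unrolled(half (&dst)[ncols], const half (&src)[ncols]) {
    #pragma unroll
        for (int j = 0; j < ncols; ++j) {
            dst[j] = src[j]; // constant indices after unrolling -> register-allocatable,
        }                    // unlike a byte-wise memcpy through the arrays' addresses
    }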