partially revert changes
This commit is contained in:
Parent: d9bcb92f75
Commit: fa81c3a22c
2 changed files with 7 additions and 3 deletions
|
@@ -430,7 +430,7 @@ static __device__ __forceinline__ half2 ggml_cuda_hmax2(const half2 a, const hal
|
||||||
GGML_UNUSED(a);
|
GGML_UNUSED(a);
|
||||||
GGML_UNUSED(b);
|
GGML_UNUSED(b);
|
||||||
NO_DEVICE_CODE;
|
NO_DEVICE_CODE;
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && CUDART_VERSION < CUDART_HMAX
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__))
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
static __device__ __forceinline__ half2 warp_reduce_max(half2 x) {
|
||||||
|
|
|
@@ -61,6 +61,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
|
||||||
constexpr int nwarps = D / WARP_SIZE;
|
constexpr int nwarps = D / WARP_SIZE;
|
||||||
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
|
||||||
|
__builtin_assume(tid < D);
|
||||||
|
|
||||||
__shared__ half KQ[ncols*D];
|
__shared__ half KQ[ncols*D];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
|
@@ -106,7 +107,10 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
for (int k_VKQ_0 = k_start; k_VKQ_0 < ne11; k_VKQ_0 += parallel_blocks*D) {
|
||||||
// Calculate KQ tile and keep track of new maximum KQ values:
|
// Calculate KQ tile and keep track of new maximum KQ values:
|
||||||
half kqmax_new[ncols];
|
half kqmax_new[ncols];
|
||||||
memcpy(kqmax_new, kqmax, sizeof(kqmax));
|
#pragma unroll
|
||||||
|
for (int j = 0; j < ncols; ++j) {
|
||||||
|
kqmax_new[j] = kqmax[j];
|
||||||
|
}
|
||||||
|
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
for (int i_KQ_0 = 0; i_KQ_0 < D; i_KQ_0 += nwarps) {
|
||||||
|
@@ -123,7 +127,7 @@ static __global__ void flash_attn_vec_ext_f16(
|
||||||
|
|
||||||
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
const half2 K_ik = K_h2[(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
|
||||||
#pragma unroll
|
#pragma unroll
|
||||||
for (int j = 0; j < ncols; ++j) {
|
for (int j = 0; j < ncols; ++j) {
|
||||||
sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
|
sum2[j] += K_ik * Q_h2[j][k_KQ_0/WARP_SIZE];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue