ggml : full ALiBi support (#7192)
* ggml : full ALiBi support
* ggml : update ggml_soft_max_ext() CUDA, SYCL
* ggml : ggml_flash_attn_ext() support ALiBi (CPU)
* ggml : ggml_flash_attn_ext() support ALiBi (Metal)
* ggml : fix warning
* ggml : ggml_flash_attn_ext() support ALiBi (CUDA)
  ggml-ci
* ggml : fix assert message
* vulkan : add dev notes
* ggml : require mask when using ALiBi
  ggml-ci
* convert : fix convert for refact models
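For context: ALiBi biases the attention scores with a per-head slope derived only from max_bias and the number of heads. Below is a minimal host-side sketch of that slope schedule (alibi_slopes is a hypothetical helper, not code from this commit), mirroring the m0/m1/n_head_log2 math the CUDA launch code in this diff computes:

#include <cmath>
#include <cstdint>
#include <vector>

// Sketch: per-head ALiBi slopes from max_bias and n_head (assumed to mirror
// the m0/m1/n_head_log2 formulas used by the launch code in this diff).
static std::vector<float> alibi_slopes(uint32_t n_head, float max_bias) {
    // largest power of two that does not exceed n_head
    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));

    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);

    std::vector<float> slopes(n_head);
    for (uint32_t h = 0; h < n_head; ++h) {
        const float base = h < n_head_log2 ? m0 : m1;
        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
        slopes[h] = powf(base, (float) exph); // max_bias == 0.0f yields slope == 1.0f (no bias)
    }
    return slopes;
}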
This commit is contained in:

parent e849648888
commit 9cb317f77e

16 changed files with 350 additions and 825 deletions
@@ -23,6 +23,10 @@ static __global__ void flash_attn_vec_ext_f16(
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -58,6 +62,18 @@ static __global__ void flash_attn_vec_ext_f16(
     const int stride_KV  = nb11 / sizeof(half);
     const int stride_KV2 = nb11 / sizeof(half2);
 
+    half slopeh = __float2half(1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+    }
+
     static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
     constexpr int nwarps = D / WARP_SIZE;
     const int tid = WARP_SIZE*threadIdx.y + threadIdx.x;
@@ -141,7 +157,7 @@ static __global__ void flash_attn_vec_ext_f16(
             for (int j = 0; j < ncols; ++j) {
                 sum2[j] = warp_reduce_sum(sum2[j]);
                 half sum = __low2half(sum2[j]) + __high2half(sum2[j]);
-                sum += mask ? maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
+                sum += mask ? slopeh*maskh[j*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
 
                 if (ncols == 1) {
                     kqmax_new        = ggml_cuda_hmax(kqmax_new,        sum);
@@ -249,6 +265,10 @@ static __global__ void flash_attn_ext_f16(
         float      * __restrict__ dst,
         float2     * __restrict__ dst_meta,
         const float scale,
+        const float max_bias,
+        const float m0,
+        const float m1,
+        const uint32_t n_head_log2,
         const int ne00,
         const int ne01,
         const int ne02,
@@ -305,6 +325,20 @@ static __global__ void flash_attn_ext_f16(
     const int stride_Q  = nb01 / sizeof(float);
     const int stride_KV = nb11 / sizeof(half);
 
+    half  slopeh = __float2half(1.0f);
+    half2 slope2 = make_half2(1.0f, 1.0f);
+
+    // ALiBi
+    if (max_bias > 0.0f) {
+        const int h = blockIdx.y;
+
+        const float base = h < n_head_log2 ? m0 : m1;
+        const int   exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;
+
+        slopeh = __float2half(powf(base, exph));
+        slope2 = make_half2(slopeh, slopeh);
+    }
+
     frag_b Q_b[D/16][ncols/frag_n];
 
     // A single buffer for temporarily holding tiles of KQ and VKQ parts:
@@ -421,7 +455,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE; k0 += WARP_SIZE) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
+                    KQ_f_tmp[k0/WARP_SIZE] += mask ? __half2float(slopeh*maskh[j*(nb31/sizeof(half)) + k_VKQ_0 + k]) : 0.0f;
                     KQ_max_new = max(KQ_max_new, KQ_f_tmp[k0/WARP_SIZE]);
                 }
                 KQ_max_new = warp_reduce_max(KQ_max_new);
@@ -464,7 +498,7 @@ static __global__ void flash_attn_ext_f16(
                 for (int k0 = 0; k0 < FATTN_KQ_STRIDE/2; k0 += WARP_SIZE) {
                     const int k = k0 + threadIdx.x;
 
-                    KQ2_tmp[k0/WARP_SIZE] += mask ? mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
+                    KQ2_tmp[k0/WARP_SIZE] += mask ? slope2*mask2[(j*ne11 + k_VKQ_0)/2 + k] : make_half2(0.0f, 0.0f);
                     KQ_max_new = ggml_cuda_hmax2(KQ_max_new, KQ2_tmp[k0/WARP_SIZE]);
                 }
                 KQ_max_new = __half2half2(warp_reduce_max(ggml_cuda_hmax(__low2half(KQ_max_new), __high2half(KQ_max_new))));
@@ -710,8 +744,17 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
     const     dim3 blocks_num(parallel_blocks*((Q->ne[1] + cols_per_block - 1) / cols_per_block), Q->ne[2], Q->ne[3]);
     const     int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_vec_ext_f16<D, cols_per_block, parallel_blocks>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -720,7 +763,7 @@ template <int D, int cols_per_block, int parallel_blocks> void launch_fattn_vec_
                 (const char *) V->data,
                 mask ? ((const char *) mask->data) : nullptr,
                 parallel_blocks == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale,
+                scale, max_bias, m0, m1, n_head_log2,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                 mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
@@ -761,8 +804,17 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
     const     dim3 blocks_num(parallel_blocks*(Q->ne[1] + cols_per_block - 1) / cols_per_block, Q->ne[2], Q->ne[3]);
     const     int  shmem = 0;
 
-    float scale;
-    memcpy(&scale, KQV->op_params, sizeof(float));
+    float scale    = 1.0f;
+    float max_bias = 0.0f;
+
+    memcpy(&scale,    (float *) KQV->op_params + 0, sizeof(float));
+    memcpy(&max_bias, (float *) KQV->op_params + 1, sizeof(float));
+
+    const uint32_t n_head      = Q->ne[2];
+    const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head));
+
+    const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
+    const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
     flash_attn_ext_f16<D, cols_per_block, nwarps, get_VKQ_stride(D, nwarps, frag_m), parallel_blocks, KQ_acc_t>
         <<<blocks_num, block_dim, shmem, main_stream>>> (
@@ -771,7 +823,7 @@ template <int D, int cols_per_block, int nwarps, int parallel_blocks, typename K
                 (const char *) V->data,
                 mask ? ((const char *) mask->data) : nullptr,
                 (parallel_blocks) == 1 ? (float *) KQV->data : dst_tmp.ptr, dst_tmp_meta.ptr,
-                scale,
+                scale, max_bias, m0, m1, n_head_log2,
                 Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3],
                 K->ne[0], K->ne[1], K->ne[2], K->ne[3],
                 mask ? mask->ne[1] : 0, mask ?  mask->nb[1] : 0,
@@ -837,7 +889,7 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
     const int cc  = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
     const int nsm = ggml_cuda_info().devices[ggml_cuda_get_device()].nsm;
 
-    const int32_t precision = KQV->op_params[1];
+    const int32_t precision = KQV->op_params[2];
 
     if (!fp16_mma_available(cc)) {
         GGML_ASSERT(precision == GGML_PREC_DEFAULT);
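With these kernel and launch changes, the ALiBi bias is folded directly into the masked soft-max: each head multiplies the attention mask by its slope and adds the result to the scaled KQ values (the slopeh*maskh and slope2*mask2 terms above). A rough scalar sketch of that per-head semantics, p = softmax(scale*kq + slope*mask), follows; soft_max_ref is a hypothetical reference for illustration, not the ggml implementation:

#include <cmath>
#include <vector>

// Illustrative per-row reference: p = softmax(scale*kq + slope*mask).
// Assumption: mask supplies per-position values (and -INFINITY for masked-out
// entries), so slope*mask becomes the additive bias seen in the kernels above.
static std::vector<float> soft_max_ref(const std::vector<float> & kq,
                                       const std::vector<float> & mask,
                                       float scale, float slope) {
    std::vector<float> p(kq.size());

    float maxval = -INFINITY;
    for (size_t i = 0; i < kq.size(); ++i) {
        p[i]   = scale*kq[i] + slope*mask[i];
        maxval = fmaxf(maxval, p[i]);
    }

    float sum = 0.0f;
    for (size_t i = 0; i < p.size(); ++i) {
        p[i] = expf(p[i] - maxval);
        sum += p[i];
    }
    for (size_t i = 0; i < p.size(); ++i) {
        p[i] /= sum;
    }
    return p;
}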