try CI fix
This commit is contained in:
parent
d962a56baa
commit
87099452ed
2 changed files with 12 additions and 1 deletions
|
@ -331,6 +331,10 @@ static __device__ __forceinline__ half2 __shfl_xor(half2 var, int laneMask, int
|
||||||
#define FP16_AVAILABLE
|
#define FP16_AVAILABLE
|
||||||
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
#endif // (defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) || __CUDA_ARCH__ >= CC_PASCAL
|
||||||
|
|
||||||
|
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
|
||||||
|
#define FAST_FP16_AVAILABLE
|
||||||
|
#endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
|
||||||
|
|
||||||
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
#if !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||||
#define FP16_MMA_AVAILABLE
|
#define FP16_MMA_AVAILABLE
|
||||||
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
#endif // !(defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__)) && __CUDA_ARCH__ >= CC_VOLTA
|
||||||
|
|
|
@ -839,7 +839,14 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
|
||||||
}
|
}
|
||||||
|
|
||||||
const int sc_m = bxi->scales[kqsx];
|
const int sc_m = bxi->scales[kqsx];
|
||||||
x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = bxi->dm * make_half2(sc_m & 0x0F, sc_m >> 4);
|
#ifdef FAST_FP16_AVAILABLE
|
||||||
|
const half2 x_dm_ik = __hmul2(bxi->dm, make_half2(sc_m & 0x0F, sc_m >> 4));
|
||||||
|
#else
|
||||||
|
const float2 bxi_dmf = __half22float2(bxi->dm);
|
||||||
|
const half2 x_dm_ik = make_half2(bxi_dmf.x*(sc_m & 0x0F), bxi_dmf.y*(sc_m >> 4));
|
||||||
|
#endif // FAST_FP16_AVAILABLE
|
||||||
|
|
||||||
|
x_dm[i*(WARP_SIZE + 1) + threadIdx.x] = x_dm_ik;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue