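Fix AMD: select mul_mat_vec_q nwarps/rows_per_cuda_block via HIP RDNA2/RDNA3 macros instead of __CUDA_ARCH__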
This commit is contained in:
parent
97f8a7a2fa
commit
2bb97fca5e
1 changed file with 5 additions and 5 deletions
10
ggml-cuda.cu
10
ggml-cuda.cu
|
@@ -5319,13 +5319,13 @@ static __global__ void mul_mat_vec_q(
|
||||||
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
const void * __restrict__ vx, const void * __restrict__ vy, float * __restrict__ dst,
|
||||||
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
const int ncols_x, const int nrows_x, const int nrows_y, const int nrows_dst) {
|
||||||
|
|
||||||
#if __CUDA_ARCH__ < CC_RDNA2
|
#if defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||||
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
|
||||||
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
|
||||||
#else
|
|
||||||
constexpr int nwarps = 1;
|
constexpr int nwarps = 1;
|
||||||
constexpr int rows_per_cuda_block = 1;
|
constexpr int rows_per_cuda_block = 1;
|
||||||
#endif // __CUDA_ARCH__ < CC_RDNA2
|
#else
|
||||||
|
constexpr int nwarps = ncols_y <= 4 ? 4 : 2;
|
||||||
|
constexpr int rows_per_cuda_block = ncols_y == 1 ? 1 : 2;
|
||||||
|
#endif // defined(GGML_USE_HIPBLAS) && defined(__HIP_PLATFORM_AMD__) && (defined(RDNA2) || defined(RDNA3))
|
||||||
|
|
||||||
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
|
constexpr int blocks_per_iter = vdr * nwarps*WARP_SIZE / qi;
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue