Add basic bf16 support to ggml-cuda
commit ebd5efeedf
parent 152da28ae5

3 changed files with 27 additions and 2 deletions
@@ -25,10 +25,12 @@
 #include <hip/hip_runtime.h>
 #include <hipblas/hipblas.h>
 #include <hip/hip_fp16.h>
+#include <hip/hip_bfloat16.h>
 #ifdef __HIP_PLATFORM_AMD__
 // for rocblas_initialize()
 #include "rocblas/rocblas.h"
 #endif // __HIP_PLATFORM_AMD__
+#define __nv_bfloat16 hip_bfloat16
 #define CUBLAS_COMPUTE_16F HIPBLAS_R_16F
 #define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
 #define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F

@@ -38,8 +40,8 @@
 #define CUBLAS_OP_T HIPBLAS_OP_T
 #define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
 #define CUBLAS_TF32_TENSOR_OP_MATH 0
 #define CUDA_R_16F HIPBLAS_R_16F
 #define CUDA_R_32F HIPBLAS_R_32F
 #define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
 #define cublasComputeType_t hipblasDatatype_t //deprecated, new hipblasComputeType_t not in 5.6
 #define cublasCreate hipblasCreate

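The HIP path gains the hip_bfloat16 header plus a __nv_bfloat16 alias, following the same convention as the surrounding cuBLAS-to-hipBLAS defines: the kernels are written once against CUDA spellings, and the ROCm build remaps them at the preprocessor level. A minimal standalone sketch of the pattern (the USE_HIP guard below is illustrative, not ggml's actual build flag):

    // Sketch: one kernel source targeting both CUDA and ROCm. USE_HIP is a
    // placeholder guard; ggml uses its own build flag for the HIP path.
    #if defined(USE_HIP)
    #include <hip/hip_bfloat16.h>
    #define __nv_bfloat16 hip_bfloat16  // CUDA spelling resolves to the HIP type
    #else
    #include <cuda_bf16.h>              // defines the real __nv_bfloat16
    #endif

    __global__ void bf16_to_f32(const __nv_bfloat16 * x, float * y, int n) {
        const int i = blockIdx.x * blockDim.x + threadIdx.x;
        if (i < n) {
            y[i] = (float) x[i];  // both types provide a conversion to float
        }
    }
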
@@ -123,6 +125,7 @@
 #include <cuda.h>
 #include <cublas_v2.h>
 #include <cuda_fp16.h>
+#include <cuda_bf16.h>
 
 #if CUDART_VERSION < 11020
 #define CU_DEVICE_ATTRIBUTE_VIRTUAL_MEMORY_MANAGEMENT_SUPPORTED CU_DEVICE_ATTRIBUTE_VIRTUAL_ADDRESS_MANAGEMENT_SUPPORTED

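On the CUDA side only the cuda_bf16.h include is needed, since __nv_bfloat16 is the native type there. bf16 is simply the top 16 bits of an IEEE-754 float (same sign bit and 8-bit exponent, mantissa truncated to 7 bits), which is why the conversions added below are cheap. A host-side illustration of that layout (not part of the commit):

    #include <cstdint>
    #include <cstring>

    // bf16 -> f32 is a 16-bit shift: the low mantissa bits are simply zero.
    static float bf16_bits_to_f32(uint16_t h) {
        uint32_t bits = (uint32_t) h << 16;
        float f;
        std::memcpy(&f, &bits, sizeof(f));
        return f;
    }
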
@@ -680,6 +680,8 @@ to_fp32_cuda_t ggml_get_to_fp32_cuda(ggml_type type) {
             return dequantize_row_iq3_s_cuda;
         case GGML_TYPE_F16:
             return convert_unary_cuda<half>;
+        case GGML_TYPE_BF16:
+            return convert_unary_cuda<__nv_bfloat16>;
         default:
             return nullptr;
     }

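ggml_get_to_fp32_cuda returns one conversion function per source type. Quantized formats each need a real dequantization kernel, while F16 and now BF16 share convert_unary_cuda, which is templated on the source element type, so a new 16-bit type costs exactly the one case label added above. A simplified sketch of that kind of kernel (assumed shape; the real ggml version differs in details such as launch configuration):

    #include <cstdint>

    // Sketch of a templated unary-convert kernel in the spirit of
    // convert_unary_cuda; illustrative, not the actual ggml implementation.
    template <typename src_t>
    static __global__ void convert_unary(const void * vx, float * y, const int64_t k) {
        const src_t * x = (const src_t *) vx;
        const int64_t i = (int64_t) blockIdx.x * blockDim.x + threadIdx.x;
        if (i < k) {
            y[i] = (float) x[i];  // half, __nv_bfloat16 and float all convert to float
        }
    }
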
@@ -422,6 +422,14 @@ static __device__ void convert_f16(const void * vx, const int64_t ib, const int
     v.y = x[ib + iqs + 1];
 }
 
+static __device__ void convert_bf16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){
+    const __nv_bfloat16 * x = (const __nv_bfloat16 *) vx;
+
+    // automatic __nv_bfloat16 -> float type cast if dfloat == float
+    v.x = x[ib + iqs + 0];
+    v.y = x[ib + iqs + 1];
+}
+
 template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
 static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) {
     // qk = quantized weights per x block

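convert_bf16 mirrors convert_f16 directly above it: a dequantize kernel receives a block index ib and an in-block offset iqs and writes two consecutive weights into a dfloat2. For an unquantized type there is nothing to decode, so the body is just two loads plus the implicit cast the comment mentions. The function-pointer contract, as inferred from the signatures used in this file (the actual typedef lives elsewhere in ggml-cuda):

    // Inferred contract for the dequantize_kernel_t template parameter.
    typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, dfloat2 & v);
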
@@ -584,6 +592,15 @@ static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, floa
         <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
 }
 
+static void convert_mul_mat_vec_bf16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
+    GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
+    const dim3 block_nums(block_num_y, 1, 1);
+    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
+    dequantize_mul_mat_vec<1, 1, convert_bf16>
+        <<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
+}
+
 void ggml_cuda_op_dequantize_mul_mat_vec(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,

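The launch geometry copies convert_mul_mat_vec_f16_cuda one hunk up: each block is WARP_SIZE x GGML_CUDA_MMV_Y threads, and block_num_y is a ceiling division so every row is covered even when nrows is not a multiple of GGML_CUDA_MMV_Y. The template arguments <1, 1, convert_bf16> set the block size qk and quantization ratio qr both to 1, since each bf16 element is exactly one weight. The grid arithmetic in isolation:

    // Ceiling division, as used for block_num_y above.
    static int ceil_div(int a, int b) {
        return (a + b - 1) / b;
    }
    // e.g. nrows = 4097 with GGML_CUDA_MMV_Y = 2 gives ceil_div(4097, 2) = 2049
    // blocks, so the odd row out still gets a warp.
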
@@ -649,6 +666,9 @@ void ggml_cuda_op_dequantize_mul_mat_vec(
         case GGML_TYPE_F16:
             convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
             break;
+        case GGML_TYPE_BF16:
+            convert_mul_mat_vec_bf16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream);
+            break;
         default:
             GGML_ASSERT(false);
             break;

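With the dispatch case in place, a bf16 src0 takes exactly the same dmmv path as f16. For reference, this is the computation the kernel is expected to match, written host-side (a checking sketch, not ggml code; it assumes dfloat is float, i.e. GGML_CUDA_F16 is not defined):

    #include <cuda_bf16.h>

    // Host-side reference: dst[r] = dot(row r of the bf16 matrix, y), with the
    // accumulation done in float.
    static void mul_mat_vec_bf16_ref(const __nv_bfloat16 * x, const float * y,
                                     float * dst, const int ncols, const int nrows) {
        for (int r = 0; r < nrows; ++r) {
            float sum = 0.0f;
            for (int c = 0; c < ncols; ++c) {
                sum += __bfloat162float(x[r*ncols + c]) * y[c];
            }
            dst[r] = sum;
        }
    }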