Merge branch 'master' into concedo_experimental
# Conflicts: # CMakeLists.txt # Makefile # README.md
This commit is contained in:
commit
cf94340dfc
4 changed files with 148 additions and 77 deletions
4
Makefile
4
Makefile
|
@ -200,18 +200,14 @@ ifeq ($(OS),Windows_NT)
|
||||||
CLBLAST_NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ lib/OpenCL.lib lib/clblast.lib -shared -o $@.dll $(LDFLAGS)
|
CLBLAST_NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ lib/OpenCL.lib lib/clblast.lib -shared -o $@.dll $(LDFLAGS)
|
||||||
else
|
else
|
||||||
DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS)
|
DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS)
|
||||||
FAILSAFE_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS)
|
|
||||||
ifdef LLAMA_OPENBLAS
|
ifdef LLAMA_OPENBLAS
|
||||||
OPENBLAS_BUILD = $(CXX) $(CXXFLAGS) $^ $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
OPENBLAS_BUILD = $(CXX) $(CXXFLAGS) $^ $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
||||||
OPENBLAS_NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
|
||||||
endif
|
endif
|
||||||
ifdef LLAMA_CLBLAST
|
ifdef LLAMA_CLBLAST
|
||||||
ifeq ($(UNAME_S),Darwin)
|
ifeq ($(UNAME_S),Darwin)
|
||||||
CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -framework OpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -framework OpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
||||||
CLBLAST_NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -framework OpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
|
||||||
else
|
else
|
||||||
CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -lOpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -lOpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
||||||
CLBLAST_NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -lOpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
|
|
||||||
endif
|
endif
|
||||||
endif
|
endif
|
||||||
|
|
||||||
|
|
202
ggml-cuda.cu
202
ggml-cuda.cu
|
@ -50,7 +50,15 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
|
||||||
} while (0)
|
} while (0)
|
||||||
#endif // CUDART_VERSION >= 11
|
#endif // CUDART_VERSION >= 11
|
||||||
|
|
||||||
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1);
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
|
typedef half dfloat; // dequantize float
|
||||||
|
typedef half2 dfloat2;
|
||||||
|
#else
|
||||||
|
typedef float dfloat; // dequantize float
|
||||||
|
typedef float2 dfloat2;
|
||||||
|
#endif //GGML_CUDA_DMMV_F16
|
||||||
|
|
||||||
|
typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, dfloat2 & v);
|
||||||
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream);
|
||||||
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
|
typedef void (*dot_kernel_k_t)(const void * vx, const int ib, const int iqs, const float * y, float & v);
|
||||||
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
typedef void (*cpy_kernel_t)(const char * cx, char * cdst);
|
||||||
|
@ -234,82 +242,106 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q4_0 * x = (const block_q4_0 *) vx;
|
const block_q4_0 * x = (const block_q4_0 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
const uint8_t vui = x[ib].qs[iqs];
|
const int vui = x[ib].qs[iqs];
|
||||||
|
|
||||||
const int8_t vi0 = vui & 0xF;
|
v.x = vui & 0xF;
|
||||||
const int8_t vi1 = vui >> 4;
|
v.y = vui >> 4;
|
||||||
|
|
||||||
v0 = (vi0 - 8)*d;
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
v1 = (vi1 - 8)*d;
|
v = __hsub2(v, {8.0f, 8.0f});
|
||||||
|
v = __hmul2(v, {d, d});
|
||||||
|
#else
|
||||||
|
v.x = (v.x - 8.0f) * d;
|
||||||
|
v.y = (v.y - 8.0f) * d;
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q4_1 * x = (const block_q4_1 *) vx;
|
const block_q4_1 * x = (const block_q4_1 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
const float m = x[ib].m;
|
const dfloat m = x[ib].m;
|
||||||
|
|
||||||
const uint8_t vui = x[ib].qs[iqs];
|
const int vui = x[ib].qs[iqs];
|
||||||
|
|
||||||
const int8_t vi0 = vui & 0xF;
|
v.x = vui & 0xF;
|
||||||
const int8_t vi1 = vui >> 4;
|
v.y = vui >> 4;
|
||||||
|
|
||||||
v0 = vi0*d + m;
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
v1 = vi1*d + m;
|
v = __hmul2(v, {d, d});
|
||||||
|
v = __hadd2(v, {m, m});
|
||||||
|
#else
|
||||||
|
v.x = (v.x * d) + m;
|
||||||
|
v.y = (v.y * d) + m;
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q5_0 * x = (const block_q5_0 *) vx;
|
const block_q5_0 * x = (const block_q5_0 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
uint32_t qh;
|
uint32_t qh;
|
||||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||||
|
|
||||||
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||||
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||||
|
|
||||||
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16;
|
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||||
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16;
|
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
v0 = x0*d;
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
v1 = x1*d;
|
v = __hsub2(v, {16.0f, 16.0f});
|
||||||
|
v = __hmul2(v, {d, d});
|
||||||
|
#else
|
||||||
|
v.x = (v.x - 16.0f) * d;
|
||||||
|
v.y = (v.y - 16.0f) * d;
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q5_1 * x = (const block_q5_1 *) vx;
|
const block_q5_1 * x = (const block_q5_1 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
const float m = x[ib].m;
|
const dfloat m = x[ib].m;
|
||||||
|
|
||||||
uint32_t qh;
|
uint32_t qh;
|
||||||
memcpy(&qh, x[ib].qh, sizeof(qh));
|
memcpy(&qh, x[ib].qh, sizeof(qh));
|
||||||
|
|
||||||
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
const int xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
|
||||||
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
const int xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
|
||||||
|
|
||||||
const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
|
||||||
const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1);
|
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
|
||||||
|
|
||||||
v0 = x0*d + m;
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
v1 = x1*d + m;
|
v = __hmul2(v, {d, d});
|
||||||
|
v = __hadd2(v, {m, m});
|
||||||
|
#else
|
||||||
|
v.x = (v.x * d) + m;
|
||||||
|
v.y = (v.y * d) + m;
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ __forceinline__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const block_q8_0 * x = (const block_q8_0 *) vx;
|
const block_q8_0 * x = (const block_q8_0 *) vx;
|
||||||
|
|
||||||
const float d = x[ib].d;
|
const dfloat d = x[ib].d;
|
||||||
|
|
||||||
const int8_t vi0 = x[ib].qs[iqs + 0];
|
v.x = x[ib].qs[iqs + 0];
|
||||||
const int8_t vi1 = x[ib].qs[iqs + 1];
|
v.y = x[ib].qs[iqs + 1];
|
||||||
|
|
||||||
v0 = vi0*d;
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
v1 = vi1*d;
|
v = __hmul2(v, {d, d});
|
||||||
|
#else
|
||||||
|
v.x *= d;
|
||||||
|
v.y *= d;
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
}
|
}
|
||||||
|
|
||||||
//================================== k-quants
|
//================================== k-quants
|
||||||
|
@ -843,11 +875,12 @@ static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
|
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, dfloat2 & v){
|
||||||
const half * x = (const half *) vx;
|
const half * x = (const half *) vx;
|
||||||
|
|
||||||
v0 = __half2float(x[ib + iqs + 0]);
|
// automatic half -> float type cast if dfloat == float
|
||||||
v1 = __half2float(x[ib + iqs + 1]);
|
v.x = x[ib + iqs + 0];
|
||||||
|
v.y = x[ib + iqs + 1];
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
|
@ -864,13 +897,15 @@ static __global__ void dequantize_block(const void * vx, float * y, const int k)
|
||||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
float & v0 = y[iybs + iqs + 0];
|
dfloat2 v;
|
||||||
float & v1 = y[iybs + iqs + y_offset];
|
dequantize_kernel(vx, ib, iqs, v);
|
||||||
dequantize_kernel(vx, ib, iqs, v0, v1);
|
|
||||||
|
y[iybs + iqs + 0] = v.x;
|
||||||
|
y[iybs + iqs + y_offset] = v.y;
|
||||||
}
|
}
|
||||||
|
|
||||||
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
|
||||||
static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
|
static __global__ void dequantize_mul_mat_vec(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows) {
|
||||||
// qk = quantized weights per x block
|
// qk = quantized weights per x block
|
||||||
// qr = number of quantized weights per data value in x block
|
// qr = number of quantized weights per data value in x block
|
||||||
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
const int row = blockIdx.y*blockDim.y + threadIdx.y;
|
||||||
|
@ -885,7 +920,12 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
||||||
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
|
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
|
||||||
const int y_offset = qr == 1 ? 1 : qk/2;
|
const int y_offset = qr == 1 ? 1 : qk/2;
|
||||||
|
|
||||||
float tmp = 0.0f; // partial sum for thread in warp
|
// partial sum for each thread
|
||||||
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
|
half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics
|
||||||
|
#else
|
||||||
|
float tmp = 0.0f;
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
|
|
||||||
for (int i = 0; i < ncols; i += iter_stride) {
|
for (int i = 0; i < ncols; i += iter_stride) {
|
||||||
const int col = i + vals_per_iter*tid;
|
const int col = i + vals_per_iter*tid;
|
||||||
|
@ -899,14 +939,21 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
||||||
// process 2 vals per j iter
|
// process 2 vals per j iter
|
||||||
|
|
||||||
// dequantize
|
// dequantize
|
||||||
float v0, v1;
|
|
||||||
dequantize_kernel(vx, ib, iqs + j/qr, v0, v1);
|
|
||||||
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
|
||||||
|
dfloat2 v;
|
||||||
|
dequantize_kernel(vx, ib, iqs + j/qr, v);
|
||||||
|
|
||||||
// matrix multiplication
|
// matrix multiplication
|
||||||
tmp += v0 * y[iybs + iqs + j/qr + 0];
|
|
||||||
tmp += v1 * y[iybs + iqs + j/qr + y_offset];
|
|
||||||
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
// for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2
|
||||||
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
|
tmp += __hmul2(v, {
|
||||||
|
y[iybs + iqs + j/qr + 0],
|
||||||
|
y[iybs + iqs + j/qr + y_offset]
|
||||||
|
});
|
||||||
|
#else
|
||||||
|
tmp += v.x * y[iybs + iqs + j/qr + 0];
|
||||||
|
tmp += v.y * y[iybs + iqs + j/qr + y_offset];
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -918,7 +965,11 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
|
||||||
}
|
}
|
||||||
|
|
||||||
if (tid == 0) {
|
if (tid == 0) {
|
||||||
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
|
dst[row] = tmp.x + tmp.y;
|
||||||
|
#else
|
||||||
dst[row] = tmp;
|
dst[row] = tmp;
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1213,7 +1264,7 @@ static void dequantize_row_q6_K_cuda(const void * vx, float * y, const int k, cu
|
||||||
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
dequantize_block_q6_K<<<nb, 64, 0, stream>>>(vx, y);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1222,7 +1273,7 @@ static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, f
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1231,7 +1282,7 @@ static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, f
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1240,7 +1291,7 @@ static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, f
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1249,7 +1300,7 @@ static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, f
|
||||||
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1299,7 +1350,7 @@ static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, c
|
||||||
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
|
||||||
}
|
}
|
||||||
|
|
||||||
static void convert_mul_mat_vec_f16_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
|
||||||
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
|
||||||
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
const int block_num_y = (nrows + GGML_CUDA_DMMV_Y - 1) / GGML_CUDA_DMMV_Y;
|
||||||
const dim3 block_nums(1, block_num_y, 1);
|
const dim3 block_nums(1, block_num_y, 1);
|
||||||
|
@ -1714,21 +1765,40 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
const int64_t ne00 = src0->ne[0];
|
const int64_t ne00 = src0->ne[0];
|
||||||
const int64_t nrows = i01_high - i01_low;
|
const int64_t nrows = i01_high - i01_low;
|
||||||
|
|
||||||
|
// on some GPUs it is faster to convert src1 to half and to use half precision intrinsics
|
||||||
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
|
size_t ash;
|
||||||
|
dfloat * src1_dfloat = nullptr; // dfloat == half
|
||||||
|
|
||||||
|
bool src1_convert_f16 = src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 ||
|
||||||
|
src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 ||
|
||||||
|
src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16;
|
||||||
|
|
||||||
|
if (src1_convert_f16) {
|
||||||
|
src1_dfloat = (half *) ggml_cuda_pool_malloc(ne00*sizeof(half), &ash);
|
||||||
|
ggml_cpy_f32_f16_cuda((char *) src1_ddf_i, (char *) src1_dfloat, ne00,
|
||||||
|
ne00, 1, sizeof(float), 0, 0,
|
||||||
|
ne00, 1, sizeof(half), 0, 0, cudaStream_main);
|
||||||
|
}
|
||||||
|
#else
|
||||||
|
dfloat * src1_dfloat = src1_ddf_i; // dfloat == float, no conversion
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
|
|
||||||
switch (src0->type) {
|
switch (src0->type) {
|
||||||
case GGML_TYPE_Q4_0:
|
case GGML_TYPE_Q4_0:
|
||||||
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q4_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q4_1:
|
case GGML_TYPE_Q4_1:
|
||||||
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q4_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_0:
|
case GGML_TYPE_Q5_0:
|
||||||
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q5_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q5_1:
|
case GGML_TYPE_Q5_1:
|
||||||
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q5_1_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q8_0:
|
case GGML_TYPE_Q8_0:
|
||||||
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q8_0_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_Q2_K:
|
case GGML_TYPE_Q2_K:
|
||||||
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q2_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
|
@ -1746,7 +1816,7 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
dequantize_mul_mat_vec_q6_K_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
case GGML_TYPE_F16:
|
case GGML_TYPE_F16:
|
||||||
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_ddf_i, dst_ddf_i, ne00, nrows, cudaStream_main);
|
convert_mul_mat_vec_f16_cuda(src0_ddq_i, src1_dfloat, dst_ddf_i, ne00, nrows, cudaStream_main);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
GGML_ASSERT(false);
|
GGML_ASSERT(false);
|
||||||
|
@ -1754,6 +1824,12 @@ inline void ggml_cuda_op_dequantize_mul_mat_vec(
|
||||||
}
|
}
|
||||||
CUDA_CHECK(cudaGetLastError());
|
CUDA_CHECK(cudaGetLastError());
|
||||||
|
|
||||||
|
#ifdef GGML_CUDA_DMMV_F16
|
||||||
|
if (src1_convert_f16) {
|
||||||
|
ggml_cuda_pool_free(src1_dfloat, ash);
|
||||||
|
}
|
||||||
|
#endif // GGML_CUDA_DMMV_F16
|
||||||
|
|
||||||
(void) src1;
|
(void) src1;
|
||||||
(void) dst;
|
(void) dst;
|
||||||
(void) src0_ddf_i;
|
(void) src0_ddf_i;
|
||||||
|
|
2
ggml.c
2
ggml.c
|
@ -7918,7 +7918,7 @@ static void ggml_compute_forward_add_q_f32(
|
||||||
|
|
||||||
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03));
|
||||||
float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
|
float * src1_row = (float *)((char *) src1->data + (i11*nb11 + i12*nb12 + i13*nb13));
|
||||||
void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb0));
|
void * dst_row = (void *) ((char *) dst->data + ( i1*nb1 + i2*nb2 + i3*nb3));
|
||||||
|
|
||||||
assert(ne00 % 32 == 0);
|
assert(ne00 % 32 == 0);
|
||||||
|
|
||||||
|
|
17
llama.cpp
17
llama.cpp
|
@ -1620,7 +1620,7 @@ static bool llama_eval_internal(
|
||||||
model.layers[il].w1,
|
model.layers[il].w1,
|
||||||
cur);
|
cur);
|
||||||
offload_func(cur);
|
offload_func(cur);
|
||||||
ggml_set_name(cur, "result_w2");
|
ggml_set_name(cur, "result_w1");
|
||||||
|
|
||||||
// SILU activation
|
// SILU activation
|
||||||
cur = ggml_silu(ctx0, cur);
|
cur = ggml_silu(ctx0, cur);
|
||||||
|
@ -1657,11 +1657,7 @@ static bool llama_eval_internal(
|
||||||
{
|
{
|
||||||
cur = ggml_rms_norm(ctx0, inpL);
|
cur = ggml_rms_norm(ctx0, inpL);
|
||||||
offload_func_nr(cur);
|
offload_func_nr(cur);
|
||||||
ggml_set_name(cur, "rms_norm_inpL");
|
ggml_set_name(cur, "rms_norm_2");
|
||||||
|
|
||||||
cur = ggml_rms_norm(ctx0, cur);
|
|
||||||
offload_func_nr(cur);
|
|
||||||
ggml_set_name(cur, "rms_norm_after");
|
|
||||||
|
|
||||||
// cur = cur*norm(broadcasted)
|
// cur = cur*norm(broadcasted)
|
||||||
cur = ggml_mul(ctx0, cur, model.norm);
|
cur = ggml_mul(ctx0, cur, model.norm);
|
||||||
|
@ -3471,9 +3467,12 @@ void llama_print_timings(struct llama_context * ctx) {
|
||||||
|
|
||||||
fprintf(stderr, "\n");
|
fprintf(stderr, "\n");
|
||||||
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
|
fprintf(stderr, "%s: load time = %8.2f ms\n", __func__, ctx->t_load_us / 1000.0);
|
||||||
fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample);
|
fprintf(stderr, "%s: sample time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||||
fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval);
|
__func__, 1e-3 * ctx->t_sample_us, n_sample, 1e-3 * ctx->t_sample_us / n_sample, 1e6 / ctx->t_sample_us * n_sample);
|
||||||
fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token)\n", __func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval);
|
fprintf(stderr, "%s: prompt eval time = %8.2f ms / %5d tokens (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||||
|
__func__, 1e-3 * ctx->t_p_eval_us, n_p_eval, 1e-3 * ctx->t_p_eval_us / n_p_eval, 1e6 / ctx->t_p_eval_us * n_p_eval);
|
||||||
|
fprintf(stderr, "%s: eval time = %8.2f ms / %5d runs (%8.2f ms per token, %8.2f tokens per second)\n",
|
||||||
|
__func__, 1e-3 * ctx->t_eval_us, n_eval, 1e-3 * ctx->t_eval_us / n_eval, 1e6 / ctx->t_eval_us * n_eval);
|
||||||
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0);
|
fprintf(stderr, "%s: total time = %8.2f ms\n", __func__, (t_end_us - ctx->t_start_us)/1000.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue