Merge 21b68f3032
into 3c04bf6da8
This commit is contained in:
commit
602cebba7a
2 changed files with 44 additions and 14 deletions
21
Makefile
21
Makefile
|
@ -466,6 +466,27 @@ ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
|||
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
|
||||
endif # LLAMA_HIPBLAS
|
||||
|
||||
ifdef LLAMA_CHIPSTAR
|
||||
CUSPVC ?= cuspvc
|
||||
LLAMA_CUDA_DMMV_X ?= 32
|
||||
LLAMA_CUDA_MMV_Y ?= 1
|
||||
LLAMA_CUDA_KQUANTS_ITER ?= 2
|
||||
MK_CPPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS -DGGML_USE_CHIPSTAR
|
||||
MK_CPPFLAGS += -I/opt/H4IBLAS/include
|
||||
MK_LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
|
||||
# MK_LDFLAGS += -lhipblas -lamdhip64 -lrocblas
|
||||
MK_LDFLAGS += -L/opt/chipstar/lib -lCHIP -L/opt/H4IBLAS/lib -lhipblas -L/opt/H4IMKL/lib -lMKLShim
|
||||
# HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
|
||||
HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
|
||||
HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
|
||||
HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
|
||||
HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
|
||||
OBJS += ggml-cuda.o
|
||||
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
|
||||
$(CUSPVC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
|
||||
endif # LLAMA_HIPBLAS
|
||||
|
||||
|
||||
ifdef LLAMA_METAL
|
||||
MK_CPPFLAGS += -DGGML_USE_METAL
|
||||
MK_LDFLAGS += -framework Foundation -framework Metal -framework MetalKit
|
||||
|
|
37
ggml-cuda.cu
37
ggml-cuda.cu
|
@ -13,7 +13,7 @@
|
|||
|
||||
#if defined(GGML_USE_HIPBLAS)
|
||||
#include <hip/hip_runtime.h>
|
||||
#include <hipblas/hipblas.h>
|
||||
#include <hipblas.h>
|
||||
#include <hip/hip_fp16.h>
|
||||
#ifdef __HIP_PLATFORM_AMD__
|
||||
// for rocblas_initialize()
|
||||
|
@ -106,7 +106,7 @@
|
|||
// TODO: improve this to be correct for more hardware
|
||||
// for example, currently fails for GeForce GTX 1660 which is TURING arch (> VOLTA) but does not have tensor cores
|
||||
// probably other such cases, and not sure what happens on AMD hardware
|
||||
#if !defined(GGML_CUDA_FORCE_MMQ)
|
||||
#if !defined(GGML_CUDA_FORCE_MMQ) && !defined(GGML_USE_CHIPSTAR)
|
||||
#define CUDA_USE_TENSOR_CORES
|
||||
#endif
|
||||
|
||||
|
@ -5499,23 +5499,23 @@ static void sqr_f32_cuda(const float * x, float * dst, const int k, cudaStream_t
|
|||
|
||||
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
if (ncols < 256) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
const dim3 block_dims(256, 1, 1);
|
||||
norm_f32<256><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
}
|
||||
}
|
||||
|
||||
static void group_norm_f32_cuda(const float * x, float * dst, const int num_groups, const int group_size, const int ne_elements, cudaStream_t stream) {
|
||||
static const float eps = 1e-6f;
|
||||
if (group_size < 1024) {
|
||||
if (group_size < 256) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
group_norm_f32<WARP_SIZE><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
group_norm_f32<1024><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
||||
const dim3 block_dims(256, 1, 1);
|
||||
group_norm_f32<256><<<num_groups, block_dims, 0, stream>>>(x, dst, group_size, ne_elements, eps);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -5542,12 +5542,12 @@ static void pad_f32_cuda(const float * x, float * dst,
|
|||
|
||||
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
|
||||
GGML_ASSERT(ncols % WARP_SIZE == 0);
|
||||
if (ncols < 1024) {
|
||||
if (ncols < 256) {
|
||||
const dim3 block_dims(WARP_SIZE, 1, 1);
|
||||
rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
} else {
|
||||
const dim3 block_dims(1024, 1, 1);
|
||||
rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
const dim3 block_dims(256, 1, 1);
|
||||
rms_norm_f32<256><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -7376,6 +7376,7 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
|||
|
||||
const int compute_capability = g_compute_capabilities[id];
|
||||
|
||||
#ifndef GGML_USE_CHIPSTAR
|
||||
if (compute_capability >= CC_VOLTA && (src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type)) && ggml_is_contiguous(src0) && row_diff == src0->ne[1]) {
|
||||
// convert src0 and src1 to fp16, multiply as fp16, convert dst to fp32
|
||||
half * src0_as_f16 = nullptr;
|
||||
|
@ -7428,7 +7429,9 @@ inline void ggml_cuda_op_mul_mat_cublas(
|
|||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
}
|
||||
}
|
||||
else {
|
||||
else
|
||||
#endif
|
||||
{
|
||||
float * src0_ddq_as_f32 = nullptr;
|
||||
size_t src0_as = 0;
|
||||
|
||||
|
@ -8323,6 +8326,7 @@ static __global__ void k_compute_batched_ptrs(
|
|||
ptrs_dst[0*ne23 + i12 + i13*ne12] = ( char *) dst_f16 + i12* nb2/2 + i13* nb3/2;
|
||||
}
|
||||
|
||||
#ifndef GGML_USE_CHIPSTAR
|
||||
static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
GGML_ASSERT(!ggml_is_transposed(src0));
|
||||
GGML_ASSERT(!ggml_is_transposed(src1));
|
||||
|
@ -8470,6 +8474,7 @@ static void ggml_cuda_mul_mat_mat_batched_cublas(const ggml_tensor * src0, const
|
|||
ggml_cuda_pool_free(src1_as_f16, src1_as);
|
||||
ggml_cuda_pool_free(dst_f16, dst_as);
|
||||
}
|
||||
#endif
|
||||
|
||||
static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
|
||||
const bool all_on_device =
|
||||
|
@ -8506,10 +8511,14 @@ static void ggml_cuda_mul_mat(const ggml_tensor * src0, const ggml_tensor * src1
|
|||
} else if (!split && all_on_device && !use_tensor_cores && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) {
|
||||
// KQV single-batch
|
||||
ggml_cuda_mul_mat_vec_nc(src0, src1, dst);
|
||||
} else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
}
|
||||
#ifndef GGML_USE_CHIPSTAR
|
||||
else if (!split && all_on_device && use_tensor_cores && src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && !ggml_is_transposed(src0) && !ggml_is_transposed(src1)) {
|
||||
// KQ + KQV multi-batch
|
||||
ggml_cuda_mul_mat_mat_batched_cublas(src0, src1, dst);
|
||||
} else if (src0->type == GGML_TYPE_F32) {
|
||||
}
|
||||
#endif
|
||||
else if (src0->type == GGML_TYPE_F32) {
|
||||
ggml_cuda_op_mul_mat(src0, src1, dst, ggml_cuda_op_mul_mat_cublas, false);
|
||||
} else if (ggml_is_quantized(src0->type) || src0->type == GGML_TYPE_F16) {
|
||||
if (src1->ne[1] == 1 && src0->ne[0] % GGML_CUDA_DMMV_X == 0) {
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue