diff --git a/CMakeLists.txt b/CMakeLists.txt
index 466ffd04f..a3cb22e0c 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -43,13 +43,14 @@ if (NOT MSVC)
 endif()
 
 # 3rd party libs
-option(LLAMA_CUBLAS "llama: use CUDA" ON)
+option(LLAMA_CUBLAS "llama: use CUDA" OFF)
 set(LLAMA_CUDA_MMQ_Y "64" CACHE STRING "llama: y tile size for mmq CUDA kernels")
 set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
 set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
 set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
 option(LLAMA_CUDA_F16 "llama: use 16 bit floats for dmmv CUDA kernels" OFF)
 set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
+option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
 option(LLAMA_K_QUANTS "llama: use k-quants" ON)
@@ -121,6 +122,43 @@ if (LLAMA_CUBLAS)
     endif()
 endif()
 
+if (LLAMA_HIPBLAS)
+    list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
+
+    if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
+    endif()
+    if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
+        message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
+    endif()
+
+    find_package(hip)
+    find_package(hipblas)
+    find_package(rocblas)
+
+    if (${hipblas_FOUND} AND ${hip_FOUND})
+        message(STATUS "HIP and hipBLAS found")
+        add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
+        add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
+        if (LLAMA_CUDA_FORCE_DMMV)
+            target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV)
+        endif()
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
+        target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
+        target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
+        target_compile_definitions(ggml-rocm PRIVATE CC_TURING=1000000000)
+        set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
+        target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
+
+        if (LLAMA_STATIC)
+            message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
+        endif()
+        set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm)
+    else()
+        message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
+    endif()
+endif()
+
 if (LLAMA_ALL_WARNINGS)
     if (NOT MSVC)
         set(c_flags
diff --git a/Makefile b/Makefile
index f7cf21d5c..1b2c6bc7f 100644
--- a/Makefile
+++ b/Makefile
@@ -20,8 +20,6 @@ ifneq ($(shell grep -e "Arch Linux" -e "ID_LIKE=arch" /etc/os-release 2>/dev/nul
     ARCH_ADD = -lcblas
 endif
 
-CCV := $(shell $(CC) --version | head -n 1)
-CXXV := $(shell $(CXX) --version | head -n 1)
 
 # Mac OS + Arm can report x86_64
 # ref: https://github.com/ggerganov/whisper.cpp/issues/66#issuecomment-1282546789
@@ -195,6 +193,45 @@ ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-l
 	$(NVCC) $(NVCCFLAGS) $(subst -Ofast,-O3,$(CXXFLAGS)) $(CUBLAS_FLAGS) $(CUBLAS_CXXFLAGS) -Wno-pedantic -c $< -o $@
 endif # LLAMA_CUBLAS
 
+ifdef LLAMA_HIPBLAS
+	ROCM_PATH ?= /opt/rocm
+	CC  := $(ROCM_PATH)/llvm/bin/clang
+	CXX := $(ROCM_PATH)/llvm/bin/clang++
+	GPU_TARGETS ?= gfx803 gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100
+	LLAMA_CUDA_DMMV_X       ?= 128
+	LLAMA_CUDA_MMV_Y        ?= 2
+	LLAMA_CUDA_KQUANTS_ITER ?= 1
+	HIPFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS $(shell $(ROCM_PATH)/bin/hipconfig -C)
+ifdef LLAMA_CUDA_FORCE_DMMV
+	HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
+endif # LLAMA_CUDA_FORCE_DMMV
+	HIPLDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib -lhipblas -lamdhip64 -lrocblas
+	HIP_OBJS   += ggml-cuda.o ggml_v2-cuda.o ggml_v2-cuda-legacy.o
+ggml-cuda.o: HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) \
+	-DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) \
+	-DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) \
+	-DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) \
+	-DCC_TURING=1000000000
+ggml_v2-cuda.o: HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) \
+	-DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) \
+	-DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) \
+	-DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) \
+	-DCC_TURING=1000000000
+ggml_v2-cuda-legacy.o: HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS)) \
+	-DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X) \
+	-DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y) \
+	-DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER) \
+	-DCC_TURING=1000000000 # DGGML_CUDA_DMMV_F16 does not currently work with AMD.
+ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
+	$(CXX) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
+ggml_v2-cuda.o: otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h
+	$(CXX) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
+ggml_v2-cuda-legacy.o: otherarch/ggml_v2-cuda-legacy.cu otherarch/ggml_v2-cuda-legacy.h
+	$(CXX) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
+endif # LLAMA_HIPBLAS
+
+
 ifdef LLAMA_METAL
 	CFLAGS   += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
 	CXXFLAGS += -DGGML_USE_METAL
@@ -224,12 +261,16 @@ ifneq ($(filter armv8%,$(UNAME_M)),)
 	CFLAGS += -mfp16-format=ieee -mno-unaligned-access
 endif
 
+CCV := $(shell $(CC) --version | head -n 1)
+CXXV := $(shell $(CXX) --version | head -n 1)
+
 DEFAULT_BUILD =
 FAILSAFE_BUILD =
 OPENBLAS_BUILD =
 NOAVX2_BUILD =
 CLBLAST_BUILD =
 CUBLAS_BUILD =
+HIPBLAS_BUILD =
 
 ifeq ($(OS),Windows_NT)
 	DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.dll $(LDFLAGS)
@@ -238,10 +279,12 @@ ifeq ($(OS),Windows_NT)
 	NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.dll $(LDFLAGS)
 	CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ lib/OpenCL.lib lib/clblast.lib -shared -o $@.dll $(LDFLAGS)
 
-ifdef LLAMA_CUBLAS
-	CUBLAS_BUILD = $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $^ -shared -o $@.dll $(CUBLASLD_FLAGS) $(LDFLAGS)
-endif
-
+	ifdef LLAMA_CUBLAS
+		CUBLAS_BUILD = $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $^ -shared -o $@.dll $(CUBLASLD_FLAGS) $(LDFLAGS)
+	endif
+	ifdef LLAMA_HIPBLAS
+		HIPBLAS_BUILD = $(CXX) $(CXXFLAGS) $(HIPFLAGS) $^ -shared -o $@.dll $(HIPLDFLAGS) $(LDFLAGS)
+	endif
 else
 	DEFAULT_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS)
 	FAILSAFE_BUILD = $(CXX) $(CXXFLAGS) $^ -shared -o $@.so $(LDFLAGS)
@@ -250,24 +293,29 @@ else
 		NOAVX2_BUILD = $(CXX) $(CXXFLAGS) $^ $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
 	endif
 	ifdef LLAMA_CLBLAST
-	ifeq ($(UNAME_S),Darwin)
-		CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -framework OpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
-	else
-		CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -lOpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
-	endif
+		ifeq ($(UNAME_S),Darwin)
+			CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -framework OpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
+		else
+			CLBLAST_BUILD = $(CXX) $(CXXFLAGS) $^ -lclblast -lOpenCL $(ARCH_ADD) -lopenblas -shared -o $@.so $(LDFLAGS)
+		endif
 	endif
-ifdef LLAMA_CUBLAS
-	CUBLAS_BUILD = $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $^ -shared -o $@.so $(CUBLASLD_FLAGS) $(LDFLAGS)
-endif
+	ifdef LLAMA_CUBLAS
+		CUBLAS_BUILD = $(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $^ -shared -o $@.so $(CUBLASLD_FLAGS) $(LDFLAGS)
+	endif
+	ifdef LLAMA_HIPBLAS
+		HIPBLAS_BUILD = $(CXX) $(CXXFLAGS) $(HIPFLAGS) $^ -shared -o $@.so $(HIPLDFLAGS) $(LDFLAGS)
+	endif
 
 	ifndef LLAMA_OPENBLAS
 	ifndef LLAMA_CLBLAST
 	ifndef LLAMA_CUBLAS
+	ifndef LLAMA_HIPBLAS
 		OPENBLAS_BUILD = @echo 'Your OS $(OS) does not appear to be Windows. For faster speeds, install and link a BLAS library. Set LLAMA_OPENBLAS=1 to compile with OpenBLAS support or LLAMA_CLBLAST=1 to compile with ClBlast support. This is just a reminder, not an error.'
 	endif
 	endif
 	endif
+	endif
 endif
@@ -302,7 +350,7 @@ ggml_noavx2.o: ggml.c ggml.h
 ggml_clblast.o: ggml.c ggml.h
 	$(CC) $(CFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
 ggml_cublas.o: ggml.c ggml.h
-	$(CC) $(CFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) -c $< -o $@
+	$(CC) $(CFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
 
 #quants K
 k_quants.o: k_quants.c k_quants.h ggml.h ggml-cuda.h
@@ -328,7 +376,7 @@ ggml_v2_noavx2.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
 ggml_v2_clblast.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
 	$(CC) $(CFLAGS) $(FULLCFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
 ggml_v2_cublas.o: otherarch/ggml_v2.c otherarch/ggml_v2.h
-	$(CC) $(CFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) -c $< -o $@
+	$(CC) $(CFLAGS) $(FULLCFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
 
 #extreme old version compat
 ggml_v1.o: otherarch/ggml_v1.c otherarch/ggml_v1.h
@@ -365,7 +413,7 @@ gpttype_adapter.o: $(GPTTYPE_ADAPTER)
 gpttype_adapter_clblast.o: $(GPTTYPE_ADAPTER)
 	$(CXX) $(CXXFLAGS) $(CLBLAST_FLAGS) -c $< -o $@
 gpttype_adapter_cublas.o: $(GPTTYPE_ADAPTER)
-	$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) -c $< -o $@
+	$(CXX) $(CXXFLAGS) $(CUBLAS_FLAGS) $(HIPFLAGS) -c $< -o $@
 
 clean:
 	rm -vf *.o main quantize_llama quantize_gpt2 quantize_gptj quantize_neox quantize_mpt quantize-stats perplexity embedding benchmark-matmult save-load-state gguf gguf.exe main.exe quantize_llama.exe quantize_gptj.exe quantize_gpt2.exe quantize_neox.exe quantize_mpt.exe koboldcpp_default.dll koboldcpp_openblas.dll koboldcpp_failsafe.dll koboldcpp_noavx2.dll koboldcpp_clblast.dll koboldcpp_cublas.dll koboldcpp_default.so koboldcpp_openblas.so koboldcpp_failsafe.so koboldcpp_noavx2.so koboldcpp_clblast.so koboldcpp_cublas.so
@@ -390,8 +438,8 @@ koboldcpp_noavx2: ggml_noavx2.o ggml_v2_noavx2.o ggml_v1_failsafe.o expose.o com
 	$(NOAVX2_BUILD)
 koboldcpp_clblast: ggml_clblast.o ggml_v2_clblast.o ggml_v1.o expose.o common.o gpttype_adapter_clblast.o ggml-opencl.o ggml_v2-opencl.o ggml_v2-opencl-legacy.o k_quants.o ggml-alloc.o $(OBJS)
 	$(CLBLAST_BUILD)
-koboldcpp_cublas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o $(CUBLAS_OBJS) $(OBJS)
-	$(CUBLAS_BUILD)
+koboldcpp_cublas: ggml_cublas.o ggml_v2_cublas.o ggml_v1.o expose.o common.o gpttype_adapter_cublas.o k_quants.o ggml-alloc.o $(CUBLAS_OBJS) $(HIP_OBJS) $(OBJS)
+	$(CUBLAS_BUILD) $(HIPBLAS_BUILD)
 
 quantize_llama: examples/quantize/quantize.cpp ggml.o llama.o k_quants.o ggml-alloc.o
 	$(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS)
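Aside on the CC_TURING=1000000000 define that both the CMake path and the Makefile path set for the ROCm objects: ggml-cuda.cu gates some of its newer kernel choices on the device's compute capability being at least CC_TURING, so defining the threshold to an unreachably large value keeps those paths disabled under HIP. The sketch below only illustrates that gating pattern; the function name and the default threshold are assumptions, not code from the tree.

```cpp
// Minimal illustration of a compile-time "compute capability" gate.
// Building with -DCC_TURING=1000000000 (as the ROCm rules above do) makes the
// Turing-specific branch unreachable, so the generic kernels are always used.
#include <cstdio>

#ifndef CC_TURING
#define CC_TURING 700   // assumed default threshold for NVIDIA builds
#endif

static bool prefer_turing_path(int compute_capability) {
    // With CC_TURING defined as 1000000000 this comparison can never be true.
    return compute_capability >= CC_TURING;
}

int main() {
    std::printf("cc 906 -> turing path: %d\n", prefer_turing_path(906) ? 1 : 0);
    return 0;
}
```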
Split!\n"); } #endif diff --git a/koboldcpp.py b/koboldcpp.py index 96d422685..6289ffd1b 100755 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -215,13 +215,6 @@ def load_model(model_filename): if args.useclblast: clblastids = 100 + int(args.useclblast[0])*10 + int(args.useclblast[1]) inputs.clblast_info = clblastids - inputs.cublas_info = 0 - if (args.usecublas and "0" in args.usecublas): - os.environ["CUDA_VISIBLE_DEVICES"] = "0" - elif (args.usecublas and "1" in args.usecublas): - os.environ["CUDA_VISIBLE_DEVICES"] = "1" - elif (args.usecublas and "2" in args.usecublas): - os.environ["CUDA_VISIBLE_DEVICES"] = "2" for n in range(tensor_split_max): if args.tensor_split and n < len(args.tensor_split): @@ -229,6 +222,22 @@ def load_model(model_filename): else: inputs.tensor_split[n] = 0 + # we must force an explicit tensor split + # otherwise the default will divide equally and multigpu crap will slow it down badly + inputs.cublas_info = 0 + if (args.usecublas and "0" in args.usecublas): + inputs.cublas_info = 0 + if not args.tensor_split: + inputs.tensor_split[inputs.cublas_info] = 100 + elif (args.usecublas and "1" in args.usecublas): + inputs.cublas_info = 1 + if not args.tensor_split: + inputs.tensor_split[inputs.cublas_info] = 100 + elif (args.usecublas and "2" in args.usecublas): + inputs.cublas_info = 2 + if not args.tensor_split: + inputs.tensor_split[inputs.cublas_info] = 100 + inputs.executable_path = (getdirpath()+"/").encode("UTF-8") inputs.debugmode = args.debugmode banned_tokens = args.bantokens @@ -730,7 +739,7 @@ def show_new_gui(): lib_option_pairs = [ (lib_openblas, "Use OpenBLAS"), (lib_clblast, "Use CLBlast"), - (lib_cublas, "Use CuBLAS"), + (lib_cublas, "Use CuBLAS/hipBLAS"), (lib_default, "Use No BLAS"), (lib_noavx2, "NoAVX2 Mode (Old CPU)"), (lib_failsafe, "Failsafe Mode (Old CPU)")] @@ -895,7 +904,7 @@ def show_new_gui(): def changerunmode(a,b,c): index = runopts_var.get() - if index == "Use CLBlast" or index == "Use CuBLAS": + if index == "Use CLBlast" or index == "Use CuBLAS/hipBLAS": gpu_selector_label.grid(row=3, column=0, padx = 8, pady=1, stick="nw") quick_gpu_selector_label.grid(row=3, column=0, padx = 8, pady=1, stick="nw") if index == "Use CLBlast": @@ -903,7 +912,7 @@ def show_new_gui(): quick_gpu_selector_box.grid(row=3, column=1, padx=8, pady=1, stick="nw") if gpu_choice_var.get()=="All": gpu_choice_var.set("1") - elif index == "Use CuBLAS": + elif index == "Use CuBLAS/hipBLAS": CUDA_gpu_selector_box.grid(row=3, column=1, padx=8, pady=1, stick="nw") CUDA_quick_gpu_selector_box.grid(row=3, column=1, padx=8, pady=1, stick="nw") else: @@ -914,7 +923,7 @@ def show_new_gui(): quick_gpu_selector_box.grid_forget() CUDA_quick_gpu_selector_box.grid_forget() - if index == "Use CuBLAS": + if index == "Use CuBLAS/hipBLAS": lowvram_box.grid(row=4, column=0, padx=8, pady=1, stick="nw") quick_lowvram_box.grid(row=4, column=0, padx=8, pady=1, stick="nw") mmq_box.grid(row=4, column=1, padx=8, pady=1, stick="nw") @@ -925,7 +934,7 @@ def show_new_gui(): mmq_box.grid_forget() quick_mmq_box.grid_forget() - if index == "Use CLBlast" or index == "Use CuBLAS": + if index == "Use CLBlast" or index == "Use CuBLAS/hipBLAS": gpu_layers_label.grid(row=5, column=0, padx = 8, pady=1, stick="nw") gpu_layers_entry.grid(row=5, column=1, padx=8, pady=1, stick="nw") quick_gpu_layers_label.grid(row=5, column=0, padx = 8, pady=1, stick="nw") @@ -1111,7 +1120,7 @@ def show_new_gui(): gpuchoiceidx = int(gpu_choice_var.get())-1 if runopts_var.get() == "Use CLBlast": args.useclblast = [[0,0], [1,0], 
[0,1]][gpuchoiceidx] - if runopts_var.get() == "Use CuBLAS": + if runopts_var.get() == "Use CuBLAS/hipBLAS": if gpu_choice_var.get()=="All": args.usecublas = ["lowvram"] if lowvram_var.get() == 1 else ["normal"] else: @@ -1337,7 +1346,7 @@ def show_old_gui(): blaschoice = tk.StringVar() blaschoice.set("BLAS = 512") - runopts = ["Use OpenBLAS","Use CLBLast GPU #1","Use CLBLast GPU #2","Use CLBLast GPU #3","Use CuBLAS GPU","Use No BLAS","NoAVX2 Mode (Old CPU)","Failsafe Mode (Old CPU)"] + runopts = ["Use OpenBLAS","Use CLBLast GPU #1","Use CLBLast GPU #2","Use CLBLast GPU #3","Use CuBLAS/hipBLAS GPU","Use No BLAS","NoAVX2 Mode (Old CPU)","Failsafe Mode (Old CPU)"] runchoice = tk.StringVar() runchoice.set("Use OpenBLAS") @@ -1779,8 +1788,8 @@ if __name__ == '__main__': compatgroup = parser.add_mutually_exclusive_group() compatgroup.add_argument("--noblas", help="Do not use OpenBLAS for accelerated prompt ingestion", action='store_true') compatgroup.add_argument("--useclblast", help="Use CLBlast for GPU Acceleration. Must specify exactly 2 arguments, platform ID and device ID (e.g. --useclblast 1 0).", type=int, choices=range(0,9), nargs=2) - compatgroup.add_argument("--usecublas", help="Use CuBLAS for GPU Acceleration. Requires CUDA. Select lowvram to not allocate VRAM scratch buffer. Enter a number afterwards to select and use 1 GPU. Leaving no number will use all GPUs.", nargs='*',metavar=('[lowvram|normal] [main GPU ID] [mmq]'), choices=['normal', 'lowvram', '0', '1', '2', 'mmq']) + compatgroup.add_argument("--usecublas", help="Use CuBLAS/hipBLAS for GPU Acceleration. Requires CUDA. Select lowvram to not allocate VRAM scratch buffer. Enter a number afterwards to select and use 1 GPU. Leaving no number will use all GPUs.", nargs='*',metavar=('[lowvram|normal] [main GPU ID] [mmq]'), choices=['normal', 'lowvram', '0', '1', '2', 'mmq']) parser.add_argument("--gpulayers", help="Set number of layers to offload to GPU when using GPU. Requires GPU.",metavar=('[GPU layers]'), type=int, default=0) parser.add_argument("--tensor_split", help="For CUDA with ALL GPU set only, ratio to split tensors across multiple GPUs, space-separated list of proportions, e.g. 7 3", metavar=('[Ratios]'), type=float, nargs='+') - main(parser.parse_args(),start_server=True) + main(parser.parse_args(),start_server=True) \ No newline at end of file diff --git a/llama.cpp b/llama.cpp index dae8c1484..9ee84345e 100644 --- a/llama.cpp +++ b/llama.cpp @@ -2055,7 +2055,11 @@ static void llm_load_tensors( #ifdef GGML_USE_CUBLAS const int max_backend_supported_layers = hparams.n_layer + 3; +#if defined(GGML_USE_HIPBLAS) + const int max_offloadable_layers = low_vram ? hparams.n_layer + 3 : hparams.n_layer + 3; +#else const int max_offloadable_layers = low_vram ? 
hparams.n_layer + 1 : hparams.n_layer + 3; +#endif if (n_gpu_layers > (int) hparams.n_layer + 1) { if (low_vram) { LLAMA_LOG_INFO("%s: cannot offload v cache to GPU due to low VRAM option\n", __func__); diff --git a/otherarch/ggml_v2-cuda-legacy.cu b/otherarch/ggml_v2-cuda-legacy.cu index d3220a786..e7053b764 100644 --- a/otherarch/ggml_v2-cuda-legacy.cu +++ b/otherarch/ggml_v2-cuda-legacy.cu @@ -4,9 +4,64 @@ #include #include +#if defined(GGML_USE_HIPBLAS) +#include +#include +#include +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasCreate hipblasCreate +#define cublasGemmEx hipblasGemmEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEvent_t hipEvent_t +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMemcpy hipMemcpy +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent hipStreamWaitEvent +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#else #include #include #include +#endif #include "ggml_v2-cuda-legacy.h" #include "ggml_v2-cuda.h" diff --git a/otherarch/ggml_v2-cuda.cu b/otherarch/ggml_v2-cuda.cu index 8314adb25..b4502df00 100644 --- a/otherarch/ggml_v2-cuda.cu +++ b/otherarch/ggml_v2-cuda.cu @@ -4,10 +4,66 @@ #include #include +#if defined(GGML_USE_HIPBLAS) +#include +#include +#include +#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F +#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F +#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT +#define CUBLAS_OP_N HIPBLAS_OP_N +#define CUBLAS_OP_T HIPBLAS_OP_T +#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS +#define CUBLAS_TF32_TENSOR_OP_MATH 0 +#define CUDA_R_16F HIPBLAS_R_16F +#define CUDA_R_32F HIPBLAS_R_32F +#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width) +#define cublasCreate hipblasCreate +#define 
cublasGemmEx hipblasGemmEx +#define cublasHandle_t hipblasHandle_t +#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS +#define cublasSetStream hipblasSetStream +#define cublasSgemm hipblasSgemm +#define cublasStatus_t hipblasStatus_t +#define cudaDeviceProp hipDeviceProp_t +#define cudaDeviceSynchronize hipDeviceSynchronize +#define cudaError_t hipError_t +#define cudaEventCreateWithFlags hipEventCreateWithFlags +#define cudaEventDisableTiming hipEventDisableTiming +#define cudaEventRecord hipEventRecord +#define cudaEvent_t hipEvent_t +#define cudaFree hipFree +#define cudaFreeHost hipHostFree +#define cudaGetDevice hipGetDevice +#define cudaGetDeviceCount hipGetDeviceCount +#define cudaGetDeviceProperties hipGetDeviceProperties +#define cudaGetErrorString hipGetErrorString +#define cudaGetLastError hipGetLastError +#define cudaMalloc hipMalloc +#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) +#define cudaMemcpy hipMemcpy +#define cudaMemcpy2DAsync hipMemcpy2DAsync +#define cudaMemcpyAsync hipMemcpyAsync +#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice +#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost +#define cudaMemcpyHostToDevice hipMemcpyHostToDevice +#define cudaMemcpyKind hipMemcpyKind +#define cudaMemset hipMemset +#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize +#define cudaSetDevice hipSetDevice +#define cudaStreamCreateWithFlags hipStreamCreateWithFlags +#define cudaStreamNonBlocking hipStreamNonBlocking +#define cudaStreamSynchronize hipStreamSynchronize +#define cudaStreamWaitEvent hipStreamWaitEvent +#define cudaStream_t hipStream_t +#define cudaSuccess hipSuccess +#else #include #include #include +#endif + #include "ggml_v2-cuda.h" #include "ggml_v2.h" @@ -807,4 +863,4 @@ void ggml_v2_cuda_transform_tensor(ggml_v2_tensor * tensor) { tensor->data = d_Q; tensor->backend = GGML_V2_BACKEND_CUDA; -} \ No newline at end of file +} diff --git a/otherarch/gpt2_v3.cpp b/otherarch/gpt2_v3.cpp index 981e01c7c..97b23265f 100644 --- a/otherarch/gpt2_v3.cpp +++ b/otherarch/gpt2_v3.cpp @@ -359,7 +359,11 @@ ModelLoadResult gpt2_model_load(const std::string & fname, gpt2_model & model, g const auto & hparams = model.hparams; size_t vram_total = 0; const int n_gpu = std::min(gpulayers, int(hparams.n_layer)); - fprintf(stderr, "%s: [GPU] offloading %d layers to GPU\n", __func__, n_gpu); + #if defined(GGML_USE_CLBLAST) + fprintf(stderr, "%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu); + #else + fprintf(stderr, "%s: [CUDA] offloading %d layers to GPU\n", __func__, n_gpu); + #endif for (int i = 0; i < n_gpu; ++i) { const auto & layer = model.layers[i]; layer.c_attn_attn_w->backend = GGML_BACKEND_GPU; @@ -378,7 +382,11 @@ ModelLoadResult gpt2_model_load(const std::string & fname, gpt2_model & model, g ggml_cuda_transform_tensor(layer.c_mlp_proj_w->data,layer.c_mlp_proj_w); vram_total += ggml_nbytes(layer.c_mlp_proj_w); #endif } - fprintf(stderr, "%s: [GPU] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024); + #if defined(GGML_USE_CLBLAST) + fprintf(stderr, "%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024); + #else + fprintf(stderr, "%s: [CUDA] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024); + #endif } #endif diff --git a/otherarch/gptj_v3.cpp b/otherarch/gptj_v3.cpp index 8f7cc47f1..42512e190 100644 --- a/otherarch/gptj_v3.cpp +++ b/otherarch/gptj_v3.cpp @@ -348,7 +348,11 @@ ModelLoadResult gptj_model_load(const std::string & fname, 
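The two #if defined(GGML_USE_HIPBLAS) blocks above alias CUDA/cuBLAS names to their HIP/hipBLAS equivalents so the existing .cu sources compile unchanged for ROCm. Below is a self-contained sketch of the same aliasing technique, reduced to a handful of runtime calls; the macro subset shown is only a fragment of the table above, and it needs either the CUDA or the ROCm toolkit to build.

```cpp
// Sketch: CUDA-style code that also builds as HIP when GGML_USE_HIPBLAS is defined,
// mirroring the header blocks in ggml_v2-cuda.cu / ggml_v2-cuda-legacy.cu.
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#define cudaError_t            hipError_t
#define cudaSuccess            hipSuccess
#define cudaMalloc             hipMalloc
#define cudaMemcpy             hipMemcpy
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaFree               hipFree
#define cudaGetErrorString     hipGetErrorString
#else
#include <cuda_runtime.h>
#endif

#include <cstdio>
#include <vector>

int main() {
    std::vector<float> host(16, 1.0f);
    float * dev = nullptr;

    // Written against the CUDA runtime API; under GGML_USE_HIPBLAS the macros
    // above turn these calls into HIP runtime calls.
    cudaError_t err = cudaMalloc((void **) &dev, host.size() * sizeof(float));
    if (err != cudaSuccess) {
        std::fprintf(stderr, "alloc failed: %s\n", cudaGetErrorString(err));
        return 1;
    }
    cudaMemcpy(dev, host.data(), host.size() * sizeof(float), cudaMemcpyHostToDevice);
    cudaFree(dev);
    return 0;
}
```

Compiled with nvcc it targets the CUDA runtime directly; compiled with the ROCm clang++ using -x hip and -DGGML_USE_HIPBLAS, as the Makefile rules above do, the same calls resolve to HIP.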
diff --git a/otherarch/gpt2_v3.cpp b/otherarch/gpt2_v3.cpp
index 981e01c7c..97b23265f 100644
--- a/otherarch/gpt2_v3.cpp
+++ b/otherarch/gpt2_v3.cpp
@@ -359,7 +359,11 @@ ModelLoadResult gpt2_model_load(const std::string & fname, gpt2_model & model, g
         const auto & hparams = model.hparams;
         size_t vram_total = 0;
         const int n_gpu = std::min(gpulayers, int(hparams.n_layer));
-        fprintf(stderr, "%s: [GPU] offloading %d layers to GPU\n", __func__, n_gpu);
+        #if defined(GGML_USE_CLBLAST)
+        fprintf(stderr, "%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu);
+        #else
+        fprintf(stderr, "%s: [CUDA] offloading %d layers to GPU\n", __func__, n_gpu);
+        #endif
         for (int i = 0; i < n_gpu; ++i) {
             const auto & layer = model.layers[i];
             layer.c_attn_attn_w->backend = GGML_BACKEND_GPU;
@@ -378,7 +382,11 @@ ModelLoadResult gpt2_model_load(const std::string & fname, gpt2_model & model, g
             ggml_cuda_transform_tensor(layer.c_mlp_proj_w->data,layer.c_mlp_proj_w); vram_total += ggml_nbytes(layer.c_mlp_proj_w);
             #endif
         }
-        fprintf(stderr, "%s: [GPU] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #if defined(GGML_USE_CLBLAST)
+        fprintf(stderr, "%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #else
+        fprintf(stderr, "%s: [CUDA] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #endif
     }
 #endif
diff --git a/otherarch/gptj_v3.cpp b/otherarch/gptj_v3.cpp
index 8f7cc47f1..42512e190 100644
--- a/otherarch/gptj_v3.cpp
+++ b/otherarch/gptj_v3.cpp
@@ -348,7 +352,11 @@ ModelLoadResult gptj_model_load(const std::string & fname, gptj_model & model, g
         const auto & hparams = model.hparams;
         size_t vram_total = 0;
         const int n_gpu = std::min(gpulayers, int(hparams.n_layer));
-        fprintf(stderr, "%s: [GPU] offloading %d layers to GPU\n", __func__, n_gpu);
+        #if defined(GGML_USE_CLBLAST)
+        fprintf(stderr, "%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu);
+        #else
+        fprintf(stderr, "%s: [CUDA] offloading %d layers to GPU\n", __func__, n_gpu);
+        #endif
         for (int i = 0; i < n_gpu; ++i) {
             const auto & layer = model.layers[i];
             layer.c_attn_q_proj_w->backend = GGML_BACKEND_GPU;
@@ -373,7 +377,11 @@ ModelLoadResult gptj_model_load(const std::string & fname, gptj_model & model, g
             ggml_cuda_transform_tensor(layer.c_mlp_proj_w->data,layer.c_mlp_proj_w); vram_total += ggml_nbytes(layer.c_mlp_proj_w);
             #endif
         }
-        fprintf(stderr, "%s: [GPU] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #if defined(GGML_USE_CLBLAST)
+        fprintf(stderr, "%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #else
+        fprintf(stderr, "%s: [CUDA] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #endif
     }
 #endif
@@ -644,4 +652,4 @@ bool gptj_eval(
     ggml_free(ctx0);
 
     return true;
-}
\ No newline at end of file
+}
diff --git a/otherarch/llama_v2.cpp b/otherarch/llama_v2.cpp
index ab9d82f93..01b47697c 100644
--- a/otherarch/llama_v2.cpp
+++ b/otherarch/llama_v2.cpp
@@ -3101,4 +3101,4 @@ std::vector llama_v2_tokenize(struct llama_v2_context * ctx, const
 
     res.resize(n);
     return res;
-}
\ No newline at end of file
+}
diff --git a/otherarch/mpt_v3.cpp b/otherarch/mpt_v3.cpp
index 2bf23055c..57ed90888 100644
--- a/otherarch/mpt_v3.cpp
+++ b/otherarch/mpt_v3.cpp
@@ -301,7 +301,11 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
         const auto & hparams = model.hparams;
         size_t vram_total = 0;
         const int n_gpu = std::min(gpulayers, int(hparams.n_layers));
-        fprintf(stderr, "%s: [GPU] offloading %d layers to GPU\n", __func__, n_gpu);
+        #if defined(GGML_USE_CLBLAST)
+        fprintf(stderr, "%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu);
+        #else
+        fprintf(stderr, "%s: [CUDA] offloading %d layers to GPU\n", __func__, n_gpu);
+        #endif
         for (int i = 0; i < n_gpu; ++i) {
             const auto & layer = model.layers[i];
             layer.ffn_up_proj->backend = GGML_BACKEND_GPU;
@@ -320,7 +324,11 @@ bool mpt_model_load(const std::string & fname, mpt_model & model, gpt_vocab & vo
             ggml_cuda_transform_tensor(layer.c_attn_out_proj_weight->data,layer.c_attn_out_proj_weight); vram_total += ggml_nbytes(layer.c_attn_out_proj_weight);
             #endif
         }
-        fprintf(stderr, "%s: [GPU] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #if defined(GGML_USE_CLBLAST)
+        fprintf(stderr, "%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #else
+        fprintf(stderr, "%s: [CUDA] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #endif
     }
 #endif
diff --git a/otherarch/neox_v3.cpp b/otherarch/neox_v3.cpp
index 7802cab86..d9fb93b28 100644
--- a/otherarch/neox_v3.cpp
+++ b/otherarch/neox_v3.cpp
@@ -335,7 +335,11 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
         const auto & hparams = model.hparams;
         size_t vram_total = 0;
         const int n_gpu = std::min(gpulayers, int(hparams.n_layer));
-        fprintf(stderr, "%s: [GPU] offloading %d layers to GPU\n", __func__, n_gpu);
+        #if defined(GGML_USE_CLBLAST)
+        fprintf(stderr, "%s: [opencl] offloading %d layers to GPU\n", __func__, n_gpu);
+        #else
+        fprintf(stderr, "%s: [CUDA] offloading %d layers to GPU\n", __func__, n_gpu);
+        #endif
         for (int i = 0; i < n_gpu; ++i) {
             const auto & layer = model.layers[i];
             layer.c_attn_attn_w->backend = GGML_BACKEND_GPU;
@@ -354,7 +358,11 @@ ModelLoadResult gpt_neox_model_load(const std::string & fname, gpt_neox_model &
             ggml_cuda_transform_tensor(layer.c_mlp_proj_w->data,layer.c_mlp_proj_w); vram_total += ggml_nbytes(layer.c_mlp_proj_w);
             #endif
         }
-        fprintf(stderr, "%s: [GPU] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #if defined(GGML_USE_CLBLAST)
+        fprintf(stderr, "%s: [opencl] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #else
+        fprintf(stderr, "%s: [CUDA] total VRAM used: %zu MB\n", __func__, vram_total / 1024 / 1024);
+        #endif
     }
 #endif
@@ -663,4 +671,4 @@ bool gpt_neox_eval(
     ggml_free(ctx0);
 
     return true;
-}
\ No newline at end of file
+}
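The [opencl]/[CUDA] log-tag selection added to gpt2_v3.cpp, gptj_v3.cpp, mpt_v3.cpp and neox_v3.cpp repeats the same #if defined(GGML_USE_CLBLAST) block around every fprintf. Purely as an illustration of the equivalent logic, here is the selection done once in a small helper; the helper name and the placeholder values are hypothetical and not part of the patch.

```cpp
// Hypothetical equivalent of the repeated #if blocks above: the backend tag is
// chosen once at compile time and reused by every log line.
#include <cstdio>

static const char * gpu_backend_tag() {
#if defined(GGML_USE_CLBLAST)
    return "opencl";
#else
    return "CUDA";   // covers both the CUDA and the HIP/ROCm (GGML_USE_HIPBLAS) builds
#endif
}

int main() {
    const int n_gpu = 32;                            // placeholder layer count
    const std::size_t vram_total = 512u * 1024 * 1024; // placeholder, 512 MB
    std::fprintf(stderr, "[%s] offloading %d layers to GPU\n", gpu_backend_tag(), n_gpu);
    std::fprintf(stderr, "[%s] total VRAM used: %zu MB\n", gpu_backend_tag(), vram_total / 1024 / 1024);
    return 0;
}
```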