From a6a0fa338a8fb390c47ca85e11ce54672ceed38b Mon Sep 17 00:00:00 2001
From: Concedo <39025047+LostRuins@users.noreply.github.com>
Date: Thu, 8 Jun 2023 22:40:53 +0800
Subject: [PATCH] cleanup indentation, fixing cublas build

---
 CMakeLists.txt            |   4 +-
 Makefile                  |  29 +-
 expose.cpp                |  12 +-
 gpttype_adapter.cpp       | 142 +++----
 koboldcpp.py              |  16 +-
 otherarch/ggml_v2-cuda.cu | 810 ++++++++++++++++++++++++++++++++++++++
 otherarch/ggml_v2-cuda.h  |  21 +
 otherarch/ggml_v2.c       |  20 +-
 otherarch/llama_v2-util.h |   6 +-
 9 files changed, 933 insertions(+), 127 deletions(-)
 create mode 100644 otherarch/ggml_v2-cuda.cu
 create mode 100644 otherarch/ggml_v2-cuda.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 4b33ea248..ced0c6a43 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -68,6 +68,7 @@ if (LLAMA_CUBLAS)
         enable_language(CUDA)

         set(GGML_CUDA_SOURCES ggml-cuda.cu ggml-cuda.h)
+        set(GGML_V2_CUDA_SOURCES otherarch/ggml_v2-cuda.cu otherarch/ggml_v2-cuda.h)

         add_compile_definitions(GGML_USE_CUBLAS)

@@ -257,7 +258,8 @@ set_target_properties(ggml_v1 PROPERTIES POSITION_INDEPENDENT_CODE ON)

 add_library(ggml_v2 OBJECT
             otherarch/ggml_v2.c
-            otherarch/ggml_v2.h)
+            otherarch/ggml_v2.h
+            ${GGML_V2_CUDA_SOURCES})
 target_include_directories(ggml_v2 PUBLIC . ./otherarch ./otherarch/tools)
 target_compile_features(ggml_v2 PUBLIC c_std_11) # don't bump
 target_link_libraries(ggml_v2 PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
diff --git a/Makefile b/Makefile
index 6a952b386..69dee96a7 100644
--- a/Makefile
+++ b/Makefile
@@ -131,35 +131,8 @@ ifndef LLAMA_NO_ACCELERATE
 	endif
 endif

-ifdef LLAMA_CUBLAS
-	CFLAGS    += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
-	CXXFLAGS  += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
-	LDFLAGS   += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
-	OBJS      += ggml-cuda.o
-	NVCC      = nvcc
-	NVCCFLAGS = --forward-unknown-to-host-compiler -arch=native
-ifdef LLAMA_CUDA_DMMV_X
-	NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
-else
-	NVCCFLAGS += -DGGML_CUDA_DMMV_X=32
-endif # LLAMA_CUDA_DMMV_X
-ifdef LLAMA_CUDA_DMMV_Y
-	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=$(LLAMA_CUDA_DMMV_Y)
-else
-	NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
-endif # LLAMA_CUDA_DMMV_Y
-ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
-	$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
-endif # LLAMA_CUBLAS
+# to ease maintenance burden, please use the CMake file to generate CUDA builds instead.
-ifdef LLAMA_GPROF - CFLAGS += -pg - CXXFLAGS += -pg -endif -ifdef LLAMA_PERF - CFLAGS += -DGGML_PERF - CXXFLAGS += -DGGML_PERF -endif ifdef LLAMA_METAL CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG CXXFLAGS += -DGGML_USE_METAL diff --git a/expose.cpp b/expose.cpp index fffa978d0..6388bed88 100644 --- a/expose.cpp +++ b/expose.cpp @@ -83,7 +83,7 @@ extern "C" file_format = FileFormat::GPTJ_3; printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format); lr = gpttype_load_model(inputs, file_format); - } + } //lastly try format 2 if (lr == ModelLoadResult::RETRY_LOAD) @@ -91,8 +91,8 @@ extern "C" file_format = FileFormat::GPTJ_2; printf("\n---\nRetrying as GPT-J model: (ver %d)\nAttempting to Load...\n---\n", file_format); lr = gpttype_load_model(inputs, file_format); - } - } + } + } if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD) { @@ -131,7 +131,7 @@ extern "C" else if(file_format==FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2) { printf("\n---\nIdentified as RWKV model: (ver %d)\nAttempting to Load...\n---\n", file_format); - ModelLoadResult lr = gpttype_load_model(inputs, file_format); + ModelLoadResult lr = gpttype_load_model(inputs, file_format); if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD) { return false; @@ -165,7 +165,7 @@ extern "C" file_format = FileFormat::NEOX_1; printf("\n---\nRetrying as GPT-NEO-X model: (ver %d)\nAttempting to Load...\n---\n", file_format); lr = gpttype_load_model(inputs, file_format); - } + } if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD) { return false; @@ -178,7 +178,7 @@ extern "C" else if(file_format==FileFormat::MPT_1) { printf("\n---\nIdentified as MPT model: (ver %d)\nAttempting to Load...\n---\n", file_format); - ModelLoadResult lr = gpttype_load_model(inputs, file_format); + ModelLoadResult lr = gpttype_load_model(inputs, file_format); if (lr == ModelLoadResult::FAIL || lr == ModelLoadResult::RETRY_LOAD) { return false; diff --git a/gpttype_adapter.cpp b/gpttype_adapter.cpp index db48b2cad..d7c334c00 100644 --- a/gpttype_adapter.cpp +++ b/gpttype_adapter.cpp @@ -97,14 +97,14 @@ inline bool LogitsDuplicated(std::vector & arr1, std::vector & arr } -llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng) +llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng) { llama_sample_softmax(nullptr, candidates); std::vector probs; probs.reserve(candidates->size); top_picks.clear(); for (size_t i = 0; i < candidates->size; ++i) { - probs.push_back(candidates->data[i].p); + probs.push_back(candidates->data[i].p); } std::discrete_distribution<> dist(probs.begin(), probs.end()); @@ -113,21 +113,21 @@ llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng if(debugmode) { top_picks.push_back(candidates->data[idx]); - for (size_t i = 0; (i < candidates->size && i<4); ++i) - { + for (size_t i = 0; (i < candidates->size && i<4); ++i) + { if(i!=idx) { top_picks.push_back(candidates->data[i]); } - } + } } llama_token result = candidates->data[idx].id; return result; } -llama_token sample_token_mirostat(int n_vocab, llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int m, float * mu) -{ +llama_token sample_token_mirostat(int n_vocab, llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int m, float * mu) +{ float N = float(n_vocab); llama_sample_softmax(nullptr, candidates); // Estimate s_hat using the most probable m tokens @@ 
-157,7 +157,7 @@ llama_token sample_token_mirostat(int n_vocab, llama_token_data_array * candidat return X; } -llama_token sample_token_mirostat_v2(llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float * mu) +llama_token sample_token_mirostat_v2(llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float * mu) { llama_sample_softmax(nullptr, candidates); // Truncate the words with surprise values greater than mu @@ -191,11 +191,11 @@ void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep) // Compute the cumulative probabilities float maxprob = candidates->data[0].p; - + float threshold = a * maxprob * maxprob; //tokens with probs less than this are removed size_t last_idx = candidates->size; - for (size_t i = 0; i < candidates->size; ++i) { + for (size_t i = 0; i < candidates->size; ++i) { // Go until we reach a value under the threshold float checkprob = candidates->data[i].p; if (checkprob < threshold && i >= min_keep) { @@ -223,11 +223,11 @@ int mirostat, float mirostat_tau, float mirostat_eta) llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false }; // Apply penalties - auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx); + auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx); llama_sample_repetition_penalty(nullptr, &candidates_p, last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, last_n_repeat, rep_pen); - + // llama_sample_frequency_and_presence_penalties(nullptr, &candidates_p, // last_n_tokens.data() + last_n_tokens.size() - last_n_repeat, // last_n_repeat, alpha_frequency, alpha_presence); @@ -300,15 +300,15 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in params.memory_f16 = inputs.f16_kv; params.n_ctx = inputs.max_context_length; - neox_ctx_v2.hparams.n_ctx = gptj_ctx_v1.hparams.n_ctx = gptj_ctx_v2.hparams.n_ctx = gpt2_ctx_v1.hparams.n_ctx = gpt2_ctx_v2.hparams.n_ctx + neox_ctx_v2.hparams.n_ctx = gptj_ctx_v1.hparams.n_ctx = gptj_ctx_v2.hparams.n_ctx = gpt2_ctx_v1.hparams.n_ctx = gpt2_ctx_v2.hparams.n_ctx = neox_ctx_v3.hparams.n_ctx = gptj_ctx_v3.hparams.n_ctx = gptj_ctx_v3.hparams.n_ctx = mpt_ctx_v3.hparams.n_ctx = params.n_ctx; printf("System Info: %s\n", llama_print_system_info()); - SetQuantsUnshuffled(false); + SetQuantsUnshuffled(false); if(file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2) { //newer format has bit unshuffling - SetQuantsUnshuffled(file_format == FileFormat::GGJT_2); + SetQuantsUnshuffled(file_format == FileFormat::GGJT_2); llama_ctx_params_v2 = llama_v2_context_default_params(); llama_ctx_params_v2.n_ctx = inputs.max_context_length; @@ -319,21 +319,21 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in llama_ctx_params_v2.use_mmap = inputs.use_mmap; llama_ctx_params_v2.use_mlock = inputs.use_mlock; llama_ctx_params_v2.n_gpu_layers = inputs.gpulayers; - + llama_ctx_v2 = llama_v2_init_from_file(modelname.c_str(), llama_ctx_params_v2); - + if (llama_ctx_v2 == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, modelname.c_str()); return ModelLoadResult::FAIL; } - + printf("\n---\nWarning: Your model may be an OUTDATED format (ver %d). 
Please reconvert it for better results!\n---\n", file_format); - + if (lora_filename != "") { printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str()); - + int err = llama_v2_apply_lora_from_file(llama_ctx_v2, lora_filename.c_str(), NULL, @@ -361,9 +361,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in llama_ctx_params.use_mmap = inputs.use_mmap; llama_ctx_params.use_mlock = inputs.use_mlock; llama_ctx_params.n_gpu_layers = inputs.gpulayers; - + llama_ctx_v3 = llama_init_from_file(modelname.c_str(), llama_ctx_params); - + if (llama_ctx_v3 == NULL) { fprintf(stderr, "%s: error: failed to load model '%s'\n", __func__, modelname.c_str()); @@ -372,7 +372,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in if (lora_filename != "") { printf("\nAttempting to apply LORA adapter: %s\n", lora_filename.c_str()); - + int err = llama_apply_lora_from_file(llama_ctx_v3, lora_filename.c_str(), NULL, @@ -479,8 +479,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in printf("\nTensor Transposition Detected! Retrying GPT-2 model loading..."); return res; } - // determine the required inference memory per token: - legacy_gpt2_eval(gpt2_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); + // determine the required inference memory per token: + legacy_gpt2_eval(gpt2_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); return ModelLoadResult::SUCCESS; } else if (file_format == FileFormat::GPT2_2 || file_format==FileFormat::GPT2_3 || file_format==FileFormat::GPT2_4) @@ -492,34 +492,34 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return res; - } + } else if(res==ModelLoadResult::RETRY_LOAD) { printf("\nTensor Transposition Detected! Retrying GPT-2 model loading..."); return res; } // determine the required inference memory per token: - gpt2_eval(gpt2_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); + gpt2_eval(gpt2_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); return ModelLoadResult::SUCCESS; } else { //newer format has bit unshuffling - SetQuantsUnshuffled(file_format == FileFormat::GPT2_3); + SetQuantsUnshuffled(file_format == FileFormat::GPT2_3); ModelLoadResult res = gpt2_v2_model_load(params.model, gpt2_ctx_v2, vocab, file_format, inputs.gpulayers); if(res==ModelLoadResult::FAIL) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return res; - } + } else if(res==ModelLoadResult::RETRY_LOAD) { printf("\nTensor Transposition Detected! Retrying GPT-2 model loading..."); return res; } - // determine the required inference memory per token: - gpt2_v2_eval(gpt2_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); + // determine the required inference memory per token: + gpt2_v2_eval(gpt2_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); return ModelLoadResult::SUCCESS; } } @@ -536,9 +536,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in printf("\nTensor Transposition Detected! 
Retrying GPT-J model loading..."); return res; } - // determine the required inference memory per token: - legacy_gptj_eval(gptj_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); - + // determine the required inference memory per token: + legacy_gptj_eval(gptj_ctx_v1, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token, file_format); + //if the logits are NAN or duplicated, it means the model is incompatible if(logits.size()>0 && IsNanCheck(logits[0])) { @@ -565,16 +565,16 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return loadresult; } - // determine the required inference memory per token: + // determine the required inference memory per token: gptj_eval(gptj_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); - + //if the logits are NAN or duplicated, it means the model is incompatible std::vector oldlogits(logits); //this is another hack because they change the library - we run the eval through the model //twice and compare logits. if they give the same logits for different inputs, model is broken gptj_eval(gptj_ctx_v3, params.n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token); - + if(logits.size()>0 && (IsNanCheck(logits[0]) || LogitsDuplicated(oldlogits,logits))) { printf("\nBad Logits detected! Retrying GPT-J model loading..."); @@ -587,7 +587,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in else { //newer format has bit unshuffling - SetQuantsUnshuffled(file_format == FileFormat::GPTJ_4); + SetQuantsUnshuffled(file_format == FileFormat::GPTJ_4); ModelLoadResult loadresult = gptj_v2_model_load(params.model, gptj_ctx_v2, vocab, inputs.gpulayers); if (loadresult == ModelLoadResult::FAIL) @@ -601,16 +601,16 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return loadresult; } - // determine the required inference memory per token: + // determine the required inference memory per token: gptj_v2_eval(gptj_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); - + //if the logits are NAN or duplicated, it means the model is incompatible std::vector oldlogits(logits); //this is another hack because they change the library - we run the eval through the model //twice and compare logits. if they give the same logits for different inputs, model is broken gptj_v2_eval(gptj_ctx_v2, params.n_threads, 0, {4, 5, 6, 7}, logits, mem_per_token); - + if(logits.size()>0 && (IsNanCheck(logits[0]) || LogitsDuplicated(oldlogits,logits))) { printf("\nBad Logits detected! 
Retrying GPT-J model loading..."); @@ -624,8 +624,8 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in else if(file_format==FileFormat::NEOX_1 || file_format==FileFormat::NEOX_2 || file_format==FileFormat::NEOX_3 || file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5|| file_format==FileFormat::NEOX_6|| file_format==FileFormat::NEOX_7) { if(file_format==FileFormat::NEOX_6|| file_format==FileFormat::NEOX_7) - { - ModelLoadResult res = gpt_neox_model_load(params.model, neox_ctx_v3, vocab, file_format); + { + ModelLoadResult res = gpt_neox_model_load(params.model, neox_ctx_v3, vocab, file_format); if(res==ModelLoadResult::FAIL) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); @@ -637,7 +637,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return res; } - // determine the required inference memory per token: + // determine the required inference memory per token: gpt_neox_eval(neox_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); return ModelLoadResult::SUCCESS; @@ -645,9 +645,9 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in else { //newer format has bit unshuffling - SetQuantsUnshuffled(file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5); + SetQuantsUnshuffled(file_format==FileFormat::NEOX_4 || file_format==FileFormat::NEOX_5); - ModelLoadResult res = gpt_neox_v2_model_load(params.model, neox_ctx_v2, vocab, file_format); + ModelLoadResult res = gpt_neox_v2_model_load(params.model, neox_ctx_v2, vocab, file_format); if(res==ModelLoadResult::FAIL) { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); @@ -659,7 +659,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return res; } - // determine the required inference memory per token: + // determine the required inference memory per token: gpt_neox_v2_eval(neox_ctx_v2, params.n_threads, 0, { 0, 1, 2, 3 }, logits, mem_per_token); if(logits.size()>0 && file_format==FileFormat::NEOX_2 && !IsNanCheck(logits[0])) @@ -669,7 +669,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in auto orig_par_res = neox_ctx_v2.hparams.par_res; neox_ctx_v2.hparams.par_res = 0; //test with residual false gpt_neox_v2_eval(neox_ctx_v2, params.n_threads, 0, test_embd, logits, mem_per_token); - neox_ctx_v2.hparams.par_res = orig_par_res; + neox_ctx_v2.hparams.par_res = orig_par_res; int topid = std::max_element(logits.begin(),logits.end())-logits.begin(); std::string predicted = vocab.id_to_token[topid].c_str(); auto findresult = predicted.find("8"); @@ -683,7 +683,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in return ModelLoadResult::SUCCESS; } - + } else if(file_format==FileFormat::MPT_1) { @@ -692,10 +692,10 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in { fprintf(stderr, "%s: failed to load model from '%s'\n", __func__, params.model.c_str()); return ModelLoadResult::FAIL; - } - + } + // determine the required inference memory per token: - mpt_eval(mpt_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token); + mpt_eval(mpt_ctx_v3, params.n_threads, 0, { 0, 1, 2, 3 }, logits, false, mem_per_token); return ModelLoadResult::SUCCESS; } else @@ -703,7 +703,7 @@ ModelLoadResult gpttype_load_model(const load_model_inputs inputs, FileFormat in printf("\nUnknown Model, cannot load.\n"); 
return ModelLoadResult::FAIL; } - + } @@ -802,10 +802,10 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o //if using BLAS and prompt is big enough, switch to single thread and use a huge batch bool approved_format = !(file_format == FileFormat::BADFORMAT || - file_format == FileFormat::GPT2_1 || + file_format == FileFormat::GPT2_1 || file_format == FileFormat::GPTJ_1 || - file_format == FileFormat::GPTJ_2 || - file_format == FileFormat::RWKV_1 || + file_format == FileFormat::GPTJ_2 || + file_format == FileFormat::RWKV_1 || file_format==FileFormat::RWKV_2); bool blasmode = (approved_format && embd_inp.size() >= 32 && ggml_cpu_has_blas() && blasbatchsize!=-1); // bool blasmode = false; @@ -856,7 +856,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o else if (file_format == FileFormat::GPTJ_1 || file_format == FileFormat::GPTJ_2) { n_vocab = gptj_ctx_v1.hparams.n_vocab; - } + } else if(file_format == FileFormat::GPTJ_3 || file_format==FileFormat::GPTJ_4) { n_vocab = gptj_ctx_v2.hparams.n_vocab; @@ -963,7 +963,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o if (!startedsampling) { printf("\rProcessing Prompt%s (%d / %d tokens)", (blasmode ? " [BLAS]" : ""), input_consumed, embd_inp.size()); - } + } fflush(stdout); if (embdsize > 0) @@ -1081,7 +1081,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o } eosID = llama_token_eos(); - + if (!unbanTokens) { // set the logit of the eos token (2) to zero to avoid sampling it @@ -1112,7 +1112,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o logits[eosID] = (logits[topid] < 0 ? logits[topid] : 0); } } - + // set the logit of the eos token (0) to minimum to avoid sampling it if (file_format == FileFormat::RWKV_1 || file_format == FileFormat::RWKV_2 || @@ -1130,13 +1130,13 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o logits[eosID] = (logits[topid] < 0 ? 
logits[topid] : 0); } } - + } - - id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, + + id = SampleLogits(logitsPtr, nctx, n_vocab, last_n_size, repeat_penalty, top_k, top_a, top_p, typical_p, tfs_z, temp, rng, params.mirostat,params.mirostat_tau,params.mirostat_eta); - + last_n_tokens.erase(last_n_tokens.begin()); last_n_tokens.push_back(id); current_context_tokens.push_back(id); @@ -1151,11 +1151,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o { concat_output += FileFormatTokenizeID(id,file_format); } - + if (startedsampling) - { + { printf("\rGenerating (%d / %d tokens)", (params.n_predict - remaining_tokens), params.n_predict); - } + } if(debugmode && top_picks.size()>0) { printf(" ["); @@ -1163,11 +1163,11 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o for (auto & pick : top_picks) { if (!firstloop) - { + { printf(" "); } firstloop = false; - std::string tokenizedstr = FileFormatTokenizeID(pick.id, file_format); + std::string tokenizedstr = FileFormatTokenizeID(pick.id, file_format); ::utreplace(tokenizedstr, "\n", "\\n"); printf("(%s %.2f%%)", tokenizedstr.c_str(), pick.p*100); } @@ -1178,7 +1178,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o { printf("\n(EOS token triggered!)"); remaining_tokens = 0; - } + } for (const auto &matched : stop_sequence) { @@ -1199,7 +1199,7 @@ generation_outputs gpttype_generate(const generation_inputs inputs, generation_o { embd.push_back(embd_inp[input_consumed]); last_n_tokens.erase(last_n_tokens.begin()); - last_n_tokens.push_back(embd_inp[input_consumed]); + last_n_tokens.push_back(embd_inp[input_consumed]); current_context_tokens.push_back(embd_inp[input_consumed]); ++input_consumed; if ((int)embd.size() >= params.n_batch) diff --git a/koboldcpp.py b/koboldcpp.py index c8566db8b..cf8847bc8 100644 --- a/koboldcpp.py +++ b/koboldcpp.py @@ -115,7 +115,7 @@ def init_library(): if use_blas: libname = lib_openblas_noavx2 else: - libname = lib_failsafe + libname = lib_failsafe else: if use_clblast: libname = lib_clblast @@ -182,7 +182,7 @@ def generate(prompt,max_length=20, max_context_length=512,temperature=0.8,top_k= inputs.mirostat_eta = float(args.usemirostat[2]) else: inputs.mirostat = inputs.mirostat_tau = inputs.mirostat_eta = 0 - inputs.seed = seed + inputs.seed = seed for n in range(0,stop_token_max): if not stop_sequence or n >= len(stop_sequence): inputs.stop_sequence[n] = "".encode("UTF-8") @@ -468,17 +468,17 @@ def show_gui(): if sel==runopts[1] or sel==runopts[2] or sel==runopts[3]: frameC.grid(row=4,column=0,pady=4) else: - frameC.grid_forget() + frameC.grid_forget() frameA = tk.Frame(root) - tk.OptionMenu( frameA , runchoice , command = onDropdownChange ,*runopts ).grid(row=0,column=0) + tk.OptionMenu( frameA , runchoice , command = onDropdownChange ,*runopts ).grid(row=0,column=0) tk.OptionMenu( frameA , blaschoice ,*blasbatchopts ).grid(row=0,column=1) frameA.grid(row=2,column=0) frameB = tk.Frame(root) threads_var=tk.StringVar() threads_var.set(str(default_threads)) - threads_lbl = tk.Label(frameB, text = 'Threads: ', font=('calibre',10, 'bold')) + threads_lbl = tk.Label(frameB, text = 'Threads: ', font=('calibre',10, 'bold')) threads_input = tk.Entry(frameB,textvariable = threads_var, font=('calibre',10,'normal')) threads_lbl.grid(row=0,column=0) threads_input.grid(row=0,column=1) @@ -487,7 +487,7 @@ def show_gui(): frameC = tk.Frame(root) gpu_layers_var=tk.StringVar() gpu_layers_var.set("0") - 
gpu_lbl = tk.Label(frameC, text = 'GPU Layers (CLBlast only): ', font=('calibre',10, 'bold')) + gpu_lbl = tk.Label(frameC, text = 'GPU Layers (CLBlast only): ', font=('calibre',10, 'bold')) gpu_layers_input = tk.Entry(frameC,textvariable = gpu_layers_var, font=('calibre',10,'normal')) gpu_lbl.grid(row=0,column=0) gpu_layers_input.grid(row=0,column=1) @@ -507,7 +507,7 @@ def show_gui(): tk.Checkbutton(frameD, text='High Priority',variable=highpriority, onvalue=1, offvalue=0).grid(row=1,column=0) tk.Checkbutton(frameD, text='Disable MMAP',variable=disablemmap, onvalue=1, offvalue=0).grid(row=1,column=1) tk.Checkbutton(frameD, text='Unban Tokens',variable=unbantokens, onvalue=1, offvalue=0).grid(row=2,column=0) - tk.Checkbutton(frameD, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1) + tk.Checkbutton(frameD, text='Launch Browser',variable=launchbrowser, onvalue=1, offvalue=0).grid(row=2,column=1) frameD.grid(row=5,column=0,pady=4) # Create button, it will change label text @@ -628,7 +628,7 @@ def main(args): print("Error, Could not change process priority: " + str(ex)) if args.contextsize: - global maxctx + global maxctx maxctx = args.contextsize init_library() # Note: if blas does not exist and is enabled, program will crash. diff --git a/otherarch/ggml_v2-cuda.cu b/otherarch/ggml_v2-cuda.cu new file mode 100644 index 000000000..8314adb25 --- /dev/null +++ b/otherarch/ggml_v2-cuda.cu @@ -0,0 +1,810 @@ +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "ggml_v2-cuda.h" +#include "ggml_v2.h" + +static_assert(sizeof(half) == sizeof(ggml_v2_fp16_t), "wrong fp16 size"); + +#define CUDA_CHECK(err) \ + do { \ + cudaError_t err_ = (err); \ + if (err_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \ + cudaGetErrorString(err_)); \ + exit(1); \ + } \ + } while (0) + +#define CUBLAS_CHECK(err) \ + do { \ + cublasStatus_t err_ = (err); \ + if (err_ != CUBLAS_STATUS_SUCCESS) { \ + fprintf(stderr, "cuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +typedef void (*dequantize_kernel_t)(const void * vx, const int ib, const int iqs, float & v0, float & v1); +typedef void (*to_fp32_cuda_t)(const void * x, float * y, int k, cudaStream_t stream); +typedef void (*dequantize_mul_mat_vec_cuda_t)(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream); + +// QK = number of values after dequantization +// QR = QK / number of values before dequantization + +#define QK4_0 32 +#define QR4_0 2 +typedef struct { + float d; // delta + uint8_t qs[QK4_0 / 2]; // nibbles / quants +} block_q4_0; +static_assert(sizeof(block_q4_0) == sizeof(float) + QK4_0 / 2, "wrong q4_0 block size/padding"); + +#define QK4_1 32 +#define QR4_1 2 +typedef struct { + float d; // delta + float m; // min + uint8_t qs[QK4_1 / 2]; // nibbles / quants +} block_q4_1; +static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding"); + +#define QK5_0 32 +#define QR5_0 2 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants +} block_q5_0; +static_assert(sizeof(block_q5_0) == sizeof(ggml_v2_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding"); + +#define QK5_1 32 +#define QR5_1 2 +typedef struct { + half d; // delta + half m; // min + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_1 / 2]; // nibbles / quants +} 
block_q5_1; +static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_v2_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding"); + +#define QK8_0 32 +#define QR8_0 1 +typedef struct { + float d; // delta + int8_t qs[QK8_0]; // quants +} block_q8_0; +static_assert(sizeof(block_q8_0) == sizeof(float) + QK8_0, "wrong q8_0 block size/padding"); + +#define CUDA_DEQUANTIZE_BLOCK_SIZE 256 +#define CUDA_DMMV_BLOCK_SIZE 32 // dmmv = dequantize_mul_mat_vec + +static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q4_0 * x = (const block_q4_0 *) vx; + + const float d = x[ib].d; + + const uint8_t vui = x[ib].qs[iqs]; + + const int8_t vi0 = vui & 0xF; + const int8_t vi1 = vui >> 4; + + v0 = (vi0 - 8)*d; + v1 = (vi1 - 8)*d; +} + +static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q4_1 * x = (const block_q4_1 *) vx; + + const float d = x[ib].d; + const float m = x[ib].m; + + const uint8_t vui = x[ib].qs[iqs]; + + const int8_t vi0 = vui & 0xF; + const int8_t vi1 = vui >> 4; + + v0 = vi0*d + m; + v1 = vi1*d + m; +} + +static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q5_0 * x = (const block_q5_0 *) vx; + + const float d = x[ib].d; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0) - 16; + const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1) - 16; + + v0 = x0*d; + v1 = x1*d; +} + +static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q5_1 * x = (const block_q5_1 *) vx; + + const float d = x[ib].d; + const float m = x[ib].m; + + uint32_t qh; + memcpy(&qh, x[ib].qh, sizeof(qh)); + + const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10; + const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10; + + const int32_t x0 = ((x[ib].qs[iqs] & 0xf) | xh_0); + const int32_t x1 = ((x[ib].qs[iqs] >> 4) | xh_1); + + v0 = x0*d + m; + v1 = x1*d + m; +} + +static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const block_q8_0 * x = (const block_q8_0 *) vx; + + const float d = x[ib].d; + + const int8_t vi0 = x[ib].qs[iqs + 0]; + const int8_t vi1 = x[ib].qs[iqs + 1]; + + v0 = vi0*d; + v1 = vi1*d; +} + +static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ + const half * x = (const half *) vx; + + v0 = __half2float(x[ib + 0]); + v1 = __half2float(x[ib + 1]); +} + +template +static __global__ void dequantize_block(const void * vx, float * y, const int k) { + const int i = blockDim.x*blockIdx.x + 2*threadIdx.x; + + if (i >= k) { + return; + } + + const int ib = i/qk; // block index + const int iqs = (i%qk)/qr; // quant index + const int iybs = i - i%qk; // y block start index + const int y_offset = qr == 1 ? 1 : qk/2; + + // dequantize + float & v0 = y[iybs + iqs + 0]; + float & v1 = y[iybs + iqs + y_offset]; + dequantize_kernel(vx, ib, iqs, v0, v1); +} + +template +static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y, float * dst, const int ncols) { + const int row = blockIdx.x; + const int tid = threadIdx.x; + + const int y_offset = qr == 1 ? 
1 : qk/2; + + __shared__ float tmp[block_size]; // separate sum for each thread + tmp[tid] = 0; + + for (int i = 0; i < ncols/block_size; i += 2) { + const int col = i*block_size + 2*tid; + const int ib = (row*ncols + col)/qk; // block index + const int iqs = (col%qk)/qr; // quant index + const int iybs = col - col%qk; // y block start index + + // dequantize + float v0, v1; + dequantize_kernel(vx, ib, iqs, v0, v1); + + // matrix multiplication + tmp[tid] += v0 * y[iybs + iqs + 0]; + tmp[tid] += v1 * y[iybs + iqs + y_offset]; + } + + // sum up partial sums and write back result + __syncthreads(); + for (int s=block_size/2; s>0; s>>=1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + __syncthreads(); + } + if (tid == 0) { + dst[row] = tmp[0]; + } +} + +static void dequantize_row_q4_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +static void dequantize_row_q4_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +static void dequantize_row_q5_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +static void dequantize_row_q5_1_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +static void dequantize_row_q8_0_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<<>>(vx, y, k); +} + +static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_V2_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_V2_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_V2_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_V2_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + +static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_V2_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + +static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { + const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE; + dequantize_block<32, 1, convert_f16><<>>(vx, y, k); +} + +static void convert_mul_mat_vec_f16_cuda(const void * vx, const float 
* y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { + GGML_V2_ASSERT(ncols % CUDA_DMMV_BLOCK_SIZE == 0); + dequantize_mul_mat_vec + <<>>(vx, y, dst, ncols); +} + +static to_fp32_cuda_t ggml_v2_get_to_fp32_cuda(ggml_v2_type type) { + switch (type) { + case GGML_V2_TYPE_Q4_0: + return dequantize_row_q4_0_cuda; + case GGML_V2_TYPE_Q4_1: + return dequantize_row_q4_1_cuda; + case GGML_V2_TYPE_Q5_0: + return dequantize_row_q5_0_cuda; + case GGML_V2_TYPE_Q5_1: + return dequantize_row_q5_1_cuda; + case GGML_V2_TYPE_Q8_0: + return dequantize_row_q8_0_cuda; + case GGML_V2_TYPE_F16: + return convert_fp16_to_fp32_cuda; + default: + return nullptr; + } +} + +static dequantize_mul_mat_vec_cuda_t ggml_v2_get_dequantize_mul_mat_vec_cuda(ggml_v2_type type) { + switch (type) { + case GGML_V2_TYPE_Q4_0: + return dequantize_mul_mat_vec_q4_0_cuda; + case GGML_V2_TYPE_Q4_1: + return dequantize_mul_mat_vec_q4_1_cuda; + case GGML_V2_TYPE_Q5_0: + return dequantize_mul_mat_vec_q5_0_cuda; + case GGML_V2_TYPE_Q5_1: + return dequantize_mul_mat_vec_q5_1_cuda; + case GGML_V2_TYPE_Q8_0: + return dequantize_mul_mat_vec_q8_0_cuda; + case GGML_V2_TYPE_F16: + return convert_mul_mat_vec_f16_cuda; + default: + return nullptr; + } +} + +// buffer pool for cuda +#define MAX_CUDA_BUFFERS 256 + +struct scoped_spin_lock { + std::atomic_flag& lock; + scoped_spin_lock(std::atomic_flag& lock) : lock(lock) { + while (lock.test_and_set(std::memory_order_acquire)) { + ; // spin + } + } + ~scoped_spin_lock() { + lock.clear(std::memory_order_release); + } + scoped_spin_lock(const scoped_spin_lock&) = delete; + scoped_spin_lock& operator=(const scoped_spin_lock&) = delete; +}; + +struct cuda_buffer { + void * ptr = nullptr; + size_t size = 0; +}; + +static cuda_buffer g_cuda_buffer_pool[MAX_CUDA_BUFFERS]; +static std::atomic_flag g_cuda_pool_lock = ATOMIC_FLAG_INIT; + +static void * ggml_v2_cuda_pool_malloc(size_t size, size_t * actual_size) { + scoped_spin_lock lock(g_cuda_pool_lock); + + for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { + cuda_buffer& b = g_cuda_buffer_pool[i]; + if (b.size >= size && b.ptr != nullptr) { + void * ptr = b.ptr; + *actual_size = b.size; + b.ptr = nullptr; + b.size = 0; + return ptr; + } + } + void * ptr; + CUDA_CHECK(cudaMalloc((void **) &ptr, size)); + *actual_size = size; + return ptr; +} + +static void ggml_v2_cuda_pool_free(void * ptr, size_t size) { + scoped_spin_lock lock(g_cuda_pool_lock); + + for (int i = 0; i < MAX_CUDA_BUFFERS; ++i) { + cuda_buffer& b = g_cuda_buffer_pool[i]; + if (b.ptr == nullptr) { + b.ptr = ptr; + b.size = size; + return; + } + } + fprintf(stderr, "WARNING: cuda buffer pool full, increase MAX_CUDA_BUFFERS\n"); + CUDA_CHECK(cudaFree(ptr)); +} + +#define GGML_V2_CUDA_MAX_STREAMS 8 // Set this to 1 for reproducible matrix multiplication. 
+#define GGML_V2_CUDA_MAX_EVENTS 64 +static cublasHandle_t g_cublasH = nullptr; +static cudaStream_t g_cudaStreams[GGML_V2_CUDA_MAX_STREAMS] = { nullptr }; +static cudaStream_t g_cudaStreams2[GGML_V2_CUDA_MAX_STREAMS] = { nullptr }; +static cudaEvent_t g_cudaEvents[GGML_V2_CUDA_MAX_EVENTS] = { nullptr }; + +void ggml_v2_init_cublas() { + if (g_cublasH == nullptr) { + // create streams + for (int i = 0; i < GGML_V2_CUDA_MAX_STREAMS; ++i) { + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams[i], cudaStreamNonBlocking)); + CUDA_CHECK(cudaStreamCreateWithFlags(&g_cudaStreams2[i], cudaStreamNonBlocking)); + } + // create events + for (int i = 0; i < GGML_V2_CUDA_MAX_EVENTS; ++i) { + CUDA_CHECK(cudaEventCreateWithFlags(&g_cudaEvents[i], cudaEventDisableTiming)); + } + + // create cublas handle + CUBLAS_CHECK(cublasCreate(&g_cublasH)); + CUBLAS_CHECK(cublasSetMathMode(g_cublasH, CUBLAS_TF32_TENSOR_OP_MATH)); + + // configure logging to stdout + // CUBLAS_CHECK(cublasLoggerConfigure(1, 1, 0, nullptr)); + } +} + +void * ggml_v2_cuda_host_malloc(size_t size) { + if (getenv("GGML_V2_CUDA_NO_PINNED") != nullptr) { + return nullptr; + } + + void * ptr = nullptr; + cudaError_t err = cudaMallocHost((void **) &ptr, size); + if (err != cudaSuccess) { + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory: %s\n", + size/1024.0/1024.0, cudaGetErrorString(err)); + return nullptr; + } + + return ptr; +} + +void ggml_v2_cuda_host_free(void * ptr) { + CUDA_CHECK(cudaFreeHost(ptr)); +} + +static cudaError_t ggml_v2_cuda_h2d_tensor_2d(void * dst, const struct ggml_v2_tensor * src, uint64_t i3, uint64_t i2, cudaStream_t stream) { + const uint64_t ne0 = src->ne[0]; + const uint64_t ne1 = src->ne[1]; + const uint64_t nb0 = src->nb[0]; + const uint64_t nb1 = src->nb[1]; + const uint64_t nb2 = src->nb[2]; + const uint64_t nb3 = src->nb[3]; + const enum ggml_v2_type type = src->type; + const size_t ts = ggml_v2_type_size(type); + const size_t bs = ggml_v2_blck_size(type); + + const void * x = (const void *) ((const char *) src->data + i2*nb2 + i3*nb3); + if (nb0 == ts && nb1 == ts*ne0/bs) { + return cudaMemcpyAsync(dst, x, ne1*nb1, cudaMemcpyHostToDevice, stream); + } else if (nb0 == ts) { + return cudaMemcpy2DAsync(dst, ts*ne0/bs, x, nb1, ts*ne0/bs, ne1, cudaMemcpyHostToDevice, stream); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + const void * rx = (const void *) ((const char *) x + i1*nb1); + void * rd = (void *) ((char *) dst + i1*ts*ne0/bs); + // pretend the row is a matrix with cols=1 + cudaError_t r = cudaMemcpy2DAsync(rd, ts/bs, rx, nb0, ts/bs, ne0, cudaMemcpyHostToDevice, stream); + if (r != cudaSuccess) return r; + } + return cudaSuccess; + } +} + +static void ggml_v2_cuda_mul_mat_f32(const ggml_v2_tensor * src0, const ggml_v2_tensor * src1, ggml_v2_tensor * dst) { + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const float alpha = 1.0f; + const float beta = 0.0f; + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; + + size_t x_size, y_size, d_size; + float * d_X = (float *) ggml_v2_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + float * d_Y = (float *) ggml_v2_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); + float * d_D = (float *) 
ggml_v2_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_V2_CUDA_MAX_STREAMS]; + + float * c_X = d_X + i * x_ne; + float * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + + // copy data to device + CUDA_CHECK(ggml_v2_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream)); + CUDA_CHECK(ggml_v2_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + CUBLAS_CHECK( + cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, ne00, + c_Y, ne10, + &beta, c_D, ne01)); + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + } + } + + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_v2_cuda_pool_free(d_X, x_size); + ggml_v2_cuda_pool_free(d_Y, y_size); + ggml_v2_cuda_pool_free(d_D, d_size); +} + +static void ggml_v2_cuda_mul_mat_f16(const ggml_v2_tensor * src0, const ggml_v2_tensor * src1, ggml_v2_tensor * dst, void * wdata, size_t /* wsize */) { + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int nb10 = src1->nb[0]; + const int nb11 = src1->nb[1]; + const int nb12 = src1->nb[2]; + const int nb13 = src1->nb[3]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + + const float alpha = 1.0f; + const float beta = 0.0f; + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; + + size_t x_size, y_size, d_size; + half * d_X = (half *) ggml_v2_cuda_pool_malloc(n_mm * sizeof(half) * x_ne, &x_size); + half * d_Y = (half *) ggml_v2_cuda_pool_malloc(n_mm * sizeof(half) * y_ne, &y_size); + float * d_D = (float *) ggml_v2_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); + + bool src1_cont_rows = nb10 == sizeof(float); + bool src1_cont_cols = (size_t)nb11 == ne11*sizeof(float); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_V2_CUDA_MAX_STREAMS]; + + half * c_X = d_X + i * x_ne; + half * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + + // copy src0 to device + CUDA_CHECK(ggml_v2_cuda_h2d_tensor_2d(c_X, src0, i03, i02, cudaStream)); + + // convert src1 to fp16 + // TODO: use multiple threads + ggml_v2_fp16_t * const tmp = (ggml_v2_fp16_t *) wdata + (ne11 * ne10) * (i03 * ne02 + i02); + char * src1i = (char *) src1->data + i03*nb13 + i02*nb12; + if (src1_cont_rows) { + if (src1_cont_cols) { + ggml_v2_fp32_to_fp16_row((float *) src1i, tmp, ne10*ne11); + } + else { + for (int64_t i01 = 0; i01 < ne11; i01++) { + ggml_v2_fp32_to_fp16_row((float *) (src1i + i01*nb11), tmp + i01*ne10, ne10); + } + } + } + else { + for (int64_t i01 = 0; i01 < ne11; i01++) { + for (int64_t i00 = 0; i00 < ne10; i00++) { + // very slow due to no inlining + tmp[i01*ne10 + i00] = ggml_v2_fp32_to_fp16(*(float *) (src1i + i01*nb11 + i00*nb10)); + } + } + } + + // copy src1 to device + CUDA_CHECK(cudaMemcpyAsync(c_Y, tmp, sizeof(half) * y_ne, cudaMemcpyHostToDevice, cudaStream)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + 
CUBLAS_CHECK( + cublasGemmEx(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, CUDA_R_16F, ne00, + c_Y, CUDA_R_16F, ne10, + &beta, c_D, CUDA_R_32F, ne01, + CUBLAS_COMPUTE_32F_FAST_16F, + CUBLAS_GEMM_DEFAULT)); + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + } + } + + CUDA_CHECK(cudaDeviceSynchronize()); + ggml_v2_cuda_pool_free(d_X, x_size); + ggml_v2_cuda_pool_free(d_Y, y_size); + ggml_v2_cuda_pool_free(d_D, d_size); +} + +static void ggml_v2_cuda_mul_mat_q_f32(const ggml_v2_tensor * src0, const ggml_v2_tensor * src1, ggml_v2_tensor * dst) { + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + const int64_t ne02 = src0->ne[2]; + const int64_t ne03 = src0->ne[3]; + + const int64_t ne10 = src1->ne[0]; + const int64_t ne11 = src1->ne[1]; + + const int nb2 = dst->nb[2]; + const int nb3 = dst->nb[3]; + const ggml_v2_type type = src0->type; + const bool mul_mat_vec = ne11 == 1; + + const float alpha = 1.0f; + const float beta = 0.0f; + const int x_ne = ne01 * ne00; + const int y_ne = ne11 * ne10; + const int d_ne = ne11 * ne01; + const int n_mm = ne03 * ne02; + const size_t q_sz = ggml_v2_type_size(type) * x_ne / ggml_v2_blck_size(type); + + size_t x_size, y_size, d_size, q_size; + float * d_X = nullptr; + if (!mul_mat_vec) { + d_X = (float *) ggml_v2_cuda_pool_malloc(n_mm * sizeof(float) * x_ne, &x_size); + } + float * d_Y = (float *) ggml_v2_cuda_pool_malloc(n_mm * sizeof(float) * y_ne, &y_size); + float * d_D = (float *) ggml_v2_cuda_pool_malloc(n_mm * sizeof(float) * d_ne, &d_size); + char * d_Q = (char *) ggml_v2_cuda_pool_malloc(n_mm * q_sz, &q_size); + + const to_fp32_cuda_t to_fp32_cuda = ggml_v2_get_to_fp32_cuda(type); + dequantize_mul_mat_vec_cuda_t dmmv = ggml_v2_get_dequantize_mul_mat_vec_cuda(type); + GGML_V2_ASSERT(to_fp32_cuda != nullptr); + + for (int64_t i03 = 0; i03 < ne03; i03++) { + for (int64_t i02 = 0; i02 < ne02; i02++) { + int i = i03*ne02 + i02; + cudaStream_t cudaStream = g_cudaStreams[i % GGML_V2_CUDA_MAX_STREAMS]; + cudaStream_t cudaStream2 = g_cudaStreams2[i % GGML_V2_CUDA_MAX_STREAMS]; + cudaEvent_t cudaEvent = g_cudaEvents[i % GGML_V2_CUDA_MAX_EVENTS]; + + float * c_Y = d_Y + i * y_ne; + float * c_D = d_D + i * d_ne; + char * c_Q = d_Q + i * q_sz; + + // copy src0 to device if necessary + if (src0->backend == GGML_V2_BACKEND_CPU) { + CUDA_CHECK(ggml_v2_cuda_h2d_tensor_2d(c_Q, src0, i03, i02, cudaStream2)); + } else if (src0->backend == GGML_V2_BACKEND_CUDA) { + c_Q = ((char *) src0->data) + i * q_sz; + } else { + GGML_V2_ASSERT(false); + } + if (mul_mat_vec) { // specialized dequantize_mul_mat_vec kernel + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + + // copy src1 to device + CUDA_CHECK(ggml_v2_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + + // wait for data + CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + + // compute + dmmv(c_Q, c_Y, c_D, ne00, ne01, cudaStream); + CUDA_CHECK(cudaGetLastError()); + + } else { // general dequantization kernel + cuBLAS matrix matrix multiplication + float * c_X = d_X + i * x_ne; + + // convert src0 to fp32 on device + to_fp32_cuda(c_Q, c_X, x_ne, cudaStream2); + CUDA_CHECK(cudaGetLastError()); + CUDA_CHECK(cudaEventRecord(cudaEvent, cudaStream2)); + + // copy src1 to device + CUDA_CHECK(ggml_v2_cuda_h2d_tensor_2d(c_Y, src1, i03, i02, cudaStream)); + + // wait for conversion + 
CUDA_CHECK(cudaStreamWaitEvent(cudaStream, cudaEvent, 0)); + + // compute + CUBLAS_CHECK(cublasSetStream(g_cublasH, cudaStream)); + CUBLAS_CHECK( + cublasSgemm(g_cublasH, CUBLAS_OP_T, CUBLAS_OP_N, + ne01, ne11, ne10, + &alpha, c_X, ne00, + c_Y, ne10, + &beta, c_D, ne01)); + } + + // copy dst to host + float * d = (float *) ((char *) dst->data + i02*nb2 + i03*nb3); + CUDA_CHECK(cudaMemcpyAsync(d, c_D, sizeof(float) * d_ne, cudaMemcpyDeviceToHost, cudaStream)); + } + } + + CUDA_CHECK(cudaDeviceSynchronize()); + if (!mul_mat_vec) { + ggml_v2_cuda_pool_free(d_X, x_size); + } + ggml_v2_cuda_pool_free(d_Y, y_size); + ggml_v2_cuda_pool_free(d_D, d_size); + ggml_v2_cuda_pool_free(d_Q, q_size); +} + +bool ggml_v2_cuda_can_mul_mat(const struct ggml_v2_tensor * src0, const struct ggml_v2_tensor * src1, struct ggml_v2_tensor * dst) { + const int64_t ne10 = src1->ne[0]; + + const int64_t ne0 = dst->ne[0]; + const int64_t ne1 = dst->ne[1]; + + // TODO: find the optimal values for these + if ((src0->type == GGML_V2_TYPE_F32 || src0->type == GGML_V2_TYPE_F16 || ggml_v2_is_quantized(src0->type)) && + src1->type == GGML_V2_TYPE_F32 && + dst->type == GGML_V2_TYPE_F32 && + ((ne0 >= 32 && ne1 >= 32 && ne10 >= 32) || src0->backend == GGML_V2_BACKEND_CUDA)) { + return true; + } + + return false; +} + +bool ggml_v2_cuda_mul_mat_use_f16(const struct ggml_v2_tensor * src0, const struct ggml_v2_tensor * src1, struct ggml_v2_tensor * /* dst */) { + size_t src0_sz = ggml_v2_nbytes(src0); + size_t src1_sz = ggml_v2_nbytes(src1); + + // mul_mat_q: src0 is converted to fp32 on device + size_t mul_mat_q_transfer = src0_sz + src1_sz; + + // mul_mat_f16: src1 is converted to fp16 on cpu + size_t mul_mat_f16_transfer = src0_sz + sizeof(half) * ggml_v2_nelements(src1); + + // choose the smaller one to transfer to the device + // TODO: this is not always the best choice due to the overhead of converting to fp16 + return mul_mat_f16_transfer < mul_mat_q_transfer; +} + +void ggml_v2_cuda_mul_mat(const ggml_v2_tensor * src0, const ggml_v2_tensor * src1, ggml_v2_tensor * dst, void * wdata, size_t wsize) { + GGML_V2_ASSERT(ggml_v2_cuda_can_mul_mat(src0, src1, dst)); + + if (src0->type == GGML_V2_TYPE_F32) { + ggml_v2_cuda_mul_mat_f32(src0, src1, dst); + } + else if (src0->type == GGML_V2_TYPE_F16) { + if (ggml_v2_cuda_mul_mat_use_f16(src0, src1, dst)) { + ggml_v2_cuda_mul_mat_f16(src0, src1, dst, wdata, wsize); + } + else { + ggml_v2_cuda_mul_mat_q_f32(src0, src1, dst); + } + } + else if (ggml_v2_is_quantized(src0->type)) { + ggml_v2_cuda_mul_mat_q_f32(src0, src1, dst); + } + else { + GGML_V2_ASSERT(false); + } +} + +size_t ggml_v2_cuda_mul_mat_get_wsize(const struct ggml_v2_tensor * src0, const struct ggml_v2_tensor * src1, struct ggml_v2_tensor * dst) { + if (ggml_v2_cuda_mul_mat_use_f16(src0, src1, dst)) { + return ggml_v2_nelements(src1) * sizeof(ggml_v2_fp16_t); + } + else { + return 0; + } +} + +void ggml_v2_cuda_transform_tensor(ggml_v2_tensor * tensor) { + const int64_t ne0 = tensor->ne[0]; + const int64_t ne1 = tensor->ne[1]; + const int64_t ne2 = tensor->ne[2]; + const int64_t ne3 = tensor->ne[3]; + + const ggml_v2_type type = tensor->type; + const size_t q_sz = ggml_v2_type_size(type) * ne0 * ne1 * ne2 * ne3 / ggml_v2_blck_size(type); + + size_t q_size; + char * d_Q = (char *) ggml_v2_cuda_pool_malloc(q_sz, &q_size); + + cudaStream_t cudaStream2 = g_cudaStreams2[0]; + + // copy tensor to device + CUDA_CHECK(ggml_v2_cuda_h2d_tensor_2d(d_Q, tensor, 0, 0, cudaStream2)); + CUDA_CHECK(cudaDeviceSynchronize()); + + 
tensor->data = d_Q; + tensor->backend = GGML_V2_BACKEND_CUDA; +} \ No newline at end of file diff --git a/otherarch/ggml_v2-cuda.h b/otherarch/ggml_v2-cuda.h new file mode 100644 index 000000000..4f3c75baf --- /dev/null +++ b/otherarch/ggml_v2-cuda.h @@ -0,0 +1,21 @@ +#include "ggml_v2.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void ggml_v2_init_cublas(void); + +bool ggml_v2_cuda_can_mul_mat(const struct ggml_v2_tensor * src0, const struct ggml_v2_tensor * src1, struct ggml_v2_tensor * dst); +size_t ggml_v2_cuda_mul_mat_get_wsize(const struct ggml_v2_tensor * src0, const struct ggml_v2_tensor * src1, struct ggml_v2_tensor * dst); +void ggml_v2_cuda_mul_mat(const struct ggml_v2_tensor * src0, const struct ggml_v2_tensor * src1, struct ggml_v2_tensor * dst, void * wdata, size_t wsize); + +// TODO: export these with GGML_V2_API +void * ggml_v2_cuda_host_malloc(size_t size); +void ggml_v2_cuda_host_free(void * ptr); + +void ggml_v2_cuda_transform_tensor(struct ggml_v2_tensor * tensor); + +#ifdef __cplusplus +} +#endif \ No newline at end of file diff --git a/otherarch/ggml_v2.c b/otherarch/ggml_v2.c index 097742e7a..cb7d5626b 100644 --- a/otherarch/ggml_v2.c +++ b/otherarch/ggml_v2.c @@ -140,7 +140,7 @@ inline static void* ggml_v2_aligned_malloc(size_t size) { #elif defined(GGML_USE_OPENBLAS) #include #elif defined(GGML_USE_CUBLAS) -#include "ggml-cuda.h" +#include "ggml_v2-cuda.h" #endif #if defined(GGML_USE_CLBLAST) #include "ggml_v2-opencl.h" @@ -3895,7 +3895,7 @@ struct ggml_v2_context * ggml_v2_init(struct ggml_v2_init_params params) { } #if defined(GGML_USE_CUBLAS) - ggml_init_cublas(); + ggml_v2_init_cublas(); #elif defined(GGML_USE_CLBLAST) if(quants_unshuffled) { @@ -9449,9 +9449,9 @@ static void ggml_v2_compute_forward_mul_mat_f32( // compute by src0 rows #if defined(GGML_USE_CUBLAS) - if (ggml_cuda_can_mul_mat(src0, src1, dst)) { + if (ggml_v2_cuda_can_mul_mat(src0, src1, dst)) { if (params->ith == 0 && params->type == GGML_V2_TASK_COMPUTE) { - ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); + ggml_v2_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); } return; } @@ -9643,9 +9643,9 @@ static void ggml_v2_compute_forward_mul_mat_f16_f32( // compute by src0 rows #if defined(GGML_USE_CUBLAS) - if (ggml_cuda_can_mul_mat(src0, src1, dst)) { + if (ggml_v2_cuda_can_mul_mat(src0, src1, dst)) { if (params->ith == 0 && params->type == GGML_V2_TASK_COMPUTE) { - ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); + ggml_v2_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); } return; } @@ -9882,9 +9882,9 @@ static void ggml_v2_compute_forward_mul_mat_q_f32( // compute by src0 rows #if defined(GGML_USE_CUBLAS) - if (ggml_cuda_can_mul_mat(src0, src1, dst)) { + if (ggml_v2_cuda_can_mul_mat(src0, src1, dst)) { if (params->ith == 0 && params->type == GGML_V2_TASK_COMPUTE) { - ggml_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); + ggml_v2_cuda_mul_mat(src0, src1, dst, params->wdata, params->wsize); } return; } @@ -14062,10 +14062,10 @@ void ggml_v2_graph_compute(struct ggml_v2_context * ctx, struct ggml_v2_cgraph * size_t cur = 0; #if defined(GGML_USE_CUBLAS) - if (ggml_cuda_can_mul_mat(node->src0, node->src1, node)) { + if (ggml_v2_cuda_can_mul_mat(node->src0, node->src1, node)) { node->n_tasks = 1; // TODO: this actually is doing nothing // the threads are still spinning - cur = ggml_cuda_mul_mat_get_wsize(node->src0, node->src1, node); + cur = ggml_v2_cuda_mul_mat_get_wsize(node->src0, node->src1, node); } else #elif 
defined(GGML_USE_CLBLAST)
diff --git a/otherarch/llama_v2-util.h b/otherarch/llama_v2-util.h
index a8c97ee1e..00aedf8e6 100644
--- a/otherarch/llama_v2-util.h
+++ b/otherarch/llama_v2-util.h
@@ -416,7 +416,7 @@ struct llama_v2_buffer {
 };

 #ifdef GGML_USE_CUBLAS
-#include "ggml-cuda.h"
+#include "ggml_v2-cuda.h"
 struct llama_v2_ctx_buffer {
     uint8_t * addr = NULL;
     bool is_cuda;
@@ -427,7 +427,7 @@ struct llama_v2_ctx_buffer {

     void resize(size_t size) {
         free();
-        addr = (uint8_t *) ggml_cuda_host_malloc(size);
+        addr = (uint8_t *) ggml_v2_cuda_host_malloc(size);
         if (addr) {
             is_cuda = true;
         }
@@ -442,7 +442,7 @@ struct llama_v2_ctx_buffer {
     void free() {
         if (addr) {
             if (is_cuda) {
-                ggml_cuda_host_free(addr);
+                ggml_v2_cuda_host_free(addr);
             }
             else {
                 delete[] addr;