diff --git a/.gitignore b/.gitignore index 0b56bcc7a..307c065f7 100644 --- a/.gitignore +++ b/.gitignore @@ -134,3 +134,7 @@ poetry.toml # Test models for lora adapters /lora-tests + +# Local scripts +/run-vim.sh +/run-chat.sh diff --git a/Makefile b/Makefile index 539370e06..95110d4eb 100644 --- a/Makefile +++ b/Makefile @@ -48,7 +48,6 @@ TEST_TARGETS = \ tests/test-backend-ops \ tests/test-chat-template \ tests/test-double-float \ - tests/test-grad0 \ tests/test-grammar-integration \ tests/test-grammar-parser \ tests/test-json-schema-to-grammar \ @@ -636,10 +635,6 @@ else ifndef CUDA_POWER_ARCH MK_NVCCFLAGS += -arch=native endif # CUDA_DOCKER_ARCH -ifdef GGML_CUDA_FORCE_DMMV - MK_NVCCFLAGS += -DGGML_CUDA_FORCE_DMMV -endif # GGML_CUDA_FORCE_DMMV - ifdef GGML_CUDA_FORCE_MMQ MK_NVCCFLAGS += -DGGML_CUDA_FORCE_MMQ endif # GGML_CUDA_FORCE_MMQ @@ -648,20 +643,6 @@ ifdef GGML_CUDA_FORCE_CUBLAS MK_NVCCFLAGS += -DGGML_CUDA_FORCE_CUBLAS endif # GGML_CUDA_FORCE_CUBLAS -ifdef GGML_CUDA_DMMV_X - MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) -else - MK_NVCCFLAGS += -DGGML_CUDA_DMMV_X=32 -endif # GGML_CUDA_DMMV_X - -ifdef GGML_CUDA_MMV_Y - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) -else ifdef GGML_CUDA_DMMV_Y - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_DMMV_Y) # for backwards compatibility -else - MK_NVCCFLAGS += -DGGML_CUDA_MMV_Y=1 -endif # GGML_CUDA_MMV_Y - ifdef GGML_CUDA_F16 MK_NVCCFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_F16 @@ -670,12 +651,6 @@ ifdef GGML_CUDA_DMMV_F16 MK_NVCCFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_DMMV_F16 -ifdef GGML_CUDA_KQUANTS_ITER - MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) -else - MK_NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 -endif - ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE MK_NVCCFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) else @@ -784,10 +759,6 @@ ifdef GGML_HIPBLAS AMDGPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch) endif - GGML_CUDA_DMMV_X ?= 32 - GGML_CUDA_MMV_Y ?= 1 - GGML_CUDA_KQUANTS_ITER ?= 2 - MK_CPPFLAGS += -DGGML_USE_HIP -DGGML_USE_CUDA ifdef GGML_HIP_UMA @@ -801,13 +772,6 @@ endif # GGML_HIP_UMA HIPCC ?= $(CCACHE) $(ROCM_PATH)/bin/hipcc HIPFLAGS += $(addprefix --offload-arch=,$(AMDGPU_TARGETS)) - HIPFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) - HIPFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) - HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) - -ifdef GGML_CUDA_FORCE_DMMV - HIPFLAGS += -DGGML_CUDA_FORCE_DMMV -endif # GGML_CUDA_FORCE_DMMV ifdef GGML_CUDA_FORCE_MMQ HIPFLAGS += -DGGML_CUDA_FORCE_MMQ @@ -870,10 +834,6 @@ ifdef GGML_MUSA MUSAFLAGS += $(addprefix --cuda-gpu-arch=, $(MTGPU_TARGETS)) -ifdef GGML_CUDA_FORCE_DMMV - MUSAFLAGS += -DGGML_CUDA_FORCE_DMMV -endif # GGML_CUDA_FORCE_DMMV - ifdef GGML_CUDA_FORCE_MMQ MUSAFLAGS += -DGGML_CUDA_FORCE_MMQ endif # GGML_CUDA_FORCE_MMQ @@ -882,18 +842,6 @@ ifdef GGML_CUDA_FORCE_CUBLAS MUSAFLAGS += -DGGML_CUDA_FORCE_CUBLAS endif # GGML_CUDA_FORCE_CUBLAS -ifdef GGML_CUDA_DMMV_X - MUSAFLAGS += -DGGML_CUDA_DMMV_X=$(GGML_CUDA_DMMV_X) -else - MUSAFLAGS += -DGGML_CUDA_DMMV_X=32 -endif # GGML_CUDA_DMMV_X - -ifdef GGML_CUDA_MMV_Y - MUSAFLAGS += -DGGML_CUDA_MMV_Y=$(GGML_CUDA_MMV_Y) -else - MUSAFLAGS += -DGGML_CUDA_MMV_Y=1 -endif # GGML_CUDA_MMV_Y - ifdef GGML_CUDA_F16 MUSAFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_F16 @@ -902,12 +850,6 @@ ifdef GGML_CUDA_DMMV_F16 MUSAFLAGS += -DGGML_CUDA_F16 endif # GGML_CUDA_DMMV_F16 -ifdef GGML_CUDA_KQUANTS_ITER - MUSAFLAGS += -DK_QUANTS_PER_ITERATION=$(GGML_CUDA_KQUANTS_ITER) -else - MUSAFLAGS += -DK_QUANTS_PER_ITERATION=2 -endif - ifdef GGML_CUDA_PEER_MAX_BATCH_SIZE MUSAFLAGS += -DGGML_CUDA_PEER_MAX_BATCH_SIZE=$(GGML_CUDA_PEER_MAX_BATCH_SIZE) else @@ -964,6 +906,7 @@ endif # GGML_METAL ifdef GGML_METAL ggml/src/ggml-metal/ggml-metal.o: \ ggml/src/ggml-metal/ggml-metal.m \ + ggml/src/ggml-metal/ggml-metal-impl.h \ ggml/include/ggml-metal.h \ ggml/include/ggml.h $(CC) $(CFLAGS) -c $< -o $@ @@ -971,9 +914,11 @@ ggml/src/ggml-metal/ggml-metal.o: \ ifdef GGML_METAL_EMBED_LIBRARY ggml/src/ggml-metal-embed.o: \ ggml/src/ggml-metal/ggml-metal.metal \ + ggml/src/ggml-metal/ggml-metal-impl.h \ ggml/src/ggml-common.h @echo "Embedding Metal library" - @sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal + @sed -e '/__embed_ggml-common.h__/r ggml/src/ggml-common.h' -e '/__embed_ggml-common.h__/d' < ggml/src/ggml-metal/ggml-metal.metal > ggml/src/ggml-metal/ggml-metal-embed.metal.tmp + @sed -e '/#include "ggml-metal-impl.h"/r ggml/src/ggml-metal/ggml-metal-impl.h' -e '/#include "ggml-metal-impl.h"/d' < ggml/src/ggml-metal/ggml-metal-embed.metal.tmp > ggml/src/ggml-metal/ggml-metal-embed.metal $(eval TEMP_ASSEMBLY=$(shell mktemp -d)) @echo ".section __DATA, __ggml_metallib" > $(TEMP_ASSEMBLY)/ggml-metal-embed.s @echo ".globl _ggml_metallib_start" >> $(TEMP_ASSEMBLY)/ggml-metal-embed.s @@ -997,6 +942,7 @@ OBJ_GGML = \ $(DIR_GGML)/src/ggml-alloc.o \ $(DIR_GGML)/src/ggml-backend.o \ $(DIR_GGML)/src/ggml-backend-reg.o \ + $(DIR_GGML)/src/ggml-opt.o \ $(DIR_GGML)/src/ggml-quants.o \ $(DIR_GGML)/src/ggml-threading.o \ $(DIR_GGML)/src/ggml-cpu/ggml-cpu.o \ @@ -1499,11 +1445,6 @@ tests/test-json-schema-to-grammar: tests/test-json-schema-to-grammar.cpp \ $(CXX) $(CXXFLAGS) -Iexamples/server -c $< -o $(call GET_OBJ_FILE, $<) $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) -tests/test-grad0: tests/test-grad0.cpp \ - $(OBJ_GGML) - $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) - $(CXX) $(CXXFLAGS) $(filter-out %.h $<,$^) $(call GET_OBJ_FILE, $<) -o $@ $(LDFLAGS) - tests/test-opt: tests/test-opt.cpp \ $(OBJ_GGML) $(CXX) $(CXXFLAGS) -c $< -o $(call GET_OBJ_FILE, $<) diff --git a/README.md b/README.md index 6ab6acf12..5f7933c13 100644 --- a/README.md +++ b/README.md @@ -459,14 +459,14 @@ To learn more how to measure perplexity using llama.cpp, [read this documentatio - Make sure to read this: [Inference at the edge](https://github.com/ggerganov/llama.cpp/discussions/205) - A bit of backstory for those who are interested: [Changelog podcast](https://changelog.com/podcast/532) -## Other documentations +## Other documentation - [main (cli)](./examples/main/README.md) - [server](./examples/server/README.md) - [jeopardy](./examples/jeopardy/README.md) - [GBNF grammars](./grammars/README.md) -**Development documentations** +**Development documentation** - [How to build](./docs/build.md) - [Running on Docker](./docs/docker.md) diff --git a/common/arg.cpp b/common/arg.cpp index 7c5c5e5cd..4115b2f75 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1939,17 +1939,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.simple_io = true; } ).set_examples({LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_INFILL})); - add_opt(common_arg( - {"-ld", "--logdir"}, "LOGDIR", - "path under which to save YAML logs (no logging if unset)", - [](common_params & params, const std::string & value) { - params.logdir = value; - - if (params.logdir.back() != DIRECTORY_SEPARATOR) { - params.logdir += DIRECTORY_SEPARATOR; - } - } - )); add_opt(common_arg( {"--positive-file"}, "FNAME", string_format("positive prompts file, one prompt per line (default: '%s')", params.cvector_positive_file.c_str()), diff --git a/common/common.cpp b/common/common.cpp index ebd16b600..930374621 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1890,213 +1890,3 @@ common_control_vector_data common_control_vector_load(const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%e, ", data[i]); - } - fprintf(stream, "%e]\n", data.back()); -} - -void yaml_dump_vector_int(FILE * stream, const char * prop_name, const std::vector & data) { - if (data.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - fprintf(stream, "%s: [", prop_name); - for (size_t i = 0; i < data.size() - 1; ++i) { - fprintf(stream, "%d, ", data[i]); - } - fprintf(stream, "%d]\n", data.back()); -} - -void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data) { - std::string data_str(data == NULL ? "" : data); - - if (data_str.empty()) { - fprintf(stream, "%s:\n", prop_name); - return; - } - - size_t pos_start = 0; - size_t pos_found = 0; - - if (std::isspace(data_str[0]) || std::isspace(data_str.back())) { - data_str = std::regex_replace(data_str, std::regex("\n"), "\\n"); - data_str = std::regex_replace(data_str, std::regex("\""), "\\\""); - data_str = std::regex_replace(data_str, std::regex(R"(\\[^n"])"), R"(\$&)"); - data_str = "\"" + data_str + "\""; - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - if (data_str.find('\n') == std::string::npos) { - fprintf(stream, "%s: %s\n", prop_name, data_str.c_str()); - return; - } - - fprintf(stream, "%s: |\n", prop_name); - while ((pos_found = data_str.find('\n', pos_start)) != std::string::npos) { - fprintf(stream, " %s\n", data_str.substr(pos_start, pos_found-pos_start).c_str()); - pos_start = pos_found + 1; - } -} - -void yaml_dump_non_result_info(FILE * stream, const common_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc) { - ggml_cpu_init(); // some ARM features are detected at runtime - - const auto & sparams = params.sparams; - - fprintf(stream, "build_commit: %s\n", LLAMA_COMMIT); - fprintf(stream, "build_number: %d\n", LLAMA_BUILD_NUMBER); - fprintf(stream, "cpu_has_arm_fma: %s\n", ggml_cpu_has_arm_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_avx: %s\n", ggml_cpu_has_avx() ? "true" : "false"); - fprintf(stream, "cpu_has_avx_vnni: %s\n", ggml_cpu_has_avx_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_avx2: %s\n", ggml_cpu_has_avx2() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512: %s\n", ggml_cpu_has_avx512() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vbmi: %s\n", ggml_cpu_has_avx512_vbmi() ? "true" : "false"); - fprintf(stream, "cpu_has_avx512_vnni: %s\n", ggml_cpu_has_avx512_vnni() ? "true" : "false"); - fprintf(stream, "cpu_has_fma: %s\n", ggml_cpu_has_fma() ? "true" : "false"); - fprintf(stream, "cpu_has_neon: %s\n", ggml_cpu_has_neon() ? "true" : "false"); - fprintf(stream, "cpu_has_sve: %s\n", ggml_cpu_has_sve() ? "true" : "false"); - fprintf(stream, "cpu_has_f16c: %s\n", ggml_cpu_has_f16c() ? "true" : "false"); - fprintf(stream, "cpu_has_fp16_va: %s\n", ggml_cpu_has_fp16_va() ? "true" : "false"); - fprintf(stream, "cpu_has_riscv_v: %s\n", ggml_cpu_has_riscv_v() ? "true" : "false"); - fprintf(stream, "cpu_has_wasm_simd: %s\n", ggml_cpu_has_wasm_simd() ? "true" : "false"); - fprintf(stream, "cpu_has_sse3: %s\n", ggml_cpu_has_sse3() ? "true" : "false"); - fprintf(stream, "cpu_has_vsx: %s\n", ggml_cpu_has_vsx() ? "true" : "false"); - fprintf(stream, "cpu_has_matmul_int8: %s\n", ggml_cpu_has_matmul_int8() ? "true" : "false"); - -#ifdef NDEBUG - fprintf(stream, "debug: false\n"); -#else - fprintf(stream, "debug: true\n"); -#endif // NDEBUG - - fprintf(stream, "model_desc: %s\n", model_desc); - fprintf(stream, "n_vocab: %d # output size of the final layer, 32001 for some models\n", llama_n_vocab(llama_get_model(lctx))); - -#ifdef __OPTIMIZE__ - fprintf(stream, "optimize: true\n"); -#else - fprintf(stream, "optimize: false\n"); -#endif // __OPTIMIZE__ - - fprintf(stream, "time: %s\n", timestamp.c_str()); - - fprintf(stream, "\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "# User Inputs #\n"); - fprintf(stream, "###############\n"); - fprintf(stream, "\n"); - - fprintf(stream, "alias: %s # default: unknown\n", params.model_alias.c_str()); - fprintf(stream, "batch_size: %d # default: 512\n", params.n_batch); - fprintf(stream, "chunks: %d # default: -1 (unlimited)\n", params.n_chunks); - fprintf(stream, "color: %s # default: false\n", params.use_color ? "true" : "false"); - fprintf(stream, "ctx_size: %d # default: 512\n", params.n_ctx); - fprintf(stream, "dry_allowed_length: %d # default: 2\n", sparams.dry_allowed_length); - fprintf(stream, "dry_base: %.2f # default: 1.75\n", sparams.dry_base); - fprintf(stream, "dry_multiplier: %.1f # default: 0.0\n", sparams.dry_multiplier); - fprintf(stream, "dry_penalty_last_n: %d # default: -1 (0 = disable, -1 = context size)\n", sparams.dry_penalty_last_n); - fprintf(stream, "escape: %s # default: false\n", params.escape ? "true" : "false"); - fprintf(stream, "file: # never logged, see prompt instead. Can still be specified for input.\n"); - fprintf(stream, "frequency_penalty: %f # default: 0.0 \n", sparams.penalty_freq); - yaml_dump_string_multiline(stream, "grammar", sparams.grammar.c_str()); - fprintf(stream, "grammar-file: # never logged, see grammar instead. Can still be specified for input.\n"); - fprintf(stream, "hellaswag: %s # default: false\n", params.hellaswag ? "true" : "false"); - fprintf(stream, "hellaswag_tasks: %zu # default: 400\n", params.hellaswag_tasks); - fprintf(stream, "ignore_eos: %s # default: false\n", sparams.ignore_eos ? "true" : "false"); - - yaml_dump_string_multiline(stream, "in_prefix", params.input_prefix.c_str()); - fprintf(stream, "in_prefix_bos: %s # default: false\n", params.input_prefix_bos ? "true" : "false"); - yaml_dump_string_multiline(stream, "in_suffix", params.input_prefix.c_str()); - fprintf(stream, "interactive: %s # default: false\n", params.interactive ? "true" : "false"); - fprintf(stream, "interactive_first: %s # default: false\n", params.interactive_first ? "true" : "false"); - fprintf(stream, "keep: %d # default: 0\n", params.n_keep); - fprintf(stream, "logdir: %s # default: unset (no logging)\n", params.logdir.c_str()); - - fprintf(stream, "logit_bias:\n"); - for (const auto & logit_bias : sparams.logit_bias) { - fprintf(stream, " %d: %f", logit_bias.token, logit_bias.bias); - } - - fprintf(stream, "lora:\n"); - for (auto & la : params.lora_adapters) { - if (la.scale == 1.0f) { - fprintf(stream, " - %s\n", la.path.c_str()); - } - } - fprintf(stream, "lora_scaled:\n"); - for (auto & la : params.lora_adapters) { - if (la.scale != 1.0f) { - fprintf(stream, " - %s: %f\n", la.path.c_str(), la.scale); - } - } - fprintf(stream, "lora_init_without_apply: %s # default: false\n", params.lora_init_without_apply ? "true" : "false"); - fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); - fprintf(stream, "min_keep: %d # default: 0 (disabled)\n", sparams.min_keep); - fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); - fprintf(stream, "mirostat_ent: %f # default: 5.0\n", sparams.mirostat_tau); - fprintf(stream, "mirostat_lr: %f # default: 0.1\n", sparams.mirostat_eta); - fprintf(stream, "mlock: %s # default: false\n", params.use_mlock ? "true" : "false"); - fprintf(stream, "model: %s # default: %s\n", params.model.c_str(), DEFAULT_MODEL_PATH); - fprintf(stream, "model_draft: %s # default:\n", params.model_draft.c_str()); - fprintf(stream, "multiline_input: %s # default: false\n", params.multiline_input ? "true" : "false"); - fprintf(stream, "n_gpu_layers: %d # default: -1\n", params.n_gpu_layers); - fprintf(stream, "n_predict: %d # default: -1 (unlimited)\n", params.n_predict); - fprintf(stream, "n_probs: %d # only used by server binary, default: 0\n", sparams.n_probs); - fprintf(stream, "no_mmap: %s # default: false\n", !params.use_mmap ? "true" : "false"); - fprintf(stream, "penalize_nl: %s # default: false\n", sparams.penalize_nl ? "true" : "false"); - fprintf(stream, "ppl_output_type: %d # default: 0\n", params.ppl_output_type); - fprintf(stream, "ppl_stride: %d # default: 0\n", params.ppl_stride); - fprintf(stream, "presence_penalty: %f # default: 0.0\n", sparams.penalty_present); - yaml_dump_string_multiline(stream, "prompt", params.prompt.c_str()); - fprintf(stream, "prompt_cache: %s\n", params.path_prompt_cache.c_str()); - fprintf(stream, "prompt_cache_all: %s # default: false\n", params.prompt_cache_all ? "true" : "false"); - fprintf(stream, "prompt_cache_ro: %s # default: false\n", params.prompt_cache_ro ? "true" : "false"); - yaml_dump_vector_int(stream, "prompt_tokens", prompt_tokens); - fprintf(stream, "repeat_penalty: %f # default: 1.1\n", sparams.penalty_repeat); - - fprintf(stream, "reverse_prompt:\n"); - for (std::string ap : params.antiprompt) { - size_t pos = 0; - while ((pos = ap.find('\n', pos)) != std::string::npos) { - ap.replace(pos, 1, "\\n"); - pos += 1; - } - - fprintf(stream, " - %s\n", ap.c_str()); - } - - fprintf(stream, "rope_freq_base: %f # default: 10000.0\n", params.rope_freq_base); - fprintf(stream, "rope_freq_scale: %f # default: 1.0\n", params.rope_freq_scale); - fprintf(stream, "simple_io: %s # default: false\n", params.simple_io ? "true" : "false"); - fprintf(stream, "cont_batching: %s # default: false\n", params.cont_batching ? "true" : "false"); - fprintf(stream, "flash_attn: %s # default: false\n", params.flash_attn ? "true" : "false"); - fprintf(stream, "temp: %f # default: 0.8\n", sparams.temp); - - const std::vector tensor_split_vector(params.tensor_split, params.tensor_split + llama_max_devices()); - yaml_dump_vector_float(stream, "tensor_split", tensor_split_vector); - - fprintf(stream, "threads: %d # default: %u\n", params.cpuparams.n_threads, std::thread::hardware_concurrency()); - fprintf(stream, "top_k: %d # default: 40\n", sparams.top_k); - fprintf(stream, "top_p: %f # default: 0.95\n", sparams.top_p); - fprintf(stream, "min_p: %f # default: 0.0\n", sparams.min_p); - fprintf(stream, "xtc_probability: %f # default: 0.0\n", sparams.xtc_probability); - fprintf(stream, "xtc_threshold: %f # default: 0.1\n", sparams.xtc_threshold); - fprintf(stream, "typ_p: %f # default: 1.0\n", sparams.typ_p); - fprintf(stream, "verbose_prompt: %s # default: false\n", params.verbose_prompt ? "true" : "false"); - fprintf(stream, "display_prompt: %s # default: true\n", params.display_prompt ? "true" : "false"); -} diff --git a/common/common.h b/common/common.h index 6289feaeb..7977cc7a9 100644 --- a/common/common.h +++ b/common/common.h @@ -209,7 +209,6 @@ struct common_params { std::string path_prompt_cache = ""; // path to file for saving/loading prompt eval state // NOLINT std::string input_prefix = ""; // string to prefix user inputs with // NOLINT std::string input_suffix = ""; // string to suffix user inputs with // NOLINT - std::string logdir = ""; // directory in which to save YAML log files // NOLINT std::string lookup_cache_static = ""; // path of static ngram cache file for lookup decoding // NOLINT std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT @@ -584,15 +583,3 @@ common_control_vector_data common_control_vector_load(const std::vector & data); -void yaml_dump_vector_int (FILE * stream, const char * prop_name, const std::vector & data); -void yaml_dump_string_multiline(FILE * stream, const char * prop_name, const char * data); - -void yaml_dump_non_result_info( - FILE * stream, const common_params & params, const llama_context * lctx, - const std::string & timestamp, const std::vector & prompt_tokens, const char * model_desc); diff --git a/docs/build.md b/docs/build.md index 52de2b4e2..359952b30 100644 --- a/docs/build.md +++ b/docs/build.md @@ -186,13 +186,9 @@ The following compilation options are also available to tweak performance: | Option | Legal values | Default | Description | |-------------------------------|------------------------|---------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_FORCE_DMMV | Boolean | false | Force the use of dequantization + matrix vector multiplication kernels instead of using kernels that do matrix vector multiplication on quantized data. By default the decision is made based on compute capability (MMVQ for 6.1/Pascal/GTX 1000 or higher). Does not affect k-quants. | -| GGML_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the CUDA dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | -| GGML_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the CUDA mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. | | GGML_CUDA_FORCE_MMQ | Boolean | false | Force the use of custom matrix multiplication kernels for quantized models instead of FP16 cuBLAS even if there is no int8 tensor core implementation available (affects V100, RDNA3). MMQ kernels are enabled by default on GPUs with int8 tensor core support. With MMQ force enabled, speed for large batch sizes will be worse but VRAM consumption will be lower. | | GGML_CUDA_FORCE_CUBLAS | Boolean | false | Force the use of FP16 cuBLAS instead of custom matrix multiplication kernels for quantized models | | GGML_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | -| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | GGML_CUDA_PEER_MAX_BATCH_SIZE | Positive integer | 128 | Maximum batch size for which to enable peer access between multiple GPUs. Peer access requires either Linux or NVLink. When using NVLink enabling peer access for larger batch sizes is potentially beneficial. | | GGML_CUDA_FA_ALL_QUANTS | Boolean | false | Compile support for all KV cache quantization type (combinations) for the FlashAttention CUDA kernels. More fine-grained control over KV cache size but compilation takes much longer. | @@ -268,13 +264,6 @@ You can download it from your Linux distro's package manager or from here: [ROCm The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used. If your GPU is not officially supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 (e.g. gfx1030, gfx1031, or gfx1035) or 11.0.0 on RDNA3. -The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above): - -| Option | Legal values | Default | Description | -|------------------------|------------------------|---------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| GGML_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the HIP dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. | -| GGML_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the HIP mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. | -| GGML_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per HIP thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | ### Vulkan @@ -282,9 +271,9 @@ The following compilation options are also available to tweak performance (yes, #### w64devkit -Download and extract [w64devkit](https://github.com/skeeto/w64devkit/releases). +Download and extract [`w64devkit`](https://github.com/skeeto/w64devkit/releases). -Download and install the [Vulkan SDK](https://vulkan.lunarg.com/sdk/home#windows). When selecting components, only the Vulkan SDK Core is required. +Download and install the [`Vulkan SDK`](https://vulkan.lunarg.com/sdk/home#windows) with the default settings. Launch `w64devkit.exe` and run the following commands to copy Vulkan dependencies: ```sh @@ -302,6 +291,29 @@ EOF ``` Switch into the `llama.cpp` directory and run `make GGML_VULKAN=1`. +#### Git Bash MINGW64 + +Download and install [`Git-SCM`](https://git-scm.com/downloads/win) with the default settings + +Download and install [`Visual Studio Community Edition`](https://visualstudio.microsoft.com/) and make sure you select `C++` + +Download and install [`CMake`](https://cmake.org/download/) with the default settings + +Download and install the [`Vulkan SDK`](https://vulkan.lunarg.com/sdk/home#windows) with the default settings. + +Go into your `llama.cpp` directory and right click, select `Open Git Bash Here` and then run the following commands + +``` +cmake -B build -DGGML_VULKAN=ON +cmake --build build --config Release +``` + +Now you can load the model in conversation mode using `Vulkan` + +``` +build/bin/release/llama-cli -m "[PATH TO MODEL]" -ngl 100 -c 16384 -t 10 -n -2 -cnv +``` + #### MSYS2 Install [MSYS2](https://www.msys2.org/) and then run the following commands in a UCRT terminal to install dependencies. ```sh diff --git a/examples/infill/infill.cpp b/examples/infill/infill.cpp index f18362c91..15b358dc4 100644 --- a/examples/infill/infill.cpp +++ b/examples/infill/infill.cpp @@ -43,50 +43,6 @@ static std::vector * g_output_tokens; static bool is_interacting = false; -static void write_logfile( - const llama_context * ctx, const common_params & params, const llama_model * model, - const std::vector & input_tokens, const std::string & output, - const std::vector & output_tokens -) { - if (params.logdir.empty()) { - return; - } - - const std::string timestamp = string_get_sortable_timestamp(); - - const bool success = fs_create_directory_with_parents(params.logdir); - if (!success) { - LOG_ERR("%s: warning: failed to create logdir %s, cannot write logfile\n", - __func__, params.logdir.c_str()); - return; - } - - const std::string logfile_path = params.logdir + timestamp + ".yml"; - FILE * logfile = fopen(logfile_path.c_str(), "w"); - - if (logfile == NULL) { - LOG_ERR("%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); - return; - } - - fprintf(logfile, "binary: infill\n"); - char model_desc[128]; - llama_model_desc(model, model_desc, sizeof(model_desc)); - yaml_dump_non_result_info(logfile, params, ctx, timestamp, input_tokens, model_desc); - - fprintf(logfile, "\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "# Generation Results #\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "\n"); - - yaml_dump_string_multiline(logfile, "output", output.c_str()); - yaml_dump_vector_int(logfile, "output_tokens", output_tokens); - - llama_perf_dump_yaml(logfile, ctx); - fclose(logfile); -} - #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) static void sigint_handler(int signo) { if (signo == SIGINT) { @@ -96,7 +52,6 @@ static void sigint_handler(int signo) { console::cleanup(); LOG("\n"); common_perf_print(*g_ctx, *g_smpl); - write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); // make sure all logs are flushed LOG("Interrupted by user\n"); @@ -625,7 +580,6 @@ int main(int argc, char ** argv) { LOG("\n"); common_perf_print(ctx, smpl); - write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); llama_free(ctx); llama_free_model(model); diff --git a/examples/main/main.cpp b/examples/main/main.cpp index 374ed47ad..7c4ce4be2 100644 --- a/examples/main/main.cpp +++ b/examples/main/main.cpp @@ -62,49 +62,6 @@ static bool file_is_empty(const std::string & path) { return f.tellg() == 0; } -static void write_logfile( - const llama_context * ctx, const common_params & params, const llama_model * model, - const std::vector & input_tokens, const std::string & output, - const std::vector & output_tokens -) { - if (params.logdir.empty()) { - return; - } - - const std::string timestamp = string_get_sortable_timestamp(); - - const bool success = fs_create_directory_with_parents(params.logdir); - if (!success) { - LOG_ERR("%s: failed to create logdir %s, cannot write logfile\n", __func__, params.logdir.c_str()); - return; - } - - const std::string logfile_path = params.logdir + timestamp + ".yml"; - FILE * logfile = fopen(logfile_path.c_str(), "w"); - - if (logfile == NULL) { - LOG_ERR("%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); - return; - } - - fprintf(logfile, "binary: main\n"); - char model_desc[128]; - llama_model_desc(model, model_desc, sizeof(model_desc)); - yaml_dump_non_result_info(logfile, params, ctx, timestamp, input_tokens, model_desc); - - fprintf(logfile, "\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "# Generation Results #\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "\n"); - - yaml_dump_string_multiline(logfile, "output", output.c_str()); - yaml_dump_vector_int(logfile, "output_tokens", output_tokens); - - llama_perf_dump_yaml(logfile, ctx); - fclose(logfile); -} - #if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__)) || defined (_WIN32) static void sigint_handler(int signo) { if (signo == SIGINT) { @@ -115,7 +72,6 @@ static void sigint_handler(int signo) { console::cleanup(); LOG("\n"); common_perf_print(*g_ctx, *g_smpl); - write_logfile(*g_ctx, *g_params, *g_model, *g_input_tokens, g_output_ss->str(), *g_output_tokens); // make sure all logs are flushed LOG("Interrupted by user\n"); @@ -926,7 +882,6 @@ int main(int argc, char ** argv) { LOG("\n\n"); common_perf_print(ctx, smpl); - write_logfile(ctx, params, model, input_tokens, output_ss.str(), output_tokens); common_sampler_free(smpl); diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index e803ff143..64a84607c 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -34,55 +34,6 @@ struct results_log_softmax { float prob; }; -static void write_logfile( - const llama_context * ctx, const common_params & params, const llama_model * model, - const struct results_perplexity & results -) { - if (params.logdir.empty()) { - return; - } - - if (params.hellaswag) { - LOG_WRN("%s: logging results is not implemented for HellaSwag. No files will be written.\n", __func__); - return; - } - - const std::string timestamp = string_get_sortable_timestamp(); - - const bool success = fs_create_directory_with_parents(params.logdir); - if (!success) { - LOG_WRN("%s: failed to create logdir %s, cannot write logfile\n", - __func__, params.logdir.c_str()); - return; - } - - const std::string logfile_path = params.logdir + timestamp + ".yml"; - FILE * logfile = fopen(logfile_path.c_str(), "w"); - - if (logfile == NULL) { - LOG_ERR("%s: failed to open logfile %s\n", __func__, logfile_path.c_str()); - return; - } - - fprintf(logfile, "binary: main\n"); - char model_desc[128]; - llama_model_desc(model, model_desc, sizeof(model_desc)); - yaml_dump_non_result_info(logfile, params, ctx, timestamp, results.tokens, model_desc); - - fprintf(logfile, "\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "# Perplexity Results #\n"); - fprintf(logfile, "######################\n"); - fprintf(logfile, "\n"); - - yaml_dump_vector_float(logfile, "logits", results.logits); - fprintf(logfile, "ppl_value: %f\n", results.ppl_value); - yaml_dump_vector_float(logfile, "probs", results.probs); - - llama_perf_dump_yaml(logfile, ctx); - fclose(logfile); -} - static std::vector softmax(const std::vector& logits) { std::vector probs(logits.size()); float max_logit = logits[0]; @@ -2072,8 +2023,6 @@ int main(int argc, char ** argv) { LOG("\n"); llama_perf_context_print(ctx); - write_logfile(ctx, params, model, results); - llama_free(ctx); llama_free_model(model); diff --git a/examples/server/README.md b/examples/server/README.md index 6f72c6bb8..0936e0b7b 100644 --- a/examples/server/README.md +++ b/examples/server/README.md @@ -85,7 +85,6 @@ The project is under active development, and we are [looking for feedback and co | `-hfr, --hf-repo REPO` | Hugging Face model repository (default: unused)
(env: LLAMA_ARG_HF_REPO) | | `-hff, --hf-file FILE` | Hugging Face model file (default: unused)
(env: LLAMA_ARG_HF_FILE) | | `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)
(env: HF_TOKEN) | -| `-ld, --logdir LOGDIR` | path under which to save YAML logs (no logging if unset) | | `--log-disable` | Log disable | | `--log-file FNAME` | Log to file | | `--log-colors` | Enable colored logging
(env: LLAMA_LOG_COLORS) | diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 4fb78e59f..a82818d60 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -128,14 +128,9 @@ option(GGML_LLAMAFILE "ggml: use LLAMAFILE" option(GGML_CUDA "ggml: use CUDA" OFF) option(GGML_MUSA "ggml: use MUSA" OFF) -option(GGML_CUDA_FORCE_DMMV "ggml: use dmmv instead of mmvq CUDA kernels" OFF) option(GGML_CUDA_FORCE_MMQ "ggml: use mmq kernels instead of cuBLAS" OFF) option(GGML_CUDA_FORCE_CUBLAS "ggml: always use cuBLAS instead of mmq kernels" OFF) -set (GGML_CUDA_DMMV_X "32" CACHE STRING "ggml: x stride for dmmv CUDA kernels") -set (GGML_CUDA_MMV_Y "1" CACHE STRING "ggml: y block size for mmv CUDA kernels") option(GGML_CUDA_F16 "ggml: use 16 bit floats for some calculations" OFF) -set (GGML_CUDA_KQUANTS_ITER "2" CACHE STRING - "ggml: iters./thread per block for Q2_K/Q6_K") set (GGML_CUDA_PEER_MAX_BATCH_SIZE "128" CACHE STRING "ggml: max. batch size for using peer access") option(GGML_CUDA_NO_PEER_COPY "ggml: do not use peer to peer copies" OFF) @@ -228,6 +223,7 @@ set(GGML_PUBLIC_HEADERS include/ggml-cann.h include/ggml-cuda.h include/ggml-kompute.h + include/ggml-opt.h include/ggml-metal.h include/ggml-rpc.h include/ggml-sycl.h diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h index 0a65dbfca..cef164764 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -86,7 +86,7 @@ extern "C" { GGML_API void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); - // "offset" refers to the offset of the tensor data for setting/getting data + // "offset" refers to the offset in tensor->data for setting/getting data GGML_API void ggml_backend_tensor_set( struct ggml_tensor * tensor, const void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size); GGML_API void ggml_backend_tensor_memset( struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size); @@ -242,14 +242,20 @@ extern "C" { ggml_backend_sched_reserve(sched, reserve_graph); // compute - graph = build_graph(sched); - ggml_backend_sched_graph_compute(sched, graph); + graph = build_graph(sched); // the graph and its tensors are single-use in terms of allocation, multi-use in terms of computation + for (int i = 0; i < 10; ++i) { + ggml_backend_sched_graph_compute(sched, graph); // on the first iteration the graph is allocated automatically + } // if there are graph inputs: - ggml_backend_sched_reset(sched); - ggml_backend_sched_alloc_graph(sched, graph); - ggml_backend_tensor_set(input_tensor, ...); - ggml_backend_sched_graph_compute(sched, graph); + graph = build_graph(sched); // get a new graph that is not allocated (the metadata for the old graph is freed once ggml_free is called) + ggml_backend_sched_reset(sched); // clear the allocation of the previous graph + ggml_backend_sched_alloc_graph(sched, graph); // explicitly allocate the new graph but do not execute it + ggml_backend_tensor_set(input_tensor, ...); // copy data to the newly allocated graph tensors + ggml_backend_sched_graph_compute(sched, graph); // execute the graph + + // as an alternative to the above it is also possible to assign the inputs to a dedicated context and + // allocate them statically via ggml_backend_alloc_ctx_tensors } */ @@ -264,7 +270,7 @@ extern "C" { // typedef bool (*ggml_backend_sched_eval_callback)(struct ggml_tensor * t, bool ask, void * user_data); - // Initialize a backend scheduler + // Initialize a backend scheduler, backends with low index are given priority over backends with high index GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel); GGML_API void ggml_backend_sched_free(ggml_backend_sched_t sched); @@ -289,7 +295,9 @@ extern "C" { GGML_API enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph); GGML_API void ggml_backend_sched_synchronize(ggml_backend_sched_t sched); - // Reset all assignments and allocators - must be called before changing the node backends + // Reset all assignments and allocators - must be called before changing the node backends or allocating a new graph. + // This in effect deallocates all tensors that were previously allocated and leaves them with dangling pointers. + // The correct way to use this API is to discard the deallocated tensors and create new ones. GGML_API void ggml_backend_sched_reset(ggml_backend_sched_t sched); // Set a callback to be called for each resulting node during graph compute diff --git a/ggml/include/ggml-opt.h b/ggml/include/ggml-opt.h new file mode 100644 index 000000000..eb5eab9de --- /dev/null +++ b/ggml/include/ggml-opt.h @@ -0,0 +1,216 @@ +// This file contains functionality for training models using GGML. +// It is not strictly needed vs. just vanilla GGML but it provides a more high-level interface for common needs such as datasets. +// At the bottom of this file especially there are relatively high-level functions that are suitable use or adaptation in user code. +// +// Module maintainer: Johannes Gäßler (@JohannesGaessler, johannesg@5d6.de) + +#pragma once + +#include "ggml.h" +#include "ggml-backend.h" + +#include + +#ifdef __cplusplus +extern "C" { +#endif + + struct ggml_opt_dataset; + struct ggml_opt_context; + struct ggml_opt_result; + + typedef struct ggml_opt_dataset * ggml_opt_dataset_t; + typedef struct ggml_opt_context * ggml_opt_context_t; + typedef struct ggml_opt_result * ggml_opt_result_t; + + // ====== Loss ====== + + // built-in loss types, i.e. the built-in quantities minimized by the optimizer + // custom loss types can be defined via mean or sum which simply reduce the outputs for all datapoints to a single value + enum ggml_opt_loss_type { + GGML_OPT_LOSS_TYPE_MEAN, + GGML_OPT_LOSS_TYPE_SUM, + GGML_OPT_LOSS_TYPE_CROSS_ENTROPY, + GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, + }; + + // ====== Dataset ====== + + GGML_API ggml_opt_dataset_t ggml_opt_dataset_init( + int64_t ne_datapoint, // number of elements per datapoint + int64_t ne_label, // number of elements per label + int64_t ndata, // total number of datapoints/labels + int64_t ndata_shard); // number of datapoints/labels per shard (unit at which the dataset is shuffled/copied) + GGML_API void ggml_opt_dataset_free(ggml_opt_dataset_t dataset); + + // get underlying tensors that store the data + GGML_API struct ggml_tensor * ggml_opt_dataset_data (ggml_opt_dataset_t dataset); // shape = [ne_datapoint, ndata] + GGML_API struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset); // shape = [nd_label, ndata] + + // shuffle idata first datapoints from dataset with RNG from opt_ctx, shuffle all datapoints if idata is negative + GGML_API void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata); + + // get batch at position ibatch from dataset and copy the data to data_batch and labels_batch + GGML_API void ggml_opt_dataset_get_batch( + ggml_opt_dataset_t dataset, + struct ggml_tensor * data_batch, // shape = [ne_datapoint, ndata_batch] + struct ggml_tensor * labels_batch, // shape = [ne_label, ndata_batch] + int64_t ibatch); + + // ====== Model / Context ====== + + enum ggml_opt_build_type { + GGML_OPT_BUILD_TYPE_FORWARD, + GGML_OPT_BUILD_TYPE_GRAD, + GGML_OPT_BUILD_TYPE_OPT, + }; + + // parameters that control which optimizer is used and how said optimizer tries to find the minimal loss + struct ggml_opt_optimizer_params { + // AdamW optimizer parameters + struct { + float alpha; // learning rate + float beta1; + float beta2; + float eps; // epsilon for numerical stability + float wd; // weight decay for AdamW, use 0.0f to disable + } adamw; + }; + + // callback to calculate optimizer parameters prior to a backward pass + // userdata can be used to pass arbitrary data + typedef struct ggml_opt_optimizer_params (*ggml_opt_get_optimizer_params)(void * userdata); + + // returns the default optimizer params (constant) + // userdata is not used + GGML_API struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata); + + // parameters for initializing a new optimization context + struct ggml_opt_params { + ggml_backend_sched_t backend_sched; // defines which backends are used to construct the compute graphs + + struct ggml_context * ctx_compute; // created in user code, holds non-static tensors + + // the forward graph is defined by inputs and outputs + // those tensors and all tensors inbetween are not intended to be reusable between multiple optimization contexts + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; + + enum ggml_opt_loss_type loss_type; + enum ggml_opt_build_type build_type; + + int32_t opt_period; // after how many gradient accumulation steps an optimizer step should be done + + ggml_opt_get_optimizer_params get_opt_pars; // callback for calculating optimizer parameters + void * get_opt_pars_ud; // userdata for calculating optimizer parameters + }; + + // get parameters for an optimization context with defaults set where possible + // parameters for which no sensible defaults exist are supplied as arguments to this function + GGML_API ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + struct ggml_context * ctx_compute, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs, + enum ggml_opt_loss_type loss_type); + + GGML_API ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params); + GGML_API void ggml_opt_free(ggml_opt_context_t opt_ctx); + + // set gradients to zero, initilize loss, and optionally reset the optimizer + GGML_API void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer); + + // get underlying tensors that store data + GGML_API struct ggml_tensor * ggml_opt_inputs( ggml_opt_context_t opt_ctx); // forward graph input tensor + GGML_API struct ggml_tensor * ggml_opt_outputs( ggml_opt_context_t opt_ctx); // forward graph output tensor + GGML_API struct ggml_tensor * ggml_opt_labels( ggml_opt_context_t opt_ctx); // labels to compare outputs against + GGML_API struct ggml_tensor * ggml_opt_loss( ggml_opt_context_t opt_ctx); // scalar tensor that contains the loss + GGML_API struct ggml_tensor * ggml_opt_pred( ggml_opt_context_t opt_ctx); // predictions made by outputs + GGML_API struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx); // number of matching predictions between outputs and labels + + GGML_API struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node); + + // ====== Optimization Result ====== + + GGML_API ggml_opt_result_t ggml_opt_result_init(); + GGML_API void ggml_opt_result_free(ggml_opt_result_t result); + GGML_API void ggml_opt_result_reset(ggml_opt_result_t result); + + // get data from result, uncertainties are optional and can be ignored by passing NULL + GGML_API void ggml_opt_result_ndata( ggml_opt_result_t result, int64_t * ndata); // writes 1 value, number of datapoints + GGML_API void ggml_opt_result_loss( ggml_opt_result_t result, double * loss, double * unc); // writes 1 value + GGML_API void ggml_opt_result_pred( ggml_opt_result_t result, int32_t * pred); // writes ndata values + GGML_API void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc); // writes 1 value + + // ====== Computation ====== + + // do forward pass, increment result if not NULL + GGML_API void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + + // do forward pass, increment result if not NULL, do backward pass + GGML_API void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result_t result); + + // ############################################################################ + // ## The high-level functions start here. They do not depend on any private ## + // ## functions or structs and can be copied to and adapted for user code. ## + // ############################################################################ + + // ====== Intended Usage ====== + // + // 1. Select the appropriate loss for your problem. + // 2. Create a dataset and set the data for the "data" tensor. Also set the "labels" tensor if your loss needs them. + // Setting the shard size to 1 will be fine, it's the granularity with which data is shuffled/loaded (bigger values are faster). + // 3. Create a GGML graph for your model with no_alloc == true. Use two separate contexts for the tensors. + // The first context should contain the model parameters and inputs and be allocated statically in user code. + // The second context should contain all other tensors and will be (re)allocated automatically. + // Due to this automated allocation the data of the second context is not defined when accessed in user code. + // Note that the second dimension of the inputs/outputs are interpreted as the number of datapoints in those tensors. + // 4. Call ggml_opt_fit. If you need more control you can use ggml_opt_epoch instead. + + // signature for a callback while evaluating opt_ctx on dataset, called after an evaluation + typedef void (*ggml_opt_epoch_callback)( + bool train, // true after training evaluation, false after validation evaluation + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, // result associated with the dataset subsection + int64_t ibatch, // number of batches that have been evaluated so far + int64_t ibatch_max, // total number of batches in this dataset subsection + int64_t t_start_us); // time at which the evaluation on the dataset subsection was started + + // do training on front of dataset, do evaluation only on back of dataset + GGML_API void ggml_opt_epoch( + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, // result to increment during training, ignored if NULL + ggml_opt_result_t result_eval, // result to increment during evaluation, ignored if NULL + int64_t idata_split, // data index at which to split training and evaluation + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval); + + // callback that prints a progress bar on stderr + GGML_API void ggml_opt_epoch_callback_progress_bar( + bool train, + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us); + + // fit model defined by inputs and outputs to dataset + GGML_API void ggml_opt_fit( + ggml_backend_sched_t backend_sched, // backend scheduler for constructing the compute graphs + ggml_context * ctx_compute, // context with temporarily allocated tensors to calculate the outputs + ggml_tensor * inputs, // input tensor with shape [ne_datapoint, ndata_batch] + ggml_tensor * outputs, // output tensor, must have shape [ne_label, ndata_batch] if labels are used + ggml_opt_dataset_t dataset, // dataset with data and optionally also labels + enum ggml_opt_loss_type loss_type, // loss to minimize + ggml_opt_get_optimizer_params get_opt_pars, // callback to get optimizer params, userdata is pointer to epoch (of type int64_t) + int64_t nepoch, // how many times the dataset should be iterated over + int64_t nbatch_logical, // datapoints optimizer step, must be a multiple of ndata_batch in inputs/outputs + float val_split, // fraction of the dataset to use for validation, must be in [0.0f, 1.0f) + bool silent); // whether or not info prints to stderr should be suppressed + +#ifdef __cplusplus +} +#endif diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 3b3f6798a..69e6a2434 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -602,7 +602,6 @@ extern "C" { int32_t flags; - struct ggml_tensor * grad; struct ggml_tensor * src[GGML_MAX_SRC]; // source tensor and offset for views @@ -615,7 +614,7 @@ extern "C" { void * extra; // extra things e.g. for ggml-cuda.cu - // char padding[4]; + char padding[8]; }; static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor); @@ -1985,28 +1984,20 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad, - float alpha, - float beta1, - float beta2, - float eps, - float wd); // weight decay + struct ggml_tensor * m, + struct ggml_tensor * v, + struct ggml_tensor * adamw_params); // parameters such a the learning rate // // automatic differentiation // - GGML_API void ggml_build_forward_expand (struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); - GGML_API void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate); - - GGML_API void ggml_build_opt_adamw( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - float alpha, - float beta1, - float beta2, - float eps, - float wd); // weight decay + GGML_API void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * tensor); + GGML_API void ggml_build_backward_expand( + struct ggml_context * ctx_static, // context for static gradients (loss + gradient accumulation) + struct ggml_context * ctx_compute, // context for gradient computation + struct ggml_cgraph * cgraph, + bool accumulate); // whether or not gradients should be accumulated, requires static allocation of tensors in ctx_static // graph allocation in a context GGML_API struct ggml_cgraph * ggml_new_graph (struct ggml_context * ctx); // size = GGML_DEFAULT_GRAPH_SIZE, grads = false @@ -2026,7 +2017,9 @@ extern "C" { GGML_API size_t ggml_graph_overhead(void); GGML_API size_t ggml_graph_overhead_custom(size_t size, bool grads); - GGML_API struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name); + GGML_API struct ggml_tensor * ggml_graph_get_tensor (const struct ggml_cgraph * cgraph, const char * name); + GGML_API struct ggml_tensor * ggml_graph_get_grad (const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); + GGML_API struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node); GGML_API void ggml_graph_export(const struct ggml_cgraph * cgraph, const char * fname); GGML_API struct ggml_cgraph * ggml_graph_import(const char * fname, struct ggml_context ** ctx_data, struct ggml_context ** ctx_eval); @@ -2037,198 +2030,15 @@ extern "C" { // dump the graph into a file using the dot format GGML_API void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph * gf, const char * filename); - // build gradient checkpointing backward graph gb for gf using provided checkpoints - // gb_tmp will contain original backward graph with rewritten backward process nodes, - // but without the second forward pass nodes. - GGML_API void ggml_build_backward_gradient_checkpointing( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - struct ggml_cgraph * gb_tmp, - struct ggml_tensor * * checkpoints, - int n_checkpoints); - // - // optimization - // - - // optimization methods - enum ggml_opt_type { - GGML_OPT_TYPE_ADAM, - GGML_OPT_TYPE_LBFGS, - }; - - // linesearch methods - enum ggml_linesearch { - GGML_LINESEARCH_DEFAULT = 1, - - GGML_LINESEARCH_BACKTRACKING_ARMIJO = 0, - GGML_LINESEARCH_BACKTRACKING_WOLFE = 1, - GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE = 2, - }; - - // optimization return values - enum ggml_opt_result { - GGML_OPT_RESULT_OK = 0, - GGML_OPT_RESULT_DID_NOT_CONVERGE, - GGML_OPT_RESULT_NO_CONTEXT, - GGML_OPT_RESULT_INVALID_WOLFE, - GGML_OPT_RESULT_FAIL, - GGML_OPT_RESULT_CANCEL, - - GGML_LINESEARCH_FAIL = -128, - GGML_LINESEARCH_MINIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_STEP, - GGML_LINESEARCH_MAXIMUM_ITERATIONS, - GGML_LINESEARCH_INVALID_PARAMETERS, - }; - - typedef void (*ggml_opt_callback)(void * data, int accum_step, float * sched, bool * cancel); + // TODO these functions were sandwiched in the old optimization interface, is there a better place for them? typedef void (*ggml_log_callback)(enum ggml_log_level level, const char * text, void * user_data); // Set callback for all future logging events. // If this is not called, or NULL is supplied, everything is output on stderr. GGML_API void ggml_log_set(ggml_log_callback log_callback, void * user_data); - // optimization parameters - // - // see ggml.c (ggml_opt_default_params) for default values - // - struct ggml_opt_params { - enum ggml_opt_type type; - - size_t graph_size; - - int n_threads; - - // delta-based convergence test - // - // if past == 0 - disabled - // if past > 0: - // stop if |f(x) - f(x_past)| < delta * max(1, |f(x)|) - // - int past; - float delta; - - // maximum number of iterations without improvement - // - // if 0 - disabled - // if > 0: - // assume convergence if no cost improvement in this number of iterations - // - int max_no_improvement; - - bool print_forward_graph; - bool print_backward_graph; - - int n_gradient_accumulation; - - // ADAM parameters - struct { - int n_iter; - - float sched; // schedule multiplier (fixed, decay or warmup) - float decay; // weight decay for AdamW, use 0.0f to disable - int decay_min_ndim; // minimum number of tensor dimension to apply weight decay - float alpha; // learning rate - float beta1; - float beta2; - float eps; // epsilon for numerical stability - float eps_f; // epsilon for convergence test - float eps_g; // epsilon for convergence test - float gclip; // gradient clipping - } adam; - - // LBFGS parameters - struct { - int m; // number of corrections to approximate the inv. Hessian - int n_iter; - int max_linesearch; - - float eps; // convergence tolerance - float ftol; // line search tolerance - float wolfe; - float min_step; - float max_step; - - enum ggml_linesearch linesearch; - } lbfgs; - }; - - struct ggml_opt_context { - struct ggml_context * ctx; - struct ggml_opt_params params; - - int iter; - int64_t nx; // number of parameter elements - - bool just_initialized; - - float loss_before; - float loss_after; - - struct { - struct ggml_tensor * g; // current gradient - struct ggml_tensor * m; // first moment - struct ggml_tensor * v; // second moment - struct ggml_tensor * pf; // past function values - float fx_best; - float fx_prev; - int n_no_improvement; - } adam; - - struct { - struct ggml_tensor * x; // current parameters - struct ggml_tensor * xp; // previous parameters - struct ggml_tensor * g; // current gradient - struct ggml_tensor * gp; // previous gradient - struct ggml_tensor * d; // search direction - struct ggml_tensor * pf; // past function values - struct ggml_tensor * lmal; // the L-BFGS memory alpha - struct ggml_tensor * lmys; // the L-BFGS memory ys - struct ggml_tensor * lms; // the L-BFGS memory s - struct ggml_tensor * lmy; // the L-BFGS memory y - float fx_best; - float step; - int j; - int k; - int end; - int n_no_improvement; - } lbfgs; - }; - GGML_API struct ggml_tensor * ggml_set_zero(struct ggml_tensor * tensor); - GGML_API struct ggml_opt_params ggml_opt_default_params(enum ggml_opt_type type); - - // optimize the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt( - struct ggml_context * ctx, - struct ggml_opt_params params, - struct ggml_tensor * f); - - // initialize optimizer context - GGML_API void ggml_opt_init( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_opt_params params, - int64_t nx); - - // continue optimizing the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt_resume( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f); - - // continue optimizing the function defined by the tensor f - GGML_API enum ggml_opt_result ggml_opt_resume_g( - struct ggml_context * ctx, - struct ggml_opt_context * opt, - struct ggml_tensor * f, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - ggml_opt_callback callback, - void * callback_data); - // // quantization // diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt index 71934c679..ae7d3abc8 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt @@ -207,9 +207,11 @@ add_library(ggml-base ../include/ggml-alloc.h ../include/ggml-backend.h ../include/ggml-cpp.h + ../include/ggml-opt.h ggml.c ggml-alloc.c ggml-backend.cpp + ggml-opt.cpp ggml-threading.cpp ggml-threading.h ggml-quants.c diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 041de9e3e..2b2240be8 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -466,18 +466,12 @@ static bool ggml_gallocr_is_own(ggml_gallocr_t galloc, struct ggml_tensor * t) { return ggml_gallocr_hash_get(galloc, t)->allocated; } -static void ggml_gallocr_set_node_offset(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id, size_t offset) { - struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - hn->buffer_id = buffer_id; - hn->offset = offset; - hn->allocated = true; -} - static bool ggml_gallocr_is_allocated(ggml_gallocr_t galloc, struct ggml_tensor * t) { return t->data != NULL || ggml_gallocr_hash_get(galloc, t)->allocated; } static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor * node, int buffer_id) { + GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { @@ -816,7 +810,11 @@ static void ggml_gallocr_init_tensor(ggml_gallocr_t galloc, struct ggml_tensor * } static bool ggml_gallocr_node_needs_realloc(ggml_gallocr_t galloc, struct ggml_tensor * node, struct tensor_alloc * talloc) { - size_t node_size = (node->data || node->view_src) ? 0 : ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + size_t node_size = 0; + if (!node->data && !node->view_src) { + GGML_ASSERT(talloc->buffer_id >= 0); // prevent segfault when misusing the API + node_size = ggml_backend_buft_get_alloc_size(galloc->bufts[talloc->buffer_id], node); + } return talloc->size_max >= node_size; } diff --git a/ggml/src/ggml-amx/ggml-amx.cpp b/ggml/src/ggml-amx/ggml-amx.cpp index 37da98539..8568e7965 100644 --- a/ggml/src/ggml-amx/ggml-amx.cpp +++ b/ggml/src/ggml-amx/ggml-amx.cpp @@ -317,8 +317,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st const enum ggml_type type = src0->type; const int64_t ne0 = op->ne[0]; - bool is_training = src0->grad || src1->grad; - // amx kernels enables for Q4_0, Q4_1, Q8_0, F16 // Q4_K, Q5_K, Q6_K, IQ4_XS enabled for QK_K = 256 bool has_amx_kernels = qtype_has_amx_kernels(type) || (type == GGML_TYPE_F16); @@ -326,7 +324,6 @@ static bool ggml_backend_amx_device_supports_op(ggml_backend_dev_t dev, const st bool can_use_amx = is_contiguous_2d(src0) && // src0 must be contiguous is_contiguous_2d(src1) && // src1 must be contiguous - !is_training && // inference only src1->type == GGML_TYPE_F32 && // src1 must be float32 has_amx_kernels && // with amx kernel impls ne0 % (TILE_N * 2) == 0; // out_features is 32x diff --git a/ggml/src/ggml-backend.cpp b/ggml/src/ggml-backend.cpp index e48877ba8..9dcde8d11 100644 --- a/ggml/src/ggml-backend.cpp +++ b/ggml/src/ggml-backend.cpp @@ -279,7 +279,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz buf->iface.get_tensor(buf, tensor, data, offset, size); } -GGML_API void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { +void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (size == 0) { @@ -689,7 +689,7 @@ static int ggml_backend_sched_backend_id(ggml_backend_sched_t sched, ggml_backen } static int ggml_backend_sched_backend_from_buffer(ggml_backend_sched_t sched, const struct ggml_tensor * tensor, const struct ggml_tensor * op) { - ggml_backend_buffer_t buffer = tensor->buffer; + ggml_backend_buffer_t buffer = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; if (buffer == NULL) { return -1; } @@ -722,8 +722,6 @@ static char causes[GGML_DEFAULT_GRAPH_SIZE*16 + GGML_SCHED_MAX_SPLITS_DEBUG*GGML // returns the backend that should be used for the node based on the current locations static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, struct ggml_tensor * tensor) { - // TODO: use supports_op to check if the backend supports the op - // assign pre-allocated nodes to their backend int cur_backend_id = ggml_backend_sched_backend_from_buffer(sched, tensor, tensor); if (cur_backend_id != -1) { @@ -742,7 +740,7 @@ static int ggml_backend_sched_backend_id_from_cur(ggml_backend_sched_t sched, st if (tensor->buffer || (tensor->view_src && tensor->view_src->buffer)) { // since the tensor is pre-allocated, it cannot be moved to another backend - GGML_ABORT("pre-allocated tensor in a backend that cannot run the operation"); + GGML_ABORT("pre-allocated tensor (%s) in a backend that cannot run the operation", tensor->name); } // graph input @@ -886,6 +884,9 @@ static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct gg for (int i = 0; i < graph->n_nodes; i++) { struct ggml_tensor * node = graph->nodes[i]; int * node_backend_id = &tensor_backend_id(node); + if (ggml_is_view_op(node->op)) { + continue; + } // do not overwrite user assignments if (*node_backend_id == -1) { *node_backend_id = ggml_backend_sched_backend_id_from_cur(sched, node); @@ -1538,12 +1539,13 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * ggml_backend_sched_split_graph(sched, measure_graph); + ggml_backend_sched_synchronize(sched); + if (!ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids)) { return false; } ggml_backend_sched_reset(sched); - ggml_backend_sched_synchronize(sched); return true; } diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index 61f53cd01..0d23669c2 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2369,7 +2369,7 @@ void ggml_numa_init(enum ggml_numa_strategy numa_flag) { // figure out which node we're on uint current_cpu; int getcpu_ret = 0; -#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 28) || defined(__COSMOPOLITAN__) +#if __GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ > 33) || defined(__COSMOPOLITAN__) getcpu_ret = getcpu(¤t_cpu, &g_state.numa.current_node); #else // old glibc doesn't have a wrapper for this call. Fall back on direct syscall @@ -12216,11 +12216,16 @@ static void ggml_compute_forward_opt_step_adamw_f32( const struct ggml_compute_params * params, struct ggml_tensor * dst) { - const struct ggml_tensor * src0 = dst->src[0]; - const struct ggml_tensor * src0_grad = dst->src[1]; - const struct ggml_tensor * src0_grad_m = dst->src[2]; - const struct ggml_tensor * src0_grad_v = dst->src[3]; + const struct ggml_tensor * src0 = dst->src[0]; + const struct ggml_tensor * src0_grad = dst->src[1]; + const struct ggml_tensor * src0_grad_m = dst->src[2]; + const struct ggml_tensor * src0_grad_v = dst->src[3]; + const struct ggml_tensor * adamw_params = dst->src[4]; + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m)); + GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v)); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); const int ith = params->ith; const int nth = params->nth; @@ -12237,16 +12242,14 @@ static void ggml_compute_forward_opt_step_adamw_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - /* const float gnorm = 1.0f; */ - int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); - const float alpha = ggml_get_op_params_f32(dst, 2); - const float beta1 = ggml_get_op_params_f32(dst, 3); - const float beta2 = ggml_get_op_params_f32(dst, 4); - const float eps = ggml_get_op_params_f32(dst, 5); - const float wd = ggml_get_op_params_f32(dst, 6); - - const float beta1h = alpha/(1.0f - powf(beta1, iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); + const float * adamw_params_ptr = ggml_get_data_f32(adamw_params); + const float alpha = adamw_params_ptr[0]; + const float beta1 = adamw_params_ptr[1]; + const float beta2 = adamw_params_ptr[2]; + const float eps = adamw_params_ptr[3]; + const float wd = adamw_params_ptr[4]; + const float beta1h = adamw_params_ptr[5]; + const float beta2h = adamw_params_ptr[6]; for (int ir = ir0; ir < ir1; ++ir) { const int64_t i03 = ir/(ne02*ne01); @@ -12270,17 +12273,9 @@ static void ggml_compute_forward_opt_step_adamw_f32( // The weight decay is applied independently of the Adam momenta m and v. // This is NOT equivalent to l2 regularization that adds w[i00]*w[i00] to the loss. // See: https://arxiv.org/pdf/1711.05101v3.pdf - w[i00] = w[i00]*(1.0f - alpha*wd) - mh/vh; + w[i00] = w[i00]*(1.0f - alpha*wd) - alpha*mh/vh; } } - - ggml_barrier(params->threadpool); - if (ith != 0) { - return; - } - - iter++; - memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); } static void ggml_compute_forward_opt_step_adamw( diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index 40ed2bdf3..e592f7989 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -6,15 +6,18 @@ if (CUDAToolkit_FOUND) message(STATUS "CUDA Toolkit found") if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES) - # 52 == lowest CUDA 12 standard - # 60 == FP16 CUDA intrinsics - # 61 == integer CUDA intrinsics - # 70 == compute capability at which unrolling a loop in mul_mat_q kernels is faster - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) + # native == GPUs available at build time + # 52 == Maxwell, lowest CUDA 12 standard + # 60 == P100, FP16 CUDA intrinsics + # 61 == Pascal, __dp4a instruction (per-byte integer dot product) + # 70 == V100, FP16 tensor cores + # 75 == Turing, int8 tensor cores + if (GGML_NATIVE AND CUDAToolkit_VERSION VERSION_GREATER_EQUAL "11.6") + set(CMAKE_CUDA_ARCHITECTURES "native") + elseif(GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) set(CMAKE_CUDA_ARCHITECTURES "60;61;70;75") else() set(CMAKE_CUDA_ARCHITECTURES "52;61;70;75") - #set(CMAKE_CUDA_ARCHITECTURES "OFF") # use this to compile much faster, but only F16 models work endif() endif() message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}") @@ -51,21 +54,12 @@ if (CUDAToolkit_FOUND) target_link_libraries(ggml-cuda PRIVATE ggml-base) target_include_directories(ggml-cuda PRIVATE . ..) - # TODO: change the definitions to this target only - - add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) if (GGML_CUDA_GRAPHS) add_compile_definitions(GGML_CUDA_USE_GRAPHS) endif() - if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() @@ -78,10 +72,6 @@ if (CUDAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_VMM) endif() - if (DEFINED GGML_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility - endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_F16) endif() diff --git a/ggml/src/ggml-cuda/dmmv.cu b/ggml/src/ggml-cuda/dmmv.cu deleted file mode 100644 index 00e21b5d7..000000000 --- a/ggml/src/ggml-cuda/dmmv.cu +++ /dev/null @@ -1,699 +0,0 @@ -#include "dmmv.cuh" -#include "dequantize.cuh" -#include "convert.cuh" - -#ifndef K_QUANTS_PER_ITERATION -#define K_QUANTS_PER_ITERATION 2 -#else -static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2"); -#endif - -static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q2_K * x = (const block_q2_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...15 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 16/K_QUANTS_PER_ITERATION; - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int s_offset = 8*im; - const int y_offset = 128*im + l0; - - uint32_t aux[4]; - const uint8_t * d = (const uint8_t *)aux; - const uint8_t * m = (const uint8_t *)(aux + 2); - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = a[1] & 0x0f0f0f0f; - aux[2] = (a[0] >> 4) & 0x0f0f0f0f; - aux[3] = (a[1] >> 4) & 0x0f0f0f0f; - - float sum1 = 0, sum2 = 0; - for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) { - sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3) - + y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3) - + y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3) - + y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3) - + y[l+16] * d[1] * ((q[l+16] >> 0) & 3) - + y[l+48] * d[3] * ((q[l+16] >> 2) & 3) - + y[l+80] * d[5] * ((q[l+16] >> 4) & 3) - +y[l+112] * d[7] * ((q[l+16] >> 6) & 3); - sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6] - + y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7]; - - } - tmp += dall * sum1 - dmin * sum2; - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q3_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q3_K * x = (const block_q3_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const uint16_t kmask1 = 0x0303; - const uint16_t kmask2 = 0x0f0f; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int n = K_QUANTS_PER_ITERATION; // iterations in the inner loop - const int step = 16/K_QUANTS_PER_ITERATION; - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0....15 or 0...7 - - const uint8_t m = 1 << (4*im); - - const int l0 = n*in; // 0...15 or 0...14 in steps of 2 - const int q_offset = 32*im + l0; - const int y_offset = 128*im + l0; - - uint16_t utmp[4]; - const int8_t * s = (const int8_t *)utmp; - - const uint16_t s_shift = 4*im; - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * q = x[i].qs + q_offset; - const uint8_t * h = x[i].hmask + l0; - - const uint16_t * a = (const uint16_t *)x[i].scales; - utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4); - utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4); - utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4); - utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4); - - const float d = x[i].d; - - float sum = 0; - for (int l = 0; l < n; ++l) { - sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4)) - + y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4)) - + y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4)) - + y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4)); - sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4)) - + y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4)) - + y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4)) - + y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4)); - } - tmp += d * sum; - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q4_K * x = (const block_q4_K *)vx + ib0; - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0,1 - - const int step = 8/K_QUANTS_PER_ITERATION; // 8 or 4 - - const int il = tid/step; // 0...3 - const int ir = tid - step*il; // 0...7 or 0...3 - const int n = 2 * K_QUANTS_PER_ITERATION; // 2 or 4 - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - -#if K_QUANTS_PER_ITERATION == 2 - uint32_t q32[4]; - const uint8_t * q4 = (const uint8_t *)q32; -#else - uint16_t q16[4]; - const uint8_t * q4 = (const uint8_t *)q16; -#endif - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - -#if K_QUANTS_PER_ITERATION == 2 - const uint32_t * q1 = (const uint32_t *)(x[i].qs + q_offset); - const uint32_t * q2 = q1 + 16; - - q32[0] = q1[0] & 0x0f0f0f0f; - q32[1] = q1[0] & 0xf0f0f0f0; - q32[2] = q2[0] & 0x0f0f0f0f; - q32[3] = q2[0] & 0xf0f0f0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 4; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+ 4]; - s.z += y2[l] * q4[l+8]; s.w += y2[l+32] * q4[l+12]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#else - const uint16_t * q1 = (const uint16_t *)(x[i].qs + q_offset); - const uint16_t * q2 = q1 + 32; - - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[0] & 0xf0f0; - q16[2] = q2[0] & 0x0f0f; - q16[3] = q2[0] & 0xf0f0; - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < 2; ++l) { - s.x += y1[l] * q4[l+0]; s.y += y1[l+32] * q4[l+2]; - s.z += y2[l] * q4[l+4]; s.w += y2[l+32] * q4[l+6]; - smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7]; - } - tmp += dall * (s.x * sc[0] + s.y * sc[1] * 1.f/16.f + s.z * sc[4] + s.w * sc[5] * 1.f/16.f) - dmin * smin; -#endif - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols) { - - const int row = blockIdx.x; - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q5_K * x = (const block_q5_K *)vx + ib0; - - float tmp = 0; // partial sum for thread in warp - - const uint16_t kmask1 = 0x3f3f; - const uint16_t kmask2 = 0x0f0f; - const uint16_t kmask3 = 0xc0c0; - - const int tid = threadIdx.x/2; // 0...15 - const int ix = threadIdx.x%2; - - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 2; - - const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; - - const uint8_t hm1 = 1 << (2*im); - const uint8_t hm2 = hm1 << 4; - - uint16_t aux[4]; - const uint8_t * sc = (const uint8_t *)aux; - - uint16_t q16[8]; - const uint8_t * q4 = (const uint8_t *)q16; - - for (int i = ix; i < num_blocks_per_row; i += 2) { - - const uint8_t * ql1 = x[i].qs + q_offset; - const uint8_t * qh = x[i].qh + l0; - const float * y1 = yy + i*QK_K + y_offset; - const float * y2 = y1 + 128; - - const float dall = __low2half(x[i].dm); - const float dmin = __high2half(x[i].dm); - - const uint16_t * a = (const uint16_t *)x[i].scales; - aux[0] = a[im+0] & kmask1; - aux[1] = a[im+2] & kmask1; - aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2); - aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2); - - float4 sum = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - const uint16_t * q1 = (const uint16_t *)ql1; - const uint16_t * q2 = q1 + 32; - q16[0] = q1[0] & 0x0f0f; - q16[1] = q1[8] & 0x0f0f; - q16[2] = (q1[0] >> 4) & 0x0f0f; - q16[3] = (q1[8] >> 4) & 0x0f0f; - q16[4] = q2[0] & 0x0f0f; - q16[5] = q2[8] & 0x0f0f; - q16[6] = (q2[0] >> 4) & 0x0f0f; - q16[7] = (q2[8] >> 4) & 0x0f0f; - for (int l = 0; l < n; ++l) { - sum.x += y1[l+ 0] * (q4[l +0] + (qh[l+ 0] & (hm1 << 0) ? 16 : 0)) - + y1[l+16] * (q4[l +2] + (qh[l+16] & (hm1 << 0) ? 16 : 0)); - sum.y += y1[l+32] * (q4[l +4] + (qh[l+ 0] & (hm1 << 1) ? 16 : 0)) - + y1[l+48] * (q4[l +6] + (qh[l+16] & (hm1 << 1) ? 16 : 0)); - sum.z += y2[l+ 0] * (q4[l +8] + (qh[l+ 0] & (hm2 << 0) ? 16 : 0)) - + y2[l+16] * (q4[l+10] + (qh[l+16] & (hm2 << 0) ? 16 : 0)); - sum.w += y2[l+32] * (q4[l+12] + (qh[l+ 0] & (hm2 << 1) ? 16 : 0)) - + y2[l+48] * (q4[l+14] + (qh[l+16] & (hm2 << 1) ? 16 : 0)); - smin += (y1[l] + y1[l+16]) * sc[2] + (y1[l+32] + y1[l+48]) * sc[3] - + (y2[l] + y2[l+16]) * sc[6] + (y2[l+32] + y2[l+48]) * sc[7]; - } - tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin; - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[row] = tmp; - } -} - -static __global__ void dequantize_mul_mat_vec_q6_k(const void * __restrict__ vx, const float * __restrict__ yy, float * __restrict__ dst, const int ncols, int nrows) { - - static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION"); - - const int row = blockIdx.x*blockDim.y + threadIdx.y; - if (row > nrows) return; - - const int num_blocks_per_row = ncols / QK_K; - const int ib0 = row*num_blocks_per_row; - - const block_q6_K * x = (const block_q6_K *)vx + ib0; - - const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16 - const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1 - - const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8 - - const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128... - const int in = tid - step*im; // 0...15 or 0...7 - -#if K_QUANTS_PER_ITERATION == 1 - const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15 - const int is = 0; -#else - const int l0 = 4 * in; // 0, 4, 8, ..., 28 - const int is = in / 4; -#endif - const int ql_offset = 64*im + l0; - const int qh_offset = 32*im + l0; - const int s_offset = 8*im + is; - const int y_offset = 128*im + l0; - - float tmp = 0; // partial sum for thread in warp - - for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) { - - const float * y = yy + i * QK_K + y_offset; - const uint8_t * ql = x[i].ql + ql_offset; - const uint8_t * qh = x[i].qh + qh_offset; - const int8_t * s = x[i].scales + s_offset; - - const float d = x[i].d; - -#if K_QUANTS_PER_ITERATION == 1 - float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32) - + y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32) - + y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32) - + y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32) - + y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32) - + y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32) - + y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32) - +y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32); - tmp += sum; -#else - float sum = 0; - for (int l = 0; l < 4; ++l) { - sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32) - + y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32) - + y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32) - + y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32); - } - tmp += sum; -#endif - - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { - dst[row] = tmp; - } -} - -static __device__ void convert_f16(const void * vx, const int64_t ib, const int iqs, dfloat2 & v){ - const half * x = (const half *) vx; - // load 2 halfs into register in a single instruction - const half2 x_reg = *((half2 *) &(x[ib + iqs])); - // automatic half -> float type cast if dfloat == float - v.x = __low2float(x_reg); - v.y = __high2float(x_reg); -} - -static constexpr __device__ dequantize_kernel_t get_dequantize_kernel(ggml_type type) { - return type == GGML_TYPE_Q4_0 ? dequantize_q4_0 : - type == GGML_TYPE_Q4_1 ? dequantize_q4_1 : - type == GGML_TYPE_Q5_0 ? dequantize_q5_0 : - type == GGML_TYPE_Q5_1 ? dequantize_q5_1 : - type == GGML_TYPE_Q8_0 ? dequantize_q8_0 : - type == GGML_TYPE_F16 ? convert_f16 : - nullptr; -} - -template -static __global__ void dequantize_mul_mat_vec(const void * __restrict__ vx, const dfloat * __restrict__ y, float * __restrict__ dst, const int ncols, const int nrows) { - constexpr int qk = ggml_cuda_type_traits::qk; // quantized weights per x block - constexpr int qr = ggml_cuda_type_traits::qr; // number of quantized weights per data value in x block - constexpr dequantize_kernel_t dequantize_kernel = get_dequantize_kernel(type); - - const int64_t row = (int64_t)blockIdx.x*blockDim.y + threadIdx.y; - - if (row >= nrows) { - return; - } - - const int tid = threadIdx.x; - - const int iter_stride = 2*GGML_CUDA_DMMV_X; - const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter - const int y_offset = qr == 1 ? 1 : qk/2; - -// partial sum for each thread -#ifdef GGML_CUDA_F16 - half2 tmp = {0.0f, 0.0f}; // two sums for f16 to take advantage of half2 intrinsics -#else - float tmp = 0.0f; -#endif // GGML_CUDA_F16 - - for (int i = 0; i < ncols; i += iter_stride) { - const int col = i + vals_per_iter*tid; - const int64_t ib = ((int64_t)row*ncols + col)/qk; // x block index - const int iqs = (col%qk)/qr; // x quant index - const int iybs = col - col%qk; // y block start index - -// processing >2 values per i iter is faster for fast GPUs -#pragma unroll - for (int j = 0; j < vals_per_iter; j += 2) { - // process 2 vals per j iter - - // dequantize - // for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val - dfloat2 v; - dequantize_kernel(vx, ib, iqs + j/qr, v); - - // matrix multiplication - // for qr = 2 the y index needs to increase by 1 per j iter because of y_offset = qk/2 -#ifdef GGML_CUDA_F16 - if ( y_offset == 1 ) { - // load 2 dfloats into register in a single instruction - const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr])); - tmp += __hmul2(v, y_reg); - } - else { - tmp += __hmul2(v, { - y[iybs + iqs + j/qr + 0], - y[iybs + iqs + j/qr + y_offset] - }); - } -#else - if ( y_offset == 1 ) { - // load 2 dfloats into register in a single instruction - const dfloat2 y_reg = *((dfloat2 *) &(y[iybs + iqs + j/qr])); - tmp += v.x * y_reg.x; - tmp += v.y * y_reg.y; - } - else { - tmp += v.x * y[iybs + iqs + j/qr + 0]; - tmp += v.y * y[iybs + iqs + j/qr + y_offset]; - } -#endif // GGML_CUDA_F16 - } - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (tid == 0) { -#ifdef GGML_CUDA_F16 - dst[row] = tmp.x + tmp.y; -#else - dst[row] = tmp; -#endif // GGML_CUDA_F16 - } -} - -static void dequantize_mul_mat_vec_q4_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - // the number of rows may exceed maximum grid size in the y or z dimensions, use the x dimension instead - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q4_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_1_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q8_0_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2; // very slightly faster than 1 even when K_QUANTS_PER_ITERATION = 2 - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q2_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q3_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q4_k<<>>(vx, y, dst, ncols, nrows); -} - -static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const dim3 block_dims(32, 1, 1); - dequantize_mul_mat_vec_q5_k<<>>(vx, y, dst, ncols); -} - -static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % QK_K == 0); - const int ny = 2 / K_QUANTS_PER_ITERATION; - const int block_num_y = (nrows + ny - 1) / ny; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(32, ny, 1); - dequantize_mul_mat_vec_q6_k<<>>(vx, y, dst, ncols, nrows); -} - -static void convert_mul_mat_vec_f16_cuda(const void * vx, const dfloat * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { - GGML_ASSERT(ncols % (GGML_CUDA_DMMV_X*2) == 0); - const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y; - const dim3 block_nums(block_num_y, 1, 1); - const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1); - dequantize_mul_mat_vec - <<>>(vx, y, dst, ncols, nrows); -} - -void ggml_cuda_op_dequantize_mul_mat_vec( - ggml_backend_cuda_context & ctx, - const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, - const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, - const int64_t src1_padded_row_size, cudaStream_t stream) { - GGML_UNUSED(ctx); - const int64_t ne00 = src0->ne[0]; - const int64_t row_diff = row_high - row_low; - - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - // on some GPUs it is faster to convert src1 to half and to use half precision intrinsics -#ifdef GGML_CUDA_F16 - ggml_cuda_pool_alloc src1_dfloat_a(ctx.pool()); - half * src1_dfloat = nullptr; // dfloat == half - - bool src1_convert_f16 = - src0->type == GGML_TYPE_Q4_0 || src0->type == GGML_TYPE_Q4_1 || - src0->type == GGML_TYPE_Q5_0 || src0->type == GGML_TYPE_Q5_1 || - src0->type == GGML_TYPE_Q8_0 || src0->type == GGML_TYPE_F16; - - if (src1_convert_f16) { - src1_dfloat = src1_dfloat_a.alloc(ne00); - const to_fp16_cuda_t to_fp16_cuda = ggml_get_to_fp16_cuda(src1->type); - GGML_ASSERT(to_fp16_cuda != nullptr); - to_fp16_cuda(src1_ddf_i, src1_dfloat, ne00, stream); - } -#else - const dfloat * src1_dfloat = (const dfloat *) src1_ddf_i; // dfloat == float, no conversion -#endif // GGML_CUDA_F16 - - switch (src0->type) { - case GGML_TYPE_Q4_0: - dequantize_mul_mat_vec_q4_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_1: - dequantize_mul_mat_vec_q4_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_0: - dequantize_mul_mat_vec_q5_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_1: - dequantize_mul_mat_vec_q5_1_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q8_0: - dequantize_mul_mat_vec_q8_0_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q2_K: - dequantize_mul_mat_vec_q2_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q3_K: - dequantize_mul_mat_vec_q3_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q4_K: - dequantize_mul_mat_vec_q4_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q5_K: - dequantize_mul_mat_vec_q5_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_Q6_K: - dequantize_mul_mat_vec_q6_K_cuda(src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stream); - break; - case GGML_TYPE_F16: - convert_mul_mat_vec_f16_cuda(src0_dd_i, src1_dfloat, dst_dd_i, ne00, row_diff, stream); - break; - default: - GGML_ABORT("fatal error"); - break; - } - - GGML_UNUSED(src1); - GGML_UNUSED(dst); - GGML_UNUSED(src1_ddq_i); - GGML_UNUSED(src1_ncols); - GGML_UNUSED(src1_padded_row_size); -} - -bool ggml_cuda_dmmv_type_supported(ggml_type src0_type) { - return src0_type == GGML_TYPE_Q4_0 || src0_type == GGML_TYPE_Q4_1 || - src0_type == GGML_TYPE_Q5_0 || src0_type == GGML_TYPE_Q5_1 || - src0_type == GGML_TYPE_Q8_0 || src0_type == GGML_TYPE_Q2_K || - src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q4_K || - src0_type == GGML_TYPE_Q5_K || src0_type == GGML_TYPE_Q6_K || - src0_type == GGML_TYPE_F16; -} diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 07f043328..ef56e944d 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -16,11 +16,11 @@ #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/diagmask.cuh" -#include "ggml-cuda/dmmv.cuh" #include "ggml-cuda/fattn.cuh" #include "ggml-cuda/getrows.cuh" #include "ggml-cuda/im2col.cuh" #include "ggml-cuda/mmq.cuh" +#include "ggml-cuda/mmv.cuh" #include "ggml-cuda/mmvq.cuh" #include "ggml-cuda/norm.cuh" #include "ggml-cuda/opt-step-adamw.cuh" @@ -1020,114 +1020,6 @@ typedef void (*ggml_cuda_op_mul_mat_t)( #define MUL_MAT_SRC1_COL_STRIDE 128 -static __global__ void mul_mat_p021_f16_f32( - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, - const int ncols_x, const int nrows_x, const int nchannels_x, const int nchannels_y) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / (nchannels_y / nchannels_x); - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - // x is transposed and permuted - const int ix = row_x*nchannels_x*ncols_x + channel_x*ncols_x + col_x; - const float xi = __half2float(x[ix]); - - const int row_y = col_x; - - // y is not transposed but permuted - const int iy = channel*nrows_y + row_y; - - tmp += xi * y[iy]; - } - - // dst is not transposed and not permuted - const int idst = channel*nrows_dst + row_dst; - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static __global__ void mul_mat_vec_nc_f16_f32( // nc == non-contiguous - const void * __restrict__ vx, const float * __restrict__ y, float * __restrict__ dst, const int ncols_x, const int nrows_x, - const int row_stride_x, const int channel_stride_x, const int channel_x_divisor) { - - const half * x = (const half *) vx; - - const int row_x = blockDim.y*blockIdx.y + threadIdx.y; - const int channel = blockDim.z*blockIdx.z + threadIdx.z; - const int channel_x = channel / channel_x_divisor; - - const int nrows_y = ncols_x; - const int nrows_dst = nrows_x; - const int row_dst = row_x; - - const int idst = channel*nrows_dst + row_dst; - - float tmp = 0.0f; - - for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) { - const int col_x = col_x0 + threadIdx.x; - - if (col_x >= ncols_x) { - break; - } - - const int row_y = col_x; - - const int ix = channel_x*channel_stride_x + row_x*row_stride_x + col_x; - const int iy = channel*nrows_y + row_y; - - const float xi = __half2float(x[ix]); - - tmp += xi * y[iy]; - } - - // sum up partial sums and write back result - tmp = warp_reduce_sum(tmp); - - if (threadIdx.x == 0) { - dst[idst] = tmp; - } -} - -static void ggml_mul_mat_p021_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, - const int nchannels_x, const int nchannels_y, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_p021_f16_f32<<>>(vx, y, dst, ncols_x, nrows_x, nchannels_x, nchannels_y); -} - -static void ggml_mul_mat_vec_nc_f16_f32_cuda( - const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int row_stride_x, - const int nchannels_x, const int nchannels_y, const int channel_stride_x, cudaStream_t stream) { - - const dim3 block_nums(1, nrows_x, nchannels_y); - const dim3 block_dims(WARP_SIZE, 1, 1); - mul_mat_vec_nc_f16_f32<<>> - (vx, y, dst, ncols_x, nrows_x, row_stride_x, channel_stride_x, nchannels_y/nchannels_x); -} - static cudaError_t ggml_cuda_cpy_tensor_2d( void * dst, const struct ggml_tensor * src, int64_t i3, int64_t i2, int64_t i1_low, int64_t i1_high, cudaStream_t stream) { @@ -1654,58 +1546,6 @@ static void ggml_cuda_op_mul_mat( } } -static void ggml_cuda_mul_mat_vec_p021(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // 0213 permutation - GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // 0213 permutation - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t ne12 = src1->ne[2]; - - cudaStream_t main_stream = ctx.stream(); - - void * src0_ddq = src0->data; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; - - ggml_mul_mat_p021_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, ne02, ne12, main_stream); -} - -static void ggml_cuda_mul_mat_vec_nc(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { - GGML_ASSERT(!ggml_is_transposed(src0)); - GGML_ASSERT(!ggml_is_transposed(src1)); - GGML_ASSERT(!ggml_is_permuted(src0)); - GGML_ASSERT(ggml_backend_buffer_is_cuda(src0->buffer)); - GGML_ASSERT(src0->type == GGML_TYPE_F16); - GGML_ASSERT(src1->type == GGML_TYPE_F32); - - const int64_t ne00 = src0->ne[0]; - const int64_t ne01 = src0->ne[1]; - const int64_t ne02 = src0->ne[2]; - - const int64_t nb01 = src0->nb[1]; - const int64_t nb02 = src0->nb[2]; - - const int64_t ne12 = src1->ne[2]; - - cudaStream_t main_stream = ctx.stream(); - - void * src0_ddq = src0->data; - float * src1_ddf = (float *) src1->data; - float * dst_ddf = (float *) dst->data; - - const int64_t row_stride_x = nb01 / sizeof(half); - const int64_t channel_stride_x = nb02 / sizeof(half); - - ggml_mul_mat_vec_nc_f16_f32_cuda(src0_ddq, src1_ddf, dst_ddf, ne00, ne01, row_stride_x, ne02, ne12, channel_stride_x, main_stream); -} - static __global__ void k_compute_batched_ptrs( const half * src0_as_f16, const half * src1_as_f16, char * dst, const void ** ptrs_src, void ** ptrs_dst, @@ -1879,21 +1719,17 @@ static void ggml_cuda_mul_mat_batched_cublas(ggml_backend_cuda_context & ctx, co static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { const bool split = ggml_backend_buft_is_cuda_split(src0->buffer->buft); - bool use_dequantize_mul_mat_vec = ggml_cuda_dmmv_type_supported(src0->type) + bool use_mul_mat_vec = src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 - && src0->ne[0] % (GGML_CUDA_DMMV_X*2) == 0 && src1->ne[1] == 1; - bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) + && src0->ne[0] % 2 == 0 && src1->ne[1] == 1; + bool use_mul_mat_vec_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && src1->ne[1] <= MMVQ_MAX_BATCH_SIZE; - bool use_mul_mat_q = ggml_is_quantized(src0->type) + bool use_mul_mat_q = ggml_is_quantized(src0->type) && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32; - // if mmvq is available it's a better choice than dmmv: -#ifndef GGML_CUDA_FORCE_DMMV - use_dequantize_mul_mat_vec = use_dequantize_mul_mat_vec && !use_mul_mat_vec_q; -#endif // GGML_CUDA_FORCE_DMMV - - bool any_gpus_with_slow_fp16 = false; + bool any_gpus_with_slow_fp16 = false; + bool any_gpus_without_fp16_mma = false; if (split) { ggml_backend_cuda_split_buffer_type_context * buft_ctx = (ggml_backend_cuda_split_buffer_type_context *) src0->buffer->buft->context; @@ -1904,14 +1740,16 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor continue; } - const int cc = ggml_cuda_info().devices[id].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + const int cc = ggml_cuda_info().devices[id].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); } } else { - const int cc = ggml_cuda_info().devices[ctx.device].cc; - use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); - any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + const int cc = ggml_cuda_info().devices[ctx.device].cc; + use_mul_mat_q = use_mul_mat_q && ggml_cuda_should_use_mmq(src0->type, cc, src1->ne[1]); + any_gpus_with_slow_fp16 = any_gpus_with_slow_fp16 || !fast_fp16_available(cc); + any_gpus_without_fp16_mma = any_gpus_without_fp16_mma || !fp16_mma_available(cc); } // debug helpers @@ -1922,18 +1760,14 @@ static void ggml_cuda_mul_mat(ggml_backend_cuda_context & ctx, const ggml_tensor //printf("src0 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src0), ggml_is_transposed(src0), ggml_type_name(src0->type), src0->name); //printf("src1 is contiguous %d, transposed %d, type = %s, name = %s\n", ggml_is_contiguous(src1), ggml_is_transposed(src1), ggml_type_name(src1->type), src1->name); - if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && src1->ne[1] == 1) { - // FP32 precision KQ single-batch for batch size 1 without FlashAttention - ggml_cuda_mul_mat_vec_p021(ctx, src0, src1, dst); - } else if (!split && any_gpus_with_slow_fp16 && src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && src1->ne[1] == 1) { - // FP32 precision KQV single-batch for batch size 1 without FlashAttention - ggml_cuda_mul_mat_vec_nc(ctx, src0, src1, dst); + if (!split && src0->type == GGML_TYPE_F16 && src1->ne[1] == 1 && dst->ne[3] == 1 && (src0->ne[1] < MMV_MAX_ROWS || any_gpus_without_fp16_mma)) { + ggml_cuda_mul_mat_vec(ctx, src0, src1, dst); } else if (!split && src0->type == GGML_TYPE_F16 && (src1->type == GGML_TYPE_F16 || !any_gpus_with_slow_fp16) && !ggml_is_transposed(src0) && !ggml_is_transposed(src1) && src1->ne[2]*src1->ne[3] > 1) { // KQ + KQV multi-batch without FlashAttention ggml_cuda_mul_mat_batched_cublas(ctx, src0, src1, dst); - } else if (use_dequantize_mul_mat_vec) { - ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_dequantize_mul_mat_vec, nullptr); + } else if (use_mul_mat_vec) { + ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec, nullptr); } else if (use_mul_mat_vec_q) { ggml_cuda_op_mul_mat(ctx, src0, src1, dst, ggml_cuda_op_mul_mat_vec_q, quantize_row_q8_1_cuda); } else if (use_mul_mat_q) { diff --git a/ggml/src/ggml-cuda/mmv.cu b/ggml/src/ggml-cuda/mmv.cu new file mode 100644 index 000000000..cfe91f428 --- /dev/null +++ b/ggml/src/ggml-cuda/mmv.cu @@ -0,0 +1,223 @@ +#include "common.cuh" +#include "mmv.cuh" + +template +static __global__ void mul_mat_vec( + const half * __restrict__ x, const float * __restrict__ y, float * __restrict__ dst, const int64_t ncols2, const int64_t stride_row, + const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst) { + const int64_t row = blockIdx.x; + const int64_t channel = blockIdx.z; + const int tid = threadIdx.x; + + x += (channel/channel_ratio)*stride_channel_x + row*stride_row; + y += channel *stride_channel_y; + dst += channel *stride_channel_dst; + + const half2 * x2 = (const half2 *) x; + const float2 * y2 = (const float2 *) y; + + extern __shared__ char data_mmv[]; + float * buf_iw = (float *) data_mmv; + + if (block_size > WARP_SIZE) { + if (tid < WARP_SIZE) { + buf_iw[tid] = 0.0f; + } + __syncthreads(); + } + + float sumf; + + if (std::is_same::value) { + sumf = 0.0f; + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmpx = __half22float2(x2[col2]); + const float2 tmpy = y2[col2]; + sumf += tmpx.x * tmpy.x; + sumf += tmpx.y * tmpy.y; + } + } else { +#ifdef FP16_AVAILABLE + half2 sumh2 = make_half2(0.0f, 0.0f); + + for (int64_t col2 = tid; col2 < ncols2; col2 += block_size) { + const float2 tmp = y2[col2]; + sumh2 += x2[col2] * make_half2(tmp.x, tmp.y); + } + + sumf = __low2float(sumh2) + __high2float(sumh2); +#else + NO_DEVICE_CODE; +#endif // FP16_AVAILABLE + } + + sumf = warp_reduce_sum(sumf); + + if (block_size > WARP_SIZE) { + buf_iw[tid/WARP_SIZE] = sumf; + __syncthreads(); + if (tid > WARP_SIZE) { + return; + } + sumf = buf_iw[tid]; + sumf = warp_reduce_sum(sumf); + } + + if (tid != 0) { + return; + } + + dst[row] = sumf; +} + +template +static void launch_mul_mat_vec_cuda( + const half * x, const float * y, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + cudaStream_t stream) { + GGML_ASSERT(ncols % 2 == 0); + GGML_ASSERT(stride_row % 2 == 0); + GGML_ASSERT(nchannels_y % nchannels_x == 0); + const int64_t channel_ratio = nchannels_y / nchannels_x; + + int64_t block_size_best = WARP_SIZE; + int64_t niter_best = (ncols + 2*WARP_SIZE - 1) / (2*WARP_SIZE); + for (int64_t block_size = 2*WARP_SIZE; block_size <= 256; block_size += WARP_SIZE) { + const int64_t niter = (ncols + 2*block_size - 1) / (2*block_size); + if (niter < niter_best) { + niter_best = niter; + block_size_best = block_size; + } + } + + const int smem = WARP_SIZE*sizeof(float); + const dim3 block_nums(nrows, 1, nchannels_y); + const dim3 block_dims(block_size_best, 1, 1); + switch (block_size_best) { + case 32: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 64: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 96: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 128: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 160: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 192: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 224: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + case 256: { + mul_mat_vec<<>> + (x, y, dst, ncols/2, stride_row, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst); + } break; + default: { + GGML_ABORT("fatal error"); + } break; + } +} + +static void mul_mat_vec_cuda( + const half * x, const float * y, float * dst, + const int64_t ncols, const int64_t nrows, const int64_t stride_row, const int64_t nchannels_x, const int64_t nchannels_y, + const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, + enum ggml_prec prec, cudaStream_t stream) { + switch (prec) { + case GGML_PREC_DEFAULT: { + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + stride_channel_x, stride_channel_y, stride_channel_dst, stream); + } break; + case GGML_PREC_F32: { + launch_mul_mat_vec_cuda(x, y, dst, ncols, nrows, stride_row, nchannels_x, nchannels_y, + stride_channel_x, stride_channel_y, stride_channel_dst, stream); + } break; + } +} + +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t ne01 = src0->ne[1]; + + GGML_ASSERT(src1->ne[1] == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + const half * src0_d = (const half *) src0->data; + const float * src1_d = (const float *) src1->data; + float * dst_d = (float *) dst->data; + + const int64_t ne02 = src0->ne[2]; + const int64_t ne12 = src1->ne[2]; + GGML_ASSERT(dst->ne[2] == ne12); + + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT( dst->ne[3] == 1); + + const int64_t stride_row = src0->nb[1] / ggml_type_size(src0->type); + const int64_t channel_stride_x = src0->nb[2] / ggml_type_size(src0->type); + const int64_t channel_stride_y = src1->nb[2] / ggml_type_size(src1->type); + const int64_t channel_stride_dst = dst->nb[2] / ggml_type_size( dst->type); + + mul_mat_vec_cuda(src0_d, src1_d, dst_d, ne00, ne01, stride_row, ne02, ne12, channel_stride_x, channel_stride_y, channel_stride_dst, prec, ctx.stream()); +} + +void ggml_cuda_op_mul_mat_vec( + ggml_backend_cuda_context & ctx, + const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, + const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, + const int64_t src1_padded_row_size, cudaStream_t stream) { + + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + const int64_t ne00 = src0->ne[0]; + const int64_t row_diff = row_high - row_low; + + GGML_ASSERT(src1_ncols == 1); + + const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + const enum ggml_prec prec = fast_fp16_available(cc) ? ggml_prec(dst->op_params[0]) : GGML_PREC_F32; + + + // ggml_cuda_op provides single, contiguous matrices + const int64_t stride_row = ne00; + const int64_t nchannels_x = 1; + const int64_t nchannels_y = 1; + const int64_t channel_stride_x = 0; + const int64_t channel_stride_y = 0; + const int64_t channel_stride_dst = 0; + + mul_mat_vec_cuda((const half *) src0_dd_i, src1_ddf_i, dst_dd_i, ne00, row_diff, stride_row, + nchannels_x, nchannels_y, channel_stride_x, channel_stride_y, channel_stride_dst, prec, stream); + + GGML_UNUSED(ctx); + GGML_UNUSED(src1); + GGML_UNUSED(dst); + GGML_UNUSED(src1_ddq_i); + GGML_UNUSED(src1_ncols); + GGML_UNUSED(src1_padded_row_size); +} diff --git a/ggml/src/ggml-cuda/dmmv.cuh b/ggml/src/ggml-cuda/mmv.cuh similarity index 55% rename from ggml/src/ggml-cuda/dmmv.cuh rename to ggml/src/ggml-cuda/mmv.cuh index e727eb97f..78a1cd4a6 100644 --- a/ggml/src/ggml-cuda/dmmv.cuh +++ b/ggml/src/ggml-cuda/mmv.cuh @@ -1,20 +1,12 @@ #include "common.cuh" -// dmmv = dequantize_mul_mat_vec +// maximum number of src0 rows with which to use mul_mat_vec over cuBLAS if FP16 tensor cores are available +#define MMV_MAX_ROWS 512 -// TODO: remove this? -#ifndef GGML_CUDA_DMMV_X -#define GGML_CUDA_DMMV_X 32 -#endif +void ggml_cuda_mul_mat_vec(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst); -#ifndef GGML_CUDA_MMV_Y -#define GGML_CUDA_MMV_Y 1 -#endif - -void ggml_cuda_op_dequantize_mul_mat_vec( +void ggml_cuda_op_mul_mat_vec( ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i, const char * src1_ddq_i, float * dst_dd_i, const int64_t row_low, const int64_t row_high, const int64_t src1_ncols, const int64_t src1_padded_row_size, cudaStream_t stream); - -bool ggml_cuda_dmmv_type_supported(ggml_type src0_type); diff --git a/ggml/src/ggml-cuda/opt-step-adamw.cu b/ggml/src/ggml-cuda/opt-step-adamw.cu index d6f13a9c6..35154f299 100644 --- a/ggml/src/ggml-cuda/opt-step-adamw.cu +++ b/ggml/src/ggml-cuda/opt-step-adamw.cu @@ -1,11 +1,11 @@ +#include "ggml-impl.h" #include "opt-step-adamw.cuh" #include static __global__ void opt_step_adamw_f32( - float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, const int64_t k, - const float alpha, const float beta1, const float beta2, const float eps, const float wd, - const float beta1h, const float beta2h) { + float * __restrict__ x, const float * __restrict__ g, float * __restrict__ g_m, float * __restrict__ g_v, + const float * __restrict__ pars, const int64_t k) { const int64_t i = (int64_t) blockIdx.x*blockDim.x + threadIdx.x; @@ -13,6 +13,14 @@ static __global__ void opt_step_adamw_f32( return; } + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + const float gi = g[i]; const float gmi = g_m[i]*beta1 + gi*(1.0f - beta1); const float gvi = g_v[i]*beta2 + gi*gi*(1.0f - beta2); @@ -23,58 +31,48 @@ static __global__ void opt_step_adamw_f32( const float mh = gmi*beta1h; const float vh = sqrtf(gvi*beta2h) + eps; - x[i] = x[i]*(1.0f - alpha*wd) - mh/vh; + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; } static void opt_step_adamw_f32_cuda( - float * x, const float * g, float * g_m, float * g_v, const int64_t k, - const float alpha, const float beta1, const float beta2, const float eps, const float wd, - const float beta1h, const float beta2h, cudaStream_t stream) { + float * x, const float * g, float * g_m, float * g_v, const float * pars, const int64_t k, cudaStream_t stream) { const dim3 block_dims(CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); const dim3 block_nums((k + CUDA_OPT_STEP_ADAMW_BLOCK_SIZE - 1) / CUDA_OPT_STEP_ADAMW_BLOCK_SIZE, 1, 1); - opt_step_adamw_f32<<>>(x, g, g_m, g_v, k, alpha, beta1, beta2, eps, wd, beta1h, beta2h); + opt_step_adamw_f32<<>>(x, g, g_m, g_v, pars, k); } void ggml_cuda_opt_step_adamw(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * src0 = dst->src[0]; - const ggml_tensor * src0_grad = dst->src[1]; - const ggml_tensor * src0_grad_m = dst->src[2]; - const ggml_tensor * src0_grad_v = dst->src[3]; + const ggml_tensor * src0 = dst->src[0]; + const ggml_tensor * src0_grad = dst->src[1]; + const ggml_tensor * src0_grad_m = dst->src[2]; + const ggml_tensor * src0_grad_v = dst->src[3]; + const ggml_tensor * adamw_params = dst->src[4]; - GGML_ASSERT(src0->type == GGML_TYPE_F32); - GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); - GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32); - GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32); + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_m->type == GGML_TYPE_F32); + GGML_ASSERT(src0_grad_v->type == GGML_TYPE_F32); + GGML_ASSERT(adamw_params->type == GGML_TYPE_F32); GGML_ASSERT(ggml_is_contiguous(src0)); GGML_ASSERT(ggml_is_contiguous(src0_grad)); GGML_ASSERT(ggml_is_contiguous(src0_grad_m)); GGML_ASSERT(ggml_is_contiguous(src0_grad_v)); + GGML_ASSERT(ggml_is_contiguous(adamw_params)); GGML_ASSERT(ggml_are_same_shape(src0, src0_grad)); GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_m)); GGML_ASSERT(ggml_are_same_shape(src0, src0_grad_v)); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); - float * src0_d = (float *) src0->data; - const float * src0_grad_d = (const float *) src0_grad->data; - float * src0_grad_m_d = (float *) src0_grad_m->data; - float * src0_grad_v_d = (float *) src0_grad_v->data; + float * src0_d = (float *) src0->data; + const float * src0_grad_d = (const float *) src0_grad->data; + float * src0_grad_m_d = (float *) src0_grad_m->data; + float * src0_grad_v_d = (float *) src0_grad_v->data; + const float * adamw_params_d = (const float *) adamw_params->data; cudaStream_t stream = ctx.stream(); const int64_t ne = ggml_nelements(src0); - int64_t iter; memcpy(&iter, &dst->op_params[0], sizeof(int64_t)); - float alpha; memcpy(&alpha, &dst->op_params[2], sizeof(float)); - float beta1; memcpy(&beta1, &dst->op_params[3], sizeof(float)); - float beta2; memcpy(&beta2, &dst->op_params[4], sizeof(float)); - float eps; memcpy(&eps, &dst->op_params[5], sizeof(float)); - float wd; memcpy(&wd, &dst->op_params[6], sizeof(float)); - - const float beta1h = alpha/(1.0f - powf(beta1, iter)); - const float beta2h = 1.0f/(1.0f - powf(beta2, iter)); - - opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, ne, alpha, beta1, beta2, eps, wd, beta1h, beta2h, stream); - - iter++; - memcpy(&dst->op_params[0], &iter, sizeof(int64_t)); + opt_step_adamw_f32_cuda(src0_d, src0_grad_d, src0_grad_m_d, src0_grad_v_d, adamw_params_d, ne, stream); } diff --git a/ggml/src/ggml-hip/CMakeLists.txt b/ggml/src/ggml-hip/CMakeLists.txt index 5ed186ded..fccf8eb84 100644 --- a/ggml/src/ggml-hip/CMakeLists.txt +++ b/ggml/src/ggml-hip/CMakeLists.txt @@ -75,18 +75,11 @@ target_include_directories(ggml-hip PRIVATE . ..) target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) add_compile_definitions(GGML_USE_HIP) -add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) -add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) -add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) if (GGML_HIP_UMA) add_compile_definitions(GGML_HIP_UMA) endif() -if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) -endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index aa4d2b85d..92a64fe5a 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -196,7 +196,7 @@ void ggml_hash_set_reset(struct ggml_hash_set * hash_set); static bool ggml_hash_contains(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); // returns GGML_HASHSET_FULL if table is full, otherwise the current index of the key or where it should be inserted -static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key); +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key); // returns GGML_HASHSET_ALREADY_EXISTS if key already exists, index otherwise, asserts if table is full static size_t ggml_hash_insert(struct ggml_hash_set * hash_set, struct ggml_tensor * key); @@ -210,7 +210,7 @@ static inline size_t ggml_hash(const struct ggml_tensor * p) { return (size_t)(uintptr_t)p >> 4; } -static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, struct ggml_tensor * key) { +static size_t ggml_hash_find(const struct ggml_hash_set * hash_set, const struct ggml_tensor * key) { size_t h = ggml_hash(key) % hash_set->size; // linear probing @@ -281,13 +281,14 @@ enum ggml_cgraph_eval_order { }; struct ggml_cgraph { - int size; - int n_nodes; - int n_leafs; + int size; // maximum number of nodes/leafs/grads/grad_accs + int n_nodes; // number of nodes currently in use + int n_leafs; // number of leafs currently in use - struct ggml_tensor ** nodes; - struct ggml_tensor ** grads; - struct ggml_tensor ** leafs; + struct ggml_tensor ** nodes; // tensors with data that can change if the graph is evaluated + struct ggml_tensor ** grads; // the outputs of these tensors are the gradients of the nodes + struct ggml_tensor ** grad_accs; // accumulators for node gradients + struct ggml_tensor ** leafs; // tensors with constant data struct ggml_hash_set visited_hash_set; diff --git a/ggml/src/ggml-metal/CMakeLists.txt b/ggml/src/ggml-metal/CMakeLists.txt index e0992c744..b237d79f4 100644 --- a/ggml/src/ggml-metal/CMakeLists.txt +++ b/ggml/src/ggml-metal/CMakeLists.txt @@ -25,9 +25,10 @@ if (GGML_METAL_USE_BF16) add_compile_definitions(GGML_METAL_USE_BF16) endif() -# copy ggml-common.h and ggml-metal.metal to bin directory -configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) -configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +# copy metal files to bin directory +configure_file(../ggml-common.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-common.h COPYONLY) +configure_file(ggml-metal.metal ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal.metal COPYONLY) +configure_file(ggml-metal-impl.h ${CMAKE_RUNTIME_OUTPUT_DIRECTORY}/ggml-metal-impl.h COPYONLY) if (GGML_METAL_EMBED_LIBRARY) enable_language(ASM) @@ -36,24 +37,27 @@ if (GGML_METAL_EMBED_LIBRARY) set(METALLIB_COMMON "${CMAKE_CURRENT_SOURCE_DIR}/../ggml-common.h") set(METALLIB_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal.metal") + set(METALLIB_IMPL "${CMAKE_CURRENT_SOURCE_DIR}/ggml-metal-impl.h") file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/autogenerated") # merge ggml-common.h and ggml-metal.metal into a single file - set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") - set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_EMBED_ASM "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.s") + set(METALLIB_SOURCE_EMBED "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal") + set(METALLIB_SOURCE_EMBED_TMP "${CMAKE_BINARY_DIR}/autogenerated/ggml-metal-embed.metal.tmp") add_custom_command( OUTPUT ${METALLIB_EMBED_ASM} COMMAND echo "Embedding Metal library" - COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED} + COMMAND sed -e '/__embed_ggml-common.h__/r ${METALLIB_COMMON}' -e '/__embed_ggml-common.h__/d' < ${METALLIB_SOURCE} > ${METALLIB_SOURCE_EMBED_TMP} + COMMAND sed -e '/\#include \"ggml-metal-impl.h\"/r ${METALLIB_IMPL}' -e '/\#include \"ggml-metal-impl.h\"/d' < ${METALLIB_SOURCE_EMBED_TMP} > ${METALLIB_SOURCE_EMBED} COMMAND echo ".section __DATA,__ggml_metallib" > ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_start" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_start:" >> ${METALLIB_EMBED_ASM} COMMAND echo ".incbin \\\"${METALLIB_SOURCE_EMBED}\\\"" >> ${METALLIB_EMBED_ASM} COMMAND echo ".globl _ggml_metallib_end" >> ${METALLIB_EMBED_ASM} COMMAND echo "_ggml_metallib_end:" >> ${METALLIB_EMBED_ASM} - DEPENDS ggml-metal.metal ../ggml-common.h + DEPENDS ../ggml-common.h ggml-metal.metal ggml-metal-impl.h COMMENT "Generate assembly for embedded Metal library" ) diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h new file mode 100644 index 000000000..53c135496 --- /dev/null +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -0,0 +1,249 @@ +#ifndef GGML_METAL_IMPL +#define GGML_METAL_IMPL + +// kernel argument structs +// +// - element counters (e.g. ne00) typically use int32_t to reduce register usage +// however, be careful from int overflows when using those in the kernel implementation +// +// - strides (e.g. nb00) use uint64_t + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t dim; +} ggml_metal_kargs_concat; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + uint64_t offs; +} ggml_metal_kargs_bin; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_repeat; + +typedef struct { + int64_t ne00; + int64_t ne01; + int64_t ne02; + int64_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int64_t ne0; + int64_t ne1; + int64_t ne2; + int64_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; +} ggml_metal_kargs_cpy; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + int32_t n_past; + int32_t n_dims; + int32_t n_ctx_orig; + float freq_base; + float freq_scale; + float ext_factor; + float attn_factor; + float beta_fast; + float beta_slow; +} ggml_metal_kargs_rope; + +typedef struct { + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb_12_1; + uint64_t nb_12_2; + uint64_t nb_12_3; + uint64_t nb31; + int32_t ne1; + int32_t ne2; + float scale; + float max_bias; + float m0; + float m1; + uint16_t n_head_log2; + float logit_softcap; +} ggml_metal_kargs_flash_attn_ext; + +typedef struct { + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mm; + +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne10; + int32_t ne11; + int32_t ne12; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + int32_t ne0; + int32_t ne1; + int16_t r2; + int16_t r3; +} ggml_metal_kargs_mul_mv; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne02; + uint64_t nb01; + uint64_t nb02; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; +} ggml_metal_kargs_mul_mm_id; + +typedef struct { + int32_t nei0; + int32_t nei1; + uint64_t nbi1; + int32_t ne00; + int32_t ne01; + int32_t ne02; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + int32_t ne10; + int32_t ne11; + int32_t ne12; + int32_t ne13; + uint64_t nb10; + uint64_t nb11; + uint64_t nb12; + int32_t ne0; + int32_t ne1; + uint64_t nb1; +} ggml_metal_kargs_mul_mv_id; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_norm; + +typedef struct { + int32_t ne00; + int32_t ne00_4; + uint64_t nb01; + float eps; +} ggml_metal_kargs_rms_norm; + +#endif // GGML_METAL_IMPL diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m index b4b5cfd26..58fee4bfd 100644 --- a/ggml/src/ggml-metal/ggml-metal.m +++ b/ggml/src/ggml-metal/ggml-metal.m @@ -2,6 +2,7 @@ #import "ggml-impl.h" #import "ggml-backend-impl.h" +#import "ggml-metal-impl.h" #import @@ -1193,35 +1194,39 @@ static void ggml_metal_encode_node( const int32_t dim = ((const int32_t *) dst->op_params)[0]; + ggml_metal_kargs_concat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.dim =*/ dim, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&dim length:sizeof(dim) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN(1024, ne0); @@ -1239,8 +1244,6 @@ static void ggml_metal_encode_node( bool bcast_row = false; - int64_t nb = ne00; // used by the "row" kernels - id pipeline = nil; if (ggml_nelements(src1) == ne10 && ggml_is_contiguous(src1) && ne00 % 4 == 0 && ne10 % 4 == 0) { @@ -1249,7 +1252,6 @@ static void ggml_metal_encode_node( // src1 is a row GGML_ASSERT(ne11 == 1); - nb = ne00 / 4; switch (dst->op) { case GGML_OP_ADD: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD_ROW].pipeline; break; case GGML_OP_SUB: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUB_ROW].pipeline; break; @@ -1269,36 +1271,39 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:24]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:25]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; - [encoder setBytes:&nb length:sizeof(nb) atIndex:28]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (bcast_row) { const int64_t n = ggml_nelements(dst)/4; @@ -1322,25 +1327,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); } + ggml_metal_kargs_repeat args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(nb2) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(nb3) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne0); @@ -1369,25 +1378,29 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -1396,35 +1409,39 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ADD].pipeline; + ggml_metal_kargs_bin args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ pnb1, + /*.nb02 =*/ pnb2, + /*.nb03 =*/ pnb3, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + /*.offs =*/ offs, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&ne03 length:sizeof(ne03) atIndex:6]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:7]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:8]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:9]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:10]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:11]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:12]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:13]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:14]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:15]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:16]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:17]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:18]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:19]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:20]; - [encoder setBytes:&ne2 length:sizeof(ne2) atIndex:21]; - [encoder setBytes:&ne3 length:sizeof(ne3) atIndex:22]; - [encoder setBytes:&nb0 length:sizeof(nb0) atIndex:23]; - [encoder setBytes:&pnb1 length:sizeof(pnb1) atIndex:24]; - [encoder setBytes:&pnb2 length:sizeof(pnb2) atIndex:25]; - [encoder setBytes:&pnb3 length:sizeof(pnb3) atIndex:26]; - [encoder setBytes:&offs length:sizeof(offs) atIndex:27]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; const int nth = MIN((int) pipeline.maxTotalThreadsPerThreadgroup, ne00); @@ -1465,10 +1482,10 @@ static void ggml_metal_encode_node( memcpy(&max, ((const int32_t *) dst->op_params) + 1, sizeof(float)); [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&min length:sizeof(min) atIndex:2]; - [encoder setBytes:&max length:sizeof(max) atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; + [encoder setBytes:&min length:sizeof(min) atIndex:2]; + [encoder setBytes:&max length:sizeof(max) atIndex:3]; const int64_t n = ggml_nelements(dst); @@ -1640,6 +1657,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SUM_ROWS].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1715,6 +1733,8 @@ static void ggml_metal_encode_node( const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + // TODO: add ggml_metal_kargs struct + // TODO: optimize (see https://github.com/ggerganov/llama.cpp/pull/10238/commits/7941b6b9ec29a2866fec6fa6c51612515ca509f6) [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; if (id_src1) { @@ -1731,6 +1751,7 @@ static void ggml_metal_encode_node( [encoder setBytes:&m0 length:sizeof(m0) atIndex:8]; [encoder setBytes:&m1 length:sizeof(m1) atIndex:9]; [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:10]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01*ne02*ne03, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; @@ -1747,6 +1768,7 @@ static void ggml_metal_encode_node( pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_DIAG_MASK_INF].pipeline; } + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -1771,6 +1793,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_CONV_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1841,6 +1864,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_SSM_SCAN_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -1959,24 +1983,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL MAT-MAT not implemented"); } + ggml_metal_kargs_mul_mm args = { + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:5]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:6]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:7]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:8]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:9]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:10]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:11]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:12]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:15]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:16]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setThreadgroupMemoryLength:8192 atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake( (ne11 + 31)/32, (ne01 + 63)/64, ne12*ne13) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)]; } else { @@ -2154,28 +2183,32 @@ static void ggml_metal_encode_node( } }; + ggml_metal_kargs_mul_mv args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb13 =*/ nb13, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.r2 =*/ r2, + /*.r3 =*/ r3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(nb03) atIndex:9]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(nb13) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&r2 length:sizeof(r2) atIndex:19]; - [encoder setBytes:&r3 length:sizeof(r3) atIndex:20]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || src0t == GGML_TYPE_Q5_0 || src0t == GGML_TYPE_Q5_1 || src0t == GGML_TYPE_Q8_0 || src0t == GGML_TYPE_Q2_K || @@ -2288,27 +2321,30 @@ static void ggml_metal_encode_node( default: GGML_ABORT("MUL_MAT_ID not implemented"); } + ggml_metal_kargs_mul_mm_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne02 =*/ ne02, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:13]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:14]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:15]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:16]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:18]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:19]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; [encoder setThreadgroupMemoryLength:GGML_PAD(8192 + dst_rows*4/*sizeof(ushort2)*/, 16) atIndex:0]; @@ -2467,30 +2503,34 @@ static void ggml_metal_encode_node( GGML_ASSERT(ne00 >= nth0*nth1); } + ggml_metal_kargs_mul_mv_id args = { + /*.nei0 =*/ ne20, + /*.nei1 =*/ ne21, + /*.nbi1 =*/ nb21, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.ne10 =*/ ne10, + /*.ne11 =*/ ne11, + /*.ne12 =*/ ne12, + /*.ne13 =*/ ne13, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.nb1 =*/ nb1, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; - [encoder setBytes:&ne20 length:sizeof(ne20) atIndex:4]; - [encoder setBytes:&ne21 length:sizeof(ne21) atIndex:5]; - [encoder setBytes:&nb21 length:sizeof(nb21) atIndex:6]; - [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:7]; - [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:8]; - [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:9]; - [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:10]; - [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:11]; - [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:12]; - [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:13]; - [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:14]; - [encoder setBytes:&ne12 length:sizeof(ne12) atIndex:15]; - [encoder setBytes:&ne13 length:sizeof(ne13) atIndex:16]; - [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:17]; - [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:18]; - [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:19]; - [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:20]; - [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:21]; - [encoder setBytes:&nb1 length:sizeof(nb1) atIndex:22]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:4]; const int64_t _ne1 = 1; const int tgz = dst_rows; @@ -2563,6 +2603,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; @@ -2586,20 +2627,28 @@ static void ggml_metal_encode_node( float eps; memcpy(&eps, dst->op_params, sizeof(float)); + id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + int nth = 32; // SIMD width - while (nth < ne00/4 && nth < 1024) { + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { nth *= 2; } - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_RMS_NORM].pipeline; + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_rms_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -2624,6 +2673,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_GROUP_NORM].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2641,22 +2691,35 @@ static void ggml_metal_encode_node( } break; case GGML_OP_NORM: { + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(ggml_is_contiguous_1(src0)); float eps; memcpy(&eps, dst->op_params, sizeof(float)); - const int nth = MIN(256, ne00); - id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_NORM].pipeline; + int nth = 32; // SIMD width + + while (nth < ne00/4 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) { + nth *= 2; + } + + nth = MIN(nth, ne00/4); + + ggml_metal_kargs_norm args = { + /*.ne00 =*/ ne00, + /*.ne00_4 =*/ ne00/4, + /*.nb01 =*/ nb01, + /*.eps =*/ eps, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; - [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:GGML_PAD(nth*sizeof(float), 16) atIndex:0]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; + + [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); @@ -2706,40 +2769,44 @@ static void ggml_metal_encode_node( }; } + ggml_metal_kargs_rope args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.n_past =*/ n_past, + /*.n_dims =*/ n_dims, + /*.n_ctx_orig =*/ n_ctx_orig, + /*.freq_base =*/ freq_base, + /*.freq_scale =*/ freq_scale, + /*.ext_factor =*/ ext_factor, + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; if (id_src2 != nil) { - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:2]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:3]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:11]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:14]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:15]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:18]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:19]; - [encoder setBytes:&n_past length:sizeof( int) atIndex:20]; - [encoder setBytes:&n_dims length:sizeof( int) atIndex:21]; - [encoder setBytes:&n_ctx_orig length:sizeof( int) atIndex:22]; - [encoder setBytes:&freq_base length:sizeof( float) atIndex:23]; - [encoder setBytes:&freq_scale length:sizeof( float) atIndex:24]; - [encoder setBytes:&ext_factor length:sizeof( float) atIndex:25]; - [encoder setBytes:&attn_factor length:sizeof( float) atIndex:26]; - [encoder setBytes:&beta_fast length:sizeof( float) atIndex:27]; - [encoder setBytes:&beta_slow length:sizeof( float) atIndex:28]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -2796,6 +2863,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2836,6 +2904,7 @@ static void ggml_metal_encode_node( const id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_UPSCALE_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2870,6 +2939,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_PAD_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2906,6 +2976,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ARANGE_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_dst offset:offs_dst atIndex:0]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:1]; @@ -2927,6 +2998,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_TIMESTEP_EMBEDDING_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2965,6 +3037,7 @@ static void ggml_metal_encode_node( default: GGML_ABORT("fatal error"); }; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -2983,6 +3056,7 @@ static void ggml_metal_encode_node( id pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_LEAKY_RELU_F32].pipeline; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3224,37 +3298,41 @@ static void ggml_metal_encode_node( } } + ggml_metal_kargs_flash_attn_ext args = { + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne11 =*/ ne11, + /*.ne_12_2 =*/ ne12, + /*.ne_12_3 =*/ ne13, + /*.nb_12_1 =*/ nb11, + /*.nb_12_2 =*/ nb12, + /*.nb_12_3 =*/ nb13, + /*.nb31 =*/ nb31, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.scale =*/ scale, + /*.max_bias =*/ max_bias, + /*.m0 =*/ m0, + /*.m1 =*/ m1, + /*.n_head_log2 =*/ n_head_log2, + /*.logit_softcap =*/ logit_softcap, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; - [encoder setBuffer:id_src2 offset:offs_src2 atIndex:2]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_src1 offset:offs_src1 atIndex:2]; + [encoder setBuffer:id_src2 offset:offs_src2 atIndex:3]; if (id_src3) { - [encoder setBuffer:id_src3 offset:offs_src3 atIndex:3]; + [encoder setBuffer:id_src3 offset:offs_src3 atIndex:4]; } else { - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:3]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:4]; } - [encoder setBuffer:id_dst offset:offs_dst atIndex:4]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:6]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:7]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:10]; - [encoder setBytes:&ne11 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne12 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne13 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb11 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb12 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb13 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb31 length:sizeof(uint64_t) atIndex:17]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:18]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:19]; - [encoder setBytes:&scale length:sizeof( float) atIndex:20]; - [encoder setBytes:&max_bias length:sizeof( float) atIndex:21]; - [encoder setBytes:&m0 length:sizeof(m0) atIndex:22]; - [encoder setBytes:&m1 length:sizeof(m1) atIndex:23]; - [encoder setBytes:&n_head_log2 length:sizeof(n_head_log2) atIndex:24]; - [encoder setBytes:&logit_softcap length:sizeof(logit_softcap) atIndex:25]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:5]; if (!use_vec_kernel) { // half8x8 kernel @@ -3389,25 +3467,29 @@ static void ggml_metal_encode_node( default: GGML_ABORT("not implemented"); } + ggml_metal_kargs_cpy args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + [encoder setComputePipelineState:pipeline]; - [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; - [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; - [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; - [encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; - [encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; - [encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; - [encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; - [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; - [encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; - [encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; - [encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; - [encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; - [encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; - [encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; - [encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; - [encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15]; - [encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16]; - [encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17]; + [encoder setBytes:&args length:sizeof(args) atIndex:0]; + [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1]; + [encoder setBuffer:id_dst offset:offs_dst atIndex:2]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; } break; @@ -3452,6 +3534,7 @@ static void ggml_metal_encode_node( const int64_t n_threads = MIN((int64_t)[pipeline maxTotalThreadsPerThreadgroup], parallel_elements); const int64_t n_tg = (parallel_elements + n_threads - 1) / n_threads; + // TODO: add ggml_metal_kargs struct [encoder setComputePipelineState:pipeline]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1]; @@ -3639,6 +3722,12 @@ static void * ggml_backend_metal_buffer_get_base(ggml_backend_buffer_t buffer) { return ctx->all_data; } +static void ggml_backend_metal_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + memset((char *)tensor->data + offset, value, size); + + UNUSED(buffer); +} + static void ggml_backend_metal_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { memcpy((char *)tensor->data + offset, data, size); @@ -3671,7 +3760,7 @@ static struct ggml_backend_buffer_i ggml_backend_metal_buffer_i = { /* .free_buffer = */ ggml_backend_metal_buffer_free_buffer, /* .get_base = */ ggml_backend_metal_buffer_get_base, /* .init_tensor = */ NULL, - /* .memset_tensor = */ NULL, + /* .memset_tensor = */ ggml_backend_metal_buffer_memset_tensor, /* .set_tensor = */ ggml_backend_metal_buffer_set_tensor, /* .get_tensor = */ ggml_backend_metal_buffer_get_tensor, /* .cpy_tensor = */ ggml_backend_metal_buffer_cpy_tensor, diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 8c7fcb113..86fdf1c18 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -6,6 +6,7 @@ __embed_ggml-common.h__ // TODO: this should not be a relative path, but can't figure out how to set Metal include paths in Package.swift #include "../ggml-common.h" #endif +#include "ggml-metal-impl.h" #include @@ -497,240 +498,131 @@ enum ggml_sort_order { // pros: works for non-contiguous tensors, supports broadcast across all dims // cons: not very efficient kernel void kernel_add( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) + *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) + *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_sub( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int64_t & offs, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01 + offs; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1 + offs; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) - *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_mul( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) * *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); } } kernel void kernel_div( + constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig.z; - const int64_t i02 = tgpig.y; - const int64_t i01 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; - const int64_t i13 = i03 % ne13; - const int64_t i12 = i02 % ne12; - const int64_t i11 = i01 % ne11; + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device const char * src1_ptr = src1 + i13*nb13 + i12*nb12 + i11*nb11; - device char * dst_ptr = dst + i03*nb3 + i02*nb2 + i01*nb1; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11; + device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i10 = i0 % ne10; - *((device float *)(dst_ptr + i0*nb0)) = *((device float *)(src0_ptr + i0*nb00)) / *((device float *)(src1_ptr + i10*nb10)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); } } template kernel void kernel_repeat( + constant ggml_metal_kargs_repeat & args, device const char * src0, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; - const int64_t i03 = i3 % ne03; - const int64_t i02 = i2 % ne02; - const int64_t i01 = i1 % ne01; + const int i03 = i3%args.ne03; + const int i02 = i2%args.ne02; + const int i01 = i1%args.ne01; - device const char * src0_ptr = src0 + i03*nb03 + i02*nb02 + i01*nb01; - device char * dst_ptr = dst + i3*nb3 + i2*nb2 + i1*nb1 ; + device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01; + device char * dst_ptr = dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - const int i00 = i0 % ne00; - *((device T *)(dst_ptr + i0*nb0)) = *((device T *)(src0_ptr + i00*nb00)); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i00 = i0%args.ne00; + *((device T *)(dst_ptr + i0*args.nb0)) = *((device T *)(src0_ptr + i00*args.nb00)); } } @@ -744,38 +636,42 @@ template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat // assumption: src1 is a row // broadcast src1 into src0 kernel void kernel_add_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] + src1[tpig % nb]; } kernel void kernel_sub_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] - src1[tpig % nb]; } kernel void kernel_mul_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] * src1[tpig % nb]; } kernel void kernel_div_row( + constant ggml_metal_kargs_bin & args, device const float4 * src0, device const float4 * src1, device float4 * dst, - constant uint64_t & nb [[buffer(28)]], uint tpig[[thread_position_in_grid]]) { + const uint nb = args.ne00/4; dst[tpig] = src0[tpig] / src1[tpig % nb]; } @@ -1345,102 +1241,112 @@ kernel void kernel_ssm_scan_f32( } kernel void kernel_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * sum [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device const char *) src0 + tgpig*nb01); - // MEAN - // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00]; + constant ggml_metal_kargs_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; } - // reduce - threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float mean = sum[0] / ne00; - // recenter and VARIANCE + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float4 sumf4(0.0f); + + float sumf = 0.0f; + + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf4 += x[i00]; + } + sumf = sumf4[0] + sumf4[1] + sumf4[2] + sumf4[3]; + sumf = simd_sum(sumf); + threadgroup_barrier(mem_flags::mem_threadgroup); - device float * y = dst + tgpig*ne00; - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + + sumf = 0.0f; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = x[i00] - mean; - sum[tpitg] += y[i00] * y[i00]; + sumf += dot(y[i00], y[i00]); } + sumf = simd_sum(sumf); - // reduce threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - } - const float variance = sum[0] / ne00; - const float scale = 1.0f/sqrt(variance + eps); - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; + } + + threadgroup_barrier(mem_flags::mem_threadgroup); + + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float variance = sumf/args.ne00; + + const float scale = 1.0f/sqrt(variance + args.eps); + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = y[i00] * scale; } } kernel void kernel_rms_norm( - device const void * src0, - device float * dst, - constant int64_t & ne00, - constant uint64_t & nb01, - constant float & eps, - threadgroup float * buf [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - uint tpitg[[thread_position_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint ntg[[threads_per_threadgroup]]) { - device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + constant ggml_metal_kargs_rms_norm & args, + device const char * src0, + device char * dst, + threadgroup float * shmem_f32 [[threadgroup(0)]], + uint tgpig[[threadgroup_position_in_grid]], + ushort tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort ntg[[threads_per_threadgroup]]) { + if (sgitg == 0) { + shmem_f32[tiisg] = 0.0f; + } - float4 sumf = 0; - float all_sum = 0; + device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + + float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { - sumf += x[i00] * x[i00]; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + sumf += dot(x[i00], x[i00]); } - all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; - all_sum = simd_sum(all_sum); - if (ntg > N_SIMDWIDTH) { - if (sgitg == 0) { - buf[tiisg] = 0.0f; - } + sumf = simd_sum(sumf); - threadgroup_barrier(mem_flags::mem_threadgroup); + threadgroup_barrier(mem_flags::mem_threadgroup); - if (tiisg == 0) { - buf[sgitg] = all_sum; - } - - threadgroup_barrier(mem_flags::mem_threadgroup); - - all_sum = buf[tiisg]; - all_sum = simd_sum(all_sum); + if (tiisg == 0) { + shmem_f32[sgitg] = sumf; } - const float mean = all_sum/ne00; - const float scale = 1.0f/sqrt(mean + eps); + threadgroup_barrier(mem_flags::mem_threadgroup); - device float4 * y = (device float4 *) (dst + tgpig*ne00); - for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf = shmem_f32[tiisg]; + sumf = simd_sum(sumf); + + const float mean = sumf/args.ne00; + const float scale = 1.0f/sqrt(mean + args.eps); + + device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; + for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { y[i00] = x[i00] * scale; } } @@ -1628,31 +1534,17 @@ inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thre // quantizations where the block size is 32. It also does not // guard against the number of rows not being divisible by // N_DST, so this is another explicit assumption of the implementation. -template +template void mul_vec_q_n_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { - const int nb = ne00/QK4_0; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + const int nb = args.ne00/QK4_0; const int r0 = tgpig.x; const int r1 = tgpig.y; @@ -1660,19 +1552,19 @@ void mul_vec_q_n_f32_impl( const int first_row = (r0 * nsg + sgitg) * nr; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const block_q_type * x = (device const block_q_type *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + //device const block_q_type * x = (device const block_q_type *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows device const block_q_type * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q_type *) ((device char *) src0 + offset0); } @@ -1680,10 +1572,10 @@ void mul_vec_q_n_f32_impl( float yl[16]; // src1 vector cache float sumf[nr] = {0.f}; - const int ix = (tiisg/2); - const int il = (tiisg%2)*8; + const short ix = (tiisg/2); + const short il = (tiisg%2)*8; - device const float * yb = y + ix * QK4_0 + il; + device const float * yb = y + ix*QK4_0 + il; // each thread in a SIMD group deals with half a block. for (int ib = ix; ib < nb; ib += nw/2) { @@ -1708,324 +1600,216 @@ void mul_vec_q_n_f32_impl( yb += QK4_0 * 16; } + device float * dst_f32 = (device float *) dst + im*args.ne0*args.ne1 + r1*args.ne0; + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[im*ne0*ne1 + r1*ne0 + first_row + row] = tot; + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; } } } kernel void kernel_mul_mv_q4_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q4_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } kernel void kernel_mul_mv_q5_1_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - mul_vec_q_n_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + mul_vec_q_n_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } - #define NB_Q8_0 8 +template void kernel_mul_mv_q8_0_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const int nr = N_DST; const int nsg = N_SIMDGROUP; const int nw = N_SIMDWIDTH; - const int nb = ne00/QK8_0; + const int nb = args.ne00/QK8_0; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * nsg + sgitg) * nr; + const int first_row = (r0*nsg + sgitg)*nr; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - //const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + //const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - //device const block_q8_0 * x = (device const block_q8_0 *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + //device const block_q8_0 * x = (device const block_q8_0 *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); // pointers to src0 rows device const block_q8_0 * ax[nr]; for (int row = 0; row < nr; ++row) { - const uint offset0 = (first_row + row)*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint64_t offset0 = (first_row + row)*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; ax[row] = (device const block_q8_0 *) ((device char *) src0 + offset0); } float yl[NB_Q8_0]; - float sumf[nr]={0.f}; + float sumf[nr] = { 0.f }; - const int ix = tiisg/4; - const int il = tiisg%4; + const short ix = tiisg/4; + const short il = tiisg%4; - device const float * yb = y + ix * QK8_0 + NB_Q8_0*il; + device const float * yb = y + ix*QK8_0 + il*NB_Q8_0; // each thread in a SIMD group deals with NB_Q8_0 quants at a time for (int ib = ix; ib < nb; ib += nw/4) { - for (int i = 0; i < NB_Q8_0; ++i) { + for (short i = 0; i < NB_Q8_0; ++i) { yl[i] = yb[i]; } for (int row = 0; row < nr; row++) { - device const int8_t * qs = ax[row][ib].qs + NB_Q8_0*il; + device const int8_t * qs = ax[row][ib].qs + il*NB_Q8_0; float sumq = 0.f; - for (int iq = 0; iq < NB_Q8_0; ++iq) { + for (short iq = 0; iq < NB_Q8_0; ++iq) { sumq += qs[iq] * yl[iq]; } sumf[row] += sumq*ax[row][ib].d; } - yb += NB_Q8_0 * nw; + yb += nw*NB_Q8_0; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < nr; ++row) { const float tot = simd_sum(sumf[row]); - if (tiisg == 0 && first_row + row < ne01) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + + if (tiisg == 0 && first_row + row < args.ne01) { + dst_f32[first_row + row] = tot; } } } [[host_name("kernel_mul_mv_q8_0_f32")]] kernel void kernel_mul_mv_q8_0_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q8_0_f32_impl(src0,src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,nullptr,tgpig,tiisg,sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_q8_0_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } #define N_MV_T_T 4 -template +template void kernel_mul_mv_impl( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg) { - const int64_t r0 = tgpig.x; - const int64_t rb = tgpig.y*N_MV_T_T; - const int64_t im = tgpig.z; + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg) { + const int r0 = tgpig.x; + const int rb = tgpig.y*N_MV_T_T; + const int im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T0 * x = (device const T0 *) (src0 + offset0); - if (ne00 < 128) { + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + + if (args.ne00 < 128) { for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; - if (r1 >= ne11) { + if (r1 >= args.ne11) { break; } - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); float sumf = 0; - for (int i = tiisg; i < ne00; i += 32) { + for (int i = tiisg; i < args.ne00; i += 32) { sumf += (T0) x[i] * (T1) y[i]; } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } else { device const T04 * x4 = (device const T04 *) x; for (int row = 0; row < N_MV_T_T; ++row) { int r1 = rb + row; - if (r1 >= ne11) { + if (r1 >= args.ne11) { break; } - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T1 * y = (device const T1 *) (src1 + offset1); device const T14 * y4 = (device const T14 *) y; float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], (float4) y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -2033,51 +1817,17 @@ void kernel_mul_mv_impl( template kernel void kernel_mul_mv( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_impl( + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + kernel_mul_mv_impl( + args, src0, src1, dst, - ne00, - ne01, - ne02, - nb00, - nb01, - nb02, - nb03, - ne10, - ne11, - ne12, - nb10, - nb11, - nb12, - nb13, - ne0, - ne1, - r2, - r3, tgpig, tiisg); } @@ -2094,65 +1844,50 @@ template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t kernel_mul_mv< template kernel void kernel_mul_mv_1row( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const T * x = (device const T *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + float sumf = 0; - if (ne00 < 128) { - for (int i = tiisg; i < ne00; i += 32) { + if (args.ne00 < 128) { + for (int i = tiisg; i < args.ne00; i += 32) { sumf += (float) x[i] * (float) y[i]; } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst_f32[r0] = all_sum; } } else { device const T4 * x4 = (device const T4 *) x; device const float4 * y4 = (device const float4 *) y; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - for (int i = 4*(ne00/4); i < ne00; ++i) all_sum += (float) (x[i] * y[i]); - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + for (int i = 4*(args.ne00/4); i < args.ne00; ++i) all_sum += (float) (x[i] * y[i]); + dst_f32[r0] = all_sum; } } } @@ -2167,54 +1902,39 @@ template [[host_name("kernel_mul_mv_bf16_f32_1row")]] kernel mul_mv_1row_t kerne // Assumes row size (ne00) is a multiple of 4 template kernel void kernel_mul_mv_l4( - device const char * src0, - device const char * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { - const int nrows = ne11; - const int64_t r0 = tgpig.x; - const int64_t im = tgpig.z; + const int nrows = args.ne11; + const int r0 = tgpig.x; + const int im = tgpig.z; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = r0*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; device const T4 * x4 = (device const T4 *) (src0 + offset0); + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; + for (int r1 = 0; r1 < nrows; ++r1) { - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; device const float4 * y4 = (device const float4 *) (src1 + offset1); float sumf = 0; - for (int i = tiisg; i < ne00/4; i += 32) { - for (int k = 0; k < 4; ++k) sumf += (float) (x4[i][k] * y4[i][k]); + for (int i = tiisg; i < args.ne00/4; i += 32) { + sumf += dot((float4) x4[i], y4[i]); } float all_sum = simd_sum(sumf); if (tiisg == 0) { - dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum; + dst_f32[(uint64_t)r1*args.ne0 + r0] = all_sum; } } } @@ -2234,7 +1954,7 @@ static float rope_yarn_ramp(const float low, const float high, const int i0) { // YaRN algorithm based on LlamaYaRNScaledRotaryEmbedding.py from https://github.com/jquesnelle/yarn // MIT licensed. Copyright (c) 2023 Jeffrey Quesnelle and Bowen Peng. static void rope_yarn( - float theta_extrap, float freq_scale, float corr_dims[2], int64_t i0, float ext_factor, float mscale, + float theta_extrap, float freq_scale, float corr_dims[2], int i0, float ext_factor, float mscale, thread float * cos_theta, thread float * sin_theta) { // Get n-d rotational scaling corrected for extrapolation float theta_interp = freq_scale * theta_extrap; @@ -2266,65 +1986,41 @@ static void rope_yarn_corr_dims( template kernel void kernel_rope_norm( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); - device const int32_t * pos = src1; + device const int32_t * pos = (device const int32_t *) src1; const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; + const float inv_ndims = -1.f/args.n_dims; float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); const float x0 = src[0]; const float x1 = src[1]; @@ -2332,8 +2028,8 @@ kernel void kernel_rope_norm( dst_data[0] = x0*cos_theta - x1*sin_theta; dst_data[1] = x0*sin_theta + x1*cos_theta; } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_data[0] = src[0]; dst_data[1] = src[1]; @@ -2343,74 +2039,50 @@ kernel void kernel_rope_norm( template kernel void kernel_rope_neox( - device const void * src0, - device const int32_t * src1, - device const float * src2, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int & n_past, - constant int & n_dims, - constant int & n_ctx_orig, - constant float & freq_base, - constant float & freq_scale, - constant float & ext_factor, - constant float & attn_factor, - constant float & beta_fast, - constant float & beta_slow, - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg[[threads_per_threadgroup]], - uint3 tgpig[[threadgroup_position_in_grid]]) { - const int64_t i3 = tgpig[2]; - const int64_t i2 = tgpig[1]; - const int64_t i1 = tgpig[0]; + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; float corr_dims[2]; - rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); - device const int32_t * pos = src1; + device const int32_t * pos = (device const int32_t *) src1; const float theta_base = (float) pos[i2]; - const float inv_ndims = -1.f/n_dims; + const float inv_ndims = -1.f/args.n_dims; float cos_theta; float sin_theta; - for (int64_t i0 = 2*tiitg; i0 < ne0; i0 += 2*tptg.x) { - if (i0 < n_dims) { - const int64_t ic = i0/2; + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; - const float theta = theta_base * pow(freq_base, inv_ndims*i0); + const float theta = theta_base * pow(args.freq_base, inv_ndims*i0); - const float freq_factor = src2 != src0 ? src2[ic] : 1.0f; + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; - rope_yarn(theta/freq_factor, freq_scale, corr_dims, i0, ext_factor, attn_factor, &cos_theta, &sin_theta); + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + ic*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + ic*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); const float x0 = src[0]; - const float x1 = src[n_dims/2]; + const float x1 = src[args.n_dims/2]; - dst_data[0] = x0*cos_theta - x1*sin_theta; - dst_data[n_dims/2] = x0*sin_theta + x1*cos_theta; + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; } else { - device const T * const src = (device T *)((device char *) src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); - device T * dst_data = (device T *)((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); dst_data[0] = src[0]; dst_data[1] = src[1]; @@ -2808,37 +2480,17 @@ template< short KV = 8, // key/value processed per each simdgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int32_t & ne01, - constant int32_t & ne02, - constant int32_t & ne03, - constant uint32_t & nb01, - constant uint32_t & nb02, - constant uint32_t & nb03, - constant int32_t & ne11, - constant int32_t & ne_12_2, // assume K and V are same shape - constant int32_t & ne_12_3, - constant uint32_t & nb_12_1, - constant uint32_t & nb_12_2, - constant uint32_t & nb_12_3, - constant uint32_t & nb31, - constant int32_t & ne1, - constant int32_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint16_t & n_head_log2, - constant float & logit_softcap, - threadgroup half * shared [[threadgroup(0)]], - ushort3 tgpig[[threadgroup_position_in_grid]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -2854,27 +2506,27 @@ kernel void kernel_flash_attn_ext( const short TS = nsg*SH; // shared memory size per query in (s_t == float) const short T = D + 2*TS; // shared memory size per query in (half) - threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup o_t * so = (threadgroup o_t *) (shared + 0*D); // reuse query data for accumulation - threadgroup o4_t * so4 = (threadgroup o4_t *) (shared + 0*D); // same as above but in o4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix + threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup o_t * so = (threadgroup o_t *) (shmem_f16 + 0*D); // reuse query data for accumulation + threadgroup o4_t * so4 = (threadgroup o4_t *) (shmem_f16 + 0*D); // same as above but in o4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + 2*sgitg*SH + Q*D); // scratch buffer for attention, mask and diagonal matrix - threadgroup k_t * sk = (threadgroup k_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory - threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t + threadgroup k_t * sk = (threadgroup k_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load K in shared memory + threadgroup k4x4_t * sk4x4 = (threadgroup k4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in k4x4_t - threadgroup v_t * sv = (threadgroup v_t *) (shared + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory - threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shared + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t + threadgroup v_t * sv = (threadgroup v_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // scratch buffer to load V in shared memory + threadgroup v4x4_t * sv4x4 = (threadgroup v4x4_t *) (shmem_f16 + sgitg*(4*16*KV) + Q*T); // same as above but in v4x4_t // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o8x8_t lo[D8]; // load heads from Q to shared memory for (short j = sgitg; j < Q; j += nsg) { - device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + ((iq1 + j)*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); for (short i = tiisg; i < D4; i += NW) { - if (iq1 + j < ne01) { + if (iq1 + j < args.ne01) { sq4[j*D4 + i] = (q4_t) q4[i]; } else { sq4[j*D4 + i] = (q4_t) 0.0f; @@ -2907,11 +2559,11 @@ kernel void kernel_flash_attn_ext( const short ty = tiisg/4; // broadcast kv - //const short rk2 = ne02/ne12; - //const short rk3 = ne03/ne13; + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short ikv2 = iq2/(ne02/ne_12_2); - const short ikv3 = iq3/(ne03/ne_12_3); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); // load the queries from shared memory into local memory q8x8_t mq[D8]; @@ -2925,20 +2577,20 @@ kernel void kernel_flash_attn_ext( half slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < n_head_log2 ? m0 : m1; - const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } @@ -2949,7 +2601,7 @@ kernel void kernel_flash_attn_ext( // load the mask in shared memory #pragma unroll(Q) for (short j = 0; j < Q; ++j) { - device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*nb31); + device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31); const half m = pm[ic + tiisg]; @@ -2972,18 +2624,18 @@ kernel void kernel_flash_attn_ext( // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const k_t * pk = (device const k_t *) ((device const char *) k + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D8) for (short i = 0; i < D8; ++i) { k8x8_t mk; - simdgroup_load(mk, pk + i*8, nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 + simdgroup_load(mk, pk + i*8, args.nb_12_1/sizeof(k_t), 0, true); // transpose // TODO: use ne10 simdgroup_multiply_accumulate(mqk, mq[i], mk, mqk); } } else { for (short ii = 0; ii < D16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); if (D16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -3042,10 +2694,10 @@ kernel void kernel_flash_attn_ext( const half m = M[j]; // scale and apply the logitcap / mask - half s = ss[j*TS + tiisg]*scale; + half s = ss[j*TS + tiisg]*args.scale; - if (logit_softcap != 0.0f) { - s = logit_softcap*precise::tanh(s); + if (args.logit_softcap != 0.0f) { + s = args.logit_softcap*precise::tanh(s); } // mqk = mqk + mask*slope @@ -3087,18 +2739,18 @@ kernel void kernel_flash_attn_ext( if (is_same::value) { // we can read directly from global memory - device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const v_t * pv = (device const v_t *) ((device const char *) v + ((ic + 8*cc)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D8) for (short i = 0; i < D8; ++i) { v8x8_t mv; - simdgroup_load(mv, pv + i*8, nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 + simdgroup_load(mv, pv + i*8, args.nb_12_1/sizeof(v_t), 0, false); // TODO: use ne20 simdgroup_multiply_accumulate(lo[i], ms, mv, lo[i]); } } else { for (short ii = 0; ii < D16; ii += 4) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); if (D16%4 == 0) { // no need for bound checks @@ -3227,11 +2879,11 @@ kernel void kernel_flash_attn_ext( // final rescale with 1/S and store to global memory if (sgitg == 0) { - for (short j = 0; j < Q && iq1 + j < ne01; ++j) { + for (short j = 0; j < Q && iq1 + j < args.ne01; ++j) { const float S = ss[j*TS + 0]; for (short i = tiisg; i < D4; i += NW) { - dst4[((int64_t)iq3*ne2*ne1 + iq2 + (iq1 + j)*ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; + dst4[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*D4 + i] = (float4) so4[j*D4 + i]/S; } } } @@ -3323,38 +2975,17 @@ template< short Q = 1, // queries per threadgroup short C = 32> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( - device const char * q, - device const char * k, - device const char * v, - device const char * mask, - device float * dst, - constant int32_t & ne01, - constant int32_t & ne02, - constant int32_t & ne03, - constant uint32_t & nb01, - constant uint32_t & nb02, - constant uint32_t & nb03, - constant int32_t & ne11, - constant int32_t & ne_12_2, // assume K and V are same shape - constant int32_t & ne_12_3, - constant uint32_t & nb_12_1, - constant uint32_t & nb_12_2, - constant uint32_t & nb_12_3, - constant uint32_t & nb31, - constant int32_t & ne1, - constant int32_t & ne2, - constant float & scale, - constant float & max_bias, - constant float & m0, - constant float & m1, - constant uint16_t & n_head_log2, - constant float & logit_softcap, - threadgroup half * shared [[threadgroup(0)]], - ushort3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_flash_attn_ext & args, + device const char * q, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + threadgroup half * shmem_f16 [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 ntg[[threads_per_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const short nsg = ntg.y; // number of simdgroups const int iq3 = tgpig[2]; @@ -3369,22 +3000,22 @@ kernel void kernel_flash_attn_ext_vec( const short T = D + nsg*SH; // shared memory size per query in (half) - //threadgroup q_t * sq = (threadgroup q_t *) (shared + 0*D); // holds the query data - threadgroup q4_t * sq4 = (threadgroup q4_t *) (shared + 0*D); // same as above but in q4_t - threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shared + 0*D); // same as above but in q4x4_t - threadgroup s_t * ss = (threadgroup s_t *) (shared + sgitg*SH + Q*D); // scratch buffer for attention - threadgroup s4_t * ss4 = (threadgroup s4_t *) (shared + sgitg*SH + Q*D); // same as above but in s4_t - threadgroup half * sm = (threadgroup half *) (shared + sgitg*SH + C + Q*D); // scratch buffer for mask - threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shared + sgitg*D + Q*T); // scratch buffer for the results + //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*D); // holds the query data + threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*D); // same as above but in q4_t + threadgroup q4x4_t * sq4x4 = (threadgroup q4x4_t *) (shmem_f16 + 0*D); // same as above but in q4x4_t + threadgroup s_t * ss = (threadgroup s_t *) (shmem_f16 + sgitg*SH + Q*D); // scratch buffer for attention + threadgroup s4_t * ss4 = (threadgroup s4_t *) (shmem_f16 + sgitg*SH + Q*D); // same as above but in s4_t + threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*D); // scratch buffer for mask + threadgroup o4x4_t * sr4x4 = (threadgroup o4x4_t *) (shmem_f16 + sgitg*D + Q*T); // scratch buffer for the results // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper) o4x4_t lo[D16/NL]; // load heads from Q to shared memory - device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*nb01 + iq2*nb02 + iq3*nb03)); + device const float4 * q4 = (device const float4 *) ((device const char *) q + (iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03)); for (short i = tiisg; i < D4; i += NW) { - if (iq1 < ne01) { + if (iq1 < args.ne01) { sq4[i] = (q4_t) q4[i]; } else { sq4[i] = (q4_t) 0.0f; @@ -3412,11 +3043,11 @@ kernel void kernel_flash_attn_ext_vec( const short ty = tiisg/NL; // broadcast kv - //const short rk2 = ne02/ne12; - //const short rk3 = ne03/ne13; + //const short rk2 = args.ne02/args.ne12; + //const short rk3 = args.ne03/args.ne13; - const short ikv2 = iq2/(ne02/ne_12_2); - const short ikv3 = iq3/(ne03/ne_12_3); + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); // load the queries from shared memory into local memory q4x4_t mq[D16/NL]; @@ -3429,25 +3060,25 @@ kernel void kernel_flash_attn_ext_vec( const bool has_mask = mask != q; // pointer to the mask - device const half * pm = (device const half *) (mask + iq1*nb31); + device const half * pm = (device const half *) (mask + iq1*args.nb31); half slope = 1.0f; // ALiBi - if (max_bias > 0.0f) { + if (args.max_bias > 0.0f) { const short h = iq2; - const half base = h < n_head_log2 ? m0 : m1; - const short exph = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1; + const half base = h < args.n_head_log2 ? args.m0 : args.m1; + const short exph = h < args.n_head_log2 ? h + 1 : 2*(h - args.n_head_log2) + 1; slope = pow(base, exph); } // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = 0; ic0 < ne11; ic0 += C*nsg) { + for (int ic0 = 0; ic0 < args.ne11; ic0 += C*nsg) { const int ic = ic0 + C*sgitg; - if (ic >= ne11) { + if (ic >= args.ne11) { break; } @@ -3461,7 +3092,7 @@ kernel void kernel_flash_attn_ext_vec( for (short cc = 0; cc < C/4; ++cc) { qk_t mqka[4] = { 0.0, 0.0, 0.0, 0.0 }; - device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const kd4x4_t * pk = (device const kd4x4_t *) ((device const char *) k + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); #pragma unroll(D16/NL) for (short ii = 0; ii < D16; ii += NL) { @@ -3497,10 +3128,10 @@ kernel void kernel_flash_attn_ext_vec( // mqk = mqk*scale + mask*slope if (tx == 0) { - mqk *= scale; + mqk *= args.scale; - if (logit_softcap != 0.0f) { - mqk = logit_softcap*precise::tanh(mqk); + if (args.logit_softcap != 0.0f) { + mqk = args.logit_softcap*precise::tanh(mqk); } mqk += sm[4*cc + ty]*slope; @@ -3539,7 +3170,7 @@ kernel void kernel_flash_attn_ext_vec( // O = O + (Q*K^T)*V { for (short cc = 0; cc < C/4; ++cc) { - device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*nb_12_1 + ikv2*nb_12_2 + ikv3*nb_12_3)); + device const vd4x4_t * pv4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 4*cc + ty)*args.nb_12_1 + ikv2*args.nb_12_2 + ikv3*args.nb_12_3)); const s4x4_t ms(ss[4*cc + ty]); @@ -3644,7 +3275,7 @@ kernel void kernel_flash_attn_ext_vec( const float S = ss[0]; for (short i = tiisg; i < D16; i += NW) { - dst44[((int64_t)iq3*ne2*ne1 + iq2 + (iq1)*ne1)*D16 + i] = (float4x4) sr4x4[i]/S; + dst44[((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)iq1*args.ne1)*D16 + i] = (float4x4) sr4x4[i]/S; } } } @@ -3686,42 +3317,27 @@ template [[host_name("kernel_flash_attn_ext_vec_q8_0_h256")]] kernel flash_attn_ template kernel void kernel_cpy( - device const void * src0, - device void * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { - const int64_t i03 = tgpig[2]; - const int64_t i02 = tgpig[1]; - const int64_t i01 = tgpig[0]; + constant ggml_metal_kargs_cpy & args, + device const char * src0, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig[2]; + const int i02 = tgpig[1]; + const int i01 = tgpig[0]; - const int64_t n = i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00; + const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - const int64_t i3 = n / (ne2*ne1*ne0); - const int64_t i2 = (n - i3*ne2*ne1*ne0) / (ne1*ne0); - const int64_t i1 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0) / ne0; - const int64_t i0 = (n - i3*ne2*ne1*ne0 - i2*ne1*ne0 - i1*ne0); + const int64_t i3 = n/(args.ne2*args.ne1*args.ne0); + const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0)/(args.ne1*args.ne0); + const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0)/args.ne0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0); - device T1 * dst_data = (device T1 *) ((device char *) dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < ne00; i00 += ntg.x) { - device const T0 * src = (device T0 *)((device char *) src0 + i03*nb03 + i02*nb02 + i01*nb01 + i00*nb00); + for (int64_t i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { + device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; } } @@ -3741,42 +3357,27 @@ template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy 0 ? sumqx/sumq2 : d; - } } kernel void kernel_concat( + constant ggml_metal_kargs_concat & args, device const char * src0, device const char * src1, device char * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant int64_t & ne03, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant int64_t & ne2, - constant int64_t & ne3, - constant uint64_t & nb0, - constant uint64_t & nb1, - constant uint64_t & nb2, - constant uint64_t & nb3, - constant int32_t & dim, - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - uint3 ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { - const int64_t i3 = tgpig.z; - const int64_t i2 = tgpig.y; - const int64_t i1 = tgpig.x; + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; - int64_t o[4] = {0, 0, 0, 0}; - o[dim] = dim == 0 ? ne00 : (dim == 1 ? ne01 : (dim == 2 ? ne02 : ne03)); + int o[4] = {0, 0, 0, 0}; + o[args.dim] = args.dim == 0 ? args.ne00 : (args.dim == 1 ? args.ne01 : (args.dim == 2 ? args.ne02 : args.ne03)); device const float * x; - for (int i0 = tpitg.x; i0 < ne0; i0 += ntg.x) { - if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { - x = (device const float *)(src0 + (i3 )*nb03 + (i2 )*nb02 + (i1 )*nb01 + (i0 )*nb00); + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + if (i0 < args.ne00 && i1 < args.ne01 && i2 < args.ne02 && i3 < args.ne03) { + x = (device const float *)(src0 + (i3 )*args.nb03 + (i2 )*args.nb02 + (i1 )*args.nb01 + (i0 )*args.nb00); } else { - x = (device const float *)(src1 + (i3 - o[3])*nb13 + (i2 - o[2])*nb12 + (i1 - o[1])*nb11 + (i0 - o[0])*nb10); + x = (device const float *)(src1 + (i3 - o[3])*args.nb13 + (i2 - o[2])*args.nb12 + (i1 - o[1])*args.nb11 + (i0 - o[0])*args.nb10); } - device float * y = (device float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); + device float * y = (device float *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); *y = *x; } } +template void kernel_mul_mv_q2_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q2_K * x = (device const block_q2_K *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_q2_K * x = (device const block_q2_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -4305,92 +3793,64 @@ void kernel_mul_mv_q2_K_f32_impl( (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); - qs += nb01/2; - sc += nb01; - dh += nb01/2; + qs += args.nb01/2; + sc += args.nb01; + dh += args.nb01/2; } y4 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } [[host_name("kernel_mul_mv_q2_K_f32")]] kernel void kernel_mul_mv_q2_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q2_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q2_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q3_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int64_t im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q3_K * x = (device const block_q3_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q3_K * x = (device const block_q3_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float yl[32]; @@ -4420,9 +3880,10 @@ void kernel_mul_mv_q3_K_f32_impl( const ushort4 hm = mm[2*ip + il/2]; - const int shift = 2*il; - const float v1 = il == 0 ? 4.f : 64.f; - const float v2 = 4.f * v1; + const short shift = 2*il; + + const float v1 = il == 0 ? 4.f : 64.f; + const float v2 = 4.f * v1; const uint16_t s_shift1 = 4*ip; const uint16_t s_shift2 = s_shift1 + il; @@ -4491,10 +3952,10 @@ void kernel_mul_mv_q3_K_f32_impl( sumf1[row] += d1 * (scales[1] - 32); sumf2[row] += d2 * (scales[3] - 32); - q += nb01/2; - h += nb01/2; - a += nb01/2; - dh += nb01/2; + q += args.nb01/2; + h += args.nb01/2; + a += args.nb01/2; + dh += args.nb01/2; } y1 += 4 * QK_K; @@ -4504,66 +3965,39 @@ void kernel_mul_mv_q3_K_f32_impl( const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift); sumf1[row] = simd_sum(sumf); } + + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + if (tiisg == 0) { for (int row = 0; row < 2; ++row) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = sumf1[row]; + dst_f32[first_row + row] = sumf1[row]; } } } [[host_name("kernel_mul_mv_q3_K_f32")]] kernel void kernel_mul_mv_q3_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q3_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q3_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q4_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; @@ -4574,21 +4008,21 @@ void kernel_mul_mv_q4_K_f32_impl( const int iq = it/4; // 0 or 1 const int ir = it%4; // 0...3 - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; //const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; const int first_row = r0 * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q4_K * x = (device const block_q4_K *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_q4_K * x = (device const block_q4_K *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[16]; float yh[16]; @@ -4641,92 +4075,64 @@ void kernel_mul_mv_q4_K_f32_impl( (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += nb01/2; - sc += nb01/2; - dh += nb01/2; + q1 += args.nb01/2; + sc += args.nb01/2; + dh += args.nb01/2; } y4 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + (int64_t)im*args.ne0*args.ne1 + (int64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } [[host_name("kernel_mul_mv_q4_K_f32")]] kernel void kernel_mul_mv_q4_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q4_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q4_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q5_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int r0 = tgpig.x; + const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q5_K * x = (device const block_q5_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q5_K * x = (device const block_q5_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float sumf[2]={0.f}; @@ -4800,98 +4206,70 @@ void kernel_mul_mv_q5_K_f32_impl( sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) - dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); - q1 += nb01; - qh += nb01; - dh += nb01/2; - a += nb01/2; + q1 += args.nb01; + qh += args.nb01; + dh += args.nb01/2; + a += args.nb01/2; } y1 += 4 * QK_K; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < 2; ++row) { const float tot = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = tot; + dst_f32[first_row + row] = tot; } } } [[host_name("kernel_mul_mv_q5_K_f32")]] kernel void kernel_mul_mv_q5_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q5_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q5_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_q6_K_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; const uint8_t kmask3 = 0x30; const uint8_t kmask4 = 0xC0; - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - const int im = tgpig.z; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int im = tgpig.z; - const int row = 2 * r0 + sgitg; + const int row = 2*r0 + sgitg; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_q6_K * x = (device const block_q6_K *) ((device char *) src0 + offset0); - device const float * yy = (device const float *) ((device char *) src1 + offset1); + device const block_q6_K * x = (device const block_q6_K *) (src0 + offset0); + device const float * yy = (device const float *) (src1 + offset1); float sumf = 0; @@ -4908,7 +4286,6 @@ void kernel_mul_mv_q6_K_f32_impl( const int q_offset_h = 32*ip + l0; for (int i = ix; i < nb; i += 2) { - device const uint8_t * q1 = x[i].ql + q_offset_l; device const uint8_t * q2 = q1 + 32; device const uint8_t * qh = x[i].qh + q_offset_h; @@ -4930,98 +4307,70 @@ void kernel_mul_mv_q6_K_f32_impl( } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + const float tot = simd_sum(sumf); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + row] = tot; + dst_f32[row] = tot; } } [[host_name("kernel_mul_mv_q6_K_f32")]] kernel void kernel_mul_mv_q6_K_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_q6_K_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_q6_K_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } // ======================= "True" 2-bit +template void kernel_mul_mv_iq2_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_xxs * x = (device const block_iq2_xxs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_xxs * x = (device const block_iq2_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5051,114 +4400,85 @@ void kernel_mul_mv_iq2_xxs_f32_impl( float sum = 0; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + aux8[l]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + aux8[l]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 8; ++j) { sum += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } sumf[row] += d * sum; - dh += nb01/2; - q2 += nb01/2; + dh += args.nb01/2; + q2 += args.nb01/2; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } [[host_name("kernel_mul_mv_iq2_xxs_f32")]] kernel void kernel_mul_mv_iq2_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - - kernel_mul_mv_iq2_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + kernel_mul_mv_iq2_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_xs * x = (device const block_iq2_xs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_xs * x = (device const block_iq2_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 512); + threadgroup uint64_t * svalues = (threadgroup uint64_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 512); { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq2xs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2xs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5190,122 +4510,94 @@ void kernel_mul_mv_iq2_xs_f32_impl( float sum1 = 0, sum2 = 0; for (int l = 0; l < 2; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum1 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } for (int l = 2; l < 4; ++l) { - const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(values + (q2[l] & 511)); - const uint8_t signs = shared_signs[(q2[l] >> 9)]; + const threadgroup uint8_t * grid = (const threadgroup uint8_t *)(svalues + (q2[l] & 511)); + const uint8_t signs = ssigns[(q2[l] >> 9)]; for (int j = 0; j < 8; ++j) { sum2 += yl[8*l + j] * grid[j] * (signs & kmask_iq2xs[j] ? -1.f : 1.f); } } sumf[row] += d1 * sum1 + d2 * sum2; - dh += nb01/2; - q2 += nb01/2; - sc += nb01; + dh += args.nb01/2; + q2 += args.nb01/2; + sc += args.nb01; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } [[host_name("kernel_mul_mv_iq2_xs_f32")]] kernel void kernel_mul_mv_iq2_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_xxs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq3_xxs * x = (device const block_iq3_xxs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq3_xxs * x = (device const block_iq3_xxs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; - threadgroup uint8_t * shared_signs = (threadgroup uint8_t *)(values + 256); + threadgroup uint32_t * svalues = (threadgroup uint32_t *)(shmem); + threadgroup uint8_t * ssigns = (threadgroup uint8_t *)(svalues + 256); { int nval = 4; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3xxs_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3xxs_grid[pos + i]; nval = 2; pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) shared_signs[pos+i] = ksigns_iq2xs[pos+i]; + for (int i = 0; i < nval; ++i) ssigns[pos+i] = ksigns_iq2xs[pos+i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5314,7 +4606,6 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const float * y4 = y + 32 * ix; for (int ib32 = ix; ib32 < nb32; ib32 += 32) { - for (int i = 0; i < 32; ++i) { yl[i] = y4[i]; } @@ -5328,16 +4619,15 @@ void kernel_mul_mv_iq3_xxs_f32_impl( device const half * dh = &xr->d; for (int row = 0; row < N_DST; row++) { - const float db = dh[0]; const uint32_t aux32 = gas[0] | (gas[1] << 16); const float d = db * (0.5f + (aux32 >> 28)); float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + q3[2*l+0]); - const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + q3[2*l+1]); - const uint8_t signs = shared_signs[(aux32 >> 7*l) & 127]; + const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + q3[2*l+0]); + const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + q3[2*l+1]); + const uint8_t signs = ssigns[(aux32 >> 7*l) & 127]; for (int j = 0; j < 4; ++j) { sum[0] += yl[8*l + j + 0] * grid1[j] * (signs & kmask_iq2xs[j+0] ? -1.f : 1.f); sum[1] += yl[8*l + j + 4] * grid2[j] * (signs & kmask_iq2xs[j+4] ? -1.f : 1.f); @@ -5345,103 +4635,75 @@ void kernel_mul_mv_iq3_xxs_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb01/2; - q3 += nb01; - gas += nb01/2; + dh += args.nb01/2; + q3 += args.nb01; + gas += args.nb01/2; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.5f; + dst_f32[first_row + row] = all_sum * 0.5f; } } } [[host_name("kernel_mul_mv_iq3_xxs_f32")]] kernel void kernel_mul_mv_iq3_xxs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_xxs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_xxs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq3_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq3_s * x = (device const block_iq3_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq3_s * x = (device const block_iq3_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - threadgroup uint32_t * values = (threadgroup uint32_t *)shared_values; + threadgroup uint32_t * svalues = (threadgroup uint32_t *) shmem; { int nval = 8; int pos = (32*sgitg + tiisg)*nval; - for (int i = 0; i < nval; ++i) values[pos + i] = iq3s_grid[pos + i]; + for (int i = 0; i < nval; ++i) svalues[pos + i] = iq3s_grid[pos + i]; threadgroup_barrier(mem_flags::mem_threadgroup); } @@ -5472,8 +4734,8 @@ void kernel_mul_mv_iq3_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 4; ++l) { - const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? values + 256 : values; - const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? values + 256 : values; + const threadgroup uint32_t * table1 = qh[0] & kmask_iq2xs[2*l+0] ? svalues + 256 : svalues; + const threadgroup uint32_t * table2 = qh[0] & kmask_iq2xs[2*l+1] ? svalues + 256 : svalues; const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(table1 + qs[2*l+0]); const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(table2 + qs[2*l+1]); for (int j = 0; j < 4; ++j) { @@ -5483,105 +4745,77 @@ void kernel_mul_mv_iq3_s_f32_impl( } sumf[row] += d * (sum[0] + sum[1]); - dh += nb01/2; - qs += nb01; - qh += nb01; - sc += nb01; - signs += nb01; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } [[host_name("kernel_mul_mv_iq3_s_f32")]] kernel void kernel_mul_mv_iq3_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq3_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq3_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq2_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq2_s * x = (device const block_iq2_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq2_s * x = (device const block_iq2_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; const int nb32 = nb * (QK_K / 32); - //threadgroup uint64_t * values = (threadgroup uint64_t *)shared_values; + //threadgroup uint64_t * svalues = (threadgroup uint64_t *) shmem; //{ // int nval = 32; // int pos = (32*sgitg + tiisg)*nval; - // for (int i = 0; i < nval; ++i) values[pos + i] = iq2s_grid[pos + i]; + // for (int i = 0; i < nval; ++i) svalues[pos + i] = iq2s_grid[pos + i]; // threadgroup_barrier(mem_flags::mem_threadgroup); //} @@ -5613,8 +4847,8 @@ void kernel_mul_mv_iq2_s_f32_impl( float2 sum = {0}; for (int l = 0; l < 2; ++l) { - //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(values + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); - //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(values + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); + //const threadgroup uint8_t * grid1 = (const threadgroup uint8_t *)(svalues + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); + //const threadgroup uint8_t * grid2 = (const threadgroup uint8_t *)(svalues + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); constant uint8_t * grid1 = (constant uint8_t *)(iq2s_grid + (qs[l+0] | ((qh[0] << (8-2*l)) & 0x300))); constant uint8_t * grid2 = (constant uint8_t *)(iq2s_grid + (qs[l+2] | ((qh[0] << (4-2*l)) & 0x300))); for (int j = 0; j < 8; ++j) { @@ -5624,94 +4858,66 @@ void kernel_mul_mv_iq2_s_f32_impl( } sumf[row] += d1 * sum[0] + d2 * sum[1]; - dh += nb01/2; - qs += nb01; - qh += nb01; - sc += nb01; - signs += nb01; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01; + sc += args.nb01; + signs += args.nb01; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum * 0.25f; + dst_f32[first_row + row] = all_sum * 0.25f; } } } [[host_name("kernel_mul_mv_iq2_s_f32")]] kernel void kernel_mul_mv_iq2_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq2_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq2_s_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } +template void kernel_mul_mv_iq1_s_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq1_s * x = (device const block_iq1_s *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq1_s * x = (device const block_iq1_s *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5754,61 +4960,50 @@ void kernel_mul_mv_iq1_s_f32_impl( } sumf[row] += (float)dh[0] * (sum + sumy * (qh[0] & 0x8000 ? -1 - IQ1S_DELTA : -1 + IQ1S_DELTA)) * (2*((qh[0] >> 12) & 7) + 1); - dh += nb01/2; - qs += nb01; - qh += nb01/2; + dh += args.nb01/2; + qs += args.nb01; + qh += args.nb01/2; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } +template void kernel_mul_mv_iq1_m_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_value, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - const int nb = ne00/QK_K; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq1_m * x = (device const block_iq1_m *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq1_m * x = (device const block_iq1_m *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); float yl[32]; float sumf[N_DST]={0.f}, all_sum; @@ -5860,66 +5055,55 @@ void kernel_mul_mv_iq1_m_f32_impl( sumf[row] += (float)scale.f16 * ((sum[0] + delta1) * (2*((sc[ib/2] >> (6*(ib%2)+0)) & 7) + 1) + (sum[1] + delta2) * (2*((sc[ib/2] >> (6*(ib%2)+3)) & 7) + 1)); - sc += nb01/2; - qs += nb01; - qh += nb01; + sc += args.nb01/2; + qs += args.nb01; + qh += args.nb01; } y4 += 32 * 32; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < N_DST; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } +template void kernel_mul_mv_iq4_nl_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK4_NL; + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq4_nl * x = (device const block_iq4_nl *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); const int ix = tiisg/2; // 0...15 const int it = tiisg%2; // 0 or 1 - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -5937,7 +5121,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { device const block_iq4_nl & xb = x[row*nb + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); @@ -5947,16 +5131,16 @@ void kernel_mul_mv_iq4_nl_f32_impl( aux32[0] = q4[0] | (q4[1] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; aux32[0] = q4[2] | (q4[3] << 16); aux32[1] = (aux32[0] >> 4) & 0x0f0f0f0f; aux32[0] &= 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -5969,60 +5153,49 @@ void kernel_mul_mv_iq4_nl_f32_impl( yb += 16 * QK4_NL; } - for (int row = 0; row < 2 && first_row + row < ne01; ++row) { + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + + for (int row = 0; row < 2 && first_row + row < args.ne01; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } +template void kernel_mul_mv_iq4_xs_f32_impl( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values_i8, - uint3 tgpig, - uint tiisg, - uint sgitg) { + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { - threadgroup float * shared_values = (threadgroup float *)shared_values_i8; - const int nb = ne00/QK_K; + threadgroup float * shmem_f32 = (threadgroup float *) shmem; + const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; const int first_row = (r0 * 2 + sgitg) * 2; - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const uint i12 = im%args.ne12; + const uint i13 = im/args.ne12; - const uint offset0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; - const uint offset1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + const uint64_t offset0 = first_row*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - device const block_iq4_xs * x = (device const block_iq4_xs *) ((device char *) src0 + offset0); - device const float * y = (device const float *) ((device char *) src1 + offset1); + device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); + device const float * y = (device const float *) (src1 + offset1); const int ix = tiisg/16; // 0 or 1 const int it = tiisg%16; // 0...15 const int ib = it/2; const int il = it%2; - shared_values[tiisg] = kvalues_iq4nl_f[tiisg%16]; + shmem_f32[tiisg] = kvalues_iq4nl_f[tiisg%16]; threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; @@ -6036,28 +5209,26 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; for (int ibl = ix; ibl < nb; ibl += 2) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; for (int row = 0; row < 2; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; - aux32[0] = q4[0] & 0x0f0f0f0f; + aux32[0] = (q4[0] ) & 0x0f0f0f0f; aux32[1] = (q4[0] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[0] * qf1; acc2 += yl[1] * qf2; - aux32[0] = q4[1] & 0x0f0f0f0f; + aux32[0] = (q4[1] ) & 0x0f0f0f0f; aux32[1] = (q4[1] >> 4) & 0x0f0f0f0f; - qf1 = {shared_values[q8[0]], shared_values[q8[1]], shared_values[q8[2]], shared_values[q8[3]]}; - qf2 = {shared_values[q8[4]], shared_values[q8[5]], shared_values[q8[6]], shared_values[q8[7]]}; + qf1 = {shmem_f32[q8[0]], shmem_f32[q8[1]], shmem_f32[q8[2]], shmem_f32[q8[3]]}; + qf2 = {shmem_f32[q8[4]], shmem_f32[q8[5]], shmem_f32[q8[6]], shmem_f32[q8[7]]}; acc1 += yl[2] * qf1; acc2 += yl[3] * qf2; @@ -6071,134 +5242,68 @@ void kernel_mul_mv_iq4_xs_f32_impl( yb += 2 * QK_K; } + device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; + for (int row = 0; row < 2; ++row) { all_sum = simd_sum(sumf[row]); if (tiisg == 0) { - dst[r1*ne0 + im*ne0*ne1 + first_row + row] = all_sum; + dst_f32[first_row + row] = all_sum; } } } [[host_name("kernel_mul_mv_iq1_s_f32")]] kernel void kernel_mul_mv_iq1_s_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_s_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_s_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq1_m_f32")]] kernel void kernel_mul_mv_iq1_m_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq1_m_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, nullptr, tgpig, tiisg, sgitg); + kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_nl_f32")]] kernel void kernel_mul_mv_iq4_nl_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_nl_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } [[host_name("kernel_mul_mv_iq4_xs_f32")]] kernel void kernel_mul_mv_iq4_xs_f32( - device const void * src0, - device const float * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mv & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_iq4_xs_f32_impl(src0, src1, dst, ne00, ne01, ne02, nb01, nb02, nb03, ne10, ne12, nb11, nb12, nb13, ne0, ne1, r2, r3, shared_values, tgpig, tiisg, sgitg); + kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } template @@ -6302,38 +5407,26 @@ kernel void kernel_get_rows_i32( // each block_q contains 16*nl weights template -kernel void kernel_mul_mm(device const uchar * src0, - device const uchar * src1, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant uint64_t & nb03, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant uint64_t & nb13, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint & r2, - constant uint & r3, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { +kernel void kernel_mul_mm( + constant ggml_metal_kargs_mul_mm & args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup T * sa = (threadgroup T *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + threadgroup T * sa = (threadgroup T *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; - const uint im = tgpig.z; + const int r0 = tgpig.y; + const int r1 = tgpig.x; + const int im = tgpig.z; // if this block is of 64x32 shape or smaller - short n_rows = (ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; - short n_cols = (ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; + short n_rows = (args.ne0 - r0*BLOCK_SIZE_M < BLOCK_SIZE_M) ? (args.ne0 - r0*BLOCK_SIZE_M) : BLOCK_SIZE_M; + short n_cols = (args.ne1 - r1*BLOCK_SIZE_N < BLOCK_SIZE_N) ? (args.ne1 - r1*BLOCK_SIZE_N) : BLOCK_SIZE_N; // a thread shouldn't load data outside of the matrix short thread_row = ((short)tiitg/THREAD_PER_ROW) < n_rows ? ((short)tiitg/THREAD_PER_ROW) : n_rows - 1; @@ -6349,20 +5442,20 @@ kernel void kernel_mul_mm(device const uchar * src0, short il = (tiitg % THREAD_PER_ROW); - const uint i12 = im%ne12; - const uint i13 = im/ne12; + const int i12 = im%args.ne12; + const int i13 = im/args.ne12; - uint offset0 = (i12/r2)*nb02 + (i13/r3)*nb03; - ushort offset1 = il/nl; + uint64_t offset0 = (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; + short offset1 = il/nl; - device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*nb01 + offset0) + offset1; + device const block_q * x = (device const block_q *)(src0 + (r0*BLOCK_SIZE_M + thread_row)*args.nb01 + offset0) + offset1; device const float * y = (device const float *)(src1 - + nb13 * i13 - + nb12 * i12 - + nb11 * (r1 * BLOCK_SIZE_N + thread_col) - + nb10 * (BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); + + args.nb13*i13 + + args.nb12*i12 + + args.nb11*(r1 * BLOCK_SIZE_N + thread_col) + + args.nb10*(BLOCK_SIZE_K / THREAD_PER_COL * (tiitg % THREAD_PER_COL))); - for (int loop_k = 0; loop_k < ne00; loop_k += BLOCK_SIZE_K) { + for (int loop_k = 0; loop_k < args.ne00; loop_k += BLOCK_SIZE_K) { // load data and store to threadgroup memory T4x4 temp_a; dequantize_func(x, il, temp_a); @@ -6409,16 +5502,18 @@ kernel void kernel_mul_mm(device const uchar * src0, } } - if ((r0 + 1) * BLOCK_SIZE_M <= ne0 && (r1 + 1) * BLOCK_SIZE_N <= ne1) { - device float * C = dst + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) \ - + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * ne0 + im*ne1*ne0; + if ((r0 + 1) * BLOCK_SIZE_M <= args.ne0 && (r1 + 1) * BLOCK_SIZE_N <= args.ne1) { + device float * C = (device float *) dst + + (BLOCK_SIZE_M * r0 + 32 * (sgitg & 1)) + \ + (BLOCK_SIZE_N * r1 + 16 * (sgitg >> 1)) * args.ne0 + im*args.ne1*args.ne0; + for (short i = 0; i < 8; i++) { - simdgroup_store(mc[i], C + 8 * (i%4) + 8 * ne0 * (i/4), ne0); + simdgroup_store(mc[i], C + 8 * (i%4) + 8 * args.ne0 * (i/4), args.ne0); } } else { // block is smaller than 64x32, we should avoid writing data outside of the matrix threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *) shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1))*BLOCK_SIZE_M; for (short i = 0; i < 8; i++) { simdgroup_store(mc[i], temp_str + 8*(i%4) + 8*BLOCK_SIZE_M*(i/4), BLOCK_SIZE_M); @@ -6428,7 +5523,7 @@ kernel void kernel_mul_mm(device const uchar * src0, if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { - device float * D = dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*ne0 + im*ne1*ne0; + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + (r1*BLOCK_SIZE_N + j)*args.ne0 + im*args.ne1*args.ne0; device float4 * D4 = (device float4 *) D; threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); @@ -6449,36 +5544,37 @@ kernel void kernel_mul_mm(device const uchar * src0, } // same as kernel_mul_mm_impl, but src1 and dst are accessed via indices stored in rowids +// TODO: this kernel needs to be reimplemented from scratch for better performance template void kernel_mul_mm_id_impl( - device const uchar * src0, - device const uchar * src1, + int32_t ne00, + int32_t ne02, + uint64_t nb01, + uint64_t nb02, + int32_t ne11, + int32_t ne12, + uint64_t nb10, + uint64_t nb11, + uint64_t nb12, + int32_t ne0, + int32_t ne1, + int64_t ne0ne1, + device const char * src0, + device const char * src1, threadgroup ushort2 * rowids, - device float * dst, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - int64_t ne1, - int64_t ne0ne1, - threadgroup uchar * shared_memory, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + device char * dst, + threadgroup char * shmem, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { - threadgroup half * sa = (threadgroup half *)(shared_memory); - threadgroup float * sb = (threadgroup float *)(shared_memory + 4096); + threadgroup half * sa = (threadgroup half *)(shmem); + threadgroup float * sb = (threadgroup float *)(shmem + 4096); - const uint r0 = tgpig.y; - const uint r1 = tgpig.x; + const int r0 = tgpig.y; + const int r1 = tgpig.x; - if (r1 * BLOCK_SIZE_N >= ne1) return; + if (r1*BLOCK_SIZE_N >= ne1) return; // if this block is of 64x32 shape or smaller short n_rows = (ne0 - r0 * BLOCK_SIZE_M < BLOCK_SIZE_M) ? (ne0 - r0 * BLOCK_SIZE_M) : BLOCK_SIZE_M; @@ -6490,9 +5586,9 @@ void kernel_mul_mm_id_impl( simdgroup_half8x8 ma[4]; simdgroup_float8x8 mb[2]; - simdgroup_float8x8 c_res[8]; + simdgroup_float8x8 mc[8]; for (int i = 0; i < 8; i++){ - c_res[i] = make_filled_simdgroup_matrix(0.f); + mc[i] = make_filled_simdgroup_matrix(0.f); } short il = (tiitg % THREAD_PER_ROW); @@ -6530,11 +5626,14 @@ void kernel_mul_mm_id_impl( threadgroup half * lsma = (sa + THREAD_MAT_M * SG_MAT_SIZE * (sgitg % 2)); threadgroup float * lsmb = (sb + THREAD_MAT_N * SG_MAT_SIZE * (sgitg / 2)); + #pragma unroll(BLOCK_SIZE_K/8) for (int ik = 0; ik < BLOCK_SIZE_K / 8; ik++) { + #pragma unroll(4) for (int i = 0; i < 4; i++) { simdgroup_load(ma[i], lsma + SG_MAT_SIZE * i); } simdgroup_barrier(mem_flags::mem_none); + #pragma unroll(2) for (int i = 0; i < 2; i++) { simdgroup_load(mb[i], lsmb + SG_MAT_SIZE * i); } @@ -6542,29 +5641,42 @@ void kernel_mul_mm_id_impl( lsma += BLOCK_SIZE_M / SG_MAT_ROW * SG_MAT_SIZE; lsmb += BLOCK_SIZE_N / SG_MAT_ROW * SG_MAT_SIZE; + #pragma unroll(8) for (int i = 0; i < 8; i++){ - simdgroup_multiply_accumulate(c_res[i], mb[i/4], ma[i%4], c_res[i]); + simdgroup_multiply_accumulate(mc[i], mb[i/4], ma[i%4], mc[i]); } } } { threadgroup_barrier(mem_flags::mem_threadgroup); - threadgroup float * temp_str = ((threadgroup float *)shared_memory) \ + threadgroup float * temp_str = ((threadgroup float *) shmem) \ + 32 * (sgitg&1) + (16 * (sgitg>>1)) * BLOCK_SIZE_M; for (int i = 0; i < 8; i++) { - simdgroup_store(c_res[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); + simdgroup_store(mc[i], temp_str + 8 * (i%4) + 8 * BLOCK_SIZE_M * (i/4), BLOCK_SIZE_M); } threadgroup_barrier(mem_flags::mem_threadgroup); - device float * C = dst + (BLOCK_SIZE_M * r0); if (sgitg == 0) { for (int j = tiitg; j < n_cols; j += BLOCK_SIZE_N) { threadgroup const auto & jid = rowids[r1 * BLOCK_SIZE_N + j]; - int joff = jid[0] * ne0 + jid[1] * ne0ne1; - for (int i = 0; i < n_rows; i++) { - *(C + i + joff) = *(temp_str + i + j * BLOCK_SIZE_M); + int64_t joff = jid[0]*ne0 + jid[1]*ne0ne1; + + device float * D = (device float *) dst + (r0*BLOCK_SIZE_M) + joff; + device float4 * D4 = (device float4 *) D; + + threadgroup float * C = temp_str + (j*BLOCK_SIZE_M); + threadgroup float4 * C4 = (threadgroup float4 *) C; + + int i = 0; + for (; i < n_rows/4; i++) { + *(D4 + i) = *(C4 + i); + } + + i *= 4; + for (; i < n_rows; i++) { + *(D + i) = *(C + i); } } } @@ -6573,48 +5685,34 @@ void kernel_mul_mm_id_impl( template kernel void kernel_mul_mm_id( - device const uchar * src0s, - device const uchar * src1, - device float * dst, - device const uchar * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne02, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup uchar * shared_memory [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { + constant ggml_metal_kargs_mul_mm_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { const int32_t i02 = tgpig.z; + tgpig.z = 0; - device const uchar * src0 = src0s + i02*nb02; + device const char * src0 = src0s + i02*args.nb02; // row indices - threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shared_memory + 8192); + threadgroup ushort2 * rowids = (threadgroup ushort2 *)(shmem + 8192); // TODO: parallelize this loop - int64_t _ne1 = 0; - for (ushort ii1 = 0; ii1 < nei1; ii1++) { - for (ushort ii0 = 0; ii0 < nei0; ii0++) { - int32_t id = ((device int32_t *) (ids + ii1*nbi1))[ii0]; + int32_t _ne1 = 0; + for (ushort ii1 = 0; ii1 < args.nei1; ii1++) { + for (ushort ii0 = 0; ii0 < args.nei0; ii0++) { + int32_t id = ((device int32_t *) (ids + ii1*args.nbi1))[ii0]; if (id == i02) { - //if (tiitg == 0) { + if (tiitg == 0) { rowids[_ne1] = ushort2(ii0, ii1); - //} + } _ne1++; } } @@ -6623,23 +5721,23 @@ kernel void kernel_mul_mm_id( threadgroup_barrier(mem_flags::mem_threadgroup); kernel_mul_mm_id_impl( + args.ne00, + args.ne02, + args.nb01, + args.nb02, + args.ne11, + args.ne12, + args.nb10, + args.nb11, + args.nb12, + args.ne0, + _ne1, + (int64_t)args.ne0*args.ne1, src0, src1, rowids, dst, - ne00, - ne02, - nb01, - nb02, - ne11, - ne12, - nb10, - nb11, - nb12, - ne0, - _ne1, - ne0*ne1, - shared_memory, + shmem, tgpig, tiitg, sgitg); @@ -6748,194 +5846,110 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mat_mm_id_t kernel // typedef void (kernel_mul_mv_impl_t)( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - uint3 tgpig, - uint tiisg); + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + uint3 tgpig, + ushort tiisg); typedef void (kernel_mul_mv2_impl_t)( - device const void * src0, - device const float * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne12, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiisg, - uint sgitg); + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg); template void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,src1,dst,ne00,ne01,ne02,nb00,nb01,nb02,nb03,ne10,ne11,ne12,nb10,nb11,nb12,nb13,ne0,ne1,r2,r3,tgpig,tiisg); + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { + impl_fn(args, src0, src1, dst, tgpig, tiisg); } template void mmv_fn( - device const char * src0, - device const char * src1, - device float * dst, - int64_t ne00, - int64_t ne01, - int64_t ne02, - uint64_t nb00, - uint64_t nb01, - uint64_t nb02, - uint64_t nb03, - int64_t ne10, - int64_t ne11, - int64_t ne12, - int64_t ne13, - uint64_t nb10, - uint64_t nb11, - uint64_t nb12, - uint64_t nb13, - int64_t ne0, - int64_t ne1, - uint64_t nb1, - uint r2, - uint r3, - threadgroup int8_t * shared_values, - uint3 tgpig, - uint tiitg, - uint tiisg, - uint sgitg) { - impl_fn(src0,(const device float *)src1,dst,ne00,ne01,ne02,nb01,nb02,nb03,ne10,ne12,nb11,nb12,nb13,ne0,ne1,r2,r3,shared_values,tgpig,tiisg,sgitg); + ggml_metal_kargs_mul_mv args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiitg, + ushort tiisg, + ushort sgitg) { + impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_impl_fn_t; template kernel void kernel_mul_mv_id( - device const char * src0s, - device const char * src1, - device float * dst, - device const char * ids, - constant int64_t & nei0, - constant int64_t & nei1, - constant uint64_t & nbi1, - constant int64_t & ne00, - constant int64_t & ne01, - constant int64_t & ne02, - constant uint64_t & nb00, - constant uint64_t & nb01, - constant uint64_t & nb02, - constant int64_t & ne10, - constant int64_t & ne11, - constant int64_t & ne12, - constant int64_t & ne13, - constant uint64_t & nb10, - constant uint64_t & nb11, - constant uint64_t & nb12, - constant int64_t & ne0, - constant int64_t & ne1, - constant uint64_t & nb1, - threadgroup int8_t * shared_values [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint tiisg[[thread_index_in_simdgroup]], - uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int iid1 = tgpig.z/nei0; - const int idx = tgpig.z%nei0; + constant ggml_metal_kargs_mul_mv_id & args, + device const char * src0s, + device const char * src1, + device char * dst, + device const char * ids, + threadgroup char * shmem [[threadgroup(0)]], + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]]) { + const int iid1 = tgpig.z/args.nei0; + const int idx = tgpig.z%args.nei0; tgpig.z = 0; - const int32_t i02 = ((device const int32_t *) (ids + iid1*nbi1))[idx]; + const int32_t i02 = ((device const int32_t *) (ids + iid1*args.nbi1))[idx]; - const int64_t i11 = idx % ne11; + const int64_t i11 = idx % args.ne11; const int64_t i12 = iid1; const int64_t i1 = idx; const int64_t i2 = i12; - device const char * src0_cur = src0s + i02*nb02; - device const char * src1_cur = src1 + i11*nb11 + i12*nb12; - device float * dst_cur = dst + i1*ne0 + i2*ne1*ne0; + device const char * src0_cur = src0s + i02*args.nb02; + device const char * src1_cur = src1 + i11*args.nb11 + i12*args.nb12; + + device char * dst_cur = dst + (i1*args.ne0 + i2*args.ne1*args.ne0)*sizeof(float); + + ggml_metal_kargs_mul_mv args0 = { + /*.ne00 =*/ args.ne00, + /*.ne01 =*/ args.ne01, + /*.ne02 =*/ 1, // args.ne02, + /*.nb00 =*/ args.nb00, + /*.nb01 =*/ args.nb01, + /*.nb02 =*/ args.nb02, + /*.nb03 =*/ args.nb02, // args.ne02 == 1 + /*.ne10 =*/ args.ne10, + /*.ne11 =*/ 1, // args.ne11, + /*.ne12 =*/ 1, // args.ne12, + /*.nb10 =*/ args.nb10, + /*.nb11 =*/ args.nb11, + /*.nb12 =*/ args.nb12, + /*.nb13 =*/ args.nb12, // ne12 == 1 + /*.ne0 =*/ args.ne0, + /*.ne1 =*/ 1, // args.ne1, + /*.r2 =*/ 1, + /*.r3 =*/ 1, + }; impl_fn( + args0, /* src0 */ src0_cur, /* src1 */ src1_cur, /* dst */ dst_cur, - /* ne00 */ ne00, - /* ne01 */ ne01, - /* ne02 */ 1, // ne02, - /* nb00 */ nb00, - /* nb01 */ nb01, - /* nb02 */ nb02, - /* nb03 */ nb02, // ne02 == 1 - /* ne10 */ ne10, - /* ne11 */ 1, // ne11, - /* ne12 */ 1, // ne12, - /* ne13 */ 1, // ne13, - /* nb10 */ nb10, - /* nb11 */ nb11, - /* nb12 */ nb12, - /* ne13 */ nb12, // ne12 == 1 - /* ne0 */ ne0, - /* ne1 */ 1, // ne1, - /* nb1 */ nb1, - /* r2 */ 1, - /* r3 */ 1, - shared_values, + shmem, tgpig, tiitg, tiisg, diff --git a/ggml/src/ggml-musa/CMakeLists.txt b/ggml/src/ggml-musa/CMakeLists.txt index 8edc75cc5..f3c013692 100644 --- a/ggml/src/ggml-musa/CMakeLists.txt +++ b/ggml/src/ggml-musa/CMakeLists.txt @@ -58,19 +58,12 @@ if (MUSAToolkit_FOUND) target_compile_definitions(ggml PUBLIC GGML_USE_CUDA) add_compile_definitions(GGML_USE_MUSA) - add_compile_definitions(GGML_CUDA_DMMV_X=${GGML_CUDA_DMMV_X}) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_MMV_Y}) - add_compile_definitions(K_QUANTS_PER_ITERATION=${GGML_CUDA_KQUANTS_ITER}) add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE}) if (GGML_CUDA_GRAPHS) add_compile_definitions(GGML_CUDA_USE_GRAPHS) endif() - if (GGML_CUDA_FORCE_DMMV) - add_compile_definitions(GGML_CUDA_FORCE_DMMV) - endif() - if (GGML_CUDA_FORCE_MMQ) add_compile_definitions(GGML_CUDA_FORCE_MMQ) endif() @@ -83,10 +76,6 @@ if (MUSAToolkit_FOUND) add_compile_definitions(GGML_CUDA_NO_VMM) endif() - if (DEFINED GGML_CUDA_DMMV_Y) - add_compile_definitions(GGML_CUDA_MMV_Y=${GGML_CUDA_DMMV_Y}) # for backwards compatibility - endif() - if (GGML_CUDA_F16 OR GGML_CUDA_DMMV_F16) add_compile_definitions(GGML_CUDA_F16) endif() diff --git a/ggml/src/ggml-opt.cpp b/ggml/src/ggml-opt.cpp new file mode 100644 index 000000000..040205a31 --- /dev/null +++ b/ggml/src/ggml-opt.cpp @@ -0,0 +1,867 @@ +#include "ggml-opt.h" + +#include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-impl.h" + +#include +#include +#include +#include +#include +#include +#include + +struct ggml_opt_dataset { + struct ggml_context * ctx; + ggml_backend_buffer_t buf; + struct ggml_tensor * data; + struct ggml_tensor * labels; + + int64_t ndata; + int64_t ndata_shard; + size_t nbs_data; + size_t nbs_labels; + + std::vector permutation; +}; + +struct ggml_opt_context { + ggml_backend_sched_t backend_sched; + ggml_cgraph * allocated_graph; + ggml_cgraph * allocated_graph_copy; + struct ggml_context * ctx_static; + struct ggml_context * ctx_static_cpu; + struct ggml_context * ctx_compute; + struct ggml_context * ctx_copy; + ggml_backend_buffer_t buf_static; + ggml_backend_buffer_t buf_static_cpu; + std::mt19937 rng; + + struct ggml_tensor * inputs; + struct ggml_tensor * outputs; + struct ggml_tensor * labels; + + struct ggml_tensor * loss; + struct ggml_tensor * pred; + struct ggml_tensor * ncorrect; + + struct ggml_cgraph * gf; + struct ggml_cgraph * gb_grad; + struct ggml_cgraph * gb_opt; + + int64_t iter; + int32_t opt_period; + int32_t opt_i; + bool loss_per_datapoint; + + ggml_opt_get_optimizer_params get_opt_pars; + void * get_opt_pars_ud; + struct ggml_tensor * adamw_params; +}; + +struct ggml_opt_result { + int64_t ndata = 0; + std::vector loss; + std::vector pred; + int64_t ncorrect = 0; + + bool loss_per_datapoint = false; + int64_t opt_period = -1; +}; + +// ====== Dataset ====== + +ggml_opt_dataset_t ggml_opt_dataset_init(int64_t ne_datapoint, int64_t ne_label, int64_t ndata, int64_t ndata_shard) { + GGML_ASSERT(ne_datapoint > 0); + GGML_ASSERT(ne_label >= 0); + GGML_ASSERT(ndata > 0); + GGML_ASSERT(ndata_shard > 0); + + ggml_opt_dataset_t result = new ggml_opt_dataset; + result->ndata = ndata; + result->ndata_shard = ndata_shard; + + { + struct ggml_init_params params = { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx = ggml_init(params); + } + + result->data = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_datapoint, ndata); + result->nbs_data = ggml_nbytes(result->data) * ndata_shard/ndata; + + if (ne_label > 0) { + result->labels = ggml_new_tensor_2d(result->ctx, GGML_TYPE_F32, ne_label, ndata); + result->nbs_labels = ggml_nbytes(result->labels) * ndata_shard/ndata; + } else { + result->labels = nullptr; + result->nbs_labels = 0; + } + + result->buf = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx, ggml_backend_cpu_buffer_type()); + + const int64_t nshards = ndata/ndata_shard; + result->permutation.resize(nshards); + for (int64_t i = 0; i < nshards; ++i) { + result->permutation[i] = i; + } + return result; +} + +void ggml_opt_dataset_free(ggml_opt_dataset_t dataset) { + ggml_backend_buffer_free(dataset->buf); + ggml_free(dataset->ctx); + delete dataset; +} + +struct ggml_tensor * ggml_opt_dataset_data(ggml_opt_dataset_t dataset) { + return dataset->data; +} + +struct ggml_tensor * ggml_opt_dataset_labels(ggml_opt_dataset_t dataset) { + return dataset->labels; +} + +void ggml_opt_dataset_shuffle(ggml_opt_context_t opt_ctx, ggml_opt_dataset_t dataset, int64_t idata) { + GGML_ASSERT(idata <= dataset->ndata); + + if (idata < 0) { + std::shuffle(dataset->permutation.begin(), dataset->permutation.end(), opt_ctx->rng); + return; + } + + GGML_ASSERT(idata % dataset->ndata_shard == 0); + const int64_t ishard_max = idata / dataset->ndata_shard; + std::shuffle(dataset->permutation.begin(), dataset->permutation.begin() + ishard_max, opt_ctx->rng); +} + +void ggml_opt_dataset_get_batch(ggml_opt_dataset_t dataset, struct ggml_tensor * data_batch, struct ggml_tensor * labels_batch, int64_t ibatch) { + GGML_ASSERT( data_batch && ggml_is_contiguous(data_batch)); + GGML_ASSERT(!labels_batch || ggml_is_contiguous(labels_batch)); + GGML_ASSERT((labels_batch == nullptr) == (dataset->labels == nullptr)); + + const size_t nb_data_batch = ggml_nbytes(data_batch); + GGML_ASSERT(nb_data_batch % dataset->nbs_data == 0); + const int64_t shards_per_batch = nb_data_batch / dataset->nbs_data; + + if (labels_batch) { + const size_t nb_labels_batch = ggml_nbytes(labels_batch); + GGML_ASSERT(nb_labels_batch == shards_per_batch*dataset->nbs_labels); + } + + GGML_ASSERT((ibatch + 1)*shards_per_batch <= int64_t(dataset->permutation.size())); + + for (int64_t ishard_batch = 0; ishard_batch < shards_per_batch; ++ishard_batch) { + const int64_t ishard = dataset->permutation[ibatch*shards_per_batch + ishard_batch]; + + const char * ptr_data = (const char *) dataset->data->data + ishard*dataset->nbs_data; + ggml_backend_tensor_set(data_batch, ptr_data, ishard_batch*dataset->nbs_data, dataset->nbs_data); + + if (!labels_batch) { + continue; + } + + const char * ptr_labels = (const char *) dataset->labels->data + ishard*dataset->nbs_labels; + ggml_backend_tensor_set(labels_batch, ptr_labels, ishard_batch*dataset->nbs_labels, dataset->nbs_labels); + } +} + +// ====== Model / Context ====== + +struct ggml_opt_optimizer_params ggml_opt_get_default_optimizer_params(void * userdata) { + GGML_UNUSED(userdata); + + ggml_opt_optimizer_params result; + + result.adamw.alpha = 0.001f; + result.adamw.beta1 = 0.9f; + result.adamw.beta2 = 0.999f; + result.adamw.eps = 1e-8f; + result.adamw.wd = 0.0f; + + return result; +} + +struct ggml_opt_params ggml_opt_default_params( + ggml_backend_sched_t backend_sched, + struct ggml_context * ctx_compute, + struct ggml_tensor * inputs, + struct ggml_tensor * outputs, + enum ggml_opt_loss_type loss_type) { + return { + /*backend_sched =*/ backend_sched, + /*ctx_compute =*/ ctx_compute, + /*inputs =*/ inputs, + /*logits =*/ outputs, + /*loss_type =*/ loss_type, + /*build_type =*/ GGML_OPT_BUILD_TYPE_OPT, + /*opt_period =*/ 1, + /*get_opt_pars =*/ ggml_opt_get_default_optimizer_params, + /*get_opt_pars_ud =*/ nullptr, + }; +} + +static ggml_tensor * map_tensor(std::map & tensor_map, ggml_context * ctx, ggml_tensor * tensor) { + if (!tensor) { + return nullptr; + } + + if (tensor_map.find(tensor) != tensor_map.end()) { + return tensor_map[tensor]; + } + + ggml_tensor * new_tensor = ggml_dup_tensor(ctx, tensor); + tensor_map[tensor] = new_tensor; + + new_tensor->op = tensor->op; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + new_tensor->nb[i] = tensor->nb[i]; + } + new_tensor->flags = tensor->flags; + memcpy(new_tensor->op_params, tensor->op_params, sizeof(tensor->op_params)); + strcpy(new_tensor->name, tensor->name); + new_tensor->data = tensor->data; + new_tensor->buffer = tensor->buffer; + new_tensor->extra = tensor->extra; + new_tensor->view_offs = tensor->view_offs; + new_tensor->view_src = map_tensor(tensor_map, ctx, tensor->view_src); + for (int i = 0; i < GGML_MAX_SRC; i++) { + new_tensor->src[i] = map_tensor(tensor_map, ctx, tensor->src[i]); + } + + return new_tensor; +} + +static ggml_cgraph * dup_graph(ggml_context * ctx, ggml_cgraph * graph) { + std::map tensor_map; + + ggml_cgraph * new_graph = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); + + for (int i = 0; i < graph->n_leafs; i++) { + ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->leafs[i])); + } + for (int i = 0; i < graph->n_nodes; i++) { + ggml_build_forward_expand(new_graph, map_tensor(tensor_map, ctx, graph->nodes[i])); + } + for (int i = 0; i < graph->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&graph->visited_hash_set, graph->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&new_graph->visited_hash_set, new_graph->nodes[i]); + graph->grads[igrad_dst] = new_graph->grads[igrad_src]; + graph->grad_accs[igrad_dst] = new_graph->grad_accs[igrad_src]; + } + + return new_graph; +} + +static void ggml_opt_alloc_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph) { + GGML_ASSERT(graph); + if (opt_ctx->allocated_graph == graph) { + return; + } + + ggml_backend_sched_reset(opt_ctx->backend_sched); // clear allocation of previous graph + + { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead() * GGML_DEFAULT_GRAPH_SIZE, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ggml_free(opt_ctx->ctx_copy); + opt_ctx->ctx_copy = ggml_init(params); + } + + opt_ctx->allocated_graph_copy = dup_graph(opt_ctx->ctx_copy, graph); + + ggml_backend_sched_alloc_graph(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->allocated_graph = graph; +} + +ggml_opt_context_t ggml_opt_init(struct ggml_opt_params params) { + ggml_opt_context_t result = new struct ggml_opt_context; + result->backend_sched = params.backend_sched; + result->allocated_graph = nullptr; + result->allocated_graph_copy = nullptr; + result->ctx_compute = params.ctx_compute; + result->ctx_copy = nullptr; + result->inputs = params.inputs; + result->outputs = params.outputs; + result->iter = 1; + result->opt_period = params.opt_period; + result->opt_i = 0; + result->get_opt_pars = params.get_opt_pars; + result->get_opt_pars_ud = params.get_opt_pars_ud; + + GGML_ASSERT(result->inputs->data && "the inputs must be allocated statically"); + GGML_ASSERT(result->opt_period >= 1); + + const bool accumulate = params.build_type == GGML_OPT_BUILD_TYPE_GRAD || + (params.build_type == GGML_OPT_BUILD_TYPE_OPT && result->opt_period > 1); + + ggml_set_input(result->inputs); + ggml_set_output(result->outputs); + + result->gf = ggml_new_graph_custom(result->ctx_compute, GGML_DEFAULT_GRAPH_SIZE, /*grads =*/ true); // Forward pass. + ggml_build_forward_expand(result->gf, result->outputs); + + int n_param = 0; + for (int i = 0; i < result->gf->n_nodes; ++i) { + if (result->gf->nodes[i]->flags & GGML_TENSOR_FLAG_PARAM) { + n_param++; + } + } + + { + // The static context is used for: + // - gradients (1 tensor per param if using gradient accumulation) + // - optimizer momenta (2 tensors per param) + // - labels + // - loss + its gradient (up to 5 tensors) + // - pred + // - ncorrect (2 tensors). + const size_t tensors_per_param = (accumulate ? 1 : 0) + (params.build_type == GGML_OPT_BUILD_TYPE_OPT ? 2 : 0); + const size_t size_meta = (tensors_per_param*n_param + 9) * ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx_static = ggml_init(params); + } + { + // The static cpu context is used for: + // - optimizer parameters (1 for the entire context) + const size_t size_meta = 1 * ggml_tensor_overhead(); + struct ggml_init_params params = { + /*.mem_size =*/ size_meta, + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + result->ctx_static_cpu = ggml_init(params); + } + + + switch (params.loss_type) { + case GGML_OPT_LOSS_TYPE_MEAN: { + result->labels = nullptr; + result->loss = ggml_sum(result->ctx_static, result->outputs); + ggml_set_name(result->loss, "loss_sum"); + const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); + result->loss = ggml_scale(result->ctx_static, result->loss, scale); + ggml_set_name(result->loss, "loss_mean"); + result->loss_per_datapoint = true; + break; + } + case GGML_OPT_LOSS_TYPE_SUM: { + result->labels = nullptr; + result->loss = ggml_sum(result->ctx_static, result->outputs); + ggml_set_name(result->loss, "loss_sum"); + result->loss_per_datapoint = false; + break; + } + case GGML_OPT_LOSS_TYPE_CROSS_ENTROPY: { + result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); + ggml_set_input(result->labels); + ggml_set_name(result->labels, "labels"); + result->loss = ggml_cross_entropy_loss(result->ctx_static, result->outputs, result->labels); + ggml_set_name(result->loss, "loss_cross_entropy"); + if (result->opt_period > 1) { + result->loss = ggml_scale(result->ctx_static, result->loss, 1.0f / result->opt_period); + ggml_set_name(result->loss, "loss_cross_entropy_scaled"); + } + result->loss_per_datapoint = true; + break; + } + case GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR: { + result->labels = ggml_dup_tensor(result->ctx_static, result->outputs); + ggml_set_input(result->labels); + ggml_set_name(result->labels, "labels"); + result->loss = ggml_sub(result->ctx_static, result->outputs, result->labels); + ggml_set_name(result->loss, "loss_error"); + result->loss = ggml_sqr(result->ctx_static, result->loss); + ggml_set_name(result->loss, "loss_squared_error"); + result->loss = ggml_sum(result->ctx_static, result->loss); + ggml_set_name(result->loss, "loss_sum_squared_error"); + const float scale = 1.0f / (result->opt_period * ggml_nelements(result->outputs)); + result->loss = ggml_scale(result->ctx_static, result->loss, scale); + ggml_set_name(result->loss, "loss_mean_squared_error"); + result->loss_per_datapoint = true; + break; + } + } + ggml_set_output(result->loss); + ggml_set_loss(result->loss); + ggml_build_forward_expand(result->gf, result->loss); + + result->pred = ggml_argmax(result->ctx_static, result->outputs); + ggml_set_name(result->pred, "pred"); + ggml_set_output(result->pred); + ggml_build_forward_expand(result->gf, result->pred); + + if (result->labels) { + result->ncorrect = ggml_count_equal(result->ctx_static, result->pred, ggml_argmax(result->ctx_static, result->labels)); + ggml_set_name(result->ncorrect, "ncorrect"); + ggml_set_output(result->ncorrect); + ggml_build_forward_expand(result->gf, result->ncorrect); + } else { + result->ncorrect = nullptr; + } + + if (params.build_type == GGML_OPT_BUILD_TYPE_FORWARD) { + result->gb_grad = nullptr; + result->gb_opt = nullptr; + + result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + result->buf_static_cpu = nullptr; + + ggml_opt_alloc_graph(result, result->gf); + + return result; + } + + // gb_grad == graph backward gradients, forward pass, then backward pass to calculate gradients. + result->gb_grad = ggml_graph_dup(result->ctx_compute, result->gf); + ggml_build_backward_expand(result->ctx_static, result->ctx_compute, result->gb_grad, accumulate); + + if (params.build_type == GGML_OPT_BUILD_TYPE_GRAD) { + result->gb_opt = nullptr; + + result->buf_static = ggml_backend_alloc_ctx_tensors(result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + result->buf_static_cpu = nullptr; + + ggml_opt_alloc_graph(result, result->gb_grad); + ggml_graph_reset(result->gb_grad); + + return result; + } + + GGML_ASSERT(params.build_type == GGML_OPT_BUILD_TYPE_OPT); + + // gb_opt == graph backward optimize, forward pass, then backward pass to calculate gradients, then optimizer step. + result->gb_opt = ggml_graph_dup(result->ctx_compute, result->gb_grad); + + result->adamw_params = ggml_new_tensor_1d(result->ctx_static_cpu, GGML_TYPE_F32, 7); + ggml_set_input(result->adamw_params); + ggml_set_name(result->adamw_params, "adamw_params"); + + for (int i = result->gf->n_nodes-1; i >= 0; --i) { + struct ggml_tensor * node = result->gb_opt->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(result->gb_opt, node); + + if (node->flags & GGML_TENSOR_FLAG_PARAM) { + struct ggml_tensor * m = ggml_dup_tensor(result->ctx_static, node); + struct ggml_tensor * v = ggml_dup_tensor(result->ctx_static, node); + struct ggml_tensor * opt_step = ggml_opt_step_adamw(result->ctx_compute, node, grad, m, v, result->adamw_params); + ggml_build_forward_expand(result->gb_opt, opt_step); + } + } + + result->buf_static = ggml_backend_alloc_ctx_tensors( + result->ctx_static, ggml_backend_sched_get_backend(result->backend_sched, 0)); + + result->buf_static_cpu = ggml_backend_alloc_ctx_tensors_from_buft(result->ctx_static_cpu, ggml_backend_cpu_buffer_type()); + + ggml_opt_alloc_graph(result, result->gb_opt); + ggml_graph_reset(result->gb_opt); + + return result; +} + +void ggml_opt_free(ggml_opt_context_t opt_ctx) { + if (opt_ctx == nullptr) { + return; + } + ggml_backend_buffer_free(opt_ctx->buf_static); + ggml_backend_buffer_free(opt_ctx->buf_static_cpu); + ggml_free(opt_ctx->ctx_static); + ggml_free(opt_ctx->ctx_static_cpu); + delete opt_ctx; +} + +void ggml_opt_reset(ggml_opt_context_t opt_ctx, bool optimizer) { + if (optimizer) { + ggml_graph_reset(opt_ctx->gb_opt); + opt_ctx->iter = 1; + } else { + ggml_graph_reset(opt_ctx->gb_grad); + } +} + +struct ggml_tensor * ggml_opt_inputs(ggml_opt_context_t opt_ctx) { + return opt_ctx->inputs; +} + +struct ggml_tensor * ggml_opt_outputs(ggml_opt_context_t opt_ctx) { + return opt_ctx->outputs; +} + +struct ggml_tensor * ggml_opt_labels(ggml_opt_context_t opt_ctx) { + return opt_ctx->labels; +} + +struct ggml_tensor * ggml_opt_loss(ggml_opt_context_t opt_ctx) { + return opt_ctx->loss; +} + +struct ggml_tensor * ggml_opt_pred(ggml_opt_context_t opt_ctx) { + return opt_ctx->pred; +} + +struct ggml_tensor * ggml_opt_ncorrect(ggml_opt_context_t opt_ctx) { + return opt_ctx->ncorrect; +} + +struct ggml_tensor * ggml_opt_grad_acc(ggml_opt_context_t opt_ctx, struct ggml_tensor * node) { + return ggml_graph_get_grad_acc(opt_ctx->gb_opt, node); +} + +// ====== Optimization Result ====== + +ggml_opt_result_t ggml_opt_result_init() { + return new ggml_opt_result; +} + +void ggml_opt_result_free(ggml_opt_result_t result) { + delete result; +} + +void ggml_opt_result_reset(ggml_opt_result_t result) { + result->ndata = 0; + result->loss.clear(); + result->pred.clear(); + result->ncorrect = 0; +} + +void ggml_opt_result_ndata(ggml_opt_result_t result, int64_t * ndata) { + *ndata = result->ndata; +} + +void ggml_opt_result_loss(ggml_opt_result_t result, double * loss, double * unc) { + const int64_t nbatches = result->loss.size(); // Number of physical batches. + + if (nbatches == 0) { + *loss = 0.0; + *unc = NAN; + return; + } + + double sum = 0.0; + double sum_squared = 0.0; + + for (const float & loss : result->loss) { + // If the loss is per datapoint it was scaled by 1.0f/opt_period for each physical batch. + const float loss_scaled = result->loss_per_datapoint ? loss*result->opt_period : loss; + sum += loss_scaled; + sum_squared += loss_scaled*loss_scaled; + } + + const double mean = sum/nbatches; + *loss = result->loss_per_datapoint ? mean : sum; + + if (!unc) { + return; + } + + if (nbatches < 2) { + *unc = NAN; + return; + } + + const double var_sum = sum_squared/nbatches - mean*mean; // variance without Bessel's correction, i.e. nbatches/(nbatches-1) + *unc = result->loss_per_datapoint ? sqrt(var_sum / (nbatches - 1)) : sqrt(var_sum * nbatches/(nbatches - 1)); +} + +void ggml_opt_result_pred(ggml_opt_result_t result, int32_t * pred) { + for (size_t i = 0; i < result->pred.size(); ++i) { + pred[i] = result->pred[i]; + } +} + +void ggml_opt_result_accuracy(ggml_opt_result_t result, double * accuracy, double * unc) { + *accuracy = result->ncorrect >= 0 ? double(result->ncorrect) / double(result->ndata) : NAN; + + if (!unc) { + return; + } + + *unc = result->ncorrect >= 0 && result->ndata >= 2 ? + sqrt((*accuracy) * (1.0 - (*accuracy)) / double(result->ndata - 1)) : NAN; +} + +// ====== Computation ====== + +static void ggml_opt_eval_graph(ggml_opt_context_t opt_ctx, ggml_cgraph * graph, ggml_opt_result * result) { + if (graph != opt_ctx->gf) { + struct ggml_opt_optimizer_params opt_pars = opt_ctx->get_opt_pars(opt_ctx->get_opt_pars_ud); + + GGML_ASSERT(opt_pars.adamw.alpha > 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta1 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.beta2 >= 0.0f); + GGML_ASSERT(opt_pars.adamw.beta2 <= 1.0f); + GGML_ASSERT(opt_pars.adamw.eps >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd >= 0.0f); + GGML_ASSERT(opt_pars.adamw.wd <= 1.0f); + + // beta1, beta2 after applying warmup + const float beta1h = 1.0f/(1.0f - powf(opt_pars.adamw.beta1, opt_ctx->iter)); + const float beta2h = 1.0f/(1.0f - powf(opt_pars.adamw.beta2, opt_ctx->iter)); + + float * adamw_par_data = ggml_get_data_f32(opt_ctx->adamw_params); + adamw_par_data[0] = opt_pars.adamw.alpha; + adamw_par_data[1] = opt_pars.adamw.beta1; + adamw_par_data[2] = opt_pars.adamw.beta2; + adamw_par_data[3] = opt_pars.adamw.eps; + adamw_par_data[4] = opt_pars.adamw.wd; + adamw_par_data[5] = beta1h; + adamw_par_data[6] = beta2h; + } + + ggml_opt_alloc_graph(opt_ctx, graph); + ggml_backend_sched_graph_compute(opt_ctx->backend_sched, opt_ctx->allocated_graph_copy); + opt_ctx->iter += opt_ctx->allocated_graph == opt_ctx->gb_opt; + + if (!result) { + return; + } + + if (result->ndata == 0) { + result->loss_per_datapoint = opt_ctx->loss_per_datapoint; + result->opt_period = opt_ctx->opt_period; + } else { + GGML_ASSERT(result->loss_per_datapoint == opt_ctx->loss_per_datapoint); + GGML_ASSERT(result->opt_period == opt_ctx->opt_period); + } + + const int64_t ndata = opt_ctx->outputs->ne[1]; + GGML_ASSERT(result->ndata == ndata*int64_t(result->loss.size()) && "varying batch size not supported"); + result->ndata += ndata; + + GGML_ASSERT(ggml_is_scalar(opt_ctx->loss)); + GGML_ASSERT(opt_ctx->loss->type == GGML_TYPE_F32); + float loss; + ggml_backend_tensor_get(opt_ctx->loss, &loss, 0, ggml_nbytes(opt_ctx->loss)); + result->loss.push_back(loss); + + GGML_ASSERT(opt_ctx->pred->type == GGML_TYPE_I32); + std::vector pred(ndata); + ggml_backend_tensor_get(opt_ctx->pred, pred.data(), 0, ggml_nbytes(opt_ctx->pred)); + result->pred.insert(result->pred.end(), pred.begin(), pred.end()); + + if (!opt_ctx->labels || result->ncorrect < 0) { + result->ncorrect = -1; + return; + } + + GGML_ASSERT(ggml_is_scalar(opt_ctx->ncorrect)); + GGML_ASSERT(opt_ctx->ncorrect->type == GGML_TYPE_I64); + int64_t ncorrect; + ggml_backend_tensor_get(opt_ctx->ncorrect, &ncorrect, 0, ggml_nbytes(opt_ctx->ncorrect)); + result->ncorrect += ncorrect; +} + +void ggml_opt_forward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gf, result); +} + +void ggml_opt_forward_backward(ggml_opt_context_t opt_ctx, ggml_opt_result * result) { + if (opt_ctx->opt_period == 1) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); + return; + } + + const int32_t opt_i_next = (opt_ctx->opt_i + 1) % opt_ctx->opt_period; + if (opt_i_next == 0) { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_opt, result); + ggml_opt_reset(opt_ctx, /*optimizer =*/ false); + } else { + ggml_opt_eval_graph(opt_ctx, opt_ctx->gb_grad, result); + } + opt_ctx->opt_i = opt_i_next; +} + +// ====== High-Level Functions ====== + +void ggml_opt_epoch( + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result_train, + ggml_opt_result_t result_eval, + int64_t idata_split, + ggml_opt_epoch_callback callback_train, + ggml_opt_epoch_callback callback_eval) { + struct ggml_tensor * inputs = ggml_opt_inputs(opt_ctx); + struct ggml_tensor * labels = ggml_opt_labels(opt_ctx); + struct ggml_tensor * data = ggml_opt_dataset_data(dataset); + GGML_ASSERT(data->ne[0] == inputs->ne[0]); + + const int64_t ndata = data->ne[1]; + const int64_t ndata_batch = inputs->ne[1]; + + GGML_ASSERT(data->ne[1] % inputs->ne[1] == 0); + const int64_t nbatches = ndata/ndata_batch; + + idata_split = idata_split < 0 ? ndata : idata_split; + GGML_ASSERT(idata_split % ndata_batch == 0); + const int64_t ibatch_split = idata_split / ndata_batch; + + int64_t ibatch = 0; + int64_t t_loop_start = ggml_time_us(); + for (; ibatch < ibatch_split; ++ibatch) { + ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + ggml_opt_forward_backward(opt_ctx, result_train); + if (callback_train) { + callback_train(true, opt_ctx, dataset, result_train, ibatch+1, ibatch_split, t_loop_start); + } + } + t_loop_start = ggml_time_us(); + for (; ibatch < nbatches; ++ibatch) { + ggml_opt_dataset_get_batch(dataset, inputs, labels, ibatch); + ggml_opt_forward(opt_ctx, result_eval); + if (callback_eval) { + callback_eval(false, opt_ctx, dataset, result_eval, ibatch+1-ibatch_split, nbatches-ibatch_split, t_loop_start); + } + } +} + +void ggml_opt_epoch_callback_progress_bar( + bool train, + ggml_opt_context_t opt_ctx, + ggml_opt_dataset_t dataset, + ggml_opt_result_t result, + int64_t ibatch, + int64_t ibatch_max, + int64_t t_start_us) { + fprintf(stderr, "%s[", train ? "train: " : "val: "); + + constexpr int64_t bar_length = 25; + for (int64_t j = 0; j < bar_length; ++j) { + const int64_t ibatch_j = ibatch_max * j/bar_length; + if (ibatch_j < ibatch) { + fprintf(stderr, "="); + } else if (ibatch_max * (j - 1)/bar_length < ibatch) { + fprintf(stderr, ">"); + } else { + fprintf(stderr, " "); + } + } + + const int64_t batch_size = ggml_opt_inputs(opt_ctx)->ne[1]; + const int64_t idata = ibatch*batch_size; + const int64_t idata_max = ibatch_max*batch_size; + + double loss; + double loss_unc; + ggml_opt_result_loss(result, &loss, &loss_unc); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(result, &accuracy, &accuracy_unc); + + const int64_t t_ibatch_us = ggml_time_us() - t_start_us; + int64_t t_ibatch_s = t_ibatch_us / 1000000; + const int64_t t_ibatch_h = t_ibatch_s / 3600; + t_ibatch_s -= t_ibatch_h * 3600; + const int64_t t_ibatch_m = t_ibatch_s / 60; + t_ibatch_s -= t_ibatch_m * 60; + + const int64_t t_eta_us = t_ibatch_us * (ibatch_max - ibatch)/ibatch; + int64_t t_eta_s = t_eta_us / 1000000; + const int64_t t_eta_h = t_eta_s / 3600; + t_eta_s -= t_eta_h * 3600; + const int64_t t_eta_m = t_eta_s / 60; + t_eta_s -= t_eta_m * 60; + + fprintf(stderr, "| data=%06" PRId64 "/%06" PRId64 ", loss=%.6lf+-%.6lf, accuracy=%.2lf+-%.2lf%%, " + "t=%02" PRId64 ":%02" PRId64 ":%02" PRId64 ", ETA=%02" PRId64 ":%02" PRId64 ":%02" PRId64 "]\r", + idata, idata_max, loss, loss_unc, 100.0*accuracy, 100.0*accuracy_unc, + t_ibatch_h, t_ibatch_m, t_ibatch_s, t_eta_h, t_eta_m, t_eta_s); + if (ibatch == ibatch_max) { + fprintf(stderr, "\n"); + } + fflush(stderr); + + GGML_UNUSED(dataset); +} + +void ggml_opt_fit( + ggml_backend_sched_t backend_sched, + ggml_context * ctx_compute, + ggml_tensor * inputs, + ggml_tensor * outputs, + ggml_opt_dataset_t dataset, + enum ggml_opt_loss_type loss_type, + ggml_opt_get_optimizer_params get_opt_pars, + int64_t nepoch, + int64_t nbatch_logical, + float val_split, + bool silent) { + ggml_time_init(); + const int64_t t_start_us = ggml_time_us(); + + const int64_t ndata = ggml_opt_dataset_data(dataset)->ne[1]; + const int64_t nbatch_physical = inputs->ne[1]; + GGML_ASSERT(ndata % nbatch_logical == 0); + GGML_ASSERT(nbatch_logical % nbatch_physical == 0); + + const int64_t opt_period = nbatch_logical / nbatch_physical; + const int64_t nbatches_logical = ndata / nbatch_logical; + + GGML_ASSERT(val_split >= 0.0f); + GGML_ASSERT(val_split < 1.0f); + const int64_t ibatch_split = int64_t(((1.0f - val_split) * nbatches_logical)) * opt_period; // train <-> val split index (physical) + const int64_t idata_split = ibatch_split * nbatch_physical; + + int64_t epoch = 1; + + ggml_opt_params params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + params.opt_period = opt_period; + params.get_opt_pars = get_opt_pars; + params.get_opt_pars_ud = &epoch; + ggml_opt_context_t opt_ctx = ggml_opt_init(params); + + // Shuffling the data is generally useful but there is only a point if not all data is used in a single batch. + if (nbatch_logical < ndata) { + ggml_opt_dataset_shuffle(opt_ctx, dataset, -1); // Shuffle all data (train + validation). + } + + ggml_opt_result_t result_train = ggml_opt_result_init(); + ggml_opt_result_t result_val = ggml_opt_result_init(); + + ggml_opt_epoch_callback epoch_callback = silent ? nullptr : ggml_opt_epoch_callback_progress_bar; + + for (; epoch <= nepoch; ++epoch) { + if (nbatch_logical < idata_split) { + ggml_opt_dataset_shuffle(opt_ctx, dataset, idata_split); + } + + ggml_opt_result_reset(result_train); + ggml_opt_result_reset(result_val); + + if (!silent) { + fprintf(stderr, "%s: epoch %04" PRId64 "/%04" PRId64 ":\n", __func__, epoch, nepoch); + } + ggml_opt_epoch(opt_ctx, dataset, result_train, result_val, idata_split, epoch_callback, epoch_callback); + if (!silent) { + fprintf(stderr, "\n"); + } + } + + if (!silent) { + int64_t t_total_s = (ggml_time_us() - t_start_us) / 1000000; + const int64_t t_total_h = t_total_s / 3600; + t_total_s -= t_total_h * 3600; + const int64_t t_total_m = t_total_s / 60; + t_total_s -= t_total_m * 60; + fprintf(stderr, "%s: training took %02" PRId64 ":%02" PRId64 ":%02" PRId64 "\n", __func__, t_total_h, t_total_m, t_total_s); + } + + ggml_opt_free(opt_ctx); + ggml_opt_result_free(result_train); + ggml_opt_result_free(result_val); +} diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 5cdf59f25..ee72a173e 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1592,14 +1592,13 @@ static struct ggml_tensor * ggml_new_tensor_impl( /*.op =*/ GGML_OP_NONE, /*.op_params =*/ { 0 }, /*.flags =*/ 0, - /*.grad =*/ NULL, /*.src =*/ { NULL }, /*.view_src =*/ view_src, /*.view_offs =*/ view_offs, /*.data =*/ obj_alloc_size > 0 ? (void *)(result + 1) : data, /*.name =*/ { 0 }, /*.extra =*/ NULL, - ///*.padding =*/ { 0 }, + /*.padding =*/ { 0 }, }; #ifdef __clang__ @@ -4194,8 +4193,6 @@ struct ggml_tensor * ggml_flash_attn_ext( GGML_ASSERT(mask); } - bool is_node = false; - // permute(0, 2, 1, 3) int64_t ne[4] = { q->ne[0], q->ne[2], q->ne[1], q->ne[3] }; struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); @@ -4203,8 +4200,7 @@ struct ggml_tensor * ggml_flash_attn_ext( float params[] = { scale, max_bias, logit_softcap }; ggml_set_op_params(result, params, sizeof(params)); - result->op = GGML_OP_FLASH_ATTN_EXT; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_FLASH_ATTN_EXT; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -4272,14 +4268,6 @@ struct ggml_tensor * ggml_flash_attn_back( GGML_ASSERT(ne2 % kvne2 == 0); - bool is_node = false; - - if (q->grad || k->grad || v->grad) { - // when using this operation (in backwards pass) these grads are set. - // we don't want to create (big) grad of our result, so is_node is false. - is_node = false; - } - // store gradients of q, k and v as continuous tensors concatenated in result. // note: v and gradv are actually transposed, i.e. v->ne[0] != D. const int64_t elem_q = ggml_nelements(q); @@ -4302,8 +4290,7 @@ struct ggml_tensor * ggml_flash_attn_back( int32_t masked_i = masked ? 1 : 0; ggml_set_op_params(result, &masked_i, sizeof(masked_i)); - result->op = GGML_OP_FLASH_ATTN_BACK; - result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; + result->op = GGML_OP_FLASH_ATTN_BACK; result->src[0] = q; result->src[1] = k; result->src[2] = v; @@ -4945,34 +4932,24 @@ struct ggml_tensor * ggml_opt_step_adamw( struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * grad, - float alpha, - float beta1, - float beta2, - float eps, - float wd) { + struct ggml_tensor * m, + struct ggml_tensor * v, + struct ggml_tensor * adamw_params) { GGML_ASSERT(a->flags & GGML_TENSOR_FLAG_PARAM); GGML_ASSERT(ggml_are_same_shape(a, grad)); - GGML_ASSERT(alpha > 0.0f); - GGML_ASSERT(beta1 >= 0.0f && beta1 <= 1.0f); - GGML_ASSERT(beta2 >= 0.0f && beta2 <= 1.0f); - GGML_ASSERT(eps >= 0.0f); - GGML_ASSERT(wd >= 0.0f && wd <= 1.0f); + GGML_ASSERT(ggml_are_same_shape(a, m)); + GGML_ASSERT(ggml_are_same_shape(a, v)); + GGML_ASSERT(adamw_params->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_nelements(adamw_params) == 7); struct ggml_tensor * result = ggml_view_tensor(ctx, a); - const int64_t iter = 1; - memcpy(&result->op_params[0], &iter, sizeof(int64_t)); - ggml_set_op_params_f32(result, 2, alpha); - ggml_set_op_params_f32(result, 3, beta1); - ggml_set_op_params_f32(result, 4, beta2); - ggml_set_op_params_f32(result, 5, eps); - ggml_set_op_params_f32(result, 6, wd); - result->op = GGML_OP_OPT_STEP_ADAMW; result->src[0] = a; result->src[1] = grad; - result->src[2] = ggml_dup_tensor(ctx, grad); - result->src[3] = ggml_dup_tensor(ctx, grad); + result->src[2] = m; + result->src[3] = v; + result->src[4] = adamw_params; return result; } @@ -5041,1112 +5018,514 @@ static void ggml_hash_map_free(struct hash_map * map) { GGML_FREE(map); } -// gradient checkpointing - -static struct ggml_tensor * ggml_recompute_graph_node( - struct ggml_context * ctx, - struct ggml_cgraph * graph, - struct hash_map * replacements, - struct ggml_tensor * node) { - - if (node == NULL) { - return NULL; - } - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - return node; - } - - if (!ggml_hash_contains(&graph->visited_hash_set, node)) { - return node; - } - - int count_children = 0; - for (int k = 0; k < GGML_MAX_SRC; ++k) { - if (node->src[k]) { - ++count_children; - } - } - - if (count_children == 0) { - return node; - } - - size_t i = ggml_hash_find(&replacements->set, node); - GGML_ASSERT(i != GGML_HASHSET_FULL); // assert that not full - if (replacements->set.keys[i] == node) { - return replacements->vals[i]; - } - - struct ggml_tensor * clone = ggml_new_tensor(ctx, node->type, GGML_MAX_DIMS, node->ne); - - // insert clone into replacements - GGML_ASSERT(replacements->set.keys[i] == NULL); // assert that we don't overwrite - replacements->set.keys[i] = node; - replacements->vals[i] = clone; - - clone->op = node->op; - clone->grad = node->grad; - clone->flags = node->flags; - clone->extra = node->extra; - for (int k = 0; k < GGML_MAX_DIMS; ++k) { - clone->nb[k] = node->nb[k]; - } - for (int k = 0; k < GGML_MAX_SRC; ++k) { - clone->src[k] = ggml_recompute_graph_node(ctx, graph, replacements, node->src[k]); - } - if (node->view_src != NULL) { - clone->data = (node->view_src->data == NULL) - ? NULL // view_src not yet allocated - : (char *) node->view_src->data // view_src already allocated - + node->view_offs; - clone->view_src = node->view_src; - clone->view_offs = node->view_offs; - } - - GGML_ASSERT(sizeof(node->op_params) == sizeof(int32_t) * (GGML_MAX_OP_PARAMS / sizeof(int32_t))); - GGML_ASSERT(sizeof(node->name) == GGML_MAX_NAME); - memcpy(clone->op_params, node->op_params, sizeof(node->op_params)); - ggml_format_name(clone, "%s (clone)", ggml_get_name(node)); - - return clone; -} - -void ggml_build_backward_gradient_checkpointing( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - struct ggml_cgraph * gb_tmp, - struct ggml_tensor * * checkpoints, - int n_checkpoints) { - ggml_graph_cpy(gf, gb_tmp); - ggml_build_backward_expand(ctx, gf, gb_tmp, false); - - if (n_checkpoints <= 0) { - ggml_graph_cpy(gb_tmp, gb); - return; - } - - struct hash_map * replacements = ggml_new_hash_map(gf->n_nodes + gf->n_leafs + n_checkpoints); - - // insert checkpoints in replacements - for (int i = 0; i < n_checkpoints; ++i) { - size_t k = ggml_hash_find(&replacements->set, checkpoints[i]); - GGML_ASSERT(k != GGML_HASHSET_FULL); // assert that not full - GGML_ASSERT(replacements->set.keys[k] == NULL); // assert that we don't overwrite - replacements->set.keys[k] = checkpoints[i]; - replacements->vals[k] = checkpoints[i]; - } - - ggml_graph_cpy(gf, gb); - // rewrite gb_tmp->nodes[gf->n_nodes:gb_tmp->n_nodes], - // replacing references to gb_tmp->nodes[0:gf->n_nodes] ( == gf->nodes[0:gf->n_nodes]), - // by recomputing them from checkpoints - for (int i = gf->n_nodes; in_nodes; ++i) { - struct ggml_tensor * node = gb_tmp->nodes[i]; - for (int k = 0; k < GGML_MAX_SRC; ++k) { - // insert new tensors recomputing src, reusing already made replacements, - // remember replacements: remember new tensors with mapping from corresponding gf nodes - // recurse for input tensors, - // unless (i.e. terminating when) input tensors are replacements (like checkpoints) - node->src[k] = ggml_recompute_graph_node(ctx, gf, replacements, node->src[k]); - } - // insert rewritten backward node with replacements made into resulting backward graph gb - ggml_build_forward_expand(gb, node); - } - - ggml_hash_map_free(replacements); -} - // utility functions to change gradients // if a is in acc_table, modify gradients in-place and mark result as gradient accumulator // else if a is in zero_table, replace a // else, just add/subtract/etc. the gradients -static struct ggml_tensor * ggml_add_or_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_hash_set * zero_table, - struct ggml_hash_set * acc_table) { - if (ggml_hash_contains(acc_table, a)) { - struct ggml_tensor * ret = ggml_add_impl(ctx, a, b, true); - const size_t insert_result = ggml_hash_insert(acc_table, ret); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - return ret; +static void ggml_add_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_add_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = tensor; } - if (ggml_hash_contains(zero_table, a)) { - return b; - } - return ggml_add_impl(ctx, a, b, false); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct ggml_tensor * ggml_acc_or_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - const size_t nb1, - const size_t nb2, - const size_t nb3, - const size_t offset, - struct ggml_hash_set * zero_table, - struct ggml_hash_set * acc_table) { - if (ggml_hash_contains(acc_table, a)) { - struct ggml_tensor * ret = ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, true); - const size_t insert_result = ggml_hash_insert(acc_table, ret); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - return ret; +static void ggml_acc_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * src, + struct ggml_tensor * tensor, + const size_t nb1, + const size_t nb2, + const size_t nb3, + const size_t offset) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_acc_impl(ctx, cgraph->grads[isrc], tensor, nb1, nb2, nb3, offset, cgraph->grad_accs[isrc]); + } else { + struct ggml_tensor * a_zero = ggml_scale(ctx, src, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN + cgraph->grads[isrc] = ggml_acc_impl(ctx, a_zero, tensor, nb1, nb2, nb3, offset, false); } - if (ggml_hash_contains(zero_table, a)) { - struct ggml_tensor * a_zero = ggml_scale(ctx, a, 0.0f); // FIXME this is going to produce NaN if a contains inf/NaN - return ggml_acc_impl(ctx, a_zero, b, nb1, nb2, nb3, offset, false); - } - return ggml_acc_impl(ctx, a, b, nb1, nb2, nb3, offset, false); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct ggml_tensor * ggml_add1_or_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_hash_set * zero_table, - struct ggml_hash_set * acc_table) { - if (ggml_hash_contains(acc_table, a)) { - struct ggml_tensor * ret = ggml_add1_impl(ctx, a, b, true); - const size_t insert_result = ggml_hash_insert(acc_table, ret); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - return ret; +static void ggml_add1_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * src, + struct ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_add1_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = ggml_repeat(ctx, tensor, src); } - if (ggml_hash_contains(zero_table, a)) { - return ggml_repeat(ctx, b, a); - } - return ggml_add1_impl(ctx, a, b, false); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static struct ggml_tensor * ggml_sub_or_set( - struct ggml_context * ctx, - struct ggml_tensor * a, - struct ggml_tensor * b, - struct ggml_hash_set * zero_table, - struct ggml_hash_set * acc_table) { - if (ggml_hash_contains(acc_table, a)) { - struct ggml_tensor * ret = ggml_sub_impl(ctx, a, b, true); - const size_t insert_result = ggml_hash_insert(acc_table, ret); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - return ret; +static void ggml_sub_or_set( + struct ggml_context * ctx, + struct ggml_cgraph * cgraph, + size_t isrc, + struct ggml_tensor * tensor) { + if (cgraph->grads[isrc]) { + cgraph->grads[isrc] = ggml_sub_impl(ctx, cgraph->grads[isrc], tensor, cgraph->grad_accs[isrc]); + } else { + cgraph->grads[isrc] = ggml_neg(ctx, tensor); } - if (ggml_hash_contains(zero_table, a)) { - return ggml_neg(ctx, b); - } - return ggml_sub_impl(ctx, a, b, false); + ggml_build_forward_expand(cgraph, cgraph->grads[isrc]); } -static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor * tensor, struct ggml_hash_set * zero_table, struct ggml_hash_set * acc_table) { +static void ggml_compute_backward( + struct ggml_context * ctx, struct ggml_cgraph * cgraph, int i, bool * grads_needed) { + struct ggml_tensor * tensor = cgraph->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, tensor); + + if (!grad) { + return; + } + struct ggml_tensor * src0 = tensor->src[0]; struct ggml_tensor * src1 = tensor->src[1]; struct ggml_tensor * src2 = tensor->src[2]; + struct ggml_hash_set * hash_set = &cgraph->visited_hash_set; + const size_t isrc0 = ggml_hash_find(hash_set, src0); + const size_t isrc1 = ggml_hash_find(hash_set, src1); + const size_t isrc2 = ggml_hash_find(hash_set, src2); + const bool src0_needs_grads = isrc0 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc0) && grads_needed[isrc0]; + const bool src1_needs_grads = isrc1 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc1) && grads_needed[isrc1]; + const bool src2_needs_grads = isrc2 != GGML_HASHSET_FULL && ggml_bitset_get(hash_set->used, isrc2) && grads_needed[isrc2]; switch (tensor->op) { - case GGML_OP_DUP: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case GGML_OP_ADD: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - if (ggml_are_same_shape(src0, src1)) { - src1->grad = ggml_add_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); - } else { - src1->grad = ggml_add_or_set(ctx, src1->grad, ggml_repeat_back(ctx, tensor->grad, src1), zero_table, acc_table); - } - } - } break; - case GGML_OP_ADD1: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - src1->grad = ggml_add_or_set(ctx, - src1->grad, - ggml_mean(ctx, tensor->grad), // TODO: should probably be sum instead of mean - zero_table, acc_table); - } - } break; - case GGML_OP_ACC: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, - tensor->grad, - src1->grad->ne[0], - src1->grad->ne[1], - src1->grad->ne[2], - src1->grad->ne[3], - nb1, nb2, nb3, offset); - - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_reshape(ctx, - ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_SUB: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - src1->grad = ggml_sub_or_set(ctx, src1->grad, tensor->grad, zero_table, acc_table); - } - } break; - case GGML_OP_MUL: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, src1, tensor->grad), - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_mul(ctx, src0, tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_DIV: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_div(ctx, tensor->grad, src1), - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - ggml_sub_or_set(ctx, - src1->grad, - ggml_mul(ctx, - tensor->grad, - ggml_div(ctx, tensor, src1)), - zero_table, acc_table); - } - } break; - case GGML_OP_SQR: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale(ctx, - ggml_mul(ctx, src0, tensor->grad), - 2.0f), - zero_table, acc_table); - } - } break; - case GGML_OP_SQRT: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale(ctx, - ggml_div(ctx, - tensor->grad, - tensor), - 0.5f), - zero_table, acc_table); - } - } break; - case GGML_OP_LOG: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_div(ctx, - tensor->grad, - src0), - zero_table, acc_table); - } - } break; - case GGML_OP_SIN: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - tensor->grad, - ggml_cos(ctx, src0)), - zero_table, acc_table); - } - } break; - case GGML_OP_COS: - { - if (src0->grad) { - src0->grad = - ggml_sub_or_set(ctx, - src0->grad, - ggml_mul(ctx, - tensor->grad, - ggml_sin(ctx, src0)), - zero_table, acc_table); - } - } break; - case GGML_OP_SUM: - { - if (src0->grad) { - src0->grad = - ggml_add1_or_set(ctx, - src0->grad, - tensor->grad, - zero_table, acc_table); - } - } break; - case GGML_OP_SUM_ROWS: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_repeat(ctx, - tensor->grad, - src0->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_MEAN: - case GGML_OP_ARGMAX: - case GGML_OP_COUNT_EQUAL: - { - GGML_ABORT("fatal error"); // TODO: implement + case GGML_OP_DUP: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_REPEAT: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_repeat_back(ctx, tensor->grad, src0->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_REPEAT_BACK: - { - if (src0->grad) { - // TODO: test this - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_repeat(ctx, tensor->grad, src0->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_CONCAT: - { - GGML_ABORT("fatal error"); // TODO: implement + } break; + case GGML_OP_ADD: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_SILU_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + struct ggml_tensor * tmp = grad; + if (!ggml_are_same_shape(src0, src1)) { + tmp = ggml_repeat_back(ctx, tmp, src1); + } + ggml_add_or_set(ctx, cgraph, isrc1, tmp); } - case GGML_OP_NORM: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_ADD1: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_RMS_NORM: - { - // necessary for llama - if (src0->grad) { - float eps; - memcpy(&eps, tensor->op_params, sizeof(float)); - - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rms_norm_back(ctx, src0, tensor->grad, eps), - zero_table, acc_table); - } - } break; - case GGML_OP_RMS_NORM_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_mean(ctx, grad)); // TODO: should probably be sum instead of mean } - case GGML_OP_GROUP_NORM: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_ACC: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_MUL_MAT: - { - // https://cs231n.github.io/optimization-2/#staged - // # forward pass - // s0 = np.random.randn(5, 10) - // s1 = np.random.randn(10, 3) - // t = s0.dot(s1) + if (src1_needs_grads) { + const size_t nb1 = ((int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((int32_t *) tensor->op_params)[2]; + const size_t offset = ((int32_t *) tensor->op_params)[3]; - // # now suppose we had the gradient on t from above in the circuit - // dt = np.random.randn(*t.shape) # same shape as t - // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix - // ds1 = t.T.dot(dt) + struct ggml_tensor * tensor_grad_view = ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); - // tensor.shape [m,p,qq,rr] - // src0.shape [n,m,q1,r1] - // src1.shape [n,p,qq,rr] - - // necessary for llama - if (src0->grad) { - struct ggml_tensor * s1_tg = - ggml_out_prod(ctx, // [n,m,qq,rr] - src1, // [n,p,qq,rr] - tensor->grad); // [m,p,qq,rr] - const int64_t qq = s1_tg->ne[2]; - const int64_t rr = s1_tg->ne[3]; - const int64_t q1 = src0->ne[2]; - const int64_t r1 = src0->ne[3]; - const bool ne2_broadcasted = qq > q1; - const bool ne3_broadcasted = rr > r1; - if (ne2_broadcasted || ne3_broadcasted) { - // sum broadcast repetitions of s1_tg into shape of src0 - s1_tg = ggml_repeat_back(ctx, s1_tg, src0); - } - src0->grad = - ggml_add_or_set(ctx, - src0->grad, // [n,m,q1,r1] - s1_tg, // [n,m,q1,r1] - zero_table, acc_table); - } - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, // [n,p,qq,rr] - // ggml_mul_mat(ctx, // [n,p,qq,rr] - // ggml_cont(ctx, // [m,n,q1,r1] - // ggml_transpose(ctx, src0)), // [m,n,q1,r1] - // tensor->grad), // [m,p,qq,rr] - - // // when src0 is bigger than tensor->grad (this is mostly the case in llama), - // // avoid transpose of src0, rather transpose smaller tensor->grad - // // and then use ggml_out_prod - ggml_out_prod(ctx, // [n,p,qq,rr] - src0, // [n,m,q1,r1] - ggml_transpose(ctx, // [p,m,qq,rr] - tensor->grad)), // [m,p,qq,rr] - zero_table, acc_table); - } - } break; - case GGML_OP_MUL_MAT_ID: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1)); } - case GGML_OP_OUT_PROD: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_SUB: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_SCALE: - { - // necessary for llama - if (src0->grad) { - float s; - memcpy(&s, tensor->op_params, sizeof(float)); - - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_scale_impl(ctx, tensor->grad, s, false), - zero_table, acc_table); - } - } break; - case GGML_OP_SET: - { - const size_t nb1 = ((int32_t *) tensor->op_params)[0]; - const size_t nb2 = ((int32_t *) tensor->op_params)[1]; - const size_t nb3 = ((int32_t *) tensor->op_params)[2]; - const size_t offset = ((int32_t *) tensor->op_params)[3]; - - struct ggml_tensor * tensor_grad_view = NULL; - - if (src0->grad || src1->grad) { - GGML_ASSERT(src0->type == tensor->type); - GGML_ASSERT(tensor->grad->type == tensor->type); - GGML_ASSERT(!src1->grad || src1->grad->type == tensor->grad->type); - - tensor_grad_view = ggml_view_4d(ctx, - tensor->grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], - nb1, nb2, nb3, offset); - } - - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_acc_impl(ctx, - tensor->grad, - ggml_neg(ctx, tensor_grad_view), - nb1, nb2, nb3, offset, false), - zero_table, acc_table); - } - - if (src1->grad) { - src1->grad = - ggml_add_or_set(ctx, - src1->grad, - ggml_reshape(ctx, - ggml_cont(ctx, tensor_grad_view), - src1->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_CPY: - { - // necessary for llama - // cpy overwrites value of src1 by src0 and returns view(src1) - // the overwriting is mathematically equivalent to: - // tensor = src0 * 1 + src1 * 0 - if (src0->grad) { - // dsrc0 = dtensor * 1 - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - if (src1->grad) { - // dsrc1 = dtensor * 0 -> noop - } - } break; - case GGML_OP_CONT: - { - // same as cpy - if (src0->grad) { - GGML_ASSERT(ggml_is_contiguous(src0->grad)); - GGML_ASSERT(ggml_is_contiguous(tensor->grad)); - src0->grad = ggml_add_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case GGML_OP_RESHAPE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_reshape(ctx, - ggml_is_contiguous(tensor->grad) - ? tensor->grad - : ggml_cont(ctx, tensor->grad), - src0->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_VIEW: - { - // necessary for llama - if (src0->grad) { - size_t offset; - - memcpy(&offset, tensor->op_params, sizeof(offset)); - - size_t nb1 = tensor->nb[1]; - size_t nb2 = tensor->nb[2]; - size_t nb3 = tensor->nb[3]; - - if (src0->type != src0->grad->type) { - // gradient is typically F32, but src0 could be other type - size_t ng = ggml_element_size(src0->grad); - size_t n0 = ggml_element_size(src0); - GGML_ASSERT(offset % n0 == 0); - GGML_ASSERT(nb1 % n0 == 0); - GGML_ASSERT(nb2 % n0 == 0); - GGML_ASSERT(nb3 % n0 == 0); - offset = (offset / n0) * ng; - nb1 = (nb1 / n0) * ng; - nb2 = (nb2 / n0) * ng; - nb3 = (nb3 / n0) * ng; - } - - src0->grad = ggml_acc_or_set(ctx, src0->grad, tensor->grad, nb1, nb2, nb3, offset, zero_table, acc_table); - } - } break; - case GGML_OP_PERMUTE: - { - // necessary for llama - if (src0->grad) { - int32_t * axes = (int32_t *) tensor->op_params; - int axis0 = axes[0] & 0x3; - int axis1 = axes[1] & 0x3; - int axis2 = axes[2] & 0x3; - int axis3 = axes[3] & 0x3; - int axes_backward[4] = {0,0,0,0}; - axes_backward[axis0] = 0; - axes_backward[axis1] = 1; - axes_backward[axis2] = 2; - axes_backward[axis3] = 3; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_permute(ctx, - tensor->grad, - axes_backward[0], - axes_backward[1], - axes_backward[2], - axes_backward[3]), - zero_table, acc_table); - } - } break; - case GGML_OP_TRANSPOSE: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_transpose(ctx, tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_OP_GET_ROWS: - { - // necessary for llama (only for tokenizer) - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - // last ggml_get_rows_back argument src0->grad is only - // necessary to setup correct output shape - ggml_get_rows_back(ctx, tensor->grad, src1, src0->grad), - zero_table, acc_table); - } - if (src1->grad) { - // noop - } - } break; - case GGML_OP_GET_ROWS_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc1, grad); } - case GGML_OP_DIAG: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_MUL: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, src1, grad)); } - case GGML_OP_DIAG_MASK_INF: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - /* ggml_diag_mask_inf_impl() shouldn't be here */ - /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ - ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table, acc_table); + if (src1_needs_grads) { + struct ggml_tensor * tmp = ggml_mul(ctx, src0, grad); + if (!ggml_are_same_shape(src0, src1)) { + tmp = ggml_repeat_back(ctx, tmp, src1); } - } break; - case GGML_OP_DIAG_MASK_ZERO: - { - // necessary for llama - if (src0->grad) { - const int n_past = ((int32_t *) tensor->op_params)[0]; - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_diag_mask_zero_impl(ctx, tensor->grad, n_past, false), - zero_table, acc_table); - } - } break; - case GGML_OP_SOFT_MAX: - { - // necessary for llama - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, src0->grad, - ggml_soft_max_back(ctx, tensor->grad, tensor), - zero_table, acc_table); - } - GGML_ASSERT((!src1 || !src1->grad) && "backward pass for softmax mask not implemented"); - } break; - case GGML_OP_SOFT_MAX_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + ggml_add_or_set(ctx, cgraph, isrc1, tmp); } - case GGML_OP_ROPE: - { - // necessary for llama - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + } break; + case GGML_OP_DIV: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src1)); + } + if (src1_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc1, ggml_mul(ctx, grad, ggml_div(ctx, tensor, src1))); + } + } break; + case GGML_OP_SQR: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_mul(ctx, src0, grad), 2.0f)); + } + } break; + case GGML_OP_SQRT: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale(ctx, ggml_div(ctx, grad, tensor), 0.5f)); + } + } break; + case GGML_OP_LOG: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_div(ctx, grad, src0)); + } + } break; + case GGML_OP_SIN: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_cos(ctx, src0))); + } + } break; + case GGML_OP_COS: { + if (src0_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, grad, ggml_sin(ctx, src0))); + } + } break; + case GGML_OP_SUM: { + if (src0_needs_grads) { + ggml_add1_or_set(ctx, cgraph, isrc0, src0, grad); + } + } break; + case GGML_OP_SUM_ROWS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0)); + } + } break; + case GGML_OP_MEAN: { + if (src0_needs_grads) { + ggml_add1_or_set(ctx, cgraph, isrc0, src0, ggml_scale_impl(ctx, grad, 1.0f/src0->ne[0], false)); + } + } break; + case GGML_OP_REPEAT: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat_back(ctx, grad, src0)); + } + } break; + case GGML_OP_REPEAT_BACK: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_repeat(ctx, grad, src0)); + } + } break; + case GGML_OP_RMS_NORM: { + if (src0_needs_grads) { + float eps; + memcpy(&eps, tensor->op_params, sizeof(float)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_rms_norm_back(ctx, src0, grad, eps)); + } + } break; + case GGML_OP_MUL_MAT: { + // https://cs231n.github.io/optimization-2/#staged + // # forward pass + // s0 = np.random.randn(5, 10) + // s1 = np.random.randn(10, 3) + // t = s0.dot(s1) - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); + // # now suppose we had the gradient on t from above in the circuit + // dt = np.random.randn(*t.shape) # same shape as t + // ds0 = dt.dot(s1.T) #.T gives the transpose of the matrix + // ds1 = t.T.dot(dt) - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rope_back(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow), - zero_table, acc_table); + // tensor.shape [m,p,qq,rr] + // src0.shape [n,m,q1,r1] + // src1.shape [n,p,qq,rr] + + if (src0_needs_grads) { + struct ggml_tensor * s1_tg = + ggml_out_prod(ctx, // [n,m,qq,rr] + src1, // [n,p,qq,rr] + grad); // [m,p,qq,rr] + const int64_t qq = s1_tg->ne[2]; + const int64_t rr = s1_tg->ne[3]; + const int64_t q1 = src0->ne[2]; + const int64_t r1 = src0->ne[3]; + const bool ne2_broadcasted = qq > q1; + const bool ne3_broadcasted = rr > r1; + if (ne2_broadcasted || ne3_broadcasted) { + // sum broadcast repetitions of s1_tg into shape of src0 + s1_tg = ggml_repeat_back(ctx, s1_tg, src0); } - GGML_ASSERT((!src2 || !src2->grad) && "gradients for freq factors not implemented"); - } break; - case GGML_OP_ROPE_BACK: - { - if (src0->grad) { - //const int n_past = ((int32_t *) tensor->op_params)[0]; - const int n_dims = ((int32_t *) tensor->op_params)[1]; - const int mode = ((int32_t *) tensor->op_params)[2]; - //const int n_ctx = ((int32_t *) tensor->op_params)[3]; - const int n_ctx_orig = ((int32_t *) tensor->op_params)[4]; - float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + ggml_add_or_set(ctx, cgraph, isrc0, s1_tg /*= [n,m,q1,r1]*/); + } + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, + // ggml_mul_mat(ctx, // [n,p,qq,rr] + // ggml_cont(ctx, // [m,n,q1,r1] + // ggml_transpose(ctx, src0)), // [m,n,q1,r1] + // grad), // [m,p,qq,rr] - memcpy(&freq_base, (int32_t *) tensor->op_params + 5, sizeof(float)); - memcpy(&freq_scale, (int32_t *) tensor->op_params + 6, sizeof(float)); - memcpy(&ext_factor, (int32_t *) tensor->op_params + 7, sizeof(float)); - memcpy(&attn_factor, (int32_t *) tensor->op_params + 8, sizeof(float)); - memcpy(&beta_fast, (int32_t *) tensor->op_params + 9, sizeof(float)); - memcpy(&beta_slow, (int32_t *) tensor->op_params + 10, sizeof(float)); + // when src0 is bigger than tensor->grad (this is mostly the case in llama), + // avoid transpose of src0, rather transpose smaller tensor->grad + // and then use ggml_out_prod + ggml_out_prod(ctx, // [n,p,qq,rr] + src0, // [n,m,q1,r1] + ggml_transpose(ctx, // [p,m,qq,rr] + grad))); // [m,p,qq,rr] + } + } break; + case GGML_OP_SCALE: { + if (src0_needs_grads) { + float s; + memcpy(&s, tensor->op_params, sizeof(float)); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_scale_impl(ctx, grad, s, false)); + } + } break; + case GGML_OP_SET: { + const size_t nb1 = ((const int32_t *) tensor->op_params)[0]; + const size_t nb2 = ((const int32_t *) tensor->op_params)[1]; + const size_t nb3 = ((const int32_t *) tensor->op_params)[2]; + const size_t offset = ((const int32_t *) tensor->op_params)[3]; - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_rope_impl(ctx, - tensor->grad, - src1, - src2, - n_dims, - mode, - n_ctx_orig, - freq_base, - freq_scale, - ext_factor, - attn_factor, - beta_fast, - beta_slow, - false), - zero_table, acc_table); - } - } break; - case GGML_OP_CLAMP: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_CONV_TRANSPOSE_1D: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_IM2COL: - { - if (src1->grad) { - const int32_t s0 = ggml_get_op_params_i32(tensor, 0); - const int32_t s1 = ggml_get_op_params_i32(tensor, 1); - const int32_t p0 = ggml_get_op_params_i32(tensor, 2); - const int32_t p1 = ggml_get_op_params_i32(tensor, 3); - const int32_t d0 = ggml_get_op_params_i32(tensor, 4); - const int32_t d1 = ggml_get_op_params_i32(tensor, 5); - const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; + struct ggml_tensor * tensor_grad_view = NULL; - src1->grad = ggml_add_or_set(ctx, - src1->grad, - ggml_im2col_back(ctx, src0, tensor->grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D), - zero_table, acc_table); - } - } break; - case GGML_OP_IM2COL_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_CONV_TRANSPOSE_2D: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_POOL_1D: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_OP_POOL_2D: - { - if (src0->grad) { - const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); - const int32_t k0 = ggml_get_op_params_i32(tensor, 1); - const int32_t k1 = ggml_get_op_params_i32(tensor, 2); - const int32_t s0 = ggml_get_op_params_i32(tensor, 3); - const int32_t s1 = ggml_get_op_params_i32(tensor, 4); - const int32_t p0 = ggml_get_op_params_i32(tensor, 5); - const int32_t p1 = ggml_get_op_params_i32(tensor, 6); + if (src0_needs_grads || src1_needs_grads) { + GGML_ASSERT(src0->type == tensor->type); + GGML_ASSERT(!cgraph->grads[isrc0] || cgraph->grads[isrc0]->type == grad->type); + GGML_ASSERT(!cgraph->grads[isrc1] || !src1_needs_grads || cgraph->grads[isrc1]->type == grad->type); - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_pool_2d_back(ctx, tensor->grad, src0, op, k0, k1, s0, s1, p0, p1), - zero_table, acc_table); - } - } break; - case GGML_OP_POOL_2D_BACK: - { - GGML_ABORT("fatal error"); // TODO: not implemented + tensor_grad_view = ggml_view_4d(ctx, + grad, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], + nb1, nb2, nb3, offset); } - case GGML_OP_UPSCALE: - { - GGML_ABORT("fatal error"); // TODO: not implemented + + if (src0_needs_grads) { + struct ggml_tensor * tmp = ggml_neg(ctx, tensor_grad_view); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_acc_impl(ctx, grad, tmp, nb1, nb2, nb3, offset, false)); } - case GGML_OP_PAD: - { - GGML_ABORT("fatal error"); // TODO: not implemented + + if (src1_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc1, ggml_reshape(ctx, ggml_cont(ctx, tensor_grad_view), src1)); } - case GGML_OP_ARANGE: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_CPY: { + // cpy overwrites value of src1 by src0 and returns view(src1) + // the overwriting is mathematically equivalent to: + // tensor = src0 * 1 + src1 * 0 + if (src0_needs_grads) { + // dsrc0 = dtensor * 1 + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_TIMESTEP_EMBEDDING: - { - GGML_ABORT("fatal error"); // TODO: not implemented + if (src1_needs_grads) { + // dsrc1 = dtensor * 0 -> noop } - case GGML_OP_ARGSORT: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_CONT: { + // same as cpy + if (src0_needs_grads) { + GGML_ASSERT(!cgraph->grads[isrc0] || ggml_is_contiguous(cgraph->grads[isrc0])); + GGML_ASSERT(ggml_is_contiguous(grad)); + ggml_add_or_set(ctx, cgraph, isrc0, grad); } - case GGML_OP_LEAKY_RELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_RESHAPE: { + if (src0_needs_grads) { + struct ggml_tensor * grad_cont = ggml_is_contiguous(grad) ? grad : ggml_cont(ctx, grad); + ggml_add_or_set(ctx, cgraph, isrc0, ggml_reshape(ctx, grad_cont, src0)); } - case GGML_OP_FLASH_ATTN_EXT: - { - GGML_ABORT("FA backward pass not adapted after rework"); - struct ggml_tensor * flash_grad = NULL; - if (src0->grad || src1->grad || tensor->src[2]->grad) { - int32_t t = ggml_get_op_params_i32(tensor, 0); - GGML_ASSERT(t == 0 || t == 1); - bool masked = t != 0; - flash_grad = - ggml_flash_attn_back(ctx, - src0, - src1, - tensor->src[2], - tensor->grad, - masked); + } break; + case GGML_OP_VIEW: { + if (src0_needs_grads) { + size_t offset; + + memcpy(&offset, tensor->op_params, sizeof(offset)); + + size_t nb1 = tensor->nb[1]; + size_t nb2 = tensor->nb[2]; + size_t nb3 = tensor->nb[3]; + + if (cgraph->grads[isrc0] && src0->type != cgraph->grads[isrc0]->type) { + // gradient is typically F32, but src0 could be other type + size_t ng = ggml_element_size(cgraph->grads[isrc0]); + size_t n0 = ggml_element_size(src0); + GGML_ASSERT(offset % n0 == 0); + GGML_ASSERT(nb1 % n0 == 0); + GGML_ASSERT(nb2 % n0 == 0); + GGML_ASSERT(nb3 % n0 == 0); + offset = (offset / n0) * ng; + nb1 = (nb1 / n0) * ng; + nb2 = (nb2 / n0) * ng; + nb3 = (nb3 / n0) * ng; } - const int64_t elem_q = ggml_nelements(src0); - const int64_t elem_k = ggml_nelements(src1); - const int64_t elem_v = ggml_nelements(src2); - - enum ggml_type result_type = flash_grad->type; - GGML_ASSERT(ggml_blck_size(result_type) == 1); - const size_t tsize = ggml_type_size(result_type); - - const size_t offs_q = 0; - const size_t offs_k = offs_q + GGML_PAD(elem_q * tsize, GGML_MEM_ALIGN); - const size_t offs_v = offs_k + GGML_PAD(elem_k * tsize, GGML_MEM_ALIGN); - - if (src0->grad) { - struct ggml_tensor * view_q = ggml_view_1d(ctx, flash_grad, elem_q, offs_q); - struct ggml_tensor * grad_q = ggml_reshape(ctx, view_q, src0); - src0->grad = ggml_add_or_set(ctx, - src0->grad, - grad_q, - zero_table, acc_table); - } - if (src1->grad) { - struct ggml_tensor * view_k = ggml_view_1d(ctx, flash_grad, elem_k, offs_k); - struct ggml_tensor * grad_k = ggml_reshape(ctx, view_k, src1); - src1->grad = ggml_add_or_set(ctx, - src1->grad, - grad_k, - zero_table, acc_table); - } - if (src2->grad) { - struct ggml_tensor * view_v = ggml_view_1d(ctx, flash_grad, elem_v, offs_v); - struct ggml_tensor * grad_v = ggml_reshape(ctx, view_v, src2); - src2->grad = ggml_add_or_set(ctx, - src2->grad, - grad_v, - zero_table, acc_table); - } - } break; - case GGML_OP_FLASH_ATTN_BACK: - { - GGML_ABORT("fatal error"); // not supported + ggml_acc_or_set(ctx, cgraph, isrc0, src0, grad, nb1, nb2, nb3, offset); } - case GGML_OP_SSM_CONV: - case GGML_OP_SSM_SCAN: - { - GGML_ABORT("fatal error"); // TODO: not implemented + } break; + case GGML_OP_PERMUTE: { + if (src0_needs_grads) { + const int32_t * axes = (const int32_t *) tensor->op_params; + const int axis0 = axes[0] & 0x3; + const int axis1 = axes[1] & 0x3; + const int axis2 = axes[2] & 0x3; + const int axis3 = axes[3] & 0x3; + int axb[4] = {0,0,0,0}; // axes backward + axb[axis0] = 0; + axb[axis1] = 1; + axb[axis2] = 2; + axb[axis3] = 3; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_permute(ctx, grad, axb[0], axb[1], axb[2], axb[3])); } + } break; + case GGML_OP_TRANSPOSE: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_transpose(ctx, grad)); + } + } break; + case GGML_OP_GET_ROWS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_get_rows_back(ctx, grad, src1, src0)); + } + if (src1_needs_grads) { + // noop + } + } break; + case GGML_OP_DIAG_MASK_INF: { + if (src0_needs_grads) { + /* ggml_diag_mask_inf_impl() shouldn't be here */ + /* ref: https://github.com/ggerganov/llama.cpp/pull/4203#discussion_r1412377992 */ + const int n_past = ((const int32_t *) tensor->op_params)[0]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case GGML_OP_DIAG_MASK_ZERO: { + if (src0_needs_grads) { + const int n_past = ((const int32_t *) tensor->op_params)[0]; + ggml_add_or_set(ctx, cgraph, isrc0, ggml_diag_mask_zero_impl(ctx, grad, n_past, false)); + } + } break; + case GGML_OP_SOFT_MAX: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_soft_max_back(ctx, grad, tensor)); + } + GGML_ASSERT((!src1 || !src1_needs_grads) && "backward pass for softmax mask not implemented"); + } break; + case GGML_OP_ROPE: { + if (src0_needs_grads) { + //const int n_past = ((int32_t *) tensor->op_params)[0]; + const int n_dims = ((const int32_t *) tensor->op_params)[1]; + const int mode = ((const int32_t *) tensor->op_params)[2]; + //const int n_ctx = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig = ((const int32_t *) tensor->op_params)[4]; + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + + memcpy(&freq_base, (const float *) tensor->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const float *) tensor->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (const float *) tensor->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (const float *) tensor->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const float *) tensor->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const float *) tensor->op_params + 10, sizeof(float)); + + ggml_add_or_set(ctx, cgraph, isrc0, + ggml_rope_back(ctx, grad, src1, src2, n_dims, mode, n_ctx_orig, freq_base, + freq_scale, ext_factor, attn_factor, beta_fast, beta_slow)); + } + GGML_ASSERT((!src2 || !src2_needs_grads) && "gradients for freq factors not implemented"); + } break; + case GGML_OP_IM2COL: { + if (src1_needs_grads) { + const int32_t s0 = ggml_get_op_params_i32(tensor, 0); + const int32_t s1 = ggml_get_op_params_i32(tensor, 1); + const int32_t p0 = ggml_get_op_params_i32(tensor, 2); + const int32_t p1 = ggml_get_op_params_i32(tensor, 3); + const int32_t d0 = ggml_get_op_params_i32(tensor, 4); + const int32_t d1 = ggml_get_op_params_i32(tensor, 5); + const bool is_2D = ggml_get_op_params_i32(tensor, 6) == 1; + + ggml_add_or_set(ctx, cgraph, isrc1, ggml_im2col_back(ctx, src0, grad, src1->ne, s0, s1, p0, p1, d0, d1, is_2D)); + } + } break; + case GGML_OP_POOL_2D: { + if (src0_needs_grads) { + const enum ggml_op_pool op = ggml_get_op_params_i32(tensor, 0); + const int32_t k0 = ggml_get_op_params_i32(tensor, 1); + const int32_t k1 = ggml_get_op_params_i32(tensor, 2); + const int32_t s0 = ggml_get_op_params_i32(tensor, 3); + const int32_t s1 = ggml_get_op_params_i32(tensor, 4); + const int32_t p0 = ggml_get_op_params_i32(tensor, 5); + const int32_t p1 = ggml_get_op_params_i32(tensor, 6); + + ggml_add_or_set(ctx, cgraph, isrc0, ggml_pool_2d_back(ctx, grad, src0, op, k0, k1, s0, s1, p0, p1)); + } + } break; case GGML_OP_WIN_PART: case GGML_OP_WIN_UNPART: - case GGML_OP_UNARY: - { - switch (ggml_get_unary_op(tensor)) { - case GGML_UNARY_OP_ABS: - { - if (src0->grad) { - src0->grad = - ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - ggml_sgn(ctx, src0), - tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_UNARY_OP_SGN: - { - if (src0->grad) { - // noop - } - } break; - case GGML_UNARY_OP_NEG: - { - if (src0->grad) { - src0->grad = ggml_sub_or_set(ctx, src0->grad, tensor->grad, zero_table, acc_table); - } - } break; - case GGML_UNARY_OP_STEP: - { - if (src0->grad) { - // noop - } - } break; - case GGML_UNARY_OP_TANH: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_ELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_RELU: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, - ggml_step(ctx, src0), - tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_UNARY_OP_SIGMOID: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_GELU: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_GELU_QUICK: - { - GGML_ABORT("fatal error"); // TODO: not implemented - } - case GGML_UNARY_OP_SILU: - { - // necessary for llama - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_silu_back(ctx, src0, tensor->grad), - zero_table, acc_table); - } - } break; - case GGML_UNARY_OP_EXP: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_mul(ctx, tensor, tensor->grad), - zero_table, acc_table); - } - } break; - default: - GGML_ABORT("fatal error"); - } - } break; - case GGML_OP_GET_REL_POS: - case GGML_OP_ADD_REL_POS: - case GGML_OP_RWKV_WKV6: - case GGML_OP_MAP_UNARY: - case GGML_OP_MAP_BINARY: - case GGML_OP_MAP_CUSTOM1_F32: - case GGML_OP_MAP_CUSTOM2_F32: - case GGML_OP_MAP_CUSTOM3_F32: - case GGML_OP_MAP_CUSTOM1: - case GGML_OP_MAP_CUSTOM2: - case GGML_OP_MAP_CUSTOM3: - { - GGML_ABORT("fatal error"); // not supported + case GGML_OP_UNARY: { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_ABS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_sgn(ctx, src0), grad)); + } + } break; + case GGML_UNARY_OP_SGN: { + // noop + } break; + case GGML_UNARY_OP_NEG: { + if (src0_needs_grads) { + ggml_sub_or_set(ctx, cgraph, isrc0, grad); + } + } break; + case GGML_UNARY_OP_STEP: { + // noop + } break; + case GGML_UNARY_OP_RELU: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, ggml_step(ctx, src0), grad)); + } + } break; + case GGML_UNARY_OP_SILU: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_silu_back(ctx, src0, grad)); + } + } break; + case GGML_UNARY_OP_EXP: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_mul(ctx, tensor, grad)); + } + } break; + default: { + fprintf(stderr, "%s: unsupported unary op for backward pass: %s\n", + __func__, ggml_unary_op_name(ggml_get_unary_op(tensor))); + GGML_ABORT("fatal error"); + } //break; } - case GGML_OP_CROSS_ENTROPY_LOSS: - { - if (src0->grad) { - src0->grad = ggml_add_or_set(ctx, - src0->grad, - ggml_cross_entropy_loss_back(ctx, - src0, - src1, - tensor->grad), - zero_table, acc_table); - } - GGML_ASSERT(!src1->grad && "backward pass for labels not implemented"); - } break; - case GGML_OP_CROSS_ENTROPY_LOSS_BACK: - { - GGML_ABORT("fatal error"); // not supported + } break; + case GGML_OP_CROSS_ENTROPY_LOSS: { + if (src0_needs_grads) { + ggml_add_or_set(ctx, cgraph, isrc0, ggml_cross_entropy_loss_back(ctx, src0, src1, grad)); } - case GGML_OP_OPT_STEP_ADAMW: - { - GGML_ABORT("fatal error"); // not supported - } - case GGML_OP_NONE: - { - // nop - } break; + GGML_ASSERT(!src1_needs_grads && "backward pass for labels not implemented"); + } break; + case GGML_OP_NONE: { + // noop + } break; case GGML_OP_COUNT: - { - GGML_ABORT("fatal error"); - } + default: { + fprintf(stderr, "%s: unsupported ggml op for backward pass: %s\n", __func__, ggml_op_name(tensor->op)); + GGML_ABORT("fatal error"); + } //break; } - for (int i = 0; i < GGML_MAX_SRC; ++i) { - if (tensor->src[i] && tensor->src[i]->grad) { - GGML_ASSERT(ggml_are_same_shape(tensor->src[i], tensor->src[i]->grad)); - } - } + GGML_ASSERT(!src0_needs_grads || ggml_are_same_shape(src0, cgraph->grads[isrc0])); + GGML_ASSERT(!src1_needs_grads || ggml_are_same_shape(src1, cgraph->grads[isrc1])); + GGML_ASSERT(!src2_needs_grads || ggml_are_same_shape(src2, cgraph->grads[isrc2])); } static void ggml_visit_parents(struct ggml_cgraph * cgraph, struct ggml_tensor * node) { - if (node->grad == NULL) { - // this usually happens when we generate intermediate nodes from constants in the backward pass - // it can also happen during forward pass, if the user performs computations with constants - if (node->op != GGML_OP_NONE) { - //GGML_PRINT_DEBUG("%s: warning: node %p has no grad, but op %d\n", __func__, (void *) node, node->op); - } - } - // check if already visited if (ggml_hash_insert(&cgraph->visited_hash_set, node) == GGML_HASHSET_ALREADY_EXISTS) { return; @@ -6207,18 +5586,42 @@ void ggml_build_forward_expand(struct ggml_cgraph * cgraph, struct ggml_tensor * ggml_build_forward_impl(cgraph, tensor, true); } -void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * gf, struct ggml_cgraph * gb, bool accumulate) { - GGML_ASSERT(gf->n_nodes > 0); - GGML_ASSERT(gf->grads); +void ggml_build_backward_expand( + struct ggml_context * ctx_static, + struct ggml_context * ctx_compute, + struct ggml_cgraph * cgraph, + bool accumulate) { + GGML_ASSERT(cgraph->n_nodes > 0); + GGML_ASSERT(cgraph->grads); + GGML_ASSERT(cgraph->grad_accs); - for (int i = 0; i < gf->n_nodes; ++i) { - struct ggml_tensor * node = gf->nodes[i]; + const int n_nodes_f = cgraph->n_nodes; + + const size_t hash_size = ggml_hash_size(2*cgraph->size); + memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *)); + bool * grads_needed = calloc(hash_size, sizeof(bool)); + + { + bool any_params = false; + bool any_loss = false; + for (int i = 0; i < n_nodes_f; ++i) { + struct ggml_tensor * node = cgraph->nodes[i]; + any_params = any_params || (node->flags & GGML_TENSOR_FLAG_PARAM); + any_loss = any_loss || (node->flags & GGML_TENSOR_FLAG_LOSS); + } + GGML_ASSERT(any_params && "no trainable parameters found, did you forget to call ggml_set_param?"); + GGML_ASSERT(any_loss && "no training loss found, did you forget to call ggml_set_loss?"); + } + + for (int i = 0; i < n_nodes_f; ++i) { + struct ggml_tensor * node = cgraph->nodes[i]; if (node->type == GGML_TYPE_I32) { continue; } - bool needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM; + bool node_needs_grad = node->flags & GGML_TENSOR_FLAG_PARAM; bool ignore_src[GGML_MAX_SRC] = {false}; switch (node->op) { // gradients in node->src[0] for one reason or another have no effect on output gradients @@ -6246,14 +5649,14 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * break; } for (int j = 0; j < GGML_MAX_SRC; ++j) { - if (!node->src[j] || !node->src[j]->grad || ignore_src[j]) { + if (!node->src[j] || ignore_src[j] || !grads_needed[ggml_hash_find(&cgraph->visited_hash_set, node->src[j])]) { continue; } GGML_ASSERT(node->src[j]->type == GGML_TYPE_F32 || node->src[j]->type == GGML_TYPE_F16); - needs_grad = true; + node_needs_grad = true; break; } - if (!needs_grad) { + if (!node_needs_grad) { continue; } @@ -6261,73 +5664,21 @@ void ggml_build_backward_expand(struct ggml_context * ctx, struct ggml_cgraph * GGML_ASSERT(!node->view_src || node->op == GGML_OP_CPY || node->op == GGML_OP_VIEW || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_TRANSPOSE); - // create a new tensor with the same type and shape as the node and set it as grad - node->grad = ggml_dup_tensor(ctx, node); - } - - // keep tables of original gradients for replacement/accumulation logic - struct ggml_hash_set zero_table = ggml_hash_set_new(gf->size); - struct ggml_hash_set acc_table = ggml_hash_set_new(gf->size); - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->grad) { - { - const size_t insert_result = ggml_hash_insert(&zero_table, node->grad); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - } - - // only gradients of trainable parameters should be accumulated - if (accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) { - const size_t insert_result = ggml_hash_insert(&acc_table, node->grad); - GGML_ASSERT(insert_result != GGML_HASHSET_FULL); - GGML_ASSERT(insert_result != GGML_HASHSET_ALREADY_EXISTS); - } + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + if ((accumulate && (node->flags & GGML_TENSOR_FLAG_PARAM)) || (node->flags & GGML_TENSOR_FLAG_LOSS)) { + cgraph->grads[igrad] = ggml_dup_tensor(ctx_static, node); + cgraph->grad_accs[igrad] = cgraph->grads[igrad]; } + grads_needed[igrad] = true; } - for (int i = gf->n_nodes - 1; i >= 0; i--) { - struct ggml_tensor * node = gf->nodes[i]; - + for (int i = n_nodes_f - 1; i >= 0; --i) { // inplace operations to add gradients are not created by ggml_compute_backward except for gradient accumulation // use allocator to automatically make inplace operations - if (node->grad) { - ggml_compute_backward(ctx, node, &zero_table, &acc_table); - } + ggml_compute_backward(ctx_compute, cgraph, i, grads_needed); } - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - ggml_build_forward_expand(gb, node->grad); - } - } - - ggml_hash_set_free(&zero_table); - ggml_hash_set_free(&acc_table); -} - -void ggml_build_opt_adamw( - struct ggml_context * ctx, - struct ggml_cgraph * gf, - struct ggml_cgraph * gb, - float alpha, - float beta1, - float beta2, - float eps, - float wd) { - for (int i = 0; i < gf->n_nodes; i++) { - struct ggml_tensor * node = gf->nodes[i]; - - if (node->flags & GGML_TENSOR_FLAG_PARAM) { - GGML_PRINT_DEBUG("%s: found root node %p\n", __func__, (void *) node); - struct ggml_tensor * opt_step = ggml_opt_step_adamw(ctx, node, node->grad, alpha, beta1, beta2, eps, wd); - ggml_build_forward_expand(gb, opt_step); - } - } + free(grads_needed); } static void * incr_ptr_aligned(void ** p, size_t size, size_t align) { @@ -6345,7 +5696,8 @@ static size_t ggml_graph_nbytes(size_t size, bool grads) { incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // leafs incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // hash keys if (grads) { - incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grads + incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); // grad_accs } incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); @@ -6371,10 +5723,12 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz void * p = cgraph + 1; - struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); - struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** nodes_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** leafs_ptr = incr_ptr_aligned(&p, size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** hash_keys_ptr = incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)); + struct ggml_tensor ** grads_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + struct ggml_tensor ** grad_accs_ptr = grads ? incr_ptr_aligned(&p, hash_size * sizeof(struct ggml_tensor *), sizeof(struct ggml_tensor *)) : NULL; + ggml_bitset_t * hash_used = incr_ptr_aligned(&p, ggml_bitset_size(hash_size) * sizeof(ggml_bitset_t), sizeof(ggml_bitset_t)); // check that we allocated the correct amount of memory @@ -6386,12 +5740,17 @@ struct ggml_cgraph * ggml_new_graph_custom(struct ggml_context * ctx, size_t siz /*.n_leafs =*/ 0, /*.nodes =*/ nodes_ptr, /*.grads =*/ grads_ptr, + /*.grad_accs =*/ grad_accs_ptr, /*.leafs =*/ leafs_ptr, /*.hash_table =*/ { hash_size, hash_used, hash_keys_ptr }, /*.order =*/ GGML_CGRAPH_EVAL_ORDER_LEFT_TO_RIGHT, }; ggml_hash_set_reset(&cgraph->visited_hash_set); + if (grads) { + memset(cgraph->grads, 0, hash_size*sizeof(struct ggml_tensor *)); + memset(cgraph->grad_accs, 0, hash_size*sizeof(struct ggml_tensor *)); + } return cgraph; } @@ -6407,6 +5766,7 @@ struct ggml_cgraph ggml_graph_view(struct ggml_cgraph * cgraph0, int i0, int i1) /*.n_leafs =*/ 0, /*.nodes =*/ cgraph0->nodes + i0, /*.grads =*/ cgraph0->grads ? cgraph0->grads + i0 : NULL, + /*.grad_accs =*/ cgraph0->grad_accs ? cgraph0->grad_accs + i0 : NULL, /*.leafs =*/ NULL, /*.hash_table =*/ { 0, NULL, NULL }, /*.order =*/ cgraph0->order, @@ -6432,19 +5792,23 @@ void ggml_graph_cpy(struct ggml_cgraph * src, struct ggml_cgraph * dst) { dst->nodes[i] = src->nodes[i]; } - if (src->grads) { - GGML_ASSERT(dst->grads != NULL); - for (int i = 0; i < src->n_nodes; ++i) { - dst->grads[i] = src->grads[i]; - } - } - for (size_t i = 0; i < src->visited_hash_set.size; ++i) { // copy all hashset keys (tensors) that are in use if (ggml_bitset_get(src->visited_hash_set.used, i)) { ggml_hash_insert(&dst->visited_hash_set, src->visited_hash_set.keys[i]); } } + + if (src->grads) { + GGML_ASSERT(dst->grads != NULL); + GGML_ASSERT(dst->grad_accs != NULL); + for (int i = 0; i < src->n_nodes; ++i) { + const size_t igrad_src = ggml_hash_find(&src->visited_hash_set, src->nodes[i]); + const size_t igrad_dst = ggml_hash_find(&dst->visited_hash_set, dst->nodes[i]); + dst->grads[igrad_dst] = src->grads[igrad_src]; + dst->grad_accs[igrad_dst] = src->grad_accs[igrad_src]; + } + } } struct ggml_cgraph * ggml_graph_dup(struct ggml_context * ctx, struct ggml_cgraph * cgraph) { @@ -6470,28 +5834,35 @@ void ggml_graph_reset(struct ggml_cgraph * cgraph) { GGML_ASSERT(cgraph->grads != NULL); for (int i = 0; i < cgraph->n_nodes; i++) { - struct ggml_tensor * node = cgraph->nodes[i]; + struct ggml_tensor * node = cgraph->nodes[i]; + struct ggml_tensor * grad_acc = ggml_graph_get_grad_acc(cgraph, node); - // initial gradients of loss should be 1, 0 otherwise - if (node->grad) { - if (node->flags & GGML_TENSOR_FLAG_LOSS) { - GGML_ASSERT(node->grad->buffer); - GGML_ASSERT(node->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_scalar(node)); - - const float onef = 1.0f; - ggml_backend_tensor_set(node->grad, &onef, 0, ggml_nbytes(node->grad)); - } else { - ggml_set_zero(node->grad); + if (node->op == GGML_OP_OPT_STEP_ADAMW) { + // clear momenta + if (node->src[2]->data) { + ggml_set_zero(node->src[2]); + } + if (node->src[3]->data) { + ggml_set_zero(node->src[3]); } } - GGML_ASSERT(node); - if (node->op == GGML_OP_OPT_STEP_ADAMW) { - // set iteration to 1 and clear momenta - ggml_set_op_params_i32(node, 0, 1); - ggml_set_zero(node->src[2]); - ggml_set_zero(node->src[3]); + // initial gradients of loss should be 1, 0 otherwise + if (grad_acc) { + if (node->flags & GGML_TENSOR_FLAG_LOSS) { + GGML_ASSERT(grad_acc->type == GGML_TYPE_F32); + GGML_ASSERT(ggml_is_scalar(grad_acc)); + + const float onef = 1.0f; + if (grad_acc->buffer) { + ggml_backend_tensor_set(grad_acc, &onef, 0, sizeof(float)); + } else { + GGML_ASSERT(grad_acc->data); + *((float *) grad_acc->data) = onef; + } + } else { + ggml_set_zero(grad_acc); + } } } } @@ -6530,7 +5901,7 @@ void ggml_graph_add_node(struct ggml_cgraph * cgraph, struct ggml_tensor * tenso cgraph->n_nodes++; } -struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const char * name) { +struct ggml_tensor * ggml_graph_get_tensor(const struct ggml_cgraph * cgraph, const char * name) { for (int i = 0; i < cgraph->n_leafs; i++) { struct ggml_tensor * leaf = cgraph->leafs[i]; @@ -6550,6 +5921,16 @@ struct ggml_tensor * ggml_graph_get_tensor(struct ggml_cgraph * cgraph, const ch return NULL; } +struct ggml_tensor * ggml_graph_get_grad(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grads[igrad] : NULL; +} + +struct ggml_tensor * ggml_graph_get_grad_acc(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { + const size_t igrad = ggml_hash_find(&cgraph->visited_hash_set, node); + return igrad != GGML_HASHSET_FULL && ggml_bitset_get(cgraph->visited_hash_set.used, igrad) ? cgraph->grad_accs[igrad] : NULL; +} + void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_LOG_INFO("=== GRAPH ===\n"); @@ -6560,7 +5941,8 @@ void ggml_graph_print(const struct ggml_cgraph * cgraph) { GGML_LOG_INFO(" - %3d: [ %5" PRId64 ", %5" PRId64 ", %5" PRId64 "] %16s %s\n", i, node->ne[0], node->ne[1], node->ne[2], - ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : node->grad ? "g" : " "); + ggml_op_name(node->op), (node->flags & GGML_TENSOR_FLAG_PARAM) ? "x" : + ggml_graph_get_grad(cgraph, node) ? "g" : " "); } GGML_LOG_INFO("n_leafs = %d\n", cgraph->n_leafs); @@ -6595,8 +5977,9 @@ static bool ggml_graph_find(const struct ggml_cgraph * cgraph, const struct ggml static struct ggml_tensor * ggml_graph_get_parent(const struct ggml_cgraph * cgraph, const struct ggml_tensor * node) { for (int i = 0; i < cgraph->n_nodes; i++) { struct ggml_tensor * parent = cgraph->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(cgraph, parent); - if (parent->grad == node) { + if (grad == node) { return parent; } } @@ -6636,6 +6019,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph for (int i = 0; i < gb->n_nodes; i++) { struct ggml_tensor * node = gb->nodes[i]; + struct ggml_tensor * grad = ggml_graph_get_grad(gb, node); if (ggml_graph_get_parent(gb, node) != NULL) { continue; @@ -6643,7 +6027,7 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph if (node->flags & GGML_TENSOR_FLAG_PARAM) { snprintf(color, sizeof(color), "yellow"); - } else if (node->grad) { + } else if (grad) { if (ggml_graph_find(gf, node)) { snprintf(color, sizeof(color), "green"); } else { @@ -6670,8 +6054,8 @@ void ggml_graph_dump_dot(const struct ggml_cgraph * gb, const struct ggml_cgraph fprintf(fp, "%d [%" PRId64 ", %" PRId64 ", %" PRId64 "] | %s", i, node->ne[0], node->ne[1], node->ne[2], ggml_op_symbol(node->op)); } - if (node->grad) { - fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(node->grad->op)); + if (grad) { + fprintf(fp, " | %s\"; ]\n", ggml_op_symbol(grad->op)); } else { fprintf(fp, "\"; ]\n"); } diff --git a/include/llama.h b/include/llama.h index 5e742642e..bc268e799 100644 --- a/include/llama.h +++ b/include/llama.h @@ -1244,8 +1244,6 @@ extern "C" { LLAMA_API void llama_perf_sampler_print(const struct llama_sampler * chain); LLAMA_API void llama_perf_sampler_reset( struct llama_sampler * chain); - LLAMA_API void llama_perf_dump_yaml(FILE * stream, const struct llama_context * ctx); - #ifdef __cplusplus } #endif diff --git a/scripts/run-with-preset.py b/scripts/run-with-preset.py deleted file mode 100755 index 8f0bf8ca8..000000000 --- a/scripts/run-with-preset.py +++ /dev/null @@ -1,146 +0,0 @@ -#!/usr/bin/env python3 - -import logging -import argparse -import os -import subprocess -import sys - -import yaml - -logger = logging.getLogger("run-with-preset") - -CLI_ARGS_LLAMA_CLI_PERPLEXITY = [ - "batch-size", "cfg-negative-prompt", "cfg-scale", "chunks", "color", "ctx-size", "escape", - "export", "file", "frequency-penalty", "grammar", "grammar-file", "hellaswag", - "hellaswag-tasks", "ignore-eos", "in-prefix", "in-prefix-bos", "in-suffix", - "interactive", "interactive-first", "keep", "logdir", "logit-bias", "lora", "lora-base", - "low-vram", "main-gpu", "mirostat", "mirostat-ent", "mirostat-lr", "mlock", - "model", "multiline-input", "n-gpu-layers", "n-predict", "no-mmap", "no-mul-mat-q", - "np-penalize-nl", "numa", "ppl-output-type", "ppl-stride", "presence-penalty", "prompt", - "prompt-cache", "prompt-cache-all", "prompt-cache-ro", "repeat-last-n", - "repeat-penalty", "reverse-prompt", "rope-freq-base", "rope-freq-scale", "rope-scale", "seed", - "simple-io", "tensor-split", "threads", "temp", "top-k", "top-p", "typical", - "verbose-prompt" -] - -CLI_ARGS_LLAMA_BENCH = [ - "batch-size", "low-vram", "model", "mul-mat-q", "n-gen", "n-gpu-layers", - "n-prompt", "output", "repetitions", "tensor-split", "threads", "verbose" -] - -CLI_ARGS_LLAMA_SERVER = [ - "alias", "batch-size", "ctx-size", "embedding", "host", "lora", "lora-base", - "low-vram", "main-gpu", "mlock", "model", "n-gpu-layers", "n-probs", "no-mmap", "no-mul-mat-q", - "numa", "path", "port", "rope-freq-base", "timeout", "rope-freq-scale", "tensor-split", - "threads", "verbose" -] - -description = """Run llama.cpp binaries with presets from YAML file(s). -To specify which binary should be run, specify the "binary" property (llama-cli, llama-perplexity, llama-bench, and llama-server are supported). -To get a preset file template, run a llama.cpp binary with the "--logdir" CLI argument. - -Formatting considerations: -- The YAML property names are the same as the CLI argument names of the corresponding binary. -- Properties must use the long name of their corresponding llama.cpp CLI arguments. -- Like the llama.cpp binaries the property names do not differentiate between hyphens and underscores. -- Flags must be defined as ": true" to be effective. -- To define the logit_bias property, the expected format is ": " in the "logit_bias" namespace. -- To define multiple "reverse_prompt" properties simultaneously the expected format is a list of strings. -- To define a tensor split, pass a list of floats. -""" -usage = "run-with-preset.py [-h] [yaml_files ...] [-- ...]" -epilog = (" -- specify additional CLI ars to be passed to the binary (override all preset files). " - "Unknown args will be ignored.") - -parser = argparse.ArgumentParser( - description=description, usage=usage, epilog=epilog, formatter_class=argparse.RawTextHelpFormatter) -parser.add_argument("-bin", "--binary", help="The binary to run.") -parser.add_argument("yaml_files", nargs="*", - help="Arbitrary number of YAML files from which to read preset values. " - "If two files specify the same values the later one will be used.") -parser.add_argument("--verbose", action="store_true", help="increase output verbosity") - -known_args, unknown_args = parser.parse_known_args() - -if not known_args.yaml_files and not unknown_args: - parser.print_help() - sys.exit(0) - -logging.basicConfig(level=logging.DEBUG if known_args.verbose else logging.INFO) - -props = dict() - -for yaml_file in known_args.yaml_files: - with open(yaml_file, "r") as f: - props.update(yaml.load(f, yaml.SafeLoader)) - -props = {prop.replace("_", "-"): val for prop, val in props.items()} - -binary = props.pop("binary", "llama-cli") -if known_args.binary: - binary = known_args.binary - -if os.path.exists(f"./{binary}"): - binary = f"./{binary}" - -if binary.lower().endswith("llama-cli") or binary.lower().endswith("llama-perplexity"): - cli_args = CLI_ARGS_LLAMA_CLI_PERPLEXITY -elif binary.lower().endswith("llama-bench"): - cli_args = CLI_ARGS_LLAMA_BENCH -elif binary.lower().endswith("llama-server"): - cli_args = CLI_ARGS_LLAMA_SERVER -else: - logger.error(f"Unknown binary: {binary}") - sys.exit(1) - -command_list = [binary] - -for cli_arg in cli_args: - value = props.pop(cli_arg, None) - - if not value or value == -1: - continue - - if cli_arg == "logit-bias": - for token, bias in value.items(): - command_list.append("--logit-bias") - command_list.append(f"{token}{bias:+}") - continue - - if cli_arg == "reverse-prompt" and not isinstance(value, str): - for rp in value: - command_list.append("--reverse-prompt") - command_list.append(str(rp)) - continue - - command_list.append(f"--{cli_arg}") - - if cli_arg == "tensor-split": - command_list.append(",".join([str(v) for v in value])) - continue - - value = str(value) - - if value != "True": - command_list.append(str(value)) - -num_unused = len(props) -if num_unused > 10: - logger.info(f"The preset file contained a total of {num_unused} unused properties.") -elif num_unused > 0: - logger.info("The preset file contained the following unused properties:") - for prop, value in props.items(): - logger.info(f" {prop}: {value}") - -command_list += unknown_args - -sp = subprocess.Popen(command_list) - -while sp.returncode is None: - try: - sp.wait() - except KeyboardInterrupt: - pass - -sys.exit(sp.returncode) diff --git a/scripts/sync-ggml-am.sh b/scripts/sync-ggml-am.sh index 74d6c6c8b..d0815cf89 100755 --- a/scripts/sync-ggml-am.sh +++ b/scripts/sync-ggml-am.sh @@ -73,17 +73,20 @@ while read c; do src/ggml*.h \ src/ggml*.c \ src/ggml*.cpp \ - src/ggml*.m \ - src/ggml*.metal \ - src/ggml*.cu \ src/ggml-amx/* \ + src/ggml-blas/* \ src/ggml-cann/* \ + src/ggml-cpu/* \ src/ggml-cuda/* \ + src/ggml-hip/* \ + src/ggml-kompute/* \ + src/ggml-metal/* \ + src/ggml-musa/* \ + src/ggml-rpc/* \ src/ggml-sycl/* \ - src/vulkan-shaders/* \ + src/ggml-vulkan/* \ include/ggml*.h \ tests/test-opt.cpp \ - tests/test-grad0.cpp \ tests/test-quantize-fns.cpp \ tests/test-quantize-perf.cpp \ tests/test-backend-ops.cpp \ @@ -121,21 +124,22 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then # src/ggml*.c -> ggml/src/ggml*.c # src/ggml*.cpp -> ggml/src/ggml*.cpp # src/ggml*.h -> ggml/src/ggml*.h - # src/ggml*.cu -> ggml/src/ggml*.cu - # src/ggml*.m -> ggml/src/ggml*.m - # src/ggml-amx/* -> ggml/src/ggml-amx/ - # src/ggml-cann/* -> ggml/src/ggml-cann/ - # src/ggml-cuda/* -> ggml/src/ggml-cuda/ - # src/ggml-sycl/* -> ggml/src/ggml-sycl/ - # src/vulkan-shaders/* -> ggml/src/vulkan-shaders/ + # src/ggml-amx/* -> ggml/src/ggml-amx/* + # src/ggml-blas/* -> ggml/src/ggml-blas/* + # src/ggml-cann/* -> ggml/src/ggml-cann/* + # src/ggml-cpu/* -> ggml/src/ggml-cpu/* + # src/ggml-cuda/* -> ggml/src/ggml-cuda/* + # src/ggml-hip/* -> ggml/src/ggml-hip/* + # src/ggml-kompute/* -> ggml/src/ggml-kompute/* + # src/ggml-metal/* -> ggml/src/ggml-metal/* + # src/ggml-musa/* -> ggml/src/ggml-musa/* + # src/ggml-rpc/* -> ggml/src/ggml-rpc/* + # src/ggml-sycl/* -> ggml/src/ggml-sycl/* + # src/ggml-vulkan/* -> ggml/src/ggml-vulkan/* # # include/ggml*.h -> ggml/include/ggml*.h # - # tests/test-opt.cpp -> tests/test-opt.cpp - # tests/test-grad0.cpp -> tests/test-grad0.cpp - # tests/test-quantize-fns.cpp -> tests/test-quantize-fns.cpp - # tests/test-quantize-perf.cpp -> tests/test-quantize-perf.cpp - # tests/test-backend-ops.cpp -> tests/test-backend-ops.cpp + # tests/test*.cpp -> tests/ # # LICENSE -> LICENSE # scripts/gen-authors.sh -> scripts/gen-authors.sh @@ -147,18 +151,20 @@ if [ -f $SRC_LLAMA/ggml-src.patch ]; then -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.c/\1ggml\/src\/ggml\2.c/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cpp/\1ggml\/src\/ggml\2.cpp/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.h/\1ggml\/src\/ggml\2.h/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.cu/\1ggml\/src\/ggml\2.cu/g' \ - -e 's/([[:space:]]|[ab]\/)src\/ggml(.*)\.m/\1ggml\/src\/ggml\2.m/g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-amx\//\1ggml\/src\/ggml-amx\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-blas\//\1ggml\/src\/ggml-blas\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-cann\//\1ggml\/src\/ggml-cann\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-cpu\//\1ggml\/src\/ggml-cpu\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-cuda\//\1ggml\/src\/ggml-cuda\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-hip\//\1ggml\/src\/ggml-hip\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-kompute\//\1ggml\/src\/ggml-kompute\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-metal\//\1ggml\/src\/ggml-metal\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-musa\//\1ggml\/src\/ggml-musa\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-rpc\//\1ggml\/src\/ggml-rpc\//g' \ -e 's/([[:space:]]|[ab]\/)src\/ggml-sycl\//\1ggml\/src\/ggml-sycl\//g' \ - -e 's/([[:space:]]|[ab]\/)src\/vulkan-shaders\//\1ggml\/src\/vulkan-shaders\//g' \ + -e 's/([[:space:]]|[ab]\/)src\/ggml-vulkan\//\1ggml\/src\/ggml-vulkan\//g' \ -e 's/([[:space:]]|[ab]\/)include\/ggml(.*)\.h/\1ggml\/include\/ggml\2.h/g' \ - -e 's/([[:space:]]|[ab]\/)examples\/common\.h/\1examples\/common.h/g' \ - -e 's/([[:space:]]|[ab]\/)examples\/common\.cpp/\1examples\/common.cpp/g' \ - -e 's/([[:space:]]|[ab]\/)examples\/common-ggml\.h/\1examples\/common-ggml.h/g' \ - -e 's/([[:space:]]|[ab]\/)examples\/common-ggml\.cpp/\1examples\/common-ggml.cpp/g' \ + -e 's/([[:space:]]|[ab]\/)tests\/(.*)\.cpp/\1tests\/\2.cpp/g' \ -e 's/([[:space:]]|[ab]\/)LICENSE/\1LICENSE/g' \ -e 's/([[:space:]]|[ab]\/)scripts\/gen-authors\.sh/\1scripts\/gen-authors.sh/g' \ > ggml-src.patch.tmp diff --git a/scripts/sync-ggml.sh b/scripts/sync-ggml.sh index f29554c82..000270afb 100755 --- a/scripts/sync-ggml.sh +++ b/scripts/sync-ggml.sh @@ -4,21 +4,25 @@ cp -rpv ../ggml/CMakeLists.txt ./ggml/CMakeLists.txt cp -rpv ../ggml/src/CMakeLists.txt ./ggml/src/CMakeLists.txt cp -rpv ../ggml/cmake/FindSIMD.cmake ./ggml/cmake/FindSIMD.cmake -cp -rpv ../ggml/src/ggml*.c ./ggml/src/ -cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/ -cp -rpv ../ggml/src/ggml*.h ./ggml/src/ -cp -rpv ../ggml/src/ggml*.cu ./ggml/src/ -cp -rpv ../ggml/src/ggml*.m ./ggml/src/ -cp -rpv ../ggml/src/ggml-amx/* ./ggml/src/ggml-amx/ -cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/ -cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/ -cp -rpv ../ggml/src/ggml-sycl/* ./ggml/src/ggml-sycl/ -cp -rpv ../ggml/src/vulkan-shaders/* ./ggml/src/vulkan-shaders/ +cp -rpv ../ggml/src/ggml*.c ./ggml/src/ +cp -rpv ../ggml/src/ggml*.cpp ./ggml/src/ +cp -rpv ../ggml/src/ggml*.h ./ggml/src/ +cp -rpv ../ggml/src/ggml-amx/* ./ggml/src/ggml-amx/ +cp -rpv ../ggml/src/ggml-blas/* ./ggml/src/ggml-blas/ +cp -rpv ../ggml/src/ggml-cann/* ./ggml/src/ggml-cann/ +cp -rpv ../ggml/src/ggml-cpu/* ./ggml/src/ggml-cpu/ +cp -rpv ../ggml/src/ggml-cuda/* ./ggml/src/ggml-cuda/ +cp -rpv ../ggml/src/ggml-hip/* ./ggml/src/ggml-hip/ +cp -rpv ../ggml/src/ggml-kompute/* ./ggml/src/ggml-kompute/ +cp -rpv ../ggml/src/ggml-metal/* ./ggml/src/ggml-metal/ +cp -rpv ../ggml/src/ggml-musa/* ./ggml/src/ggml-musa/ +cp -rpv ../ggml/src/ggml-rpc/* ./ggml/src/ggml-rpc/ +cp -rpv ../ggml/src/ggml-sycl/* ./ggml/src/ggml-sycl/ +cp -rpv ../ggml/src/ggml-vulkan/* ./ggml/src/ggml-vulkan/ cp -rpv ../ggml/include/ggml*.h ./ggml/include/ cp -rpv ../ggml/tests/test-opt.cpp ./tests/test-opt.cpp -cp -rpv ../ggml/tests/test-grad0.cpp ./tests/test-grad0.cpp cp -rpv ../ggml/tests/test-quantize-fns.cpp ./tests/test-quantize-fns.cpp cp -rpv ../ggml/tests/test-quantize-perf.cpp ./tests/test-quantize-perf.cpp cp -rpv ../ggml/tests/test-backend-ops.cpp ./tests/test-backend-ops.cpp diff --git a/src/llama.cpp b/src/llama.cpp index dc5dfba0c..de96959f2 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -3460,21 +3460,13 @@ static bool llama_kv_cache_init( const uint32_t n_embd_k_gqa = hparams.n_embd_k_gqa(i) + hparams.n_embd_k_s(); const uint32_t n_embd_v_gqa = hparams.n_embd_v_gqa(i) + hparams.n_embd_v_s(); - const llama_model::buft_list_t * buft_list; + ggml_backend_buffer_type_t buft; if (offload) { - buft_list = model.dev_layer.at(i).buft_list; + auto * dev = model.dev_layer.at(i).dev; + buft = ggml_backend_dev_buffer_type(dev); } else { - buft_list = &model.cpu_buft_list; + buft = ggml_backend_cpu_buffer_type(); } - ggml_backend_buffer_type_t buft = select_buft(*buft_list, - [&](ggml_context * ctx) { - ggml_tensor * k = ggml_new_tensor_1d(ctx, type_k, n_embd_k_gqa*kv_size); - if (hparams.rope_type == LLAMA_ROPE_TYPE_NONE) { - return k; - } - ggml_tensor * p = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 1); - return ggml_rope(ctx, k, p, hparams.n_rot, hparams.rope_type); - }); ggml_context * ctx = ctx_for_buft(buft); if (!ctx) { @@ -22075,28 +22067,6 @@ void llama_perf_context_reset(struct llama_context * ctx) { ctx->t_p_eval_us = ctx->n_p_eval = 0; } -void llama_perf_dump_yaml(FILE * stream, const llama_context * ctx) { - fprintf(stream, "\n"); - fprintf(stream, "###########\n"); - fprintf(stream, "# Timings #\n"); - fprintf(stream, "###########\n"); - fprintf(stream, "\n"); - - fprintf(stream, "mst_eval: %.2f # ms / token during generation\n", - 1.0e-3 * ctx->t_eval_us / ctx->n_eval); - fprintf(stream, "mst_p_eval: %.2f # ms / token during prompt processing\n", - 1.0e-3 * ctx->t_p_eval_us / ctx->n_p_eval); - fprintf(stream, "n_eval: %d # number of tokens generated (excluding the first one)\n", ctx->n_eval); - fprintf(stream, "n_p_eval: %d # number of tokens processed in batches at the beginning\n", ctx->n_p_eval); - fprintf(stream, "t_eval_us: %" PRId64 " # total microseconds spent generating tokens\n", ctx->t_eval_us); - fprintf(stream, "t_load_us: %" PRId64 " # total microseconds spent loading the model\n", ctx->t_load_us); - fprintf(stream, "t_p_eval_us: %" PRId64 " # total microseconds spent prompt processing\n", ctx->t_p_eval_us); - fprintf(stream, "ts_eval: %.2f # tokens / second during generation\n", - 1.0e6 * ctx->n_eval / ctx->t_eval_us); - fprintf(stream, "ts_p_eval: %.2f # tokens / second during prompt processing\n", - 1.0e6 * ctx->n_p_eval / ctx->t_p_eval_us); -} - // For internal test use const std::vector> & llama_internal_get_tensor_map( struct llama_context * ctx diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 08ad66b49..b06f122e8 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -116,9 +116,8 @@ llama_target_and_test(test-sampling.cpp) llama_target_and_test(test-chat-template.cpp) llama_target_and_test(test-grammar-parser.cpp) -llama_target_and_test(test-llama-grammar.cpp) llama_target_and_test(test-grammar-integration.cpp) -llama_target_and_test(test-grad0.cpp) +llama_target_and_test(test-llama-grammar.cpp) llama_target_and_test(test-barrier.cpp) # llama_target_and_test(test-opt.cpp) # SLOW llama_target_and_test(test-backend-ops.cpp) diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6618d03d1..f8a59b6df 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -811,11 +811,11 @@ struct test_case { ggml_build_forward_expand(gf, out); ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx, gf, gb, false); + ggml_build_backward_expand(ctx, ctx, gb, false); if (expect.size() != 1 || expect[0] != 0.0f) { GGML_ASSERT(ggml_graph_n_nodes(gb) > ggml_graph_n_nodes(gf)); for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || t->grad->op != GGML_OP_NONE); + GGML_ASSERT(!(t->flags & GGML_TENSOR_FLAG_PARAM) || ggml_graph_get_grad(gb, t)->op != GGML_OP_NONE); } } @@ -862,7 +862,13 @@ struct test_case { const char * bn = ggml_backend_name(backend); const int64_t ne = ggml_nelements(t); - std::vector ga = tensor_to_float(t->grad); + std::vector ga; + struct ggml_tensor * grad = ggml_graph_get_grad(gb, t); + if (grad) { + ga = tensor_to_float(grad); + } else { + ga.resize(ne); // default value is 0.0f + } for (int64_t i = 0; i < ne; ++i) { // gradient algebraic // check for nans @@ -2500,6 +2506,35 @@ struct test_sum_rows : public test_case { } }; +// GGML_OP_MEAN +struct test_mean : public test_case { + const ggml_type type; + const std::array ne; + + std::string vars() override { + return VARS_TO_STR2(type, ne); + } + + test_mean(ggml_type type = GGML_TYPE_F32, + std::array ne = {10, 5, 4, 3}) + : type(type), ne(ne) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); + ggml_set_param(ctx, a); + ggml_set_name(a, "a"); + + ggml_tensor * out = ggml_mean(ctx, a); + ggml_set_name(out, "out"); + + return out; + } + + float grad_eps() override { + return 0.1f * ne[0]*ne[1]*ne[2]*ne[3]; + } +}; + // GGML_OP_UPSCALE struct test_upscale : public test_case { const ggml_type type; @@ -2834,24 +2869,14 @@ struct test_cross_entropy_loss : public test_case { struct test_opt_step_adamw : public test_case { const ggml_type type; const std::array ne; - const float alpha; - const float beta1; - const float beta2; - const float eps; - const float wd; std::string vars() override { - return VARS_TO_STR7(type, ne, alpha, beta1, beta2, eps, wd); + return VARS_TO_STR2(type, ne); } test_opt_step_adamw(ggml_type type = GGML_TYPE_F32, - std::array ne = {10, 5, 4, 3}, - float alpha = 1e-3f, - float beta1 = 0.9f, - float beta2 = 0.999f, - float eps = 1e-8f, - float wd = 0.0f) - : type(type), ne(ne), alpha(alpha), beta1(beta1), beta2(beta2), eps(eps), wd(wd) {} + std::array ne = {10, 5, 4, 3}) + : type(type), ne(ne) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); @@ -2861,7 +2886,16 @@ struct test_opt_step_adamw : public test_case { ggml_tensor * grad = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); ggml_set_name(grad, "grad"); - ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, alpha, beta1, beta2, eps, wd); + ggml_tensor * grad_m = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); + ggml_set_name(grad_m, "grad_m"); + + ggml_tensor * grad_v = ggml_new_tensor_4d(ctx, type, ne[0], ne[1], ne[2], ne[3]); + ggml_set_name(grad_v, "grad_v"); + + ggml_tensor * adamw_params = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 7); + ggml_set_name(adamw_params, "adamw_params"); + + ggml_tensor * out = ggml_opt_step_adamw(ctx, a, grad, grad_m, grad_v, adamw_params); ggml_set_name(out, "out"); return out; @@ -2869,7 +2903,7 @@ struct test_opt_step_adamw : public test_case { void initialize_tensors(ggml_context * ctx) override { for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { - init_tensor_uniform(t, 0.0f, 1.0f); // grad_v needs non-negative values. + init_tensor_uniform(t, 0.0f, 1.0f); // grad_v and adamw_params need non-negative values. } } @@ -3735,6 +3769,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_sum()); test_cases.emplace_back(new test_sum_rows()); + test_cases.emplace_back(new test_mean()); test_cases.emplace_back(new test_upscale()); test_cases.emplace_back(new test_upscale(GGML_TYPE_F32, { 512, 512, 3, 1 }, 2, true)); test_cases.emplace_back(new test_upscale_ext()); @@ -3766,9 +3801,7 @@ static std::vector> make_test_cases_eval() { } test_cases.emplace_back(new test_cross_entropy_loss()); - for (float wd : {0.0f, 1e-2f}) { - test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3}, 1.0f, 1e-3f, 0.9f, 0.999f, wd)); - } + test_cases.emplace_back(new test_opt_step_adamw(GGML_TYPE_F32, {10, 5, 4, 3})); // these tests are disabled to save execution time, but they can be handy for debugging #if 0 @@ -3938,6 +3971,8 @@ int main(int argc, char ** argv) { ggml_backend_free(backend); } + ggml_quantize_free(); + printf("%zu/%zu backends passed\n", n_ok, ggml_backend_dev_count()); if (n_ok != ggml_backend_dev_count()) { @@ -3945,8 +3980,6 @@ int main(int argc, char ** argv) { return 1; } - ggml_quantize_free(); - printf("\033[1;32mOK\033[0m\n"); return 0; } diff --git a/tests/test-grad0.cpp b/tests/test-grad0.cpp deleted file mode 100644 index c712dba7f..000000000 --- a/tests/test-grad0.cpp +++ /dev/null @@ -1,1684 +0,0 @@ -#define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows -#include "ggml.h" -#include "ggml-cpu.h" - -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(_MSC_VER) -#pragma warning(disable: 4244 4267) // possible loss of data -#endif - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Wdouble-promotion" -#endif - -#define MAX_NARGS 3 - -#undef MIN -#undef MAX -#define MIN(a, b) ((a) < (b) ? (a) : (b)) -#define MAX(a, b) ((a) > (b) ? (a) : (b)) - -#define GGML_SILU_FP16 - -// -// logging -// - -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif - -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif - -#define GGML_PRINT(...) printf(__VA_ARGS__) - -static float frand(void) { - return (float)rand()/(float)RAND_MAX; -} - -static int irand(int n) { - if (n == 0) return 0; - return rand()%n; -} - -static void get_random_dims(int64_t * dims, int ndims) { - dims[0] = dims[1] = dims[2] = dims[3] = 1; - - for (int i = 0; i < ndims; i++) { - dims[i] = 1 + irand(4); - } -} - -static struct ggml_tensor * get_random_tensor_f32( - struct ggml_context * ctx0, - int ndims, - int64_t ne[], - float fmin, - float fmax) { - struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne); - - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin; - } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } - } - } - } - break; - default: - assert(false); - } - - return result; -} - -static struct ggml_tensor * get_random_tensor_f16( - struct ggml_context * ctx0, - int ndims, - int64_t ne[], - float fmin, - float fmax) { - struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F16, ndims, ne); - - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((ggml_fp16_t *)result->data)[i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); - } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((ggml_fp16_t *)result->data)[i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); - } - } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((ggml_fp16_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); - } - } - } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((ggml_fp16_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = ggml_fp32_to_fp16(frand()*(fmax - fmin) + fmin); - } - } - } - } - break; - default: - assert(false); - } - - return result; -} - -static struct ggml_tensor * get_random_tensor_i32( - struct ggml_context * ctx0, - int ndims, - int64_t ne[], - int32_t imin, - int32_t imax) { - struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_I32, ndims, ne); - - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((int32_t *)result->data)[i0] = irand(imax - imin) + imin; - } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((int32_t *)result->data)[i1*ne[0] + i0] = irand(imax - imin) + imin; - } - } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((int32_t *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin; - } - } - } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((int32_t *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = irand(imax - imin) + imin; - } - } - } - } - break; - default: - assert(false); - } - - return result; -} - -static bool check_gradient( - const char * op_name, - struct ggml_context * ctx0, - struct ggml_tensor * x[], - struct ggml_tensor * f, - int ndims, - int nargs, - float eps, - float max_error_abs, - float max_error_rel, - std::vector expected_vals) { - - static int n_threads = -1; - if (n_threads < 0) { - n_threads = GGML_DEFAULT_N_THREADS; - - const char *env = getenv("GGML_N_THREADS"); - if (env) { - n_threads = atoi(env); - } - - printf("GGML_N_THREADS = %d\n", n_threads); - } - - struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - struct ggml_cgraph * gb = ggml_new_graph_custom(ctx0, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(gf, f); - ggml_graph_cpy(gf, gb); - ggml_build_backward_expand(ctx0, gf, gb, false); - - ggml_graph_compute_with_ctx(ctx0, gf, n_threads); - - ggml_graph_reset(gb); - if (f->grad) { - ggml_set_f32(f->grad, 1.0f); - } - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - // ggml_graph_dump_dot(gf, NULL, "test-grad0-forward.dot"); - // ggml_graph_dump_dot(gb, gf, "test-grad0-backward.dot"); - - for (int i = 0; i < nargs; ++i) { - bool all_g0_bad = true; - const int nelements = ggml_nelements(x[i]); - for (int k = 0; k < nelements; ++k) { - // Calculate gradient numerically: - const float x0 = ggml_get_f32_1d(x[i], k); - const float xm = x0 - eps; - const float xp = x0 + eps; - ggml_set_f32_1d(x[i], k, xp); - - ggml_graph_compute_with_ctx(ctx0, gf, n_threads); - - const double f0 = ggml_get_f32_1d(f, 0); - - ggml_set_f32_1d(x[i], k, xm); - - ggml_graph_compute_with_ctx(ctx0, gf, n_threads); - - const double f1 = ggml_get_f32_1d(f, 0); - const double g0 = (f0 - f1)/(2.0*(double) eps); - - // The numerical calculation of the gradient fails around noncontinuities (e.g. 0 for ReLU). - // In such cases, provide a vector of expected values and skip the comparison for failed calculations. - if (!expected_vals.empty()) { - bool matches_any = false; - for (const double & ev : expected_vals) { - const double error_abs = std::fabs(g0 - ev); - if (error_abs > max_error_abs) { - continue; - } - const double error_rel = g0 != 0.0 ? fabs(g0 - ev)/fabs(g0) : 0.0; - if (error_rel > max_error_rel) { - continue; - } - matches_any = true; - break; - } - if (!matches_any) { - continue; - } - } - all_g0_bad = false; - - ggml_set_f32_1d(x[i], k, x0); - - // compute gradient using backward graph - ggml_graph_reset(gb); - if (f->grad) { - ggml_set_f32(f->grad, 1.0f); - } - - ggml_graph_compute_with_ctx(ctx0, gb, n_threads); - - const double g1 = ggml_get_f32_1d(x[i]->grad, k); - - const double error_abs = fabs(g0 - g1); - const double error_rel = g0 != 0.0 ? fabs(g0 - g1)/fabs(g0) : 0.0; - - if (error_abs > max_error_abs || error_rel > max_error_rel) { - printf("%s: ndims=%d, i=%d, k=%d, x0=%f, xm=%f, xp=%f, f0=%f, f1=%f, g0=%f, g1=%f, eps=%f, error_abs=%f, error_rel=%f\n", - op_name, ndims, i, k, x0, xm, xp, f0, f1, g0, g1, eps, error_abs, error_rel); - //assert(false); - return false; - } - } - if (all_g0_bad) { - printf("%s: numerical calculation of the gradient failed for all values\n", op_name); - return false; - } - } - - return true; -} - -// TODO: clean-up this .. -static bool check_mat_mul( - const struct ggml_tensor * y, - const struct ggml_tensor * x0, - const struct ggml_tensor * x1) { - float * dst = (float *) y->data; - float * src0 = (float *) x0->data; - float * src1 = (float *) x1->data; - - const int nc = x0->ne[1]; - const int nr = x1->ne[1]; - const int nk = x0->ne[0]; - - GGML_PRINT_DEBUG("check_mat_mul: nc=%d, nr=%d, nk=%d\n", nc, nr, nk); - - GGML_PRINT_DEBUG("x0:\n"); - for (int j = 0; j < x0->ne[1]; ++j) { - for (int i = 0; i < x0->ne[0]; ++i) { - GGML_PRINT_DEBUG("%6.3f ", src0[j*nk + i]); - } - GGML_PRINT_DEBUG("\n"); - } - GGML_PRINT_DEBUG("\n"); - - GGML_PRINT_DEBUG("x1:\n"); - for (int j = 0; j < x1->ne[1]; ++j) { - for (int i = 0; i < x1->ne[0]; ++i) { - GGML_PRINT_DEBUG("%6.3f ", src1[j*nk + i]); - } - GGML_PRINT_DEBUG("\n"); - } - GGML_PRINT_DEBUG("\n"); - - GGML_PRINT_DEBUG("y: n_dims = %d, (%lld, %lld)\n", y->n_dims, y->ne[0], y->ne[1]); - for (int j = 0; j < y->ne[1]; ++j) { - for (int i = 0; i < y->ne[0]; ++i) { - GGML_PRINT_DEBUG("%6.3f ", dst[j*nr + i]); - } - GGML_PRINT_DEBUG("\n"); - } - - for (int i = 0; i < nr; ++i) { - for (int j = 0; j < nc; ++j) { - float sum = 0.0f; - - for (int k = 0; k < nk; ++k) { - sum += src0[j*nk + k]*src1[i*nk + k]; - } - - if (fabsf(dst[i*nc + j] - sum) > 1e-5f) { - fprintf(stderr, "check_mat_mul: dst[%d] = %f, sum = %f\n", i*nc + j, dst[i*nc + j], sum); - assert(false); - return false; - } - } - } - - return true; -} - -#define NUM_PERMUTATIONS (4*3*2*1) - -int main(int argc, const char ** argv) { - struct ggml_init_params params = { - /* .mem_size = */ 256*1024*1024, - /* .mem_buffer = */ NULL, - /* .no_alloc = */ false, - }; - - int64_t ne[4]; - - int all_permutations[4 * NUM_PERMUTATIONS]; - { - int count = 0; - for (int ax0=0; ax0<4; ++ax0) { - for (int ax1=0; ax1<4; ++ax1) { - if (ax1 == ax0) continue; - for (int ax2=0; ax2<4; ++ax2) { - if (ax2 == ax0) continue; - if (ax2 == ax1) continue; - for (int ax3=0; ax3<4; ++ax3) { - if (ax3 == ax0) continue; - if (ax3 == ax1) continue; - if (ax3 == ax2) continue; - assert(count < NUM_PERMUTATIONS); - all_permutations[count*4+0] = ax0; - all_permutations[count*4+1] = ax1; - all_permutations[count*4+2] = ax2; - all_permutations[count*4+3] = ax3; - ++count; - } - } - } - } - } - - unsigned seed_iter = 1; - - // original loop: 1000 - int niter = 4; - const char *env = getenv("GGML_NLOOP"); - if (env != NULL) { - niter = atoi(env); - } - if (argc > 1) { - niter = atoi(argv[1]); - } - for (int iter = 0; iter < niter; ++iter) { - srand(seed_iter); - seed_iter = rand(); - unsigned seed = rand(); - - printf("test-grad0: iter:%d/%d\n", (iter+1), niter); - struct ggml_context * ctx0 = ggml_init(params); - - get_random_dims(ne, 4); - - struct ggml_tensor * x[MAX_NARGS]; - - // add f32 - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); - - check_gradient("add f32", ctx0, x, f, ndims, nargs, 1e-3f, 2e-3f, 2e-3f, {}); - } - } - - // add f16 - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_add(ctx0, x[0], x[1])); - - check_gradient("add f16", ctx0, x, f, ndims, nargs, 1e-1f, 2e-1f, 2e-1f, {}); - } - } - - // sub - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sub(ctx0, x[0], x[1])); - - check_gradient("sub", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // mul - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_mul(ctx0, x[0], x[1])); - - check_gradient("mul", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // div - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, 0.5f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_div(ctx0, x[0], x[1])); - - check_gradient("div", ctx0, x, f, ndims, nargs, 1e-3f, 1e-1f, 1e-1f, {}); - } - } - - // sqr - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, x[0])); - - check_gradient("sqr", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // sqrt - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqrt(ctx0, x[0])); - - check_gradient("sqrt", ctx0, x, f, ndims, nargs, 1e-3f, 2e-2f, 1e-1f, {}); - } - } - - // log - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, 2.0f*1e-3f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_log(ctx0, x[0])); - - check_gradient("log", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-1f, {}); - } - } - - // sum - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, x[0]); - - check_gradient("sum", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - - // sum_rows - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sum_rows(ctx0, x[0]))); - - check_gradient("sum_rows", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); - } - } - - // mean, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_mean(ctx0, x[0])); - - check_gradient("mean", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // argmax - if (0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_argmax(ctx0, x[0])); - - check_gradient("argmax", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // repeat - { - srand(seed); - int64_t ne2[4]; - get_random_dims(ne2, 4); - - ne2[0] = ne[0] * ne2[0]; - ne2[1] = ne[1] * ne2[1]; - ne2[2] = 1; - ne2[3] = 1; - - const int nargs = 1; - for (int ndims = 1; ndims <= 2; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[1], ggml_repeat(ctx0, x[0], x[1])))); - - check_gradient("repeat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); - } - } - - // repeat back - { - srand(seed); - int64_t ne2[4]; - get_random_dims(ne2, 4); - - ne2[0] = ne[0] * ne2[0]; - ne2[1] = ne[1] * ne2[1]; - ne2[2] = 1; - ne2[3] = 1; - - const int nargs = 1; - for (int ndims = 1; ndims <= 2; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_sqr(ctx0, ggml_sub(ctx0, x[0], ggml_repeat_back(ctx0, x[1], x[0])))); - - check_gradient("repeat back", ctx0, x, f, ndims, nargs, 1e-3f, 1e-2f, INFINITY, {}); - } - } - - // abs - { - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_abs(ctx0, x[0])); - - check_gradient("abs", ctx0, x, f, ndims, nargs, 1e-3f, INFINITY, 1e-3f, {-1.0, 1.0}); - } - } - - // sgn - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_sgn(ctx0, x[0])); - - check_gradient("sgn", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0}); - } - } - - // neg - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_neg(ctx0, x[0])); - - check_gradient("neg", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // step - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_step(ctx0, x[0])); - - check_gradient("step", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {0.0}); - } - } - - // tanh, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_tanh(ctx0, x[0])); - - check_gradient("tanh", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // mul_mat - { - srand(seed); - const int nargs = 2; - - for (int ndims = 2; ndims <= 4; ++ndims) { - int max_nrep = (ndims >= 3) ? 2 : 1; - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - for (int nrep2 = 1; nrep2 < max_nrep; ++nrep2) { - for (int nrep3 = 1; nrep3 < max_nrep; ++nrep3) { - { - int64_t ne2[4]; - get_random_dims(ne2, 4); - ne2[0] = ne[0]; - ne2[2] = nrep2 * ne[2]; - ne2[3] = nrep3 * ne[3]; - x[1] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - } - - ggml_set_param(ctx0, x[0]); - ggml_set_param(ctx0, x[1]); - - struct ggml_tensor * m = ggml_mul_mat(ctx0, x[1], x[0]); - struct ggml_tensor * f = ggml_sum(ctx0, m); - - GGML_PRINT_DEBUG("testing: mul_mat, [%lld, %lld] (%d) * [%lld, %lld] (%d)\n", x[1]->ne[0], x[1]->ne[1], x[1]->n_dims, x[0]->ne[0], x[0]->ne[1], x[0]->n_dims); - - check_gradient("mul_mat", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - if (ndims == 2) { - // check_mat_mul does not support ndims > 2 - check_mat_mul(m, x[1], x[0]); - } - } - } - } - } - - // elu, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_elu(ctx0, x[0])); - - check_gradient("elu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // relu - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_relu(ctx0, x[0])); - - check_gradient("relu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {0.0, 1.0}); - } - } - - // gelu, not yet fully implemented - if(0) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 4; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor* f = ggml_sum(ctx0, ggml_gelu(ctx0, x[0])); - - check_gradient("gelu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, 1e-3f, {}); - } - } - - // silu - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_silu(ctx0, x[0])); - -#ifdef GGML_SILU_FP16 - // due to GGML_SILU_FP16 the finite difference method will be slightly wrong -> increase error bounds. - check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 0.5, INFINITY, {}); -#else - check_gradient("silu", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); -#endif - } - } - - // rms_norm - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rms_norm(ctx0, x[0], 1e-6f)); - - check_gradient("rms_norm", ctx0, x, f, ndims, nargs, 1e-4f, 1.0f, INFINITY, {}); - } - } - - // scale - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - - const float s = -1.0f + 2.0f*frand(); - - ggml_set_param(ctx0, x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_scale(ctx0, x[0], s)); - - check_gradient("scale", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // cpy f32 - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - // x[1] is overwritten by x[0], so the gradients don't propagate to x[1] - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); - - check_gradient("cpy f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // cpy f16 - { - srand(seed); - const int nargs = 2; - - for (int ndims = 1; ndims <= 2; ++ndims) { - for (int i = 0; i < nargs; ++i) { - x[i] = get_random_tensor_f16(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[i]); - } - // x[1] is overwritten by x[0], so the gradients don't propagate to x[1] - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_cpy(ctx0, x[0], x[1])); - - check_gradient("cpy f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {}); - } - } - - // reshape (1d->nd) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - int64_t ne2[4]; - ne2[0] = 1; - ne2[1] = 1; - ne2[2] = 1; - ne2[3] = 1; - for (int i = 0; i < ndims; ++i) { - ne2[0] *= ne[i]; - } - x[0] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1])); - check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // reshape (nd->1d) - { - srand(seed); - const int nargs = 1; - - for (int ndims = 1; ndims <= 2; ++ndims) { - int64_t ne2[4]; - ne2[0] = 1; - ne2[1] = 1; - ne2[2] = 1; - ne2[3] = 1; - for (int i = 0; i < ndims; ++i) { - ne2[0] *= ne[i]; - } - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_reshape(ctx0, x[0], x[1])); - check_gradient("reshape", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // acc 1d - { - srand(seed); - int64_t ne2[4] = { 1, 1, 1, 1 }; - - const int nargs = 2; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 1); - while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 1); - } - - x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1])); - const int offset = irand(max_offset) * ggml_element_size(x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - - check_gradient("acc 1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // acc 2d - { - srand(seed); - int64_t ne2[4] = { 1, 1, 1, 1 }; - int64_t max_offsets[4] = { 0, 0, 0, 0 }; - int64_t offsets[4] = { 0, 0, 0, 0 }; - - const int nargs = 2; - for (int ndims = 2; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 2); - while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 2); - } - - x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); - max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]); - offsets[0] = irand(max_offsets[0]) * x[0]->nb[0]; - offsets[1] = irand(max_offsets[1]) * x[0]->nb[1]; - const int offset = offsets[0] + offsets[1]; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - - check_gradient("acc 2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // acc 3d - { - srand(seed); - int64_t ne2[4] = { 1, 1, 1, 1 }; - int64_t max_offsets[4] = { 0, 0, 0, 0 }; - int64_t offsets[4] = { 0, 0, 0, 0 }; - - const int nargs = 2; - for (int ndims = 3; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 3); - while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 3); - } - - x[1] = get_random_tensor_f32(ctx0, 3, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); - max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]); - max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]); - offsets[0] = irand(max_offsets[0]) * x[0]->nb[0]; - offsets[1] = irand(max_offsets[1]) * x[0]->nb[1]; - offsets[2] = irand(max_offsets[2]) * x[0]->nb[2]; - const int offset = offsets[0] + offsets[1] + offsets[2]; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - - check_gradient("acc 3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // acc 4d - { - srand(seed); - int64_t ne2[4] = { 1, 1, 1, 1 }; - int64_t max_offsets[4] = { 0, 0, 0, 0 }; - int64_t offsets[4] = { 0, 0, 0, 0 }; - - const int nargs = 2; - for (int ndims = 4; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 4); - while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[2] > ne[2]) || (ne2[3] > ne[3]) || (ne2[0]*ne2[1]*ne2[2]*ne2[3] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 4); - } - - x[1] = get_random_tensor_f32(ctx0, 4, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); - max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]); - max_offsets[2] = MAX(0, x[0]->ne[2] - x[1]->ne[2]); - max_offsets[3] = MAX(0, x[0]->ne[3] - x[1]->ne[3]); - offsets[0] = irand(max_offsets[0]) * x[0]->nb[0]; - offsets[1] = irand(max_offsets[1]) * x[0]->nb[1]; - offsets[2] = irand(max_offsets[2]) * x[0]->nb[2]; - offsets[3] = irand(max_offsets[3]) * x[0]->nb[3]; - const int offset = offsets[0] + offsets[1] + offsets[2] + offsets[3]; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_acc(ctx0, x[0], x[1], x[0]->nb[1], x[0]->nb[2], x[0]->nb[3], offset)); - - check_gradient("acc 4d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // set_1d - { - srand(seed); - int64_t ne2[4]; - - const int nargs = 2; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 1); - while ((ne2[0] > ne[0]) || (ne2[0] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 1); - } - - x[1] = get_random_tensor_f32(ctx0, 1, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - const int max_offset = MAX(0, ggml_nelements(x[0]) - ggml_nelements(x[1])); - const int offset = irand(max_offset) * ggml_element_size(x[0]); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_1d(ctx0, x[0], x[1], offset)); - - check_gradient("set_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // set_2d - { - srand(seed); - int64_t ne2[4]; - int64_t max_offsets[4] = { 0, 0, 0, 0 }; - int64_t offsets[4] = { 0, 0, 0, 0 }; - - const int nargs = 1; - for (int ndims = 2; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - ggml_set_param(ctx0, x[0]); - - get_random_dims(ne2, 2); - while ((ne2[0] > ne[0]) || (ne2[1] > ne[1]) || (ne2[0]*ne2[1] > ggml_nelements(x[0]))) { - get_random_dims(ne2, 2); - } - - x[1] = get_random_tensor_f32(ctx0, 2, ne2, -1.0f, 1.0f); - ggml_set_param(ctx0, x[1]); - - max_offsets[0] = MAX(0, x[0]->ne[0] - x[1]->ne[0]); - max_offsets[1] = MAX(0, x[0]->ne[1] - x[1]->ne[1]); - offsets[0] = irand(max_offsets[0]) * x[0]->nb[0]; - offsets[1] = irand(max_offsets[1]) * x[0]->nb[1]; - const int offset = offsets[0] + offsets[1]; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_set_2d(ctx0, x[0], x[1], x[1]->nb[1], offset)); - - check_gradient("set_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // view_1d - { - srand(seed); - const int nargs = 1; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - - ggml_set_param(ctx0, x[0]); - - const int k0 = irand(ggml_nelements(x[0])); - const int k1 = irand(ggml_nelements(x[0])); - const int i0 = MIN(k0, k1); - const int i1 = MAX(k0, k1); - - const int offset = i0 * sizeof(float); - const int nelem = i1 - i0; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_1d(ctx0, x[0], nelem, offset)); - - check_gradient("view_1d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // view_2d - { - srand(seed); - int64_t ne2[4]; - int64_t nb2[4]; - - const int nargs = 1; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - - get_random_dims(ne2, 2); - while (ne2[0]*ne2[1] > ggml_nelements(x[0])) { - get_random_dims(ne2, 2); - } - const int count = ne2[0]*ne2[1]; - - nb2[0] = sizeof(float); - nb2[1] = nb2[0]*ne2[0]; - - ggml_set_param(ctx0, x[0]); - - const int max_offset = ggml_nelements(x[0]) - count; - const int offset = irand(max_offset+1) * sizeof(float); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_2d(ctx0, x[0], ne2[0], ne2[1], nb2[1], offset)); - - check_gradient("view_2d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // view_3d - { - srand(seed); - int64_t ne2[4] = {1,1,1,1}; - int64_t nb2[4] = {0,0,0,0}; - - const int nargs = 1; - for (int ndims = 1; ndims <= 4; ++ndims) { - - x[0] = get_random_tensor_f32(ctx0, ndims, ne, -1.0f, 1.0f); - - get_random_dims(ne2, 3); - while (ne2[0]*ne2[1]*ne2[2] > ggml_nelements(x[0])) { - get_random_dims(ne2, 3); - } - const int count = ne2[0]*ne2[1]*ne2[2]; - - nb2[0] = sizeof(float); - nb2[1] = nb2[0]*ne2[0]; - nb2[2] = nb2[1]*ne2[1]; - - ggml_set_param(ctx0, x[0]); - - const int max_offset = ggml_nelements(x[0]) - count; - const int offset = irand(max_offset+1) * sizeof(float); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_view_3d(ctx0, x[0], ne2[0], ne2[1], ne2[2], nb2[1], nb2[2], offset)); - - check_gradient("view_3d", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // permute - { - srand(seed); - int64_t ne2[4]; - - const int nargs = 1; - for (int ndims = 1; ndims <= 4; ++ndims) - { - // ggml_permute will set axes of dimensions below n_dims to 1. - // to make ggml_permute work correctly on all axes, - // the input tensor needs maximal n_dim of 4. - for (int i=0; i finite differences should not work - // instead use sum(log(soft_max()*(1-eps)+eps)); use eps to avoid log(0) - struct ggml_tensor * f = ggml_sum(ctx0, - ggml_log(ctx0, - ggml_add1(ctx0, - ggml_scale(ctx0, - ggml_soft_max(ctx0, x[0]), - 1.0f - eps), - ggml_new_f32(ctx0, eps)))); - - check_gradient("softmax", ctx0, x, f, ndims, nargs, 1e-3f, 2e-1f, INFINITY, {}); - // NOTE: softmax forward is computed using f16 table lookup instead of using actual expf, but backward assumes actual expf. - // this may result in different gradients too finite differences. - // when this test reports errors, first try to replace the table lookup with actual expf and test again to see if just that was the cause. - // if only the table lookup causes gradients to differ this is acceptable. - } - } - - // cross_entropy_loss - { - srand(seed); - const int nargs = 1; - - int64_t ne2[4]; - get_random_dims(ne2, 4); - - for (int ndims = 1; ndims <= 4; ++ndims) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - x[1] = get_random_tensor_f32(ctx0, ndims, ne2, 0.0f, 1.0f); - // the second argument to cross_entropy_loss must sum up to 1 for each row - int nr = ggml_nrows(x[1]); - int nc = ggml_nelements(x[1]) / nr; - for (int ir = 0; ir < nr; ++ir) { - float sum = 0; - for (int ic = 0; ic < nc; ++ic) { - sum += ((float *) x[1]->data)[ic + ir*nc]; - } - for (int ic = 0; ic < nc; ++ic) { - ((float *) x[1]->data)[ic + ir*nc] /= sum; - } - } - ggml_set_param(ctx0, x[0]); - - struct ggml_tensor * f = ggml_cross_entropy_loss(ctx0, x[0], x[1]); - - check_gradient("cross_entropy_loss", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, {}); - } - } - - // rope f32 - { - srand(seed); - const int nargs = 1; - - int64_t ne2[4]; - get_random_dims(ne2, 4); - ne2[0] += ne2[0] % 2; - int n_rot = ne2[0]; - - for (int ndims = 3; ndims <= 4; ++ndims) { - for (int mode = 0; mode < 4; ++mode) { - for (int n_past = 1; n_past < ne2[2]; ++n_past) { - x[0] = get_random_tensor_f32(ctx0, ndims, ne2, -1.0f, 1.0f); - - struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); - for (int i = 0; i < ne2[2]; ++i) { - ((int32_t *) p->data)[i] = n_past + i; - } - - ggml_set_param(ctx0, x[0]); - - const bool skip_past = (mode & 1); - if (skip_past) { - // we have no past, so this would have to work on uninitialized memory. - // we only test the gradients here; - // skip_past should have no influence on gradient computation. - // so when other modes work, we assume that this does as well. - continue; - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); - - GGML_PRINT_DEBUG("rope f32: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); - check_gradient("rope f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {}); - } - } - } - } - - // rope f16 - { - srand(seed); - const int nargs = 1; - - int64_t ne2[4]; - get_random_dims(ne2, 4); - ne2[0] += ne2[0] % 2; - int n_rot = ne2[0]; - - for (int ndims = 3; ndims <= 4; ++ndims) { - for (int mode = 0; mode < 4; ++mode) { - for (int n_past = 1; n_past < ne2[2]; ++n_past) { - x[0] = get_random_tensor_f16(ctx0, ndims, ne2, -1.0f, 1.0f); - - struct ggml_tensor * p = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ne2[2]); - for (int i = 0; i < ne2[2]; ++i) { - ((int32_t *) p->data)[i] = n_past + i; - } - - ggml_set_param(ctx0, x[0]); - - const bool skip_past = (mode & 1); - if (skip_past) { - // we have no past, so this would have to work on uninitialized memory. - // we only test the gradients here; - // skip_past should have no influence on gradient computation. - // so when other modes work, we assume that this does as well. - continue; - } - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_rope(ctx0, x[0], p, n_rot, mode)); - - GGML_PRINT_DEBUG("rope f16: n_past: %d n_rot: %d mode: %d\n", n_past, n_rot, mode); - check_gradient("rope f16", ctx0, x, f, ndims, nargs, 1e-1f, 1e-1f, INFINITY, {}); - } - } - } - } - - // im2col f32 - { - srand(seed); - const int nargs = 1; - const int ndims = 4; - - for (const bool is_2D : {false, true}) { - int64_t ne0[ndims]; - int64_t ne1[ndims]; - get_random_dims(ne0, ndims); - get_random_dims(ne1, ndims); - - // // Ensure that the output is not zero-sized: - ne1[0] += 8; - ne1[1] += 8; - - if (is_2D) { - ne1[2] = ne0[2]; - } else { - ne1[1] = ne0[1]; - ne0[3] = 1; - ne1[3] = 1; - } - - // The order of arguments is swapped because the first tensor is only used for its shape. - x[1] = get_random_tensor_f16(ctx0, ndims, ne0, -1.0f, 1.0f); - x[0] = get_random_tensor_f32(ctx0, ndims, ne1, -1.0f, 1.0f); - - ggml_set_param(ctx0, x[0]); - - const int s0 = 1 + irand(2); - const int s1 = is_2D ? 1 + irand(2) : 0; - const int p0 = 0 + irand(2); - const int p1 = is_2D ? 0 + irand(2) : 0; - const int d0 = 1 + irand(2); - const int d1 = is_2D ? 1 + irand(2) : 0; - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_im2col(ctx0, x[1], x[0], s0, s1, p0, p1, d0, d1, is_2D, GGML_TYPE_F32)); - - GGML_PRINT_DEBUG("im2col f32: is_2D=%s, s0=%d, s1=%d, p0=%d, p1=%d, d0=%d, d1=%d\n", is_2D ? "yes" : "no", s0, s1, p0, p1, d0, d1); - check_gradient("im2col f32", ctx0, x, f, ndims, nargs, 1e-2f, 1e-3f, INFINITY, {}); - } - } - - // pool_2d f32 - { - srand(seed); - const int nargs = 1; - const int ndims = 4; - - for (const enum ggml_op_pool op : {GGML_OP_POOL_AVG, GGML_OP_POOL_MAX}) { - int64_t ne0[ndims]; - get_random_dims(ne0, ndims); - - ne0[0] += 8; - ne0[1] += 8; - - x[0] = get_random_tensor_f32(ctx0, ndims, ne0, -1.0f, 1.0f); - - ggml_set_param(ctx0, x[0]); - - const int k0 = 2 + irand(2); - const int k1 = 2 + irand(2); - const int s0 = 2 + irand(2); - const int s1 = 2 + irand(2); - const int p0 = 0 + irand(2); - const int p1 = 0 + irand(2); - - struct ggml_tensor * f = ggml_sum(ctx0, ggml_pool_2d(ctx0, x[0], op, k0, k1, s0, s1, p0, p1)); - - GGML_PRINT_DEBUG("ggml_pool_2d f32: op=%s k0=%d, k1=%d, s0=%d, s1=%d, p0=%d, p1=%d\n", - op == GGML_OP_POOL_MAX ? "max" : "avg", k0, k1, s0, s1, p0, p1); - std::vector expected_vals; - if (op == GGML_OP_POOL_MAX) { - expected_vals.push_back(0.0); - expected_vals.push_back(1.0); - } - check_gradient("ggml_pool_2d f32", ctx0, x, f, ndims, nargs, 1e-3f, 1e-3f, INFINITY, expected_vals); - } - } - - // flash_attn f32 - // TODO: adapt to ggml_flash_attn_ext() changes - //{ - // srand(seed); - // const int nargs = 3; - - // int64_t ne2[4]; - - // get_random_dims(ne2, 4); - // int64_t D = ne2[0]; - // int64_t N = ne2[1]; - // int64_t M = ne2[2] + N; - // int64_t B = ne2[3]; - - // for (int masked = 0; masked <= 1; ++masked) { - // for (int ndims = 2; ndims <= 4; ++ndims) { - // int max_nrep = (ndims >= 3) ? 2 : 1; - // for (int nrep = 1; nrep < max_nrep; ++nrep) { - // int64_t neq[4] = { D, N, B*nrep, ne[3] }; - // int64_t nek[4] = { D, M, B, ne[3] }; - // int64_t nev[4] = { M, D, B, ne[3] }; - // if (ndims == 2) { - // neq[2] = 1; neq[3] = 1; - // nek[2] = 1; nek[3] = 1; - // nev[2] = 1; nev[3] = 1; - // } else if (ndims == 3) { - // neq[3] = 1; - // nek[3] = 1; - // nev[3] = 1; - // } - // x[0] = get_random_tensor_f32(ctx0, ndims, neq, -0.1250f, 0.1250f); - // x[1] = get_random_tensor_f32(ctx0, ndims, nek, -0.1250f, 0.1250f); - // x[2] = get_random_tensor_f32(ctx0, ndims, nev, -0.1250f, 0.1250f); - // ggml_set_param(ctx0, x[0]); - // ggml_set_param(ctx0, x[1]); - // ggml_set_param(ctx0, x[2]); - - // struct ggml_tensor * f = ggml_sum(ctx0, ggml_flash_attn(ctx0, x[0], x[1], x[2], (masked == 0))); - - // check_gradient("flash_attn f32", ctx0, x, f, ndims, nargs, 1.5e-4f, 1e-3f, INFINITY, {}); - // } - // } - // } - //} - - ggml_free(ctx0); - } - - return 0; -} diff --git a/tests/test-opt.cpp b/tests/test-opt.cpp index 546ca230b..f90c92b4b 100644 --- a/tests/test-opt.cpp +++ b/tests/test-opt.cpp @@ -1,181 +1,892 @@ #include "ggml.h" +#include "ggml-alloc.h" +#include "ggml-backend.h" +#include "ggml-cpu.h" +#include "ggml-opt.h" #include -#include -#include -#include +#include +#include +#include +#include +#include -#define MAX_NARGS 2 - -#if defined(__GNUC__) -#pragma GCC diagnostic ignored "-Wdouble-promotion" -#endif - -// -// logging -// -#define GGML_DEBUG 0 -#if (GGML_DEBUG >= 1) -#define GGML_PRINT_DEBUG(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG(...) -#endif - -#if (GGML_DEBUG >= 5) -#define GGML_PRINT_DEBUG_5(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_5(...) -#endif - -#if (GGML_DEBUG >= 10) -#define GGML_PRINT_DEBUG_10(...) printf(__VA_ARGS__) -#else -#define GGML_PRINT_DEBUG_10(...) -#endif - -#define GGML_PRINT(...) printf(__VA_ARGS__) - - -static float frand(void) { - return (float)rand()/(float)RAND_MAX; +static bool almost_equal(const double a, const double b, const double atol) { + return fabs(a - b) < atol; } -static struct ggml_tensor * get_random_tensor( - struct ggml_context * ctx0, int ndims, int64_t ne[], float fmin, float fmax -) { - struct ggml_tensor * result = ggml_new_tensor(ctx0, GGML_TYPE_F32, ndims, ne); +constexpr int64_t ne_datapoint = 2; +constexpr int64_t ne_label = 1; +constexpr int64_t ndata = 6; - switch (ndims) { - case 1: - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i0] = frand()*(fmax - fmin) + fmin; +struct helper_ctx_data { + std::vector datasets_supervised; + std::vector data_batch; + std::vector labels_batch; + + ggml_opt_dataset_t dataset_unsupervised; + struct ggml_context * ctx_static; + struct ggml_context * ctx_compute; + struct ggml_opt_params opt_params; + ggml_opt_context_t opt_ctx; + struct ggml_tensor * inputs; + struct ggml_tensor * weights; + struct ggml_tensor * outputs; + ggml_backend_buffer_t buf; + ggml_opt_result_t result; + ggml_opt_result_t result2; +}; + +// These default values make it easier to check optimization results vs. expected values. +static ggml_opt_optimizer_params helper_get_test_opt_pars(void * userdata) { + ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata); + result.adamw.alpha = 1.0f; + result.adamw.beta1 = 0.0f; + result.adamw.beta2 = 0.0f; + result.adamw.eps = 0.0f; + return result; +} + +static helper_ctx_data helper_get_ctx_data( + ggml_backend_sched_t backend_sched, + ggml_backend_t backend, + const bool init_opt_ctx = true, + const bool optimizer_defaults = true, + int64_t nbatch_logical = 1, + int64_t nbatch_physical = 1, + enum ggml_opt_loss_type loss_type = GGML_OPT_LOSS_TYPE_SUM) { + std::vector datasets(ndata); + for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) { + ggml_opt_dataset_t dataset = ggml_opt_dataset_init(ne_datapoint, ne_label, ndata, ndata_shard); + + float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset)); + float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset)); + + for (int64_t idata = 0; idata < ndata; ++idata) { + for (int64_t id = 0; id < ne_datapoint; ++id) { + data[ idata*ne_datapoint + id] = 16*idata + id; } - break; - case 2: - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; - } + for (int64_t il = 0; il < ne_label; ++il) { + labels[idata*ne_label + il] = 16*(16*idata + il); } - break; - case 3: - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + } + + datasets[ndata_shard-1] = dataset; + } + + ggml_opt_dataset_t dataset_unsupervised = ggml_opt_dataset_init(1, 0, ndata, /*ndata_shard =*/ 1); + + float * data = ggml_get_data_f32(ggml_opt_dataset_data(dataset_unsupervised)); + + for (int64_t idata = 0; idata < ndata; ++idata) { + data[idata] = idata; + } + + struct ggml_context * ctx_static; + struct ggml_context * ctx_compute; + { + struct ggml_init_params params = { + /*.mem_size =*/ (2*ndata + 2)*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_static = ggml_init(params); + } + { + struct ggml_init_params params = { + /*.mem_size =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_compute = ggml_init(params); + } + + std::vector data_batch(ndata); + std::vector labels_batch(ndata); + for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) { + data_batch[ndata_batch-1] = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_datapoint); + labels_batch[ndata_batch-1] = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, ndata_batch*ne_label); + } + + struct ggml_tensor * inputs = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, nbatch_physical); + ggml_set_name(inputs, "inputs"); + + struct ggml_tensor * weights = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); + ggml_set_name(weights, "weights"); + ggml_set_param(ctx_static, weights); + + struct ggml_tensor * intermediary = ggml_add(ctx_compute, inputs, weights); + + struct ggml_tensor * outputs = ggml_scale(ctx_compute, intermediary, 1.0f); + ggml_set_name(outputs, "outputs"); + + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend); + const float w0 = float(ndata)/2; + ggml_backend_tensor_set(weights, &w0, 0, sizeof(float)); + + GGML_ASSERT(nbatch_logical % nbatch_physical == 0); + const int32_t opt_period = nbatch_logical / nbatch_physical; + + struct ggml_opt_params opt_params = ggml_opt_default_params(backend_sched, ctx_compute, inputs, outputs, loss_type); + opt_params.opt_period = opt_period; + if (!optimizer_defaults) { + opt_params.get_opt_pars = helper_get_test_opt_pars; + } + ggml_opt_context_t opt_ctx = init_opt_ctx ? ggml_opt_init(opt_params) : nullptr; + + ggml_opt_result_t result = ggml_opt_result_init(); + ggml_opt_result_t result2 = ggml_opt_result_init(); + + return {datasets, data_batch, labels_batch, dataset_unsupervised, ctx_static, ctx_compute, opt_params, opt_ctx, inputs, weights, outputs, buf, result, result2}; +} + +static void helper_free_ctx_data(struct helper_ctx_data ctx_data) { + ggml_opt_result_free(ctx_data.result); + ggml_opt_result_free(ctx_data.result2); + ggml_opt_free(ctx_data.opt_ctx); + ggml_backend_buffer_free(ctx_data.buf); + ggml_free(ctx_data.ctx_static); + ggml_free(ctx_data.ctx_compute); + for (ggml_opt_dataset_t dataset : ctx_data.datasets_supervised) { + ggml_opt_dataset_free(dataset); + } + ggml_opt_dataset_free(ctx_data.dataset_unsupervised); +} + +static void helper_after_test( + const char * func, const bool high_level, const std::string options, + const std::string subtest, const bool subtest_ok, int & ntest, int & npass) { + printf(" %s(high_level=%s%s, subtest=%s): ", + func, high_level ? "yes" : "no", options.c_str(), subtest.c_str()); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; +} + +static std::pair test_dataset(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool shuffle) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend); + + for (int64_t ndata_shard = 1; ndata_shard <= ndata; ++ndata_shard) { + ggml_opt_dataset_t dataset = cd.datasets_supervised[ndata_shard-1]; + + if (shuffle) { + ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1); + } + + for (int64_t ndata_batch = 1; ndata_batch <= ndata; ++ndata_batch) { + if (ndata_batch % ndata_shard != 0) { + continue; + } + bool subtest_ok = true; + + struct ggml_tensor * data_batch = cd.data_batch[ndata_batch-1]; + struct ggml_tensor * labels_batch = cd.labels_batch[ndata_batch-1]; + + std::vector data(ggml_nelements( data_batch)); + std::vector labels(ggml_nelements(labels_batch)); + + std::vector idata_shuffled; + const int64_t nbatches = ndata / ndata_batch; + for (int64_t ibatch = 0; ibatch < nbatches; ++ibatch) { + ggml_opt_dataset_get_batch(dataset, data_batch, labels_batch, ibatch); + + ggml_backend_tensor_get( data_batch, data.data(), 0, ggml_nbytes( data_batch)); + ggml_backend_tensor_get(labels_batch, labels.data(), 0, ggml_nbytes(labels_batch)); + + for (int64_t idata_batch = 0; idata_batch < ndata_batch; ++idata_batch) { + const int64_t idata = ibatch*ndata_batch + idata_batch; + const int64_t idata_found = data[idata_batch*ne_datapoint] / 16; + subtest_ok = subtest_ok && (shuffle || idata_found == idata); + idata_shuffled.push_back(idata_found); + + for (int64_t id = 0; id < ne_datapoint; ++id) { + if (data[ idata_batch*ne_datapoint + id] != 16*idata_found + id) { + subtest_ok = false; + } } - } - } - break; - case 4: - for (int i3 = 0; i3 < ne[3]; i3++) { - for (int i2 = 0; i2 < ne[2]; i2++) { - for (int i1 = 0; i1 < ne[1]; i1++) { - for (int i0 = 0; i0 < ne[0]; i0++) { - ((float *)result->data)[i3*ne[2]*ne[1]*ne[0] + i2*ne[1]*ne[0] + i1*ne[0] + i0] = frand()*(fmax - fmin) + fmin; + for (int64_t il = 0; il < ne_label; ++il) { + if (labels[idata_batch*ne_label + il] != 16*(16*idata_found + il)) { + subtest_ok = false; } } } } - break; - default: - assert(false); + + if (!shuffle || ndata % ndata_batch == 0) { + const int ndata_max = (ndata / ndata_batch) * ndata_batch; + + for (int64_t idata = 0; subtest_ok && idata < ndata_max; ++idata) { + int ninstances = 0; + for (int64_t id : idata_shuffled) { + ninstances += id == idata; + } + if (ninstances != 1) { + subtest_ok = false; + } + } + } + + printf(" %s(shuffle=%s, ndata_shard=%" PRId64 ", ndata_batch=%" PRId64 "): ", + __func__, shuffle ? "yes" : "no", ndata_shard, ndata_batch); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; + } } + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static std::pair test_grad(ggml_backend_sched_t backend_sched, ggml_backend_t backend) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, + /*nbatch_logical =*/ 999999, /*nbatch_physical =*/ 1); + + std::vector grad_history(ndata); + for (int64_t idata = 0; idata < ndata; ++idata) { + grad_history[idata] = NAN; + } + + for (int idata = 0; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, sizeof(float)); + } + + { + bool subtest_ok = true; + for (int idata = 0; idata < ndata; ++idata) { + if (grad_history[idata] != idata + 1) { + subtest_ok = false; + } + } + printf(" %s(): ", __func__); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; + } + + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static void helper_after_test_forward_backward( + const char * func, const bool high_level, const bool shuffle, + const std::string subtest, const bool subtest_ok, int & ntest, int & npass) { + std::string options = ", shuffle="; + options += shuffle ? "yes" : "no"; + helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass); +} + +static std::pair test_forward_backward( + ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level, const bool shuffle) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false); + struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx); + + std::vector loss_history(ndata); + for (int64_t idata = 0; idata < ndata; ++idata) { + loss_history[idata] = NAN; + } + + { + int64_t ndata; + ggml_opt_result_ndata(cd.result, &ndata); + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result, &loss, &loss_unc); + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + const bool subtest_ok = ndata == 0 && loss == 0.0 && std::isnan(loss_unc) && std::isnan(accuracy) && std::isnan(accuracy_unc); + helper_after_test_forward_backward(__func__, high_level, shuffle, "results_initial", subtest_ok, ntest, npass); + } + + if (high_level) { + ggml_opt_dataset_t dataset = cd.dataset_unsupervised; + if (shuffle) { + ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1); + } + ggml_opt_epoch(cd.opt_ctx, dataset, nullptr, cd.result, 0, nullptr, nullptr); + } else { + for (int idata = 0; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); + } + } + + { + float weights; + ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); + const bool subtest_ok = weights == ndata/2; + helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward", subtest_ok, ntest, npass); + } + { + int64_t ndata; + ggml_opt_result_ndata(cd.result, &ndata); + bool subtest_ok = ndata == 6; + + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result, &loss, &loss_unc); + subtest_ok = subtest_ok && loss == 33.0 && almost_equal(loss_unc, sqrt(3.5), 1e-10); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_forward_backward(__func__, high_level, shuffle, "results_after_forward", subtest_ok, ntest, npass); + } + + float w0; + ggml_backend_tensor_get(cd.weights, &w0, 0, sizeof(float)); + for (int i = 0; i < 10; ++i) { + ggml_opt_forward_backward(cd.opt_ctx, nullptr); + } + ggml_backend_tensor_set(cd.weights, &w0, 0, sizeof(float)); + + ggml_opt_reset(cd.opt_ctx, /*optimizer =*/ false); + ggml_opt_result_reset(cd.result); + + for (int64_t idata = 0; idata < ndata; ++idata) { + loss_history[idata] = NAN; + } + + if (high_level) { + ggml_opt_dataset_t dataset = cd.dataset_unsupervised; + if (shuffle) { + ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1); + } + ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr); + } else { + for (int idata = 0; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); + } + } + + { + float weights; + ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); + const bool subtest_ok = weights == -ndata/2; + helper_after_test_forward_backward(__func__, high_level, shuffle, "weights_after_forward_backward", subtest_ok, ntest, npass); + } + { + int64_t ndata; + ggml_opt_result_ndata(cd.result, &ndata); + bool subtest_ok = ndata == 6; + + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result, &loss, &loss_unc); + subtest_ok = subtest_ok && loss == 18.0 && (shuffle || loss_unc == 0.0); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_forward_backward(__func__, high_level, shuffle, "result_after_forward_backward", subtest_ok, ntest, npass); + } + + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static std::pair test_epoch_vs_fit(ggml_backend_sched_t backend_sched, ggml_backend_t backend) { + int ntest = 0; + int npass = 0; + + float weights_epoch; + float weights_fit; + + { + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true); + ggml_opt_dataset_t dataset = cd.dataset_unsupervised; + + ggml_opt_dataset_shuffle(cd.opt_ctx, dataset, -1); + ggml_opt_epoch(cd.opt_ctx, dataset, cd.result, nullptr, ndata, nullptr, nullptr); + + ggml_backend_tensor_get(cd.weights, &weights_epoch, 0, ggml_nbytes(cd.weights)); + helper_free_ctx_data(cd); + } + { + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ false); + ggml_opt_dataset_t dataset = cd.dataset_unsupervised; + + ggml_opt_fit(backend_sched, cd.ctx_compute, cd.inputs, cd.outputs, dataset, + GGML_OPT_LOSS_TYPE_SUM, ggml_opt_get_default_optimizer_params, 1, 1, 0.0f, true); + + ggml_backend_tensor_get(cd.weights, &weights_fit, 0, ggml_nbytes(cd.weights)); + helper_free_ctx_data(cd); + } + + const bool subtest_ok = weights_epoch == weights_fit; + + printf(" %s(): ", __func__); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; + + return std::make_pair(npass, ntest); +} + +static void helper_after_test_idata_split( + const char * func, const bool high_level, const int epoch, + const std::string subtest, const bool subtest_ok, int & ntest, int & npass) { + std::string options = ", epoch="; + options += std::to_string(epoch); + helper_after_test(func, high_level, options, subtest, subtest_ok, ntest, npass); +} + +static std::pair test_idata_split(ggml_backend_sched_t backend_sched, ggml_backend_t backend, const bool high_level) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data(backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false); + struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx); + const int idata_split = ndata * 2/3; + + std::vector loss_history(ndata); + for (int64_t idata = 0; idata < ndata; ++idata) { + loss_history[idata] = NAN; + } + + for (int epoch = 1; epoch <= 4; ++epoch) { + if (high_level) { + ggml_opt_epoch(cd.opt_ctx, cd.dataset_unsupervised, cd.result, cd.result2, idata_split, nullptr, nullptr); + } else { + int idata = 0; + for (; idata < idata_split; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); + } + for (; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, ggml_nbytes(cd.inputs)); + ggml_opt_forward(cd.opt_ctx, cd.result2); + ggml_backend_tensor_get(loss, loss_history.data() + idata, 0, sizeof(float)); + } + } + + { + float weights; + ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); + const bool subtest_ok = weights == ndata/2 - epoch*idata_split; + helper_after_test_idata_split(__func__, high_level, epoch, "weights", subtest_ok, ntest, npass); + } + { + int64_t ndata_result; + ggml_opt_result_ndata(cd.result, &ndata_result); + bool subtest_ok = ndata_result == idata_split; + + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result, &loss, &loss_unc); + subtest_ok = subtest_ok && loss == 28.0 - epoch*16.0 && loss_unc == 0.0; + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_idata_split(__func__, high_level, epoch, "results_backward", subtest_ok, ntest, npass); + } + { + int64_t ndata_result; + ggml_opt_result_ndata(cd.result2, &ndata_result); + bool subtest_ok = ndata_result == ndata - idata_split; + + double loss; + double loss_unc; + ggml_opt_result_loss(cd.result2, &loss, &loss_unc); + subtest_ok = subtest_ok && loss == 15.0 - epoch*8 && almost_equal(loss_unc, sqrt(0.5), 1e-10); + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result2, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_idata_split(__func__, high_level, epoch, "results_forward", subtest_ok, ntest, npass); + } + + ggml_opt_result_reset(cd.result); + ggml_opt_result_reset(cd.result2); + } + + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static void helper_after_test_gradient_accumulation( + const char * func, const int nbatch_physical, const enum ggml_opt_loss_type loss_type, const int epoch, + const std::string subtest, const bool subtest_ok, int & ntest, int & npass) { + std::string options = ", nbatch_physical="; + options += std::to_string(nbatch_physical); + options += ", loss_type="; + options += loss_type == GGML_OPT_LOSS_TYPE_MEAN ? "mean" : "sum"; + options += ", epoch="; + options += std::to_string(epoch); + helper_after_test(func, false, options, subtest, subtest_ok, ntest, npass); +} + +static std::pair test_gradient_accumulation( + ggml_backend_sched_t backend_sched, ggml_backend_t backend, const int32_t nbatch_physical, const enum ggml_opt_loss_type loss_type) { + int ntest = 0; + int npass = 0; + + struct helper_ctx_data cd = helper_get_ctx_data( + backend_sched, backend, /*init_opt_ctx =*/ true, /*optimizer_defaults =*/ false, /*nbatch_logical =*/ 6, nbatch_physical, loss_type); + struct ggml_tensor * loss = ggml_opt_loss(cd.opt_ctx); + + std::vector grad_history(ndata); + for (int64_t idata = 0; idata < ndata; ++idata) { + grad_history[idata] = NAN; + } + + for (int epoch = 1; epoch <= 4; ++epoch) { + if (nbatch_physical == 1) { + for (int idata = 0; idata < ndata; ++idata) { + const float idataf = idata; + ggml_backend_tensor_set(cd.inputs, &idataf, 0, 1*sizeof(float)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata, 0, 1*sizeof(float)); + } + } else if (nbatch_physical == 2) { + for (int idata = 0; idata < ndata; idata += 2) { + const float idataf[2] = {float(idata + 0), float(idata + 1)}; + ggml_backend_tensor_set(cd.inputs, idataf, 0, 2*sizeof(float)); + ggml_opt_forward_backward(cd.opt_ctx, cd.result); + + grad_history[idata + 0] = 0.0f; + ggml_backend_tensor_get(ggml_opt_grad_acc(cd.opt_ctx, cd.weights), grad_history.data() + idata + 1, 0, 1*sizeof(float)); + } + } else { + GGML_ASSERT(false); + } + + { + GGML_ASSERT(ndata == 6); + constexpr double atol = 1e-6; + bool subtest_ok = true; + if (loss_type == GGML_OPT_LOSS_TYPE_SUM) { + if (nbatch_physical == 1) { + subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0, atol); + } else { + subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0, atol); + } + subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0, atol); + } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) { + if (nbatch_physical == 1) { + subtest_ok = subtest_ok && almost_equal(grad_history[0], 1.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[2], 3.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[4], 5.0/ndata, atol); + } else { + subtest_ok = subtest_ok && almost_equal(grad_history[0], 0.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[2], 0.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[4], 0.0/ndata, atol); + } + subtest_ok = subtest_ok && almost_equal(grad_history[1], 2.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[3], 4.0/ndata, atol); + subtest_ok = subtest_ok && almost_equal(grad_history[5], 0.0/ndata, atol); + } else { + GGML_ASSERT(false); + } + helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "grads", subtest_ok, ntest, npass); + } + { + float weights; + ggml_backend_tensor_get(cd.weights, &weights, 0, sizeof(float)); + const bool subtest_ok = weights == (ndata/2) - epoch; + helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "weights", subtest_ok, ntest, npass); + } + { + int64_t ndata_result; + ggml_opt_result_ndata(cd.result, &ndata_result); + bool subtest_ok = ndata_result == ndata/nbatch_physical; + + double loss; + ggml_opt_result_loss(cd.result, &loss, /*loss_unc =*/ nullptr); + if (loss_type == GGML_OPT_LOSS_TYPE_SUM) { + subtest_ok = subtest_ok && loss == (39.0 - epoch*6.0); + } else if (loss_type == GGML_OPT_LOSS_TYPE_MEAN) { + subtest_ok = subtest_ok && almost_equal(loss, (39.0 - epoch*6.0) / ndata, 1e-6); + } else { + GGML_ASSERT(false); + } + + double accuracy; + double accuracy_unc; + ggml_opt_result_accuracy(cd.result, &accuracy, &accuracy_unc); + subtest_ok = subtest_ok && std::isnan(accuracy) && std::isnan(accuracy_unc); + + helper_after_test_gradient_accumulation(__func__, nbatch_physical, loss_type, epoch, "results", subtest_ok, ntest, npass); + } + + ggml_opt_result_reset(cd.result); + } + + helper_free_ctx_data(cd); + + return std::make_pair(npass, ntest); +} + +static ggml_opt_optimizer_params helper_get_regression_opt_pars(void * userdata) { + ggml_opt_optimizer_params result = ggml_opt_get_default_optimizer_params(userdata); + result.adamw.alpha = 0.1f; return result; } -int main(void) { - struct ggml_init_params params = { - /* .mem_size = */ 1024*1024*1024, - /* .mem_buffer = */ NULL, - /* .no_alloc = */ false, - }; +static std::pair test_regression(ggml_backend_sched_t backend_sched, ggml_backend_t backend) { + int ntest = 0; + int npass = 0; - struct ggml_context * ctx = ggml_init(params); + // Test for simple regression with f(x) = a*x + b - int64_t ne1[4] = {4, 128, 1, 1}; - int64_t ne2[4] = {4, 256, 1, 1}; - int64_t ne3[4] = {128, 256, 1, 1}; + constexpr int64_t ndata_regression = 201; + constexpr float a_true = 1.2f; + constexpr float b_true = 3.4f; - struct ggml_tensor * a = get_random_tensor(ctx, 2, ne1, -1, +1); - struct ggml_tensor * b = get_random_tensor(ctx, 2, ne2, -1, +1); - ggml_set_param(ctx, a); - ggml_set_param(ctx, b); + std::mt19937 gen(12345); + std::normal_distribution nd{0.0f, 0.1f}; - struct ggml_tensor * c = get_random_tensor(ctx, 2, ne3, -1, +1); + ggml_opt_dataset_t dataset = ggml_opt_dataset_init(1, 1, ndata_regression, ndata_regression); - struct ggml_tensor * ab = ggml_mul_mat(ctx, a, b); - struct ggml_tensor * d = ggml_sub(ctx, c, ab); - struct ggml_tensor * e = ggml_sum(ctx, ggml_sqr(ctx, d)); + float * data = ggml_get_data_f32(ggml_opt_dataset_data( dataset)); + float * labels = ggml_get_data_f32(ggml_opt_dataset_labels(dataset)); - struct ggml_cgraph * ge = ggml_new_graph_custom(ctx, GGML_DEFAULT_GRAPH_SIZE, true); - ggml_build_forward_expand(ge, e); - ggml_graph_reset(ge); + constexpr float x_min = -100.0f; + constexpr float x_max = 100.0f; - ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1); + for (int64_t idata = 0; idata < ndata_regression; ++idata) { + const float x = x_min + (x_max - x_min) * idata/(ndata_regression-1); + const float y = a_true*x + b_true + nd(gen); - const float fe = ggml_get_f32_1d(e, 0); - printf("%s: e = %.4f\n", __func__, fe); + data[idata] = x; + labels[idata] = y; + } - struct ggml_opt_params opt_params = ggml_opt_default_params(GGML_OPT_TYPE_ADAM); + struct ggml_context * ctx_static; + struct ggml_context * ctx_compute; + { + struct ggml_init_params params = { + /*.mem_size =*/ 3*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_static = ggml_init(params); + } + { + struct ggml_init_params params = { + /*.mem_size =*/ GGML_DEFAULT_GRAPH_SIZE*ggml_tensor_overhead() + 3*ggml_graph_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + ctx_compute = ggml_init(params); + } - ggml_opt(ctx, opt_params, e); + // The first dimension is the dimension of the datapoints, the second dimension is the number of datapoints. + struct ggml_tensor * x = ggml_new_tensor_2d(ctx_static, GGML_TYPE_F32, 1, ndata_regression); + ggml_set_name(x, "x"); - ggml_graph_reset(ge); + struct ggml_tensor * a = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); + ggml_set_name(a, "a"); + ggml_set_param(ctx_static, a); - ggml_graph_compute_with_ctx(ctx, ge, /*n_threads*/ 1); + struct ggml_tensor * b = ggml_new_tensor_1d(ctx_static, GGML_TYPE_F32, 1); + ggml_set_name(b, "b"); + ggml_set_param(ctx_static, b); - const float fe_opt = ggml_get_f32_1d(e, 0); - printf("%s: original e = %.4f\n", __func__, fe); - printf("%s: optimized e = %.4f\n", __func__, fe_opt); + struct ggml_tensor * f = ggml_add(ctx_compute, ggml_mul(ctx_compute, x, a), b); + ggml_set_name(f, "f"); + ggml_set_param(ctx_static, f); - const bool success = (fe_opt <= fe); - assert(success); + ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx_static, backend); + const float a0 = 1.0f; + const float b0 = 3.0f; + ggml_backend_tensor_set(a, &a0, 0, sizeof(float)); + ggml_backend_tensor_set(b, &b0, 0, sizeof(float)); - ggml_free(ctx); - return success ? 0 : -1; + ggml_opt_fit(backend_sched, ctx_compute, x, f, dataset, GGML_OPT_LOSS_TYPE_MEAN_SQUARED_ERROR, + helper_get_regression_opt_pars, 100, ndata_regression, 0.0f, true); + + { + float a_fit; + ggml_backend_tensor_get(a, &a_fit, 0, sizeof(float)); + float b_fit; + ggml_backend_tensor_get(b, &b_fit, 0, sizeof(float)); + const bool subtest_ok = almost_equal(a_fit, a_true, 1e-2) && almost_equal(b_fit, b_true, 1e-2); + printf(" %s(subtest=weights): ", __func__); + if (subtest_ok) { + printf("\033[1;32mOK\033[0m\n"); + npass++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + ntest++; + } + + ggml_backend_buffer_free(buf); + ggml_free(ctx_static); + ggml_opt_dataset_free(dataset); + + return std::make_pair(npass, ntest); } -// int64_t ne1[4] = {4, 128, 1, 1}; -// int64_t ne2[4] = {4, 256, 1, 1};; -// int64_t ne3[4] = {128, 256, 1, 1}; -// main: original e = 25890.9375 -// main: optimized e = 10094.7031 -// int64_t ne1[4] = {8, 128, 1, 1}; -// int64_t ne2[4] = {8, 256, 1, 1};; -// int64_t ne3[4] = {128, 256, 1, 1}; -// main: original e = 39429.5078 -// main: optimized e = 9275.8936 +static std::pair test_backend(ggml_backend_sched_t backend_sched, ggml_backend_t backend) { + int npass = 0; + int ntest = 0; -// int64_t ne1[4] = {16, 128, 1, 1}; -// int64_t ne2[4] = {16, 256, 1, 1};; -// int64_t ne3[4] = {128, 256, 1, 1}; -// main: original e = 68371.1328 -// main: optimized e = 7854.4502 + for (bool shuffle : {false, true}) { + std::pair partial = test_dataset(backend_sched, backend, shuffle); + npass += partial.first; + ntest += partial.second; + } + { + std::pair partial = test_grad(backend_sched, backend); + npass += partial.first; + ntest += partial.second; + } + for (bool high_level : {false, true}){ + for (bool shuffle : {false, true}) { + if (!high_level && shuffle) { + continue; + } + std::pair partial = test_forward_backward(backend_sched, backend, high_level, shuffle); + npass += partial.first; + ntest += partial.second; + } + } + { + std::pair partial = test_epoch_vs_fit(backend_sched, backend); + npass += partial.first; + ntest += partial.second; + } + for (bool high_level : {false, true}){ + std::pair partial = test_idata_split(backend_sched, backend, high_level); + npass += partial.first; + ntest += partial.second; + } + for (int32_t nbatch_physical : {2, 1}) { + for (enum ggml_opt_loss_type loss_type : {GGML_OPT_LOSS_TYPE_SUM, GGML_OPT_LOSS_TYPE_MEAN}) { + std::pair partial = test_gradient_accumulation(backend_sched, backend, nbatch_physical, loss_type); + npass += partial.first; + ntest += partial.second; + } + } + { + std::pair partial = test_regression(backend_sched, backend); + npass += partial.first; + ntest += partial.second; + } -// int64_t ne1[4] = {32, 128, 1, 1}; -// int64_t ne2[4] = {32, 256, 1, 1};; -// int64_t ne3[4] = {128, 256, 1, 1}; -// main: original e = 126061.1953 -// main: optimized e = 5451.0166 + return std::make_pair(npass, ntest); +} -// int64_t ne1[4] = {4, 1024, 1, 1}; -// int64_t ne2[4] = {4, 2048, 1, 1};; -// int64_t ne3[4] = {1024, 2048, 1, 1}; -// main: original e = 1620817.8750 -// main: optimized e = 698387.6875 +int main(void) { + const size_t dev_count = ggml_backend_dev_count(); + printf("Testing %zu devices\n\n", dev_count); + size_t n_ok = 0; -// another run on M1 -// int64_t ne1[4] = {4, 1024, 1, 1}; -// int64_t ne2[4] = {4, 2048, 1, 1};; -// int64_t ne3[4] = {1024, 2048, 1, 1}; -// main: original e = 1629595.6250 -// main: optimized e = 698169.1250 + std::vector devs; + std::vector backends; -// int64_t ne1[4] = {32, 1024, 1, 1}; -// int64_t ne2[4] = {32, 2048, 1, 1};; -// int64_t ne3[4] = {1024, 2048, 1, 1}; -// main: original e = 8146770.5000 -// main: optimized e = 651119.1250 + for (size_t i = 0; i < dev_count; ++i) { + devs.push_back(ggml_backend_dev_get(i)); + + ggml_backend_t backend = ggml_backend_dev_init(devs[i], NULL); + GGML_ASSERT(backend != NULL); + + if (ggml_backend_is_cpu(backend)) { + ggml_backend_cpu_set_n_threads(backend, std::thread::hardware_concurrency() / 2); + } + + backends.push_back(backend); + } + + for (size_t i = 0; i < dev_count; ++i) { + // Put the backend to be tested in front so that it's prioritized: + std::vector backends_modded = {backends[i]}; + backends_modded.insert(backends_modded.end(), backends.begin(), backends.end()); + + ggml_backend_sched_t backend_sched = ggml_backend_sched_new( + backends_modded.data(), nullptr, backends_modded.size(), GGML_DEFAULT_GRAPH_SIZE, false); + + printf("Backend %zu/%zu: %s\n", i + 1, dev_count, ggml_backend_dev_name(devs[i])); + printf(" Device description: %s\n", ggml_backend_dev_description(devs[i])); + size_t free, total; // NOLINT + ggml_backend_dev_memory(devs[i], &free, &total); + printf(" Device memory: %zu MB (%zu MB free)\n", total / 1024 / 1024, free / 1024 / 1024); + printf("\n"); + + std::pair result = test_backend(backend_sched, backends[i]); + + printf(" %d/%d tests passed\n", result.first, result.second); + printf(" Backend %s: ", ggml_backend_name(backends[i])); + if (result.first == result.second) { + printf("\033[1;32mOK\033[0m\n"); + n_ok++; + } else { + printf("\033[1;31mFAIL\033[0m\n"); + } + + printf("\n"); + + ggml_backend_sched_free(backend_sched); + } + + for (ggml_backend_t backend : backends) { + ggml_backend_free(backend); + } + + printf("%zu/%zu backends passed\n", n_ok, dev_count); + if (n_ok != dev_count) { + printf("\033[1;31mFAIL\033[0m\n"); + return 1; + } + printf("\033[1;32mOK\033[0m\n"); + return 0; +} diff --git a/tests/test-quantize-perf.cpp b/tests/test-quantize-perf.cpp index ac0d12714..288288493 100644 --- a/tests/test-quantize-perf.cpp +++ b/tests/test-quantize-perf.cpp @@ -7,7 +7,6 @@ #include #include #include -#include #include #include #include