diff --git a/.gitignore b/.gitignore index 52caf2e98..c26d82a74 100644 --- a/.gitignore +++ b/.gitignore @@ -62,6 +62,16 @@ perf-*.txt examples/jeopardy/results.txt + pyproject.toml poetry.lock poetry.toml + +# Test binaries +tests/test-double-float +tests/test-grad0 +tests/test-opt +tests/test-quantize-fns +tests/test-quantize-perf +tests/test-sampling +tests/test-tokenizer-0 \ No newline at end of file diff --git a/CMakeLists.txt b/CMakeLists.txt index 169332767..abc96814d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -186,16 +186,7 @@ if (LLAMA_BLAS) pkg_check_modules(DepBLAS REQUIRED flexiblas_api) elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel") # all Intel* libraries share the same include path - pkg_check_modules(DepBLAS mkl-sdl) - if (NOT DepBLAS) - if (BUILD_SHARED_LIBS) - set(LINK_METHOD dynamic) - else() - set(LINK_METHOD static) - endif() - string(REGEX REPLACE ".*_" "" DATA_TYPE_MODEL ${LLAMA_BLAS_VENDOR}) - pkg_check_modules(DepBLAS REQUIRED mkl-${LINK_METHOD}-${DATA_TYPE_MODEL}-iomp) - endif() + pkg_check_modules(DepBLAS REQUIRED mkl-sdl) elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC") # this doesn't provide pkg-config # suggest to assign BLAS_INCLUDE_DIRS on your own diff --git a/Makefile b/Makefile index a7bdb244d..1ea3c4562 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,8 @@ # Define the default target now so that it is always the first target -BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple server libembdinput.so embd-input-test +BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple server embd-input-test + +# Binaries only useful for tests +TEST_TARGETS = tests/test-double-float tests/test-grad0 tests/test-opt tests/test-quantize-fns tests/test-quantize-perf tests/test-sampling tests/test-tokenizer-0 default: $(BUILD_TARGETS) @@ -90,6 +93,28 @@ ifeq ($(UNAME_S),Haiku) CXXFLAGS += -pthread endif +# detect Windows +ifneq ($(findstring _NT,$(UNAME_S)),) + _WIN32 := 1 +endif + +# library name prefix +ifneq ($(_WIN32),1) + LIB_PRE := lib +endif + +# Dynamic Shared Object extension +ifneq ($(_WIN32),1) + DSO_EXT := .so +else + DSO_EXT := .dll +endif + +# Windows Sockets 2 (Winsock) for network-capable apps +ifeq ($(_WIN32),1) + LWINSOCK2 := -lws2_32 +endif + ifdef LLAMA_GPROF CFLAGS += -pg CXXFLAGS += -pg @@ -168,8 +193,12 @@ ifdef LLAMA_CUBLAS CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib OBJS += ggml-cuda.o - NVCC = nvcc NVCCFLAGS = --forward-unknown-to-host-compiler +ifdef LLAMA_CUDA_NVCC + NVCC = $(LLAMA_CUDA_NVCC) +else + NVCC = nvcc +endif #LLAMA_CUDA_NVCC ifdef CUDA_DOCKER_ARCH NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH) else @@ -198,7 +227,9 @@ ifdef LLAMA_CUDA_KQUANTS_ITER else NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2 endif - +ifdef LLAMA_CUDA_CCBIN + NVCCFLAGS += -ccbin $(LLAMA_CUDA_CCBIN) +endif ggml-cuda.o: ggml-cuda.cu ggml-cuda.h $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ endif # LLAMA_CUBLAS @@ -294,7 +325,7 @@ libllama.so: llama.o ggml.o $(OBJS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) clean: - rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server simple vdot train-text-from-scratch embd-input-test build-info.h + rm -vf *.o *.so *.dll 
main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server simple vdot train-text-from-scratch embd-input-test build-info.h $(TEST_TARGETS) # # Examples @@ -325,14 +356,14 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml. $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS) - $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) + $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(LWINSOCK2) -libembdinput.so: examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS) +$(LIB_PRE)embdinput$(DSO_EXT): examples/embd-input/embd-input.h examples/embd-input/embd-input-lib.cpp build-info.h ggml.o llama.o common.o $(OBJS) $(CXX) --shared $(CXXFLAGS) $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) -embd-input-test: libembdinput.so examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS) - $(CXX) $(CXXFLAGS) $(filter-out %.so,$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput +embd-input-test: $(LIB_PRE)embdinput$(DSO_EXT) examples/embd-input/embd-input-test.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %$(DSO_EXT),$(filter-out %.h,$(filter-out %.hpp,$^))) -o $@ $(LDFLAGS) -L. -lembdinput train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp build-info.h ggml.o llama.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) @@ -349,6 +380,8 @@ build-info.h: $(wildcard .git/index) scripts/build-info.sh # Tests # +tests: $(TEST_TARGETS) + benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) ./$@ @@ -356,6 +389,23 @@ benchmark-matmult: examples/benchmark/benchmark-matmult.cpp build-info.h ggml.o vdot: pocs/vdot/vdot.cpp ggml.o $(OBJS) $(CXX) $(CXXFLAGS) $^ -o $@ $(LDFLAGS) -.PHONY: tests clean -tests: - bash ./tests/run-tests.sh +tests/test-double-float: tests/test-double-float.c build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) + +tests/test-grad0: tests/test-grad0.c build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) + +tests/test-opt: tests/test-opt.c build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) + +tests/test-quantize-fns: tests/test-quantize-fns.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) + +tests/test-quantize-perf: tests/test-quantize-perf.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) + +tests/test-sampling: tests/test-sampling.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) + +tests/test-tokenizer-0: tests/test-tokenizer-0.cpp build-info.h ggml.o llama.o common.o $(OBJS) + $(CXX) $(CXXFLAGS) $(filter-out %.txt,$^) -o $@ $(LDFLAGS) diff --git a/README.md b/README.md index 073b621e9..f45e4bf08 100644 --- a/README.md +++ b/README.md @@ -360,7 +360,7 @@ Building the program with BLAS support may lead to some performance improvements ```bash mkdir build cd build - cmake .. 
-DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=Intel10_lp64 -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx + cmake .. -DLLAMA_BLAS=ON -DLLAMA_BLAS_VENDOR=Intel10_64lp -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx cmake --build . --config Release ``` diff --git a/ci/run.sh b/ci/run.sh index c823bc467..87166ba1a 100644 --- a/ci/run.sh +++ b/ci/run.sh @@ -243,7 +243,7 @@ function gg_sum_open_llama_3b_v2 { if [ -z $GG_BUILD_LOW_PERF ]; then rm -rf ${SRC}/models-mnt - mnt_models=$(realpath ${MNT}/models) + mnt_models=${MNT}/models mkdir -p ${mnt_models} ln -sfn ${mnt_models} ${SRC}/models-mnt diff --git a/examples/Miku.sh b/examples/Miku.sh index c44d9ae74..b9174b4e6 100755 --- a/examples/Miku.sh +++ b/examples/Miku.sh @@ -2,21 +2,21 @@ set -e AI_NAME="${AI_NAME:-Miku}" -MODEL="${MODEL:-./models/gpt4all-7B/gpt4all-lora-unfiltered-quantized.bin}" +MODEL="${MODEL:-./models/llama-2-7b-chat.ggmlv3.q4_K_M.bin}" USER_NAME="${USER_NAME:-Anon}" # Uncomment and adjust to the number of CPU cores you want to use. #N_THREAD="${N_THREAD:-4}" +CTX_SIZE="${CTX_SIZE:-4096}" N_PREDICTS="${N_PREDICTS:-4096}" GEN_OPTIONS=(--batch_size 1024 ---ctx_size 2048 +--ctx_size "$CTX_SIZE" --keep -1 --repeat_last_n 256 --repeat_penalty 1.17647 ---temp 0.7 ---top_k 40 ---top_p 0.5) +--temp 0.6 +--mirostat 2) if [ -n "$N_THREAD" ]; then GEN_OPTIONS+=(--threads "$N_THREAD") @@ -24,16 +24,17 @@ fi ./main "${GEN_OPTIONS[@]}" \ --model "$MODEL" \ + --in-prefix " " \ + --in-suffix "${AI_NAME}:" \ --n_predict "$N_PREDICTS" \ --color --interactive \ --reverse-prompt "${USER_NAME}:" \ - --prompt " -This is a transcript of a 1000 page, never ending conversation between ${USER_NAME} and the cute and helpful AI assistant ${AI_NAME}. ${AI_NAME} is a girl who is an AI running on the user's computer. + --prompt "This is a transcript of a 1000 page, never ending conversation between ${USER_NAME} and the cute and helpful AI assistant ${AI_NAME}. ${AI_NAME} is a girl who is an AI running on the user's computer. ${AI_NAME} can think for herself without the user seeing her thoughts by adding a /think prefix to her output. She uses this to reason about the world and to think about what she should say next. ${AI_NAME} is always coherent and makes sense, but if she isn't sure if what she is saying is correct, she will ask the user for help. ${AI_NAME} is a very helpful AI and will help the user with anything they need. She is also very friendly and will try to make the user feel better if they are sad. ${AI_NAME} is also very curious and will ask the user a lot of questions about themselves and their life. She will also try to make the user like her. -The conversation is only between ${USER_NAME} and ${AI_NAME} +The conversation is only between ${USER_NAME} and ${AI_NAME}. The conversation is only through text, so ${AI_NAME} can't see ${USER_NAME}'s face or hear his voice. ${AI_NAME} can only communicate through text, so she can't send images or videos. 
diff --git a/examples/common.cpp b/examples/common.cpp index fd6dbc0e3..476d56594 100644 --- a/examples/common.cpp +++ b/examples/common.cpp @@ -586,7 +586,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param lparams.n_batch = params.n_batch; lparams.n_gpu_layers = params.n_gpu_layers; lparams.main_gpu = params.main_gpu; - memcpy(lparams.tensor_split, params.tensor_split, LLAMA_MAX_DEVICES*sizeof(float)); + lparams.tensor_split = params.tensor_split; lparams.low_vram = params.low_vram; lparams.seed = params.seed; lparams.f16_kv = params.memory_f16; diff --git a/examples/server/CMakeLists.txt b/examples/server/CMakeLists.txt index 812a24b09..3782f9b80 100644 --- a/examples/server/CMakeLists.txt +++ b/examples/server/CMakeLists.txt @@ -7,6 +7,9 @@ target_compile_definitions(${TARGET} PRIVATE SERVER_VERBOSE=$ ) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) +if (WIN32) + TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) +endif() target_compile_features(${TARGET} PRIVATE cxx_std_11) if(TARGET BUILD_INFO) add_dependencies(${TARGET} BUILD_INFO) diff --git a/flake.nix b/flake.nix index 5657e8258..7f148f144 100644 --- a/flake.nix +++ b/flake.nix @@ -6,7 +6,7 @@ outputs = { self, nixpkgs, flake-utils }: flake-utils.lib.eachDefaultSystem (system: let - inherit (pkgs.stdenv) isAarch32 isAarch64 isx86_32 isx86_64 isDarwin; + inherit (pkgs.stdenv) isAarch32 isAarch64 isDarwin; osSpecific = with pkgs; [ openmpi ] ++ ( if isAarch64 && isDarwin then @@ -22,14 +22,13 @@ CoreGraphics CoreVideo ] - else if isx86_32 || isx86_64 then - with pkgs; [ mkl ] else with pkgs; [ openblas ] ); pkgs = import nixpkgs { inherit system; }; + nativeBuildInputs = with pkgs; [ cmake pkgconfig ]; llama-python = - pkgs.python310.withPackages (ps: with ps; [ numpy sentencepiece ]); + pkgs.python3.withPackages (ps: with ps; [ numpy sentencepiece ]); in { packages.default = pkgs.stdenv.mkDerivation { name = "llama.cpp"; @@ -37,33 +36,21 @@ postPatch = '' substituteInPlace ./ggml-metal.m \ --replace '[bundle pathForResource:@"ggml-metal" ofType:@"metal"];' "@\"$out/bin/ggml-metal.metal\";" + substituteInPlace ./*.py --replace '/usr/bin/env python' '${llama-python}/bin/python' ''; - nativeBuildInputs = with pkgs; [ cmake pkgconfig ]; + nativeBuildInputs = nativeBuildInputs; buildInputs = osSpecific; cmakeFlags = [ "-DLLAMA_BUILD_SERVER=ON" "-DLLAMA_MPI=ON" "-DBUILD_SHARED_LIBS=ON" "-DCMAKE_SKIP_BUILD_RPATH=ON" ] ++ (if isAarch64 && isDarwin then [ "-DCMAKE_C_FLAGS=-D__ARM_FEATURE_DOTPROD=1" "-DLLAMA_METAL=ON" - ] else if isx86_32 || isx86_64 then [ - "-DLLAMA_BLAS=ON" - "-DLLAMA_BLAS_VENDOR=Intel10_lp64" ] else [ "-DLLAMA_BLAS=ON" "-DLLAMA_BLAS_VENDOR=OpenBLAS" ]); - installPhase = '' - runHook preInstall - - install -D bin/* -t $out/bin - install -Dm644 lib*.so -t $out/lib + postInstall = '' mv $out/bin/main $out/bin/llama mv $out/bin/server $out/bin/llama-server - - echo "#!${llama-python}/bin/python" > $out/bin/convert.py - cat ${./convert.py} >> $out/bin/convert.py - chmod +x $out/bin/convert.py - - runHook postInstall ''; meta.mainProgram = "llama"; }; @@ -81,7 +68,7 @@ }; apps.default = self.apps.${system}.llama; devShells.default = pkgs.mkShell { - packages = with pkgs; [ cmake llama-python ] ++ osSpecific; + packages = nativeBuildInputs ++ osSpecific; }; }); } diff --git a/ggml-cuda.cu b/ggml-cuda.cu index d3054a7fa..6537897b9 100644 --- a/ggml-cuda.cu +++ b/ggml-cuda.cu @@ -2512,6 +2512,9 @@ void ggml_init_cublas() { } void ggml_cuda_set_tensor_split(const 
float * tensor_split) { + if (tensor_split == nullptr) { + return; + } bool all_zero = true; for (int i = 0; i < g_device_count; ++i) { if (tensor_split[i] != 0.0f) { diff --git a/ggml-metal.m b/ggml-metal.m index ee205bcdf..135bda9fc 100644 --- a/ggml-metal.m +++ b/ggml-metal.m @@ -676,8 +676,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_K_f32]; } break; case GGML_TYPE_Q3_K: @@ -694,8 +694,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_K_f32]; } break; case GGML_TYPE_Q5_K: @@ -703,8 +703,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_K_f32]; } break; case GGML_TYPE_Q6_K: @@ -712,8 +712,8 @@ void ggml_metal_graph_compute( GGML_ASSERT(ne02 == 1); GGML_ASSERT(ne12 == 1); - nth0 = 4; - nth1 = 16; + nth0 = 2; + nth1 = 32; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_K_f32]; } break; default: @@ -739,14 +739,17 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14]; - if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { + if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1 || + src0t == GGML_TYPE_Q2_K || src0t == GGML_TYPE_Q4_K) { [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } - else if (src0t == GGML_TYPE_Q2_K || - src0t == GGML_TYPE_Q3_K || - src0t == GGML_TYPE_Q4_K || - src0t == GGML_TYPE_Q5_K || - src0t == GGML_TYPE_Q6_K) { + else if (src0t == GGML_TYPE_Q5_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 3) / 4, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q6_K) { + [encoder dispatchThreadgroups:MTLSizeMake((ne01+1)/2, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; + } + else if (src0t == GGML_TYPE_Q3_K) { [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; [encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)]; } else { @@ -792,7 +795,7 @@ void ggml_metal_graph_compute( const float eps = 1e-6f; - const int nth = 256; + const int nth = 512; [encoder setComputePipelineState:ctx->pipeline_rms_norm]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; @@ -800,7 +803,7 @@ void ggml_metal_graph_compute( [encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; [encoder setBytes:&eps length:sizeof( float) atIndex:4]; - [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0]; + [encoder setThreadgroupMemoryLength:nth/32*sizeof(float) atIndex:0]; const int64_t nrows = ggml_nrows(src0); diff --git a/ggml-metal.metal b/ggml-metal.metal index 9f9a4fbd7..97f5c10ba 100644 --- a/ggml-metal.metal +++ b/ggml-metal.metal @@ -331,26 +331,33 @@ kernel void kernel_rms_norm( threadgroup float * sum [[threadgroup(0)]], uint tgpig[[threadgroup_position_in_grid]], uint tpitg[[thread_position_in_threadgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]], + uint tiisg[[thread_index_in_simdgroup]], uint ntg[[threads_per_threadgroup]]) { - device const float * x = (device const float *) ((device 
const char *) src0 + tgpig*nb01); + device const float4 * x = (device const float4 *) ((device const char *) src0 + tgpig*nb01); + device const float * x_scalar = (device const float *) x; + float4 sumf=0; + float all_sum=0; // parallel sum - sum[tpitg] = 0.0f; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { - sum[tpitg] += x[i00] * x[i00]; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { + sumf += x[i00] * x[i00]; + } + all_sum = sumf[0] + sumf[1] + sumf[2] + sumf[3]; + all_sum = simd_sum(all_sum); + if (tiisg == 0) { + sum[sgitg] = all_sum; } - // reduce threadgroup_barrier(mem_flags::mem_threadgroup); - for (uint i = ntg/2; i > 0; i /= 2) { - if (tpitg < i) { - sum[tpitg] += sum[tpitg + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); + // broadcast, simd group number is ntg / 32 + for (int i = ntg / 32 / 2; i > 0; i /= 2) { + if (tpitg < i) { + sum[tpitg] += sum[tpitg + i]; + } } - - // broadcast if (tpitg == 0) { + for (int i = 4 * (ne00 / 4); i < ne00; i++) {sum[0] += x_scalar[i];} sum[0] /= ne00; } @@ -359,16 +366,101 @@ kernel void kernel_rms_norm( const float mean = sum[0]; const float scale = 1.0f/sqrt(mean + eps); - device float * y = dst + tgpig*ne00; - for (int i00 = tpitg; i00 < ne00; i00 += ntg) { + device float4 * y = (device float4 *) (dst + tgpig*ne00); + device float * y_scalar = (device float *) y; + for (int i00 = tpitg; i00 < ne00/4; i00 += ntg) { y[i00] = x[i00] * scale; } + if (tpitg == 0) { + for (int i00 = 4 * (ne00 / 4); i00 < ne00; i00++) {y_scalar[i00] = x_scalar[i00] * scale;} + } +} + +// function for calculate inner product between a q4_0 block and 32 floats (yl), sumy is SUM(yl[i]) +float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl) { + float d = qb_curr->d; + float4 acc = 0.f; + device uint16_t * qs = ((device uint16_t *)qb_curr + 1); + for (int i = 0; i < 16; i+=2) { + acc[0] += yl[i] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + } + return d * (sumy * -8.f + acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f); +} + +// function for calculate inner product between a q4_1 block and 32 floats (yl), sumy is SUM(yl[i]) +float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl) { + float d = qb_curr->d; + float m = qb_curr->m; + float4 acc = 0.f; + device uint16_t * qs = ((device uint16_t *)qb_curr + 2); + for (int i = 0; i < 16; i+=2) { + acc[0] += yl[i] * (qs[i / 2] & 0x000F); + acc[1] += yl[i + 16] * (qs[i / 2] & 0x00F0); + acc[2] += yl[i + 1] * (qs[i / 2] & 0x0F00); + acc[3] += yl[i + 17] * (qs[i / 2] & 0xF000); + } + return d * (acc[0] + acc[1]/16.f + acc[2]/256.f + acc[3]/4096.f) + sumy * m; } // putting them in the kernel cause a significant performance penalty #define N_DST 4 // each SIMD group works on 4 rows #define N_SIMDGROUP 2 // number of SIMD groups in a thread group #define N_SIMDWIDTH 32 // assuming SIMD group size is 32 +template +void mul_vec_q_n_f32(device const void * src0, device const float * src1, device float * dst, + int64_t ne00, int64_t ne10, int64_t ne0, int64_t ne01, + uint2 tgpig, uint tiisg, uint sgitg) { + const int nb = ne00/QK4_0; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + device const block_q_type * x = (device const block_q_type *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; + device const float * y = (device const float *) src1 + r1*ne10; + float4 y_curr[8]; // src1 vector cache + float sumf[N_DST]={0.f}, 
all_sum; + thread float * yl=(thread float *)y_curr; + + // each thread in a SIMD group deals with 1 block. + for (int column = 0; column < nb / N_SIMDWIDTH; column++) { + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0)) + i); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + } + + for (int row = 0; row < N_DST; row++) { + sumf[row] += block_q_n_dot_y(x+(tiisg + row * nb + column * N_SIMDWIDTH), sumy, yl); + } + } + + // from now loads two rows every time and 16 blocks per row + int ir = tiisg / (N_SIMDWIDTH / 2); + int ib = tiisg % (N_SIMDWIDTH / 2); + for (int ind = 0; ind < (nb % N_SIMDWIDTH + N_SIMDWIDTH / 2 - 1)/(N_SIMDWIDTH / 2); ind++) { + int nb_start = (nb / N_SIMDWIDTH) * N_SIMDWIDTH + ind * (N_SIMDWIDTH / 2); //where the left blocks start + float sumy = 0; + for (int i = 0; i < QK4_0 / 4; i++) { + y_curr[i] = *((device float4 *)(y + (nb_start + ib) * QK4_0) + i); + sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; + } + + for (int row = 0; row < N_DST; row+=2) { + if (nb_start + ib < nb) { + sumf[row + ir] += block_q_n_dot_y(x + (nb_start + ib + (row + ir) * nb), sumy, yl); + } + } + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { + dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; + } + } +} + kernel void kernel_mul_mat_q4_0_f32( device const void * src0, device const float * src1, @@ -380,80 +472,7 @@ kernel void kernel_mul_mat_q4_0_f32( uint2 tgpig[[threadgroup_position_in_grid]], uint tiisg[[thread_index_in_simdgroup]], uint sgitg[[simdgroup_index_in_threadgroup]]) { - const int nb = ne00/QK4_0; - const int r0 = tgpig.x; - const int r1 = tgpig.y; - device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb; - device const float * y = (device const float *) src1 + r1*ne10; - block_q4_0 qb_curr, qb_next; - float4 y_curr[8]; // src1 vector cache - float sumf[N_DST]={0.f}, all_sum; - thread float * yl=(thread float *)y_curr; - - // bootstrap - qb_curr = x[tiisg]; - // each thread in a SIMD group deals with 1 block. 
-    for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
-
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-        sumy *= (-8.f);
-
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            float d = qb_curr.d;
-            float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
-            }
-            sumf[row] += d * acc;
-            qb_curr = qb_next;
-        }
-    }
-
-    if (nb % N_SIMDWIDTH == 0) {
-        for (int row = 0; row < N_DST; ++row) {
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    } else {
-
-        float sumy = 0;
-        for (int i = 0; i < QK4_0 / 4; i++) {
-            y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
-            sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
-        }
-        sumy *= (-8.f);
-
-        for (int row = 0; row < N_DST; row++) {
-            // prefetch next x block
-            qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
-
-            // calculate
-            float d = qb_curr.d;
-            float acc = sumy;
-            for (int i = 0; i < 16; i++) {
-                acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
-            }
-            if (tiisg < nb % N_SIMDWIDTH) {
-                sumf[row] += d * acc;
-            }
-            qb_curr = qb_next;
-
-            all_sum = simd_sum(sumf[row]);
-            if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
-                dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
-            }
-        }
-    }
+    mul_vec_q_n_f32<block_q4_0>(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg);
 }

 kernel void kernel_mul_mat_q4_1_f32(
         device const void * src0,
         device const float * src1,
         device float * dst,
         constant int64_t & ne00,
         constant int64_t & ne10,
         constant int64_t & ne0,
@@ -467,80 +486,7 @@ kernel void kernel_mul_mat_q4_1_f32(
         uint2 tgpig[[threadgroup_position_in_grid]],
         uint tiisg[[thread_index_in_simdgroup]],
         uint sgitg[[simdgroup_index_in_threadgroup]]) {
-    const int nb = ne00/QK4_0;
-    const int r0 = tgpig.x;
-    const int r1 = tgpig.y;
-    device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
-    device const float * y = (device const float *) src1 + r1*ne10;
-    block_q4_1 qb_curr, qb_next;
-    float4 y_curr[8]; // src1 vector cache
-    float sumf[N_DST]={0.f}, all_sum;
-    thread float * yl=(thread float *)y_curr;
-
-    // bootstrap
-    qb_curr = x[tiisg];
-    // each thread in a SIMD group deals with 1 block.
- for (int column = 0; column < nb / N_SIMDWIDTH; column++) { - - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i)); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - const float d = qb_curr.d; - const float m = qb_curr.m; - float acc = 0.f; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - sumf[row] += d * acc + m * sumy; - qb_curr = qb_next; - } - } - - if (nb % N_SIMDWIDTH == 0) { - for (int row = 0; row < N_DST; ++row) { - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } else { - - float sumy = 0; - for (int i = 0; i < QK4_0 / 4; i++) { - y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i)); - sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3]; - } - - for (int row = 0; row < N_DST; row++) { - // prefetch next x block - qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH]; - - // calculate - const float d = qb_curr.d; - const float m = qb_curr.m; - float acc = 0.f; - for (int i = 0; i < 16; i++) { - acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4); - } - if (tiisg < nb % N_SIMDWIDTH) { - sumf[row] += d * acc + m * sumy; - } - qb_curr = qb_next; - - all_sum = simd_sum(sumf[row]); - if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) { - dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum; - } - } - } + mul_vec_q_n_f32(src0,src1,dst,ne00,ne10,ne0,ne01,tgpig,tiisg,sgitg); } kernel void kernel_mul_mat_f16_f32( @@ -1263,108 +1209,133 @@ kernel void kernel_mul_mat_q2_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q2_K * x = (device const block_q2_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[32]; + float sumf[N_DST]={0.f}, all_sum; - device const block_q2_K * x = (device const block_q2_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; - - float sumf = 0; + const int step = sizeof(block_q2_K) * nb; #if QK_K == 256 - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid%4; // 0...3 - const int ip = il/2; // 0 or 1 - const int shift1 = 4*(il%2);// 0 or 4 - const int shift2 = shift1+2;// 2 or 6 - const int n = 8; - const int is = 4*il + (n*ir)/16; + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const 
int ir = it%4; // 0...3 + const int is = (8*ir)/16;// 0 or 1 - const int y_offset = 64*il + n*ir; - const int q_offset = 32*ip + n*ir; + device const float * y4 = y + ix * QK_K + 128 * im + 8 * ir; - for (int i = tpitg.x; i < nb; i += tptg.x) { + for (int ib = ix; ib < nb; ib += 4) { - device const uint8_t * q = x[i].qs + q_offset; - device const uint8_t * scales = x[i].scales + is; - - uint8_t d1 = scales[0] & 0xF; - uint8_t d2 = scales[2] & 0xF; - uint8_t m1 = scales[0] >> 4; - uint8_t m2 = scales[2] >> 4; - - device const float * y = yy + i*QK_K + y_offset; - - float2 s = {0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { - s[0] += y[l+ 0] * ((q[l] >> shift1) & 3); - s[1] += y[l+32] * ((q[l] >> shift2) & 3); - smin += y[l+ 0] * m1 + y[l+32] * m2; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+32]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+64]; sumy[2] += yl[i+16]; + yl[i+24] = y4[i+96]; sumy[3] += yl[i+24]; } - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; + device const uint8_t * sc = (device const uint8_t *)x[ib].scales + 8*im + is; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; - sumf += dall * (s[0] * d1 + s[1] * d2) - dmin * smin; + for (int row = 0; row < N_DST; row++) { + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + float dall = dh[0]; + float dmin = dh[1] * 1.f/16.f; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[2] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[4] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[6] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] & 0xF0) + sumy[1] * (sc[2] & 0xF0) + sumy[2] * (sc[4] & 0xF0) + sumy[3] * (sc[6] & 0xF0)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 4 * QK_K; } #else - const int il = 4 * tpitg.x; + const int ix = tiisg/2; // 0...15 + const int it = tiisg%2; // 0...1 - uint32_t aux[2]; - thread const uint8_t * d = (thread const uint8_t *)aux; - thread const uint8_t * m = (thread const uint8_t *)aux + 4; + device const float * y4 = y + ix * QK_K + 8 * it; - for (int i = tpitg.y; i < nb; i += tptg.y) { + for (int ib = ix; ib < nb; ib += 16) { - device const uint8_t * q = x[i].qs + il; - device const float * y = yy + i*QK_K + il; - - const float dall = (float)x[i].d; - const float dmin = (float)x[i].dmin; - - device const uint32_t * a = (device const uint32_t *)x[i].scales; - aux[0] = a[0] & 0x0f0f0f0f; - aux[1] = (a[0] >> 4) & 0x0f0f0f0f; - - for (int l = 0; l < 4; ++l) { - sumf += y[l+ 0] * (dall * d[0] * ((q[l] >> 0) & 3) - dmin * m[0]) - + y[l+16] * (dall * d[1] * ((q[l] >> 2) & 3) - dmin * m[1]) - + y[l+32] * (dall * d[2] * ((q[l] >> 4) & 3) - dmin * m[2]) - + y[l+48] * (dall * d[3] * ((q[l] >> 6) & 3) - dmin * m[3]); + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+ 0] = y4[i+ 0]; sumy[0] += yl[i+ 0]; + yl[i+ 8] = y4[i+16]; sumy[1] += yl[i+ 8]; + yl[i+16] = y4[i+32]; sumy[2] += 
yl[i+16]; + yl[i+24] = y4[i+48]; sumy[3] += yl[i+24]; } + + device const uint8_t * sc = (device const uint8_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+ 0] * (qs[i/2] & 0x0003); + acc2[0] += yl[i+ 1] * (qs[i/2] & 0x0300); + acc1[1] += yl[i+ 8] * (qs[i/2] & 0x000c); + acc2[1] += yl[i+ 9] * (qs[i/2] & 0x0c00); + acc1[2] += yl[i+16] * (qs[i/2] & 0x0030); + acc2[2] += yl[i+17] * (qs[i/2] & 0x3000); + acc1[3] += yl[i+24] * (qs[i/2] & 0x00c0); + acc2[3] += yl[i+25] * (qs[i/2] & 0xc000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc2[0]) * (sc[0] & 0xF) * 1.f/ 1.f + + (acc1[1] + 1.f/256.f * acc2[1]) * (sc[1] & 0xF) * 1.f/ 4.f + + (acc1[2] + 1.f/256.f * acc2[2]) * (sc[2] & 0xF) * 1.f/16.f + + (acc1[3] + 1.f/256.f * acc2[3]) * (sc[3] & 0xF) * 1.f/64.f) - + dmin * (sumy[0] * (sc[0] >> 4) + sumy[1] * (sc[1] >> 4) + sumy[2] * (sc[2] >> 4) + sumy[3] * (sc[3] >> 4)); + + qs += step/2; + sc += step; + dh += step/2; + } + + y4 += 16 * QK_K; } #endif - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } } } @@ -1506,6 +1477,7 @@ kernel void kernel_mul_mat_q3_K_f32( } +#if QK_K == 256 kernel void kernel_mul_mat_q4_K_f32( device const void * src0, device const float * src1, @@ -1513,131 +1485,180 @@ kernel void kernel_mul_mat_q4_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], + constant int64_t & ne01[[buffer(4)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { - - const int nb = ne00/QK_K; - - const int64_t r0 = tgpig.x; - const int64_t r1 = tgpig.y; - - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; - - device const block_q4_K * x = (device const block_q4_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; - - float sumf = 0; - -#if QK_K == 256 + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 4; + const int ix = tiisg/8; // 0...3 + const int it = tiisg%8; // 0...7 + const int im = it/4; // 0 or 1 + const int ir = it%4; // 0...3 - const int im = il/2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[16]; + float yh[16]; + float sumf[N_DST]={0.f}, all_sum; - const int l0 = n*(2*ir + in); - const int q_offset = 32*im + l0; - const int y_offset = 64*im + l0; + const int step = sizeof(block_q4_K) * nb / 2; - uchar2 sc1, sc2, sc3, sc4; + device const float * y4 = y + ix * QK_K + 64 * im + 8 * ir; - for (int i = tpitg.x; i < nb; i += tptg.x) { + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - device const uint8_t * q1 = (x + i)->qs + q_offset; - device const uint8_t * q2 = q1 + 64; - device const float * y1 = yy + i*QK_K + y_offset; - device const float * y2 = y1 + 128; - - const float dall = (float)((x + i)->d); - const float dmin = (float)((x + i)->dmin); - - device const uint16_t * a = (device const uint16_t *)(x + i)->scales; - sc1 = as_type((uint16_t)(a[im+0] & kmask1)); - sc2 = as_type((uint16_t)(a[im+2] & kmask1)); - sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); - sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); - - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { - - s[0] += y1[l] * (q1[l] & 0xF); s[1] += y1[l+32] * (q1[l] >> 4); - s[2] += y2[l] * (q2[l] & 0xF); s[3] += y2[l+32] * (q2[l] >> 4); - smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; + for (int ib = ix; ib < nb; ib += 4) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+ 0]; sumy[0] += yl[i+0]; + yl[i+8] = y4[i+ 32]; sumy[1] += yl[i+8]; + yh[i+0] = y4[i+128]; sumy[2] += yh[i+0]; + yh[i+8] = y4[i+160]; sumy[3] += yh[i+8]; } - sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; + device const uint16_t * sc = (device const uint16_t *)x[ib].scales + im; + device const uint16_t * q1 = (device const uint16_t *)x[ib].qs + 16 * im + 4 * ir; + device const half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + device const uint16_t * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (q1[i/2] & 0x000F); + acc1[1] += yl[i+1] * (q1[i/2] & 0x0F00); + acc1[2] += yl[i+8] * (q1[i/2] & 0x00F0); + acc1[3] += yl[i+9] * (q1[i/2] & 0xF000); + acc2[0] += yh[i+0] * (q2[i/2] & 0x000F); + acc2[1] += yh[i+1] * (q2[i/2] & 0x0F00); + acc2[2] += yh[i+8] * (q2[i/2] & 0x00F0); + acc2[3] += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc8[0] + + (acc1[2] + 1.f/256.f * acc1[3]) * sc8[1] * 1.f/16.f + + (acc2[0] + 1.f/256.f * acc2[1]) * sc8[4] + + (acc2[2] + 1.f/256.f * acc2[3]) * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + sc += step; + dh += step; + } + + y4 += 4 * QK_K; } -#else - uint16_t aux16[2]; - thread const uint8_t * scales = (thread const 
uint8_t *)aux16; - const int il = 4*tpitg.x; - - for (int i = tpitg.y; i < nb; i += tptg.y) { - - device const uint8_t * q = x[i].qs + il; - device const float * y = yy + i * QK_K + il; - - const float d = (float)x[i].d[0]; - const float m = (float)x[i].d[1]; - - device const uint16_t * a = (device const uint16_t *)x[i].scales; - aux16[0] = a[0] & 0x0f0f; - aux16[1] = (a[0] >> 4) & 0x0f0f; - - for (int l = 0; l < 4; ++l) { - sumf += d * scales[0] * (y[l+ 0] * (q[l] & 0xF) + y[l+16] * (q[l+16] & 0xF)) - m * scales[2] * (y[l+ 0] + y[l+16]) - + d * scales[1] * (y[l+32] * (q[l] >> 4) + y[l+48] * (q[l+16] >> 4)) - m * scales[3] * (y[l+32] + y[l+48]); + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; } } -#endif - - sum[ith] = sumf; - - // - // Accumulate the sum from all threads in the threadgroup - // This version is slightly faster than the commented out one below, - // which I copy-pasted from ggerganov's q4_0 dot product for metal. - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; - } - - //// accumulate the sum from all threads in the threadgroup - //threadgroup_barrier(mem_flags::mem_threadgroup); - //for (uint i = nth/2; i > 0; i /= 2) { - // if (ith < i) { - // sum[ith] += sum[ith + i]; - // } - // threadgroup_barrier(mem_flags::mem_threadgroup); - //} - - //if (ith == 0) { - // dst[r1*ne0 + r0] = sum[0]; - //} } +#else +kernel void kernel_mul_mat_q4_K_f32( + device const void * src0, + device const float * src1, + device float * dst, + constant int64_t & ne00, + constant int64_t & ne10, + constant int64_t & ne0, + constant int64_t & ne01[[buffer(4)]], + uint2 tgpig[[threadgroup_position_in_grid]], + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { + + const int ix = tiisg/4; // 0...7 + const int it = tiisg%4; // 0...3 + + const int nb = ne00/QK_K; + const int r0 = tgpig.x; + const int r1 = tgpig.y; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * N_DST; + const int ib_row = first_row * nb; + device const block_q4_K * x = (device const block_q4_K *) src0 + ib_row; + device const float * y = (device const float *) src1 + r1*ne10; + float yl[8]; + float yh[8]; + float sumf[N_DST]={0.f}, all_sum; + + const int step = sizeof(block_q4_K) * nb / 2; + + device const float * y4 = y + ix * QK_K + 8 * it; + + uint16_t sc16[4]; + + for (int ib = ix; ib < nb; ib += 8) { + + float2 sumy = {0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i] = y4[i+ 0]; sumy[0] += yl[i]; + yh[i] = y4[i+32]; sumy[1] += yh[i]; + } + + device const uint16_t * sc = (device const uint16_t *)x[ib].scales; + device const uint16_t * qs = (device const uint16_t *)x[ib].qs + 4 * it; + device const half * dh = x[ib].d; + + for (int row = 0; row < N_DST; row++) { + + sc16[0] = sc[0] & 0x000f; + sc16[1] = sc[0] & 0x0f00; + sc16[2] = sc[0] & 0x00f0; + sc16[3] = sc[0] & 0xf000; + + float2 acc1 = {0.f, 0.f}; + float2 acc2 = {0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1[0] += yl[i+0] * (qs[i/2] & 0x000F); + acc1[1] += yl[i+1] * (qs[i/2] & 0x0F00); + acc2[0] += yh[i+0] * (qs[i/2] & 0x00F0); + acc2[1] += yh[i+1] * (qs[i/2] & 
0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1[0] + 1.f/256.f * acc1[1]) * sc16[0] + + (acc2[0] + 1.f/256.f * acc2[1]) * sc16[1] * 1.f/4096.f) - + dmin * 1.f/16.f * (sumy[0] * sc16[2] + sumy[1] * sc16[3] * 1.f/256.f); + + qs += step; + sc += step; + dh += step; + } + + y4 += 8 * QK_K; + } + + for (int row = 0; row < N_DST; ++row) { + all_sum = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = all_sum; + } + } +} +#endif kernel void kernel_mul_mat_q5_K_f32( device const void * src0, @@ -1646,39 +1667,39 @@ kernel void kernel_mul_mat_q5_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const int nb = ne00/QK_K; const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q5_K * x = (device const block_q5_K *) src0 + r0*nb; + const int first_row = (r0 * N_SIMDGROUP + sgitg) * 2; + + device const block_q5_K * x = (device const block_q5_K *) src0 + first_row*nb; device const float * yy = (device const float *) src1 + r1*ne10; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + float sumf[2]={0.f}; - float sumf = 0; + const int step = sizeof(block_q5_K) * nb; #if QK_K == 256 +# + float yl[16], yh[16]; const uint16_t kmask1 = 0x3f3f; const uint16_t kmask2 = 0x0f0f; const uint16_t kmask3 = 0xc0c0; - const int tid = tpitg.y; // 0...16 - const int il = tid/4; // 0...3 - const int ir = tid - 4*il;// 0...3 - const int n = 4; + const int tid = tiisg/4; + const int ix = tiisg%4; + const int im = tid/4; + const int ir = tid%4; + const int n = 8; - const int im = il/2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 - const int in = il%2; - - const int l0 = n*(2*ir + in); + const int l0 = n*ir; const int q_offset = 32*im + l0; const int y_offset = 64*im + l0; @@ -1687,78 +1708,114 @@ kernel void kernel_mul_mat_q5_K_f32( const uint8_t hm3 = hm1 << 4; const uint8_t hm4 = hm2 << 4; - uchar2 sc1, sc2, sc3, sc4; + uint16_t sc16[4]; + thread const uint8_t * sc8 = (thread const uint8_t *)sc16; - for (int i = tpitg.x; i < nb; i += tptg.x) { + device const float * y1 = yy + ix*QK_K + y_offset; - device const uint8_t * q1 = (x + i)->qs + q_offset; - device const uint8_t * q2 = q1 + 64; - device const uint8_t * qh = (x + i)->qh + l0; - device const float * y1 = yy + i*QK_K + y_offset; - device const float * y2 = y1 + 128; + for (int i = ix; i < nb; i += 4) { - const float dall = (float)((x + i)->d); - const float dmin = (float)((x + i)->dmin); + device const uint8_t * q1 = x[i].qs + q_offset; + device const uint8_t * qh = x[i].qh + l0; + device const half * dh = &x[i].d; + device const uint16_t * a = (device const uint16_t *)x[i].scales + im; - device const uint16_t * a = (device const uint16_t *)(x + i)->scales; - sc1 = as_type((uint16_t)(a[im+0] & kmask1)); - sc2 = as_type((uint16_t)(a[im+2] & kmask1)); - sc3 = as_type((uint16_t)(((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2))); - sc4 = as_type((uint16_t)(((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2))); + device const float * y2 = y1 + 128; + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 8; ++l) { + yl[l+0] = y1[l+ 0]; sumy[0] += yl[l+0]; + yl[l+8] = y1[l+32]; sumy[1] += yl[l+8]; + yh[l+0] = y2[l+ 0]; sumy[2] += yh[l+0]; + yh[l+8] = y2[l+32]; sumy[3] += yh[l+8]; + } - float4 s = {0.f, 0.f, 0.f, 0.f}; - float smin = 0; - for (int l = 0; l < n; ++l) { + for (int row = 0; row < 2; ++row) { - s[0] += y1[l+ 0] * ((q1[l] & 0xF) + (qh[l] & hm1 ? 16 : 0)); - s[1] += y1[l+32] * ((q1[l] >> 4) + (qh[l] & hm2 ? 16 : 0)); - s[2] += y2[l+ 0] * ((q2[l] & 0xF) + (qh[l] & hm3 ? 16 : 0)); - s[3] += y2[l+32] * ((q2[l] >> 4) + (qh[l] & hm4 ? 16 : 0)); - smin += y1[l] * sc2[0] + y1[l+32] * sc2[1] + y2[l] * sc4[0] + y2[l+32] * sc4[1]; + device const uint8_t * q2 = q1 + 64; + + sc16[0] = a[0] & kmask1; + sc16[1] = a[2] & kmask1; + sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2); + sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2); + + float4 acc = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < n; ++l) { + uint8_t h = qh[l]; + acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0)); + acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0)); + acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0)); + acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 
256 : 0)); + } + const float dall = dh[0]; + const float dmin = dh[1]; + sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) - + dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]); + + q1 += step; + qh += step; + dh += step/2; + a += step/2; } - sumf += dall * (s[0] * sc1[0] + s[1] * sc1[1] + s[2] * sc3[0] + s[3] * sc3[1]) - dmin * smin; + + y1 += 4 * QK_K; } #else - const int il = 4 * tpitg.x; // 0, 4, 8, 12 - const int im = il/8; // 0, 0, 1, 1 - const int in = il%8; // 0, 4, 0, 4 + float yl[8], yh[8]; - for (int i = tpitg.y; i < nb; i += tptg.y) { + const int il = 4 * (tiisg/8); // 0, 4, 8, 12 + const int ix = tiisg%8; + const int im = il/8; // 0, 0, 1, 1 + const int in = il%8; // 0, 4, 0, 4 - const float d = (float)x[i].d; + device const float * y = yy + ix*QK_K + il; + + for (int i = ix; i < nb; i += 8) { + + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + yl[l+0] = y[l+ 0]; + yl[l+4] = y[l+16]; + yh[l+0] = y[l+32]; + yh[l+4] = y[l+48]; + } + + device const half * dh = &x[i].d; device const uint8_t * q = x[i].qs + il; device const uint8_t * h = x[i].qh + in; device const int8_t * s = x[i].scales; - device const float * y = yy + i*QK_K + il; - for (int l = 0; l < 4; ++l) { - const uint8_t hl = h[l] >> im; - sumf += y[l+ 0] * d * s[0] * ((q[l+ 0] & 0xF) - (hl & 0x01 ? 0 : 16)) - + y[l+16] * d * s[1] * ((q[l+16] & 0xF) - (hl & 0x04 ? 0 : 16)) - + y[l+32] * d * s[2] * ((q[l+ 0] >> 4) - (hl & 0x10 ? 0 : 16)) - + y[l+48] * d * s[3] * ((q[l+16] >> 4) - (hl & 0x40 ? 0 : 16)); + for (int row = 0; row < 2; ++row) { + + const float d = dh[0]; + + float2 acc = {0.f, 0.f}; + for (int l = 0; l < 4; ++l) { + const uint8_t hl = h[l] >> im; + acc[0] += yl[l+0] * s[0] * ((int16_t)(q[l+ 0] & 0x0F) - (hl & 0x01 ? 0 : 16)) + + yl[l+4] * s[1] * ((int16_t)(q[l+16] & 0x0F) - (hl & 0x04 ? 0 : 16)); + acc[1] += yh[l+0] * s[2] * ((int16_t)(q[l+ 0] & 0xF0) - (hl & 0x10 ? 0 : 256)) + + yh[l+4] * s[3] * ((int16_t)(q[l+16] & 0xF0) - (hl & 0x40 ? 
0 : 256)); + } + sumf[row] += d * (acc[0] + 1.f/16.f * acc[1]); + + q += step; + h += step; + s += step; + dh += step/2; + } + + y += 8 * QK_K; } #endif - sum[ith] = sumf; - // - // Accumulate the sum from all threads in the threadgroup - // - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%4 == 0) { - sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith%16 == 0) { - sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12]; - } - threadgroup_barrier(mem_flags::mem_threadgroup); - if (ith == 0) { - for (int i = 16; i < nth; i += 16) sum[0] += sum[i]; - dst[r1*ne0 + r0] = sum[0]; + for (int row = 0; row < 2; ++row) { + const float tot = simd_sum(sumf[row]); + if (tiisg == 0) { + dst[r1*ne0 + first_row + row] = tot; + } } } @@ -1770,10 +1827,9 @@ kernel void kernel_mul_mat_q6_K_f32( constant int64_t & ne00, constant int64_t & ne10, constant int64_t & ne0, - threadgroup float * sum [[threadgroup(0)]], uint2 tgpig[[threadgroup_position_in_grid]], - uint2 tpitg[[thread_position_in_threadgroup]], - uint2 tptg[[threads_per_threadgroup]]) { + uint tiisg[[thread_index_in_simdgroup]], + uint sgitg[[simdgroup_index_in_threadgroup]]) { const uint8_t kmask1 = 0x03; const uint8_t kmask2 = 0x0C; @@ -1785,19 +1841,18 @@ kernel void kernel_mul_mat_q6_K_f32( const int64_t r0 = tgpig.x; const int64_t r1 = tgpig.y; - device const block_q6_K * x = (device const block_q6_K *) src0 + r0*nb; - device const float * yy = (device const float *) src1 + r1*ne10; + const int row = 2 * r0 + sgitg; - const int nth = tptg.x*tptg.y; - const int ith = tptg.y*tpitg.x + tpitg.y; + device const block_q6_K * x = (device const block_q6_K *) src0 + row * nb; //r0*nb; + device const float * yy = (device const float *) src1 + r1*ne10; float sumf = 0; #if QK_K == 256 - // Note: we absolutely assume that tptg.y = 16 and QK_K = 256! 
-    const int iqs = 16 * tpitg.y;
-    const int ip = iqs / 128; // 0 or 1
-    const int il = (iqs - 128*ip)/16; // 0...7
+    const int tid = tiisg/2;
+    const int ix = tiisg%2;
+    const int ip = tid/8; // 0 or 1
+    const int il = tid%8;
     const int n = 4;
     const int l0 = n*il;
     const int is = 8*ip + l0/16;
@@ -1806,9 +1861,10 @@ kernel void kernel_mul_mat_q6_K_f32(
     const int q_offset_l = 64*ip + l0;
     const int q_offset_h = 32*ip + l0;

-    for (int i = tpitg.x; i < nb; i += tptg.x) {
+    for (int i = ix; i < nb; i += 2) {

-        device const uint8_t * ql = x[i].ql + q_offset_l;
+        device const uint8_t * q1 = x[i].ql + q_offset_l;
+        device const uint8_t * q2 = q1 + 32;
         device const uint8_t * qh = x[i].qh + q_offset_h;
         device const int8_t * sc = x[i].scales + is;

@@ -1818,19 +1874,21 @@ kernel void kernel_mul_mat_q6_K_f32(
         float4 sums = {0.f, 0.f, 0.f, 0.f};
         for (int l = 0; l < n; ++l) {
-            sums[0] += y[l+ 0] * ((int8_t)((ql[l+ 0] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
-            sums[1] += y[l+32] * ((int8_t)((ql[l+32] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
-            sums[2] += y[l+64] * ((int8_t)((ql[l+ 0] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
-            sums[3] += y[l+96] * ((int8_t)((ql[l+32] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
+            sums[0] += y[l+ 0] * ((int8_t)((q1[l] & 0xF) | ((qh[l] & kmask1) << 4)) - 32);
+            sums[1] += y[l+32] * ((int8_t)((q2[l] & 0xF) | ((qh[l] & kmask2) << 2)) - 32);
+            sums[2] += y[l+64] * ((int8_t)((q1[l] >> 4) | ((qh[l] & kmask3) << 0)) - 32);
+            sums[3] += y[l+96] * ((int8_t)((q2[l] >> 4) | ((qh[l] & kmask4) >> 2)) - 32);
         }

         sumf += dall * (sums[0] * sc[0] + sums[1] * sc[2] + sums[2] * sc[4] + sums[3] * sc[6]);

     }

-#else
-    const int il = 4*tpitg.x; // 0, 4, 8, 12
-    for (int i = tpitg.y; i < nb; i += tptg.y) {
+#else
+    const int ix = tiisg/4;
+    const int il = 4*(tiisg%4);
+
+    for (int i = ix; i < nb; i += 8) {
         device const float * y = yy + i * QK_K + il;
         device const uint8_t * ql = x[i].ql + il;
         device const uint8_t * qh = x[i].qh + il;
@@ -1850,23 +1908,8 @@ kernel void kernel_mul_mat_q6_K_f32(

 #endif

-    sum[ith] = sumf;
-
-    //
-    // Accumulate the sum from all threads in the threadgroup
-    //
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%4 == 0) {
-        for (int i = 1; i < 4; ++i) sum[ith] += sum[ith + i];
+    const float tot = simd_sum(sumf);
+    if (tiisg == 0) {
+        dst[r1*ne0 + row] = tot;
     }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith%16 == 0) {
-        for (int i = 4; i < 16; i += 4) sum[ith] += sum[ith + i];
-    }
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    if (ith == 0) {
-        for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
-        dst[r1*ne0 + r0] = sum[0];
-    }
-
 }
diff --git a/llama.cpp b/llama.cpp
index 3319b7023..23e746d62 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -555,7 +555,9 @@ struct llama_file_loader {
         }

         // skip to the next multiple of 32 bytes
-        file.seek(-static_cast<ssize_t>(file.tell()) & 31, SEEK_CUR);
+        if (file_version >= LLAMA_FILE_VERSION_GGJT_V1) {
+            file.seek(-static_cast<ssize_t>(file.tell()) & 31, SEEK_CUR);
+        }

         tensor.file_off = file.tell();
         tensor.name = name;
@@ -847,7 +849,7 @@ struct llama_context_params llama_context_default_params() {
         /*.n_batch                    =*/ 512,
         /*.gpu_layers                 =*/ 0,
         /*.main_gpu                   =*/ 0,
-        /*.tensor_split               =*/ {0},
+        /*.tensor_split               =*/ nullptr,
         /*.rope_freq_base             =*/ 10000.0f,
         /*.rope_freq_scale            =*/ 1.0f,
         /*.progress_callback          =*/ nullptr,
@@ -1287,7 +1289,7 @@ static bool llama_model_load(
         int n_batch,
         int n_gpu_layers,
         int main_gpu,
-        float * tensor_split,
+        const float * tensor_split,
         float rope_freq_base,
         float rope_freq_scale,
         bool low_vram,
diff --git a/llama.h b/llama.h
index b676a383b..c565f6a00 100644
--- a/llama.h
+++ b/llama.h
@@ -88,7 +88,8 @@ extern "C" {
         int32_t n_batch;      // prompt processing batch size
         int32_t n_gpu_layers; // number of layers to store in VRAM
         int32_t main_gpu;     // the GPU that is used for scratch and small tensors
-        float tensor_split[LLAMA_MAX_DEVICES]; // how to split layers across multiple GPUs
+
+        const float * tensor_split; // how to split layers across multiple GPUs (size: LLAMA_MAX_DEVICES)

         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base; // RoPE base frequency
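
The `tensor_split` hunks above (examples/common.cpp, ggml-cuda.cu, llama.cpp, llama.h) all implement one API change: `llama_context_params` now carries a `const float *` instead of an inline `float[LLAMA_MAX_DEVICES]` array. Below is a minimal caller-side sketch, not part of the patch, of what that means in practice; the helper name and the 3:1 split values are made up, and it assumes a multi-GPU build where `LLAMA_MAX_DEVICES` is at least 2.

```c
#include "llama.h"

// Hypothetical helper: request a 3:1 layer split across two GPUs.
// Because the struct now stores only a pointer, the array must outlive
// any context created from these params (static storage works here).
struct llama_context_params make_split_params(void) {
    static const float splits[LLAMA_MAX_DEVICES] = {3.0f, 1.0f};

    struct llama_context_params params = llama_context_default_params();
    // before this patch: memcpy(params.tensor_split, splits, LLAMA_MAX_DEVICES*sizeof(float));
    params.tensor_split = splits; // after: plain pointer assignment
    return params;
}
```

Leaving the new default `NULL` in place is also valid: per the ggml-cuda.cu hunk, `ggml_cuda_set_tensor_split` now returns early on a null pointer, so the backend keeps whatever default split it computes.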