Merge 'origin/master' into hipblas

This commit is contained in:
Henri Vasserman 2023-06-17 16:53:22 +03:00
commit 6f7c15637a
No known key found for this signature in database
GPG key ID: 2995FC0F58B1A986
42 changed files with 3008 additions and 1732 deletions

2
.flake8 Normal file
View file

@ -0,0 +1,2 @@
[flake8]
max-line-length = 125

4
.gitignore vendored
View file

@ -22,6 +22,7 @@ build-metal/
build-no-accel/ build-no-accel/
build-sanitize-addr/ build-sanitize-addr/
build-sanitize-thread/ build-sanitize-thread/
out/
models/* models/*
*.bin *.bin
@ -32,14 +33,17 @@ models/*
/result /result
/perplexity /perplexity
/embedding /embedding
/train-text-from-scratch
/benchmark-matmult /benchmark-matmult
/vdot /vdot
/server
/Pipfile /Pipfile
/libllama.so /libllama.so
build-info.h build-info.h
arm_neon.h arm_neon.h
compile_commands.json compile_commands.json
CMakeSettings.json
__pycache__ __pycache__

15
.pre-commit-config.yaml Normal file
View file

@ -0,0 +1,15 @@
# See https://pre-commit.com for more information
# See https://pre-commit.com/hooks.html for more hooks
exclude: prompts/.*.txt
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v3.2.0
hooks:
- id: trailing-whitespace
- id: end-of-file-fixer
- id: check-yaml
- id: check-added-large-files
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8

View file

@ -70,6 +70,7 @@ set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF) option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels") set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels") set(LLAMA_CUDA_DMMV_Y "1" CACHE STRING "llama: y block size for dmmv CUDA kernels")
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF) option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" OFF) option(LLAMA_METAL "llama: use Metal" OFF)
@ -159,17 +160,64 @@ if (LLAMA_BLAS)
if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22) if ($(CMAKE_VERSION) VERSION_GREATER_EQUAL 3.22)
set(BLA_SIZEOF_INTEGER 8) set(BLA_SIZEOF_INTEGER 8)
endif() endif()
set(BLA_VENDOR ${LLAMA_BLAS_VENDOR}) set(BLA_VENDOR ${LLAMA_BLAS_VENDOR})
find_package(BLAS) find_package(BLAS)
if (BLAS_FOUND) if (BLAS_FOUND)
message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}") message(STATUS "BLAS found, Libraries: ${BLAS_LIBRARIES}")
if ("${BLAS_INCLUDE_DIRS}" STREQUAL "")
# BLAS_INCLUDE_DIRS is missing in FindBLAS.cmake.
# see https://gitlab.kitware.com/cmake/cmake/-/issues/20268
find_package(PkgConfig REQUIRED)
if (${LLAMA_BLAS_VENDOR} MATCHES "Generic")
pkg_check_modules(DepBLAS REQUIRED blas)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "OpenBLAS")
pkg_check_modules(DepBLAS REQUIRED openblas)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "FLAME")
pkg_check_modules(DepBLAS REQUIRED blis)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "ATLAS")
pkg_check_modules(DepBLAS REQUIRED blas-atlas)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "FlexiBLAS")
pkg_check_modules(DepBLAS REQUIRED flexiblas_api)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "Intel")
# all Intel* libraries share the same include path
pkg_check_modules(DepBLAS REQUIRED mkl-sdl)
elseif (${LLAMA_BLAS_VENDOR} MATCHES "NVHPC")
# this doesn't provide pkg-config
# suggest to assign BLAS_INCLUDE_DIRS on your own
if ("${NVHPC_VERSION}" STREQUAL "")
message(WARNING "Better to set NVHPC_VERSION")
else()
set(DepBLAS_FOUND ON)
set(DepBLAS_INCLUDE_DIRS "/opt/nvidia/hpc_sdk/${CMAKE_SYSTEM_NAME}_${CMAKE_SYSTEM_PROCESSOR}/${NVHPC_VERSION}/math_libs/include")
endif()
endif()
if (DepBLAS_FOUND)
set(BLAS_INCLUDE_DIRS ${DepBLAS_INCLUDE_DIRS})
else()
message(WARNING "BLAS_INCLUDE_DIRS neither been provided nor been automatically"
" detected by pkgconfig, trying to find cblas.h from possible paths...")
find_path(BLAS_INCLUDE_DIRS
NAMES cblas.h
HINTS
/usr/include
/usr/local/include
/usr/include/openblas
/opt/homebrew/opt/openblas/include
/usr/local/opt/openblas/include
/usr/include/x86_64-linux-gnu/openblas/include
)
endif()
endif()
message(STATUS "BLAS found, Includes: ${BLAS_INCLUDE_DIRS}")
add_compile_options(${BLAS_LINKER_FLAGS}) add_compile_options(${BLAS_LINKER_FLAGS})
add_compile_definitions(GGML_USE_OPENBLAS) add_compile_definitions(GGML_USE_OPENBLAS)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES}) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ${BLAS_LIBRARIES})
set(LLAMA_EXTRA_INCLUDES ${LLAMA_EXTRA_INCLUDES} ${BLAS_INCLUDE_DIRS})
message("${BLAS_LIBRARIES} ${BLAS_INCLUDE_DIRS}")
include_directories(${BLAS_INCLUDE_DIRS})
else() else()
message(WARNING "BLAS not found, please refer to " message(WARNING "BLAS not found, please refer to "
"https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors" "https://cmake.org/cmake/help/latest/module/FindBLAS.html#blas-lapack-vendors"
@ -191,6 +239,7 @@ if (LLAMA_CUBLAS)
add_compile_definitions(GGML_USE_CUBLAS) add_compile_definitions(GGML_USE_CUBLAS)
add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X}) add_compile_definitions(GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y}) add_compile_definitions(GGML_CUDA_DMMV_Y=${LLAMA_CUDA_DMMV_Y})
add_compile_definitions(K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
if (LLAMA_STATIC) if (LLAMA_STATIC)
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
@ -440,12 +489,14 @@ add_library(ggml OBJECT
${GGML_SOURCES_EXTRA} ${GGML_SOURCES_EXTRA}
) )
target_include_directories(ggml PUBLIC .) target_include_directories(ggml PUBLIC . ${LLAMA_EXTRA_INCLUDES})
target_compile_features(ggml PUBLIC c_std_11) # don't bump target_compile_features(ggml PUBLIC c_std_11) # don't bump
target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS}) target_link_libraries(ggml PUBLIC Threads::Threads ${LLAMA_EXTRA_LIBS})
add_library(ggml_static STATIC $<TARGET_OBJECTS:ggml>)
if (BUILD_SHARED_LIBS) if (BUILD_SHARED_LIBS)
set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON) set_target_properties(ggml PROPERTIES POSITION_INDEPENDENT_CODE ON)
add_library(ggml_shared SHARED $<TARGET_OBJECTS:ggml>)
endif() endif()
add_library(llama add_library(llama

View file

@ -1,8 +1,10 @@
# Define the default target now so that it is always the first target # Define the default target now so that it is always the first target
BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot BUILD_TARGETS = main quantize quantize-stats perplexity embedding vdot train-text-from-scratch simple
ifdef LLAMA_BUILD_SERVER ifdef LLAMA_BUILD_SERVER
BUILD_TARGETS += server BUILD_TARGETS += server
LLAMA_SERVER_VERBOSE ?= 1
server: private CXXFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
endif endif
default: $(BUILD_TARGETS) default: $(BUILD_TARGETS)
@ -171,6 +173,11 @@ ifdef LLAMA_CUDA_DMMV_Y
else else
NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1 NVCCFLAGS += -DGGML_CUDA_DMMV_Y=1
endif # LLAMA_CUDA_DMMV_Y endif # LLAMA_CUDA_DMMV_Y
ifdef LLAMA_CUDA_KQUANTS_ITER
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
else
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
endif
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@ $(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif # LLAMA_CUBLAS endif # LLAMA_CUBLAS
@ -277,7 +284,7 @@ libllama.so: llama.o ggml.o $(OBJS)
$(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS) $(CXX) $(CXXFLAGS) -shared -fPIC -o $@ $^ $(LDFLAGS)
clean: clean:
rm -vf *.o main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server vdot build-info.h rm -vf *.o *.so main quantize quantize-stats perplexity embedding benchmark-matmult save-load-state server vdot train-text-from-scratch build-info.h
# #
# Examples # Examples
@ -289,6 +296,12 @@ main: examples/main/main.cpp build-info.h ggml.
@echo '==== Run ./main -h for help. ====' @echo '==== Run ./main -h for help. ===='
@echo @echo
simple: examples/simple/simple.cpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
@echo
@echo '==== Run ./simple -h for help. ===='
@echo
quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS) quantize: examples/quantize/quantize.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
@ -307,6 +320,9 @@ save-load-state: examples/save-load-state/save-load-state.cpp build-info.h ggml.
server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS) server: examples/server/server.cpp examples/server/httplib.h examples/server/json.hpp build-info.h ggml.o llama.o common.o $(OBJS)
$(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS) $(CXX) $(CXXFLAGS) -Iexamples/server $(filter-out %.h,$(filter-out %.hpp,$^)) -o $@ $(LDFLAGS)
train-text-from-scratch: examples/train-text-from-scratch/train-text-from-scratch.cpp build-info.h ggml.o llama.o $(OBJS)
$(CXX) $(CXXFLAGS) $(filter-out %.h,$^) -o $@ $(LDFLAGS)
build-info.h: $(wildcard .git/index) scripts/build-info.sh build-info.h: $(wildcard .git/index) scripts/build-info.sh
@sh scripts/build-info.sh > $@.tmp @sh scripts/build-info.sh > $@.tmp
@if ! cmp -s $@.tmp $@; then \ @if ! cmp -s $@.tmp $@; then \

View file

@ -11,6 +11,7 @@ let package = Package(
.target( .target(
name: "llama", name: "llama",
path: ".", path: ".",
exclude: ["ggml-metal.metal"],
sources: ["ggml.c", "llama.cpp"], sources: ["ggml.c", "llama.cpp"],
publicHeadersPath: "spm-headers", publicHeadersPath: "spm-headers",
cSettings: [.unsafeFlags(["-Wno-shorten-64-to-32"]), .define("GGML_USE_ACCELERATE")], cSettings: [.unsafeFlags(["-Wno-shorten-64-to-32"]), .define("GGML_USE_ACCELERATE")],

View file

@ -616,6 +616,7 @@ And after 4.45 hours, you will have the final perplexity.
### Android ### Android
#### Building the Project using Android NDK
You can easily run `llama.cpp` on Android device with [termux](https://termux.dev/). You can easily run `llama.cpp` on Android device with [termux](https://termux.dev/).
First, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake: First, obtain the [Android NDK](https://developer.android.com/ndk) and then build with CMake:
``` ```
@ -630,6 +631,46 @@ Finally, copy the `llama` binary and the model files to your device storage. Her
https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4 https://user-images.githubusercontent.com/271616/225014776-1d567049-ad71-4ef2-b050-55b0b3b9274c.mp4
#### Building the Project using Termux (F-Droid)
Termux from F-Droid offers an alternative route to execute the project on an Android device. This method empowers you to construct the project right from within the terminal, negating the requirement for a rooted device or SD Card.
Outlined below are the directives for installing the project using OpenBLAS and CLBlast. This combination is specifically designed to deliver peak performance on recent devices that feature a GPU.
If you opt to utilize OpenBLAS, you'll need to install the corresponding package.
```
apt install libopenblas
```
Subsequently, if you decide to incorporate CLBlast, you'll first need to install the requisite OpenCL packages:
```
apt install ocl-icd opencl-headers opencl-clhpp clinfo
```
In order to compile CLBlast, you'll need to first clone the respective Git repository, which can be found at this URL: https://github.com/CNugteren/CLBlast. Alongside this, clone this repository into your home directory. Once this is done, navigate to the CLBlast folder and execute the commands detailed below:
```
cmake .
make
cp libclblast.so* $PREFIX/lib
cp ./include/clblast.h ../llama.cpp
```
Following the previous steps, navigate to the LlamaCpp directory. To compile it with OpenBLAS and CLBlast, execute the command provided below:
```
cp /data/data/com.termux/files/usr/include/openblas/cblas.h .
cp /data/data/com.termux/files/usr/include/openblas/openblas_config.h .
make LLAMA_CLBLAST=1 //(sometimes you need to run this command twice)
```
Upon completion of the aforementioned steps, you will have successfully compiled the project. To run it using CLBlast, a slight adjustment is required: a command must be issued to direct the operations towards your device's physical GPU, rather than the virtual one. The necessary command is detailed below:
```
GGML_OPENCL_PLATFORM=0
GGML_OPENCL_DEVICE=0
export LD_LIBRARY_PATH=/system/vendor/lib64:$LD_LIBRARY_PATH
./main (...)
```
For easy and swift re-execution, consider documenting this final part in a .sh script file. This will enable you to rerun the process with minimal hassle.
### Docker ### Docker
#### Prerequisites #### Prerequisites

View file

@ -512,7 +512,11 @@ class LazyTensor:
if not isinstance(self.data_type, QuantizedDataType): if not isinstance(self.data_type, QuantizedDataType):
raise Exception(f"Can't turn an unquantized tensor into a quantized type ({data_type})") raise Exception(f"Can't turn an unquantized tensor into a quantized type ({data_type})")
if self.data_type.have_g_idx: if self.data_type.have_g_idx:
sys.stderr.write("Error: Input uses the newer GPTQ-for-LLaMa format (using g_idx), which is not yet natively supported by GGML. For now you can still convert this model by passing `--outtype f16` to dequantize, but that will result in a much larger output file for no quality benefit.\n") sys.stderr.write(
"Error: Input uses the newer GPTQ-for-LLaMa format (using g_idx), "
"which is not yet natively supported by GGML. "
"For now you can still convert this model by passing `--outtype f16` to dequantize, "
"but that will result in a much larger output file for no quality benefit.\n")
sys.exit(1) sys.exit(1)
assert not data_type.have_g_idx and self.data_type.have_addends and data_type.have_addends assert not data_type.have_g_idx and self.data_type.have_addends and data_type.have_addends
@ -694,8 +698,9 @@ class LazyUnpickler(pickle.Unpickler):
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}' description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
return LazyStorage(load=load, kind=pid[1], description=description) return LazyStorage(load=load, kind=pid[1], description=description)
# @staticmethod # @staticmethod
def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any, # pyright: ignore[reportSelfClsParameterName] def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
# pyright: ignore[reportSelfClsParameterName]
requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor: requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
assert isinstance(storage, LazyStorage) assert isinstance(storage, LazyStorage)
@ -812,7 +817,7 @@ def lazy_load_ggml_file(fp: io.BufferedReader, path: Path) -> ModelPlus:
# Use mmap for the actual data to avoid race conditions with the file offset. # Use mmap for the actual data to avoid race conditions with the file offset.
off = fp.raw.tell() off = fp.raw.tell()
mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ)) mapped = memoryview(mmap.mmap(fp.fileno(), 0, access=mmap.ACCESS_READ))
fp.raw.seek(off) # needed on Windows fp.raw.seek(off) # needed on Windows
def read_tensor() -> None: # this is a function so that variables captured in `load` don't change def read_tensor() -> None: # this is a function so that variables captured in `load` don't change
shape_len, name_len, ftype = struct.unpack("iii", must_read(fp, 12)) shape_len, name_len, ftype = struct.unpack("iii", must_read(fp, 12))
@ -1054,7 +1059,7 @@ def load_some_model(path: Path) -> ModelPlus:
files = list(path.glob("model-00001-of-*.safetensors")) files = list(path.glob("model-00001-of-*.safetensors"))
if not files: if not files:
# Try the PyTorch patterns too, with lower priority # Try the PyTorch patterns too, with lower priority
globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin" ] globs = ["consolidated.00.pth", "pytorch_model-00001-of-*.bin", "*.pt", "pytorch_model.bin"]
files = [file for glob in globs for file in path.glob(glob)] files = [file for glob in globs for file in path.glob(glob)]
if not files: if not files:
# Try GGML too, but with lower priority, since if both a non-GGML # Try GGML too, but with lower priority, since if both a non-GGML
@ -1094,7 +1099,9 @@ def load_vocab(path: Path) -> SentencePieceVocab:
elif path3.exists(): elif path3.exists():
path = path3 path = path3
else: else:
raise FileNotFoundError(f"Could not find tokenizer.model in {path} or its parent; if it's in another directory, pass the directory as --vocab-dir") raise FileNotFoundError(
f"Could not find tokenizer.model in {path} or its parent; "
"if it's in another directory, pass the directory as --vocab-dir")
added_tokens_path = path.parent / "added_tokens.json" added_tokens_path = path.parent / "added_tokens.json"
print(f"Loading vocab file {path}") print(f"Loading vocab file {path}")
return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None) return SentencePieceVocab(path, added_tokens_path if added_tokens_path.exists() else None)
@ -1110,7 +1117,9 @@ def default_outfile(model_paths: List[Path], params: Params) -> Path:
}[params.file_type] }[params.file_type]
ret = model_paths[0].parent / f"ggml-model-{namestr}.bin" ret = model_paths[0].parent / f"ggml-model-{namestr}.bin"
if ret in model_paths: if ret in model_paths:
sys.stderr.write(f"Error: Default output path ({ret}) would overwrite the input. Please explicitly specify a path using --outfile.\n") sys.stderr.write(
f"Error: Default output path ({ret}) would overwrite the input. "
"Please explicitly specify a path using --outfile.\n")
sys.exit(1) sys.exit(1)
return ret return ret
@ -1131,7 +1140,8 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--outtype", choices=["f32", "f16", "q4_1", "q4_0"], help="output format (default: based on input)") parser.add_argument("--outtype", choices=["f32", "f16", "q4_1", "q4_0"], help="output format (default: based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)") parser.add_argument("model", type=Path,
help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
args = parser.parse_args(args_in) args = parser.parse_args(args_in)
vocab: Vocab vocab: Vocab

View file

@ -4,6 +4,10 @@
#include <random> #include <random>
#include <cstring> #include <cstring>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
float frand() { float frand() {
return (float)rand()/(float)RAND_MAX; return (float)rand()/(float)RAND_MAX;
} }
@ -1470,7 +1474,7 @@ struct ggml_tensor * square_error_loss(struct ggml_context * ctx, struct ggml_te
} }
struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) { struct ggml_tensor * cross_entropy_loss(struct ggml_context * ctx, struct ggml_tensor * a, struct ggml_tensor * b) {
const float eps = 1e-3; const float eps = 1e-3f;
return return
ggml_sum(ctx, ggml_sum(ctx,
ggml_neg(ctx, ggml_neg(ctx,

View file

@ -16,6 +16,10 @@
#include <iterator> #include <iterator>
#include <algorithm> #include <algorithm>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
float tensor_sum_elements(const ggml_tensor * tensor) { float tensor_sum_elements(const ggml_tensor * tensor) {
float sum = 0; float sum = 0;
if (tensor->type==GGML_TYPE_F32) { if (tensor->type==GGML_TYPE_F32) {
@ -29,9 +33,9 @@ float tensor_sum_elements(const ggml_tensor * tensor) {
} }
void tensor_dump(const ggml_tensor * tensor, const char * name) { void tensor_dump(const ggml_tensor * tensor, const char * name) {
printf("%15s: type = %i (%5s) ne = %5d x %5d x %5d, nb = (%5li, %5li, %5li) - ", name, printf("%15s: type = %i (%5s) ne = %5" PRIi64 " x %5" PRIi64 " x %5" PRIi64 ", nb = (%5zi, %5zi, %5zi) - ", name,
tensor->type, ggml_type_name(tensor->type), tensor->type, ggml_type_name(tensor->type),
(int) tensor->ne[0], (int) tensor->ne[1], (int) tensor->ne[2], tensor->nb[0], tensor->nb[1], tensor->nb[2]); tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->nb[0], tensor->nb[1], tensor->nb[2]);
float sum = tensor_sum_elements(tensor); float sum = tensor_sum_elements(tensor);
printf("Sum of tensor %s is %6.2f\n", name, sum); printf("Sum of tensor %s is %6.2f\n", name, sum);
} }
@ -120,7 +124,7 @@ int main(int argc, char ** argv) {
ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS ctx_size += sizex*sizey*ggml_type_sizef(GGML_TYPE_F32); // BLAS
ctx_size += 1024*1024*16; ctx_size += 1024*1024*16;
printf("Allocating Memory of size %li bytes, %li MB\n",ctx_size, (ctx_size/1024/1024)); printf("Allocating Memory of size %zi bytes, %zi MB\n",ctx_size, (ctx_size/1024/1024));
struct ggml_init_params params = { struct ggml_init_params params = {
/*.mem_size =*/ ctx_size, /*.mem_size =*/ ctx_size,

41
examples/chat-vicuna.sh Executable file
View file

@ -0,0 +1,41 @@
#!/bin/bash
set -e
cd "$(dirname "$0")/.." || exit
MODEL="${MODEL:-./models/ggml-vic13b-uncensored-q5_0.bin}"
PROMPT_TEMPLATE=${PROMPT_TEMPLATE:-./prompts/chat.txt}
USER_NAME="### Human"
AI_NAME="### Assistant"
# Adjust to the number of CPU cores you want to use.
N_THREAD="${N_THREAD:-8}"
# Number of tokens to predict (made it larger than default because we want a long interaction)
N_PREDICTS="${N_PREDICTS:-2048}"
# Note: you can also override the generation options by specifying them on the command line:
# For example, override the context size by doing: ./chatLLaMa --ctx_size 1024
GEN_OPTIONS="${GEN_OPTIONS:---ctx_size 2048 --temp 0.7 --top_k 40 --top_p 0.5 --repeat_last_n 256 --batch_size 1024 --repeat_penalty 1.17647}"
DATE_TIME=$(date +%H:%M)
DATE_YEAR=$(date +%Y)
PROMPT_FILE=$(mktemp -t llamacpp_prompt.XXXXXXX.txt)
sed -e "s/\[\[USER_NAME\]\]/$USER_NAME/g" \
-e "s/\[\[AI_NAME\]\]/$AI_NAME/g" \
-e "s/\[\[DATE_TIME\]\]/$DATE_TIME/g" \
-e "s/\[\[DATE_YEAR\]\]/$DATE_YEAR/g" \
$PROMPT_TEMPLATE > $PROMPT_FILE
# shellcheck disable=SC2086 # Intended splitting of GEN_OPTIONS
./bin/main $GEN_OPTIONS \
--model "$MODEL" \
--threads "$N_THREAD" \
--n_predict "$N_PREDICTS" \
--color --interactive \
--file ${PROMPT_FILE} \
--reverse-prompt "### Human:" \
--in-prefix ' ' \
"$@"

View file

@ -28,6 +28,10 @@
#include <wchar.h> #include <wchar.h>
#endif #endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
int32_t get_num_physical_cores() { int32_t get_num_physical_cores() {
#ifdef __linux__ #ifdef __linux__
// enumerate the set of thread siblings, num entries is num cores // enumerate the set of thread siblings, num entries is num cores
@ -373,7 +377,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
} else { } else {
throw std::exception(); throw std::exception();
} }
} catch (const std::exception &e) { } catch (const std::exception&) {
invalid_param = true; invalid_param = true;
break; break;
} }
@ -412,6 +416,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
gpt_print_usage(argc, argv, default_params); gpt_print_usage(argc, argv, default_params);
exit(1); exit(1);
} }
#ifdef GGML_USE_CUBLAS
if (!params.lora_adapter.empty() && params.n_gpu_layers > 0) {
fprintf(stderr, "%s: error: the simultaneous use of LoRAs and GPU acceleration is not supported", __func__);
exit(1);
}
#endif // GGML_USE_CUBLAS
if (escape_prompt) { if (escape_prompt) {
process_escapes(params.prompt); process_escapes(params.prompt);
} }

View file

@ -4,6 +4,10 @@
#include <ctime> #include <ctime>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
int main(int argc, char ** argv) { int main(int argc, char ** argv) {
gpt_params params; gpt_params params;

View file

@ -1,5 +1,5 @@
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import sys, os import os
import csv import csv
labels = [] labels = []
@ -8,6 +8,7 @@ numEntries = 1
rows = [] rows = []
def bar_chart(numbers, labels, pos): def bar_chart(numbers, labels, pos):
plt.bar(pos, numbers, color='blue') plt.bar(pos, numbers, color='blue')
plt.xticks(ticks=pos, labels=labels) plt.xticks(ticks=pos, labels=labels)
@ -16,6 +17,7 @@ def bar_chart(numbers, labels, pos):
plt.ylabel("Questions Correct") plt.ylabel("Questions Correct")
plt.show() plt.show()
def calculatecorrect(): def calculatecorrect():
directory = os.fsencode("./examples/jeopardy/results/") directory = os.fsencode("./examples/jeopardy/results/")
csv_reader = csv.reader(open("./examples/jeopardy/qasheet.csv", 'rt'), delimiter=',') csv_reader = csv.reader(open("./examples/jeopardy/qasheet.csv", 'rt'), delimiter=',')
@ -38,14 +40,13 @@ def calculatecorrect():
print(line) print(line)
else: else:
print("Correct answer: " + rows[i][2] + "\n") print("Correct answer: " + rows[i][2] + "\n")
i+=1 i += 1
print("Did the AI get the question right? (y/n)") print("Did the AI get the question right? (y/n)")
if input() == "y": if input() == "y":
totalcorrect += 1 totalcorrect += 1
numbers.append(totalcorrect) numbers.append(totalcorrect)
if __name__ == '__main__': if __name__ == '__main__':
calculatecorrect() calculatecorrect()
pos = list(range(numEntries)) pos = list(range(numEntries))

View file

@ -23,11 +23,17 @@
#include <unistd.h> #include <unistd.h>
#elif defined (_WIN32) #elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN #define WIN32_LEAN_AND_MEAN
#ifndef NOMINMAX
#define NOMINMAX #define NOMINMAX
#endif
#include <windows.h> #include <windows.h>
#include <signal.h> #include <signal.h>
#endif #endif
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
static console_state con_st; static console_state con_st;
static llama_context ** g_ctx; static llama_context ** g_ctx;
@ -348,7 +354,7 @@ int main(int argc, char ** argv) {
if ((int)embd.size() > max_embd_size) { if ((int)embd.size() > max_embd_size) {
auto skipped_tokens = embd.size() - max_embd_size; auto skipped_tokens = embd.size() - max_embd_size;
console_set_color(con_st, CONSOLE_COLOR_ERROR); console_set_color(con_st, CONSOLE_COLOR_ERROR);
printf("<<input too long: skipped %ld token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : ""); printf("<<input too long: skipped %" PRIu64 " token%s>>", skipped_tokens, skipped_tokens != 1 ? "s" : "");
console_set_color(con_st, CONSOLE_COLOR_DEFAULT); console_set_color(con_st, CONSOLE_COLOR_DEFAULT);
fflush(stdout); fflush(stdout);
embd.resize(max_embd_size); embd.resize(max_embd_size);

View file

@ -5,6 +5,10 @@
#include <cmath> #include <cmath>
#include <ctime> #include <ctime>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
std::vector<float> softmax(const std::vector<float>& logits) { std::vector<float> softmax(const std::vector<float>& logits) {
std::vector<float> probs(logits.size()); std::vector<float> probs(logits.size());
float max_logit = logits[0]; float max_logit = logits[0];

View file

@ -19,6 +19,10 @@
#include <thread> #include <thread>
#include <mutex> #include <mutex>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
struct quantize_stats_params { struct quantize_stats_params {
std::string model = "models/7B/ggml-model-f16.bin"; std::string model = "models/7B/ggml-model-f16.bin";
bool verbose = false; bool verbose = false;

View file

@ -37,7 +37,7 @@ int main(int argc, char ** argv) {
// init // init
auto ctx = llama_init_from_file(params.model.c_str(), lparams); auto ctx = llama_init_from_file(params.model.c_str(), lparams);
auto tokens = std::vector<llama_token>(params.n_ctx); auto tokens = std::vector<llama_token>(params.n_ctx);
auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), tokens.size(), true); auto n_prompt_tokens = llama_tokenize(ctx, params.prompt.c_str(), tokens.data(), int(tokens.size()), true);
if (n_prompt_tokens < 1) { if (n_prompt_tokens < 1) {
fprintf(stderr, "%s : failed to tokenize prompt\n", __func__); fprintf(stderr, "%s : failed to tokenize prompt\n", __func__);

View file

@ -1,6 +1,10 @@
set(TARGET server) set(TARGET server)
option(LLAMA_SERVER_VERBOSE "Build verbose logging option for Server" ON)
include_directories(${CMAKE_CURRENT_SOURCE_DIR}) include_directories(${CMAKE_CURRENT_SOURCE_DIR})
add_executable(${TARGET} server.cpp json.hpp httplib.h) add_executable(${TARGET} server.cpp json.hpp httplib.h)
target_compile_definitions(${TARGET} PRIVATE
SERVER_VERBOSE=$<BOOL:${LLAMA_SERVER_VERBOSE}>
)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11) target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO) if(TARGET BUILD_INFO)

View file

@ -1,33 +1,74 @@
# llama.cpp/example/server # llama.cpp/example/server
This example allow you to have a llama.cpp http server to interact from a web page or consume the API. This example demonstrates a simple HTTP API server to interact with llama.cpp.
## Table of Contents Command line options:
1. [Quick Start](#quick-start) - `--threads N`, `-t N`: Set the number of threads to use during computation.
2. [Node JS Test](#node-js-test) - `-m FNAME`, `--model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
3. [API Endpoints](#api-endpoints) - `-m ALIAS`, `--alias ALIAS`: Set an alias for the model. The alias will be returned in API responses.
4. [More examples](#more-examples) - `-c N`, `--ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
5. [Common Options](#common-options) - `-ngl N`, `--n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
6. [Performance Tuning and Memory Options](#performance-tuning-and-memory-options) - `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- `-b N`, `--batch-size N`: Set the batch size for prompt processing. Default: `512`.
- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended.
- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.
- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed.
- `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.
- `--port`: Set the port to listen. Default: `8080`.
## Build
Build llama.cpp with server from repository root with either make or CMake.
- Using `make`:
```bash
LLAMA_BUILD_SERVER=1 make
```
- Using `CMake`:
```bash
mkdir build-server
cd build-server
cmake -DLLAMA_BUILD_SERVER=ON ..
cmake --build . --config Release
```
## Quick Start ## Quick Start
To get started right away, run the following command, making sure to use the correct path for the model you have: To get started right away, run the following command, making sure to use the correct path for the model you have:
#### Unix-based systems (Linux, macOS, etc.): ### Unix-based systems (Linux, macOS, etc.):
```bash ```bash
./server -m models/7B/ggml-model.bin --ctx_size 2048 ./server -m models/7B/ggml-model.bin -c 2048
``` ```
#### Windows: ### Windows:
```powershell ```powershell
server.exe -m models\7B\ggml-model.bin --ctx_size 2048 server.exe -m models\7B\ggml-model.bin -c 2048
``` ```
That will start a server that by default listens on `127.0.0.1:8080`. You can consume the endpoints with Postman or NodeJS with axios library. The above command will start a server that by default listens on `127.0.0.1:8080`.
You can consume the endpoints with Postman or NodeJS with axios library.
## Testing with CURL
Using [curl](https://curl.se/). On Windows `curl.exe` should be available in the base OS.
```sh
curl --request POST \
--url http://localhost:8080/completion \
--data '{"prompt": "Building a website can be done in 10 simple steps:","n_predict": 128}'
```
## Node JS Test ## Node JS Test
@ -50,7 +91,6 @@ const prompt = `Building a website can be done in 10 simple steps:`;
async function Test() { async function Test() {
let result = await axios.post("http://127.0.0.1:8080/completion", { let result = await axios.post("http://127.0.0.1:8080/completion", {
prompt, prompt,
batch_size: 128,
n_predict: 512, n_predict: 512,
}); });
@ -69,247 +109,75 @@ node .
## API Endpoints ## API Endpoints
You can interact with this API Endpoints. This implementations just support chat style interaction. - **POST** `/completion`: Given a prompt, it returns the predicted completion.
- **POST** `hostname:port/completion`: Setting up the Llama Context to begin the completions tasks. *Options:*
*Options:* `temperature`: Adjust the randomness of the generated text (default: 0.8).
`batch_size`: Set the batch size for prompt processing (default: 512). `top_k`: Limit the next token selection to the K most probable tokens (default: 40).
`temperature`: Adjust the randomness of the generated text (default: 0.8). `top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9).
`top_k`: Limit the next token selection to the K most probable tokens (default: 40). `n_predict`: Set the number of tokens to predict when generating text. **Note:** May exceed the set limit slightly if the last token is a partial multibyte character. (default: 128, -1 = infinity).
`top_p`: Limit the next token selection to a subset of tokens with a cumulative probability above a threshold P (default: 0.9). `n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context.
By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt.
`n_predict`: Set the number of tokens to predict when generating text (default: 128, -1 = infinity). `stream`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`.
`threads`: Set the number of threads to use during computation. `prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate.
`n_keep`: Specify the number of tokens from the initial prompt to retain when the model resets its internal context. By default, this value is set to 0 (meaning no tokens are kept). Use `-1` to retain all tokens from the initial prompt. `stop`: Specify a JSON array of stopping strings.
These words will not be included in the completion, so make sure to add them to the prompt for the next iteration (default: []).
`as_loop`: It allows receiving each predicted token in real-time instead of waiting for the completion to finish. To enable this, set to `true`. `tfs_z`: Enable tail free sampling with parameter z (default: 1.0, 1.0 = disabled).
`interactive`: It allows interacting with the completion, and the completion stops as soon as it encounters a `stop word`. To enable this, set to `true`. `typical_p`: Enable locally typical sampling with parameter p (default: 1.0, 1.0 = disabled).
`prompt`: Provide a prompt. Internally, the prompt is compared, and it detects if a part has already been evaluated, and the remaining part will be evaluate. `repeat_penalty`: Control the repetition of token sequences in the generated text (default: 1.1).
`stop`: Specify the words or characters that indicate a stop. These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. `repeat_last_n`: Last n tokens to consider for penalizing repetition (default: 64, 0 = disabled, -1 = ctx-size).
`exclude`: Specify the words or characters you do not want to appear in the completion. These words will not be included in the completion, so make sure to add them to the prompt for the next iteration. `penalize_nl`: Penalize newline tokens when applying the repeat penalty (default: true).
- **POST** `hostname:port/embedding`: Generate embedding of a given text `presence_penalty`: Repeat alpha presence penalty (default: 0.0, 0.0 = disabled).
*Options:* `frequency_penalty`: Repeat alpha frequency penalty (default: 0.0, 0.0 = disabled);
`content`: Set the text to get generate the embedding. `mirostat`: Enable Mirostat sampling, controlling perplexity during text generation (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0).
`threads`: Set the number of threads to use during computation. `mirostat_tau`: Set the Mirostat target entropy, parameter tau (default: 5.0).
To use this endpoint, you need to start the server with the `--embedding` option added. `mirostat_eta`: Set the Mirostat learning rate, parameter eta (default: 0.1).
- **POST** `hostname:port/tokenize`: Tokenize a given text `seed`: Set the random number generator (RNG) seed (default: -1, < 0 = random seed).
*Options:* `ignore_eos`: Ignore end of stream token and continue generating (default: false).
`content`: Set the text to tokenize. `logit_bias`: Modify the likelihood of a token appearing in the generated text completion. For example, use `"logit_bias": [[15043,1.0]]` to increase the likelihood of the token 'Hello', or `"logit_bias": [[15043,-1.0]]` to decrease its likelihood. Setting the value to false, `"logit_bias": [[15043,false]]` ensures that the token `Hello` is never produced (default: []).
- **GET** `hostname:port/next-token`: Receive the next token predicted, execute this request in a loop. Make sure set `as_loop` as `true` in the completion request. - **POST** `/tokenize`: Tokenize a given text.
*Options:* *Options:*
`stop`: Set `hostname:port/next-token?stop=true` to stop the token generation. `content`: Set the text to tokenize.
## More examples ## More examples
### Interactive mode ### Interactive mode
This mode allows interacting in a chat-like manner. It is recommended for models designed as assistants such as `Vicuna`, `WizardLM`, `Koala`, among others. Make sure to add the correct stop word for the corresponding model. Check the sample in [chat.mjs](chat.mjs).
Run with NodeJS version 16 or later:
The prompt should be generated by you, according to the model's guidelines. You should keep adding the model's completions to the context as well. ```sh
node chat.mjs
This example works well for `Vicuna - version 1`.
```javascript
const axios = require("axios");
let prompt = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.
### Human: Hello, Assistant.
### Assistant: Hello. How may I help you today?
### Human: Please tell me the largest city in Europe.
### Assistant: Sure. The largest city in Europe is Moscow, the capital of Russia.`;
async function ChatCompletion(answer) {
// the user's next question to the prompt
prompt += `\n### Human: ${answer}\n`
result = await axios.post("http://127.0.0.1:8080/completion", {
prompt,
batch_size: 128,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: -1,
n_predict: 2048,
stop: ["\n### Human:"], // when detect this, stop completion
exclude: ["### Assistant:"], // no show in the completion
threads: 8,
as_loop: true, // use this to request the completion token by token
interactive: true, // enable the detection of a stop word
});
// create a loop to receive every token predicted
// note: this operation is blocking, avoid use this in a ui thread
let message = "";
while (true) {
// you can stop the inference adding '?stop=true' like this http://127.0.0.1:8080/next-token?stop=true
result = await axios.get("http://127.0.0.1:8080/next-token");
process.stdout.write(result.data.content);
message += result.data.content;
// to avoid an infinite loop
if (result.data.stop) {
console.log("Completed");
// make sure to add the completion to the prompt.
prompt += `### Assistant: ${message}`;
break;
}
}
}
// This function should be called every time a question to the model is needed.
async function Test() {
// the server can't inference in paralell
await ChatCompletion("Write a long story about a time magician in a fantasy world");
await ChatCompletion("Summary the story");
}
Test();
``` ```
### Alpaca example Another sample in [chat.sh](chat.sh).
Requires [bash](https://www.gnu.org/software/bash/), [curl](https://curl.se) and [jq](https://jqlang.github.io/jq/).
Run with bash:
**Temporaly note:** no tested, if you have the model, please test it and report me some issue ```sh
bash chat.sh
```javascript
const axios = require("axios");
let prompt = `Below is an instruction that describes a task. Write a response that appropriately completes the request.
`;
async function DoInstruction(instruction) {
prompt += `\n\n### Instruction:\n\n${instruction}\n\n### Response:\n\n`;
result = await axios.post("http://127.0.0.1:8080/completion", {
prompt,
batch_size: 128,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: -1,
n_predict: 2048,
stop: ["### Instruction:\n\n"], // when detect this, stop completion
exclude: [], // no show in the completion
threads: 8,
as_loop: true, // use this to request the completion token by token
interactive: true, // enable the detection of a stop word
});
// create a loop to receive every token predicted
// note: this operation is blocking, avoid use this in a ui thread
let message = "";
while (true) {
result = await axios.get("http://127.0.0.1:8080/next-token");
process.stdout.write(result.data.content);
message += result.data.content;
// to avoid an infinite loop
if (result.data.stop) {
console.log("Completed");
// make sure to add the completion and the user's next question to the prompt.
prompt += message;
break;
}
}
}
// This function should be called every time a instruction to the model is needed.
DoInstruction("Destroy the world"); // as joke
``` ```
### Embeddings
First, run the server with `--embedding` option:
```bash
server -m models/7B/ggml-model.bin --ctx_size 2048 --embedding
```
Run this code in NodeJS:
```javascript
const axios = require('axios');
async function Test() {
let result = await axios.post("http://127.0.0.1:8080/embedding", {
content: `Hello`,
threads: 5
});
// print the embedding array
console.log(result.data.embedding);
}
Test();
```
### Tokenize
Run this code in NodeJS:
```javascript
const axios = require('axios');
async function Test() {
let result = await axios.post("http://127.0.0.1:8080/tokenize", {
content: `Hello`
});
// print the embedding array
console.log(result.data.tokens);
}
Test();
```
## Common Options
- `-m FNAME, --model FNAME`: Specify the path to the LLaMA model file (e.g., `models/7B/ggml-model.bin`).
- `-c N, --ctx-size N`: Set the size of the prompt context. The default is 512, but LLaMA models were built with a context of 2048, which will provide better results for longer input/inference.
- `-ngl N, --n-gpu-layers N`: When compiled with appropriate support (currently CLBlast or cuBLAS), this option allows offloading some layers to the GPU for computation. Generally results in increased performance.
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- `--embedding`: Enable the embedding mode. **Completion function doesn't work in this mode**.
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`;
- `--port`: Set the port to listen. Default: `8080`.
### RNG Seed
- `-s SEED, --seed SEED`: Set the random number generator (RNG) seed (default: -1, < 0 = random seed).
The RNG seed is used to initialize the random number generator that influences the text generation process. By setting a specific seed value, you can obtain consistent and reproducible results across multiple runs with the same input and settings. This can be helpful for testing, debugging, or comparing the effects of different options on the generated text to see when they diverge. If the seed is set to a value less than 0, a random seed will be used, which will result in different outputs on each run.
## Performance Tuning and Memory Options
### No Memory Mapping
- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed. However, if the model is larger than your total amount of RAM or if your system is low on available memory, using mmap might increase the risk of pageouts, negatively impacting performance.
### Memory Float 32
- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. This doubles the context memory requirement but does not appear to increase generation quality in a measurable way. Not recommended.
## Limitations:
- The actual implementation of llama.cpp need a `llama-state` for handle multiple contexts and clients, but this could require more powerful hardware.

89
examples/server/chat.mjs Normal file
View file

@ -0,0 +1,89 @@
import * as readline from 'node:readline'
import { stdin, stdout } from 'node:process'
const API_URL = 'http://127.0.0.1:8080'
const chat = [
{
human: "Hello, Assistant.",
assistant: "Hello. How may I help you today?"
},
{
human: "Please tell me the largest city in Europe.",
assistant: "Sure. The largest city in Europe is Moscow, the capital of Russia."
},
]
const instruction = `A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.`
function format_prompt(question) {
return `${instruction}\n${
chat.map(m =>`### Human: ${m.human}\n### Assistant: ${m.assistant}`).join("\n")
}\n### Human: ${question}\n### Assistant:`
}
async function tokenize(content) {
const result = await fetch(`${API_URL}/tokenize`, {
method: 'POST',
body: JSON.stringify({ content })
})
if (!result.ok) {
return []
}
return await result.json().tokens
}
const n_keep = await tokenize(instruction).length
async function chat_completion(question) {
const result = await fetch(`${API_URL}/completion`, {
method: 'POST',
body: JSON.stringify({
prompt: format_prompt(question),
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: n_keep,
n_predict: 256,
stop: ["\n### Human:"], // stop completion after generating this
stream: true,
})
})
if (!result.ok) {
return
}
let answer = ''
for await (var chunk of result.body) {
const t = Buffer.from(chunk).toString('utf8')
if (t.startsWith('data: ')) {
const message = JSON.parse(t.substring(6))
answer += message.content
process.stdout.write(message.content)
if (message.stop) {
if (message.truncated) {
chat.shift()
}
break
}
}
}
process.stdout.write('\n')
chat.push({ human: question, assistant: answer.trimStart() })
}
const rl = readline.createInterface({ input: stdin, output: stdout });
const readlineQuestion = (rl, query, options) => new Promise((resolve, reject) => {
rl.question(query, options, resolve)
});
while(true) {
const question = await readlineQuestion(rl, '> ')
await chat_completion(question)
}

77
examples/server/chat.sh Normal file
View file

@ -0,0 +1,77 @@
#!/bin/bash
API_URL="${API_URL:-http://127.0.0.1:8080}"
CHAT=(
"Hello, Assistant."
"Hello. How may I help you today?"
"Please tell me the largest city in Europe."
"Sure. The largest city in Europe is Moscow, the capital of Russia."
)
INSTRUCTION="A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions."
trim() {
shopt -s extglob
set -- "${1##+([[:space:]])}"
printf "%s" "${1%%+([[:space:]])}"
}
trim_trailing() {
shopt -s extglob
printf "%s" "${1%%+([[:space:]])}"
}
format_prompt() {
echo -n "${INSTRUCTION}"
printf "\n### Human: %s\n### Assistant: %s" "${CHAT[@]}" "$1"
}
tokenize() {
curl \
--silent \
--request POST \
--url "${API_URL}/tokenize" \
--data-raw "$(jq -ns --arg content "$1" '{content:$content}')" \
| jq '.tokens[]'
}
N_KEEP=$(tokenize "${INSTRUCTION}" | wc -l)
chat_completion() {
PROMPT="$(trim_trailing "$(format_prompt "$1")")"
DATA="$(echo -n "$PROMPT" | jq -Rs --argjson n_keep $N_KEEP '{
prompt: .,
temperature: 0.2,
top_k: 40,
top_p: 0.9,
n_keep: $n_keep,
n_predict: 256,
stop: ["\n### Human:"],
stream: true
}')"
ANSWER=''
while IFS= read -r LINE; do
if [[ $LINE = data:* ]]; then
CONTENT="$(echo "${LINE:5}" | jq -r '.content')"
printf "%s" "${CONTENT}"
ANSWER+="${CONTENT}"
fi
done < <(curl \
--silent \
--no-buffer \
--request POST \
--url "${API_URL}/completion" \
--data-raw "${DATA}")
printf "\n"
CHAT+=("$1" "$(trim "$ANSWER")")
}
while true; do
read -r -e -p "> " QUESTION
chat_completion "${QUESTION}"
done

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,7 @@
set(TARGET simple)
add_executable(${TARGET} simple.cpp)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_11)
if(TARGET BUILD_INFO)
add_dependencies(${TARGET} BUILD_INFO)
endif()

177
examples/simple/simple.cpp Normal file
View file

@ -0,0 +1,177 @@
#ifndef _GNU_SOURCE
#define _GNU_SOURCE
#endif
#include "common.h"
#include "llama.h"
#include "build-info.h"
#include <cassert>
#include <cinttypes>
#include <cmath>
#include <cstdio>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <string>
#include <vector>
#if defined (__unix__) || (defined (__APPLE__) && defined (__MACH__))
#include <signal.h>
#include <unistd.h>
#elif defined (_WIN32)
#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <windows.h>
#include <signal.h>
#endif
int main(int argc, char ** argv)
{
gpt_params params;
//---------------------------------
// Print help :
//---------------------------------
if ( argc == 1 || argv[1][0] == '-' )
{
printf( "usage: %s MODEL_PATH [PROMPT]\n" , argv[0] );
return 1 ;
}
//---------------------------------
// Load parameters :
//---------------------------------
if ( argc >= 2 )
{
params.model = argv[1];
}
if ( argc >= 3 )
{
params.prompt = argv[2];
}
if ( params.prompt.empty() )
{
params.prompt = "Hello my name is";
}
//---------------------------------
// Init LLM :
//---------------------------------
llama_init_backend();
llama_context * ctx ;
ctx = llama_init_from_gpt_params( params );
if ( ctx == NULL )
{
fprintf( stderr , "%s: error: unable to load model\n" , __func__ );
return 1;
}
//---------------------------------
// Tokenize the prompt :
//---------------------------------
std::vector<llama_token> tokens_list;
tokens_list = ::llama_tokenize( ctx , params.prompt , true );
const int max_context_size = llama_n_ctx( ctx );
const int max_tokens_list_size = max_context_size - 4 ;
if ( (int)tokens_list.size() > max_tokens_list_size )
{
fprintf( stderr , "%s: error: prompt too long (%d tokens, max %d)\n" ,
__func__ , (int)tokens_list.size() , max_tokens_list_size );
return 1;
}
fprintf( stderr, "\n\n" );
// Print the tokens from the prompt :
for( auto id : tokens_list )
{
printf( "%s" , llama_token_to_str( ctx , id ) );
}
fflush(stdout);
//---------------------------------
// Main prediction loop :
//---------------------------------
// The LLM keeps a contextual cache memory of previous token evaluation.
// Usually, once this cache is full, it is required to recompute a compressed context based on previous
// tokens (see "infinite text generation via context swapping" in the main example), but in this minimalist
// example, we will just stop the loop once this cache is full or once an end of stream is detected.
while ( llama_get_kv_cache_token_count( ctx ) < max_context_size )
{
//---------------------------------
// Evaluate the tokens :
//---------------------------------
if ( llama_eval( ctx , tokens_list.data() , tokens_list.size() , llama_get_kv_cache_token_count( ctx ) , params.n_threads ) )
{
fprintf( stderr, "%s : failed to eval\n" , __func__ );
return 1;
}
tokens_list.clear();
//---------------------------------
// Select the best prediction :
//---------------------------------
llama_token new_token_id = 0;
auto logits = llama_get_logits( ctx );
auto n_vocab = llama_n_vocab( ctx ); // the size of the LLM vocabulary (in tokens)
std::vector<llama_token_data> candidates;
candidates.reserve( n_vocab );
for( llama_token token_id = 0 ; token_id < n_vocab ; token_id++ )
{
candidates.emplace_back( llama_token_data{ token_id , logits[ token_id ] , 0.0f } );
}
llama_token_data_array candidates_p = { candidates.data(), candidates.size(), false };
// Select it using the "Greedy sampling" method :
new_token_id = llama_sample_token_greedy( ctx , &candidates_p );
// is it an end of stream ?
if ( new_token_id == llama_token_eos() )
{
fprintf(stderr, " [end of text]\n");
break;
}
// Print the new token :
printf( "%s" , llama_token_to_str( ctx , new_token_id ) );
fflush( stdout );
// Push this new token for next evaluation :
tokens_list.push_back( new_token_id );
} // wend of main loop
llama_free( ctx );
return 0;
}
// EOF

View file

@ -4,7 +4,7 @@ Basic usage instructions:
```bash ```bash
# get training data # get training data
wget https://github.com/brunoklein99/deep-learning-notes/blob/master/shakespeare.txt wget https://raw.githubusercontent.com/brunoklein99/deep-learning-notes/master/shakespeare.txt
# train # train
./bin/train-text-from-scratch \ ./bin/train-text-from-scratch \

View file

@ -12,6 +12,9 @@
#include <algorithm> #include <algorithm>
#include <string> #include <string>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
struct random_normal_distribution { struct random_normal_distribution {
std::mt19937 gen; std::mt19937 gen;
@ -20,7 +23,6 @@ struct random_normal_distribution {
float max; float max;
}; };
struct random_uniform_distribution { struct random_uniform_distribution {
std::mt19937 gen; std::mt19937 gen;
std::uniform_real_distribution<float> rd; std::uniform_real_distribution<float> rd;
@ -2366,7 +2368,7 @@ void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
file->write_u32(0); file->write_u32(0);
file->write_u32(0); file->write_u32(0);
file->write_u32(GGML_TYPE_F32); file->write_u32(GGML_TYPE_F32);
file->seek(-file->tell() & 31, SEEK_CUR); file->seek(0-file->tell() & 31, SEEK_CUR);
return; return;
} }
const char * name = ggml_get_name(tensor); const char * name = ggml_get_name(tensor);
@ -2381,7 +2383,7 @@ void write_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
file->write_u32(tensor->type); file->write_u32(tensor->type);
file->write_raw(ne, sizeof(ne[0]) * nd); file->write_raw(ne, sizeof(ne[0]) * nd);
file->write_raw(name, name_len); file->write_raw(name, name_len);
file->seek(-file->tell() & 31, SEEK_CUR); file->seek(0-file->tell() & 31, SEEK_CUR);
file->write_raw(tensor->data, ggml_nbytes(tensor)); file->write_raw(tensor->data, ggml_nbytes(tensor));
} }
@ -2402,7 +2404,7 @@ void read_tensor(struct llama_file * file, struct ggml_tensor * tensor) {
std::string name = file->read_string(name_len); std::string name = file->read_string(name_len);
GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0); GGML_ASSERT(strncmp(ggml_get_name(tensor), name.c_str(), sizeof(tensor->name)-1) == 0);
file->seek(-file->tell() & 31, SEEK_CUR); file->seek(0-file->tell() & 31, SEEK_CUR);
file->read_raw(tensor->data, ggml_nbytes(tensor)); file->read_raw(tensor->data, ggml_nbytes(tensor));
} }
@ -2756,8 +2758,8 @@ struct train_params get_default_train_params() {
params.lbfgs_n_iter = 16; params.lbfgs_n_iter = 16;
params.adam_n_iter = 16; params.adam_n_iter = 16;
params.adam_alpha = 1e-3; params.adam_alpha = 1e-3f;
params.adam_decay = 1e-3; params.adam_decay = 1e-3f;
params.mem_model_gb = 2; params.mem_model_gb = 2;
params.mem_compute_gb = 24; params.mem_compute_gb = 24;
@ -3331,8 +3333,8 @@ int main(int argc, char ** argv) {
int n_gen = params.n_predict; int n_gen = params.n_predict;
int sample_ctx = n_tokens - n_tokens/8; int sample_ctx = n_tokens - n_tokens/8;
sampler.params.temp = 0.2; sampler.params.temp = 0.2f;
sampler.params.repeat_penalty = 1.1; sampler.params.repeat_penalty = 1.1f;
sampler.params.mirostat = 2; sampler.params.mirostat = 2;
init_sampler(&sampler, lctx); init_sampler(&sampler, lctx);

View file

@ -48,6 +48,19 @@
''; '';
meta.mainProgram = "llama"; meta.mainProgram = "llama";
}; };
apps.llama-server = {
type = "app";
program = "${self.packages.${system}.default}/bin/llama-server";
};
apps.llama-embedding = {
type = "app";
program = "${self.packages.${system}.default}/bin/embedding";
};
apps.llama = {
type = "app";
program = "${self.packages.${system}.default}/bin/llama";
};
apps.default = self.apps.${system}.llama;
devShells.default = pkgs.mkShell { devShells.default = pkgs.mkShell {
packages = with pkgs; [ packages = with pkgs; [
cmake cmake

View file

@ -80,7 +80,7 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
} \ } \
} while (0) } while (0)
#if CUDART_VERSION >= 12 #if CUDART_VERSION >= 12000
#define CUBLAS_CHECK(err) \ #define CUBLAS_CHECK(err) \
do { \ do { \
cublasStatus_t err_ = (err); \ cublasStatus_t err_ = (err); \
@ -222,6 +222,12 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define GGML_CUDA_DMMV_Y 1 #define GGML_CUDA_DMMV_Y 1
#endif #endif
#ifndef K_QUANTS_PER_ITERATION
#define K_QUANTS_PER_ITERATION 2
#else
static_assert(K_QUANTS_PER_ITERATION == 1 || K_QUANTS_PER_ITERATION == 2, "K_QUANTS_PER_ITERATION must be 1 or 2");
#endif
static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) { static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
@ -381,37 +387,6 @@ static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
} }
static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
const block_q2_K * x = (const block_q2_K *) vx;
// if n is 0, we want to do the lower 128, else the upper 128,
// covering y[l+0], y[l+32], y[l+64], y[l+96] and
// y[l+16], y[l+48], y[l+80], y[l+112]
int n = iqs/128; // 0 or 1
int r = iqs - 128*n; // 0...120 in steps of 8
int l = r/8; // 0...15 in steps of 1
const float * y = yy + 128*n + l;
const uint8_t * q = x[ib].qs + 32*n + l;
const uint8_t * s = x[ib].scales + 8*n;
const float dall = x[ib].d;
const float dmin = x[ib].dmin;
float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
+ y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
+ y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
+ y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
+ y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
+ y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
+ y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
+ y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
result = sum;
}
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) { static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
int r = threadIdx.x/4; int r = threadIdx.x/4;
@ -443,51 +418,6 @@ static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
} }
static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
const block_q3_K * x = (const block_q3_K *) vx;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
uint32_t aux[3];
uint32_t utmp[4];
// if n is 0, we want to do the lower 128, else the upper 128,
// covering y[l+0], y[l+32], y[l+64], y[l+96] and
// y[l+16], y[l+48], y[l+80], y[l+112]
int n = iqs/128; // 0 or 1
int r = iqs - 128*n; // 0...120 in steps of 8
int l = r/8; // 0...15 in steps of 1
const float * y = yy + 128*n + l;
const uint8_t * q = x[ib].qs + 32*n + l;
const uint8_t * hm = x[ib].hmask + l;
const int8_t * s = (const int8_t *)utmp + 8*n;
memcpy(aux, x[ib].scales, 12);
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
const float dall = x[ib].d;
const uint8_t m = 1 << (4*n);
float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
+ y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
+ y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
+ y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
+ y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
+ y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
+ y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
+ y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
result = sum * dall;
}
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) { static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) { if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63; d = q[j] & 63; m = q[j + 4] & 63;
@ -534,38 +464,6 @@ static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
} }
} }
static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
const block_q4_K * x = (const block_q4_K *) vx;
// iqs is in 0...248 in steps of 8 =>
const int j = iqs / 64; // j is in 0...3
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
const int is = 2*j; // is is in 0...6 in steps of 2
const float * y = yy + 64*j + ir;
const uint8_t * q = x[ib].qs + 32*j + ir;
const float dall = x[ib].d;
const float dmin = x[ib].dmin;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
}
result = sum;
}
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) { static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
const block_q5_K * x = (const block_q5_K *) vx; const block_q5_K * x = (const block_q5_K *) vx;
@ -599,43 +497,6 @@ static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2; y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
} }
static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
const block_q5_K * x = (const block_q5_K *) vx;
// iqs is in 0...248 in steps of 8 =>
const int j = iqs / 64; // j is in 0...3
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
const int is = 2*j; // is is in 0...6 in steps of 2
const float * y = yy + 64*j + ir;
const uint8_t * ql = x[ib].qs + 32*j + ir;
const uint8_t * qh = x[ib].qh + ir;
const float dall = x[ib].d;
const float dmin = x[ib].dmin;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;
uint8_t hm = 1 << is;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
}
hm <<= 1;
for (int k = 0; k < 4; ++k) {
sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
}
result = sum;
}
static __global__ void dequantize_block_q6_K(const void * vx, float * yy) { static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
const block_q6_K * x = (const block_q6_K *) vx; const block_q6_K * x = (const block_q6_K *) vx;
@ -661,31 +522,376 @@ static __global__ void dequantize_block_q6_K(const void * vx, float * yy) {
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32); y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
} }
static __device__ void vec_dot_q6_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) { static __global__ void dequantize_mul_mat_vec_q2_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
const block_q6_K * x = (const block_q6_K *) vx; static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int ip = iqs / 128; // 0 or 1 const int row = blockIdx.y*blockDim.y + threadIdx.y;
const int il = (iqs - 128*ip)/8; // 0...15 if (row > nrows) return;
const int is = 8*ip;
const float * y = yy + 128*ip + il; const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const float d = x[ib].d; const block_q2_K * x = (const block_q2_K *)vx + ib0;
const uint8_t * ql = x[ib].ql + 64*ip + il; const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31
const uint8_t * qh = x[ib].qh + 32*ip + il; const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0
const int8_t * sc = x[ib].scales + is;
result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32) const int step = 16/K_QUANTS_PER_ITERATION;
+ y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
+ y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
+ y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
+ y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
+ y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
+ y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
+ y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...7
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...14 in steps of 4
const int q_offset = 32*im + l0;
const int s_offset = 8*im;
const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp
uint32_t aux[4];
const uint8_t * d = (const uint8_t *)aux;
const uint8_t * m = (const uint8_t *)(aux + 2);
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
const float dall = x[i].d;
const float dmin = x[i].dmin;
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f;
aux[1] = a[1] & 0x0f0f0f0f;
aux[2] = (a[0] >> 4) & 0x0f0f0f0f;
aux[3] = (a[1] >> 4) & 0x0f0f0f0f;
float sum1 = 0, sum2 = 0;
for (int l = 0; l < K_QUANTS_PER_ITERATION; ++l) {
sum1 += y[l+ 0] * d[0] * ((q[l+ 0] >> 0) & 3)
+ y[l+32] * d[2] * ((q[l+ 0] >> 2) & 3)
+ y[l+64] * d[4] * ((q[l+ 0] >> 4) & 3)
+ y[l+96] * d[6] * ((q[l+ 0] >> 6) & 3)
+ y[l+16] * d[1] * ((q[l+16] >> 0) & 3)
+ y[l+48] * d[3] * ((q[l+16] >> 2) & 3)
+ y[l+80] * d[5] * ((q[l+16] >> 4) & 3)
+y[l+112] * d[7] * ((q[l+16] >> 6) & 3);
sum2 += y[l+ 0] * m[0] + y[l+32] * m[2] + y[l+64] * m[4] + y[ l+96] * m[6]
+ y[l+16] * m[1] + y[l+48] * m[3] + y[l+80] * m[5] + y[l+112] * m[7];
}
tmp += dall * sum1 - dmin * sum2;
}
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q3_k(const void * vx, const float * yy, float * dst, const int ncols) {
const uint16_t kmask1 = 0x0303;
const uint16_t kmask2 = 0x0f0f;
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q3_K * x = (const block_q3_K *)vx + ib0;
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2; // 0, 1
const int n = 2; // iterations in the inner loop
const int im = tid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - 8*im; // 0...7
const uint8_t m = 1 << (4*im);
const int l0 = n*in; // 0...28 in steps of 4
const int q_offset = 32*im + l0;
const int y_offset = 128*im + l0;
uint16_t utmp[4];
const int8_t * s = (const int8_t *)utmp;
const uint16_t s_shift = 4*im;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset;
const uint8_t * h = x[i].hmask + l0;
const uint16_t * a = (const uint16_t *)x[i].scales;
utmp[0] = ((a[0] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 0)) & kmask1) << 4);
utmp[1] = ((a[1] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 0)) & kmask1) << 4);
utmp[2] = ((a[2] >> s_shift) & kmask2) | (((a[4] >> (s_shift + 2)) & kmask1) << 4);
utmp[3] = ((a[3] >> s_shift) & kmask2) | (((a[5] >> (s_shift + 2)) & kmask1) << 4);
const float d = x[i].d;
float sum = 0;
for (int l = 0; l < n; ++l) {
sum += y[l+ 0] * (s[0] - 32) * (((q[l] >> 0) & 3) - (h[l] & (m << 0) ? 0 : 4))
+ y[l+32] * (s[2] - 32) * (((q[l] >> 2) & 3) - (h[l] & (m << 1) ? 0 : 4))
+ y[l+64] * (s[4] - 32) * (((q[l] >> 4) & 3) - (h[l] & (m << 2) ? 0 : 4))
+ y[l+96] * (s[6] - 32) * (((q[l] >> 6) & 3) - (h[l] & (m << 3) ? 0 : 4));
sum += y[l+16] * (s[1] - 32) * (((q[l+16] >> 0) & 3) - (h[l+16] & (m << 0) ? 0 : 4))
+ y[l+48] * (s[3] - 32) * (((q[l+16] >> 2) & 3) - (h[l+16] & (m << 1) ? 0 : 4))
+ y[l+80] * (s[5] - 32) * (((q[l+16] >> 4) & 3) - (h[l+16] & (m << 2) ? 0 : 4))
+ y[l+112] * (s[7] - 32) * (((q[l+16] >> 6) & 3) - (h[l+16] & (m << 3) ? 0 : 4));
}
tmp += d * sum;
}
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q4_k(const void * vx, const float * yy, float * dst, const int ncols) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2;
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 4;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
const block_q4_K * x = (const block_q4_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2) {
const uint8_t * q1 = x[i].qs + q_offset;
const uint8_t * q2 = q1 + 64;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
const float dall = x[i].d;
const float dmin = x[i].dmin;
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
aux[1] = a[im+2] & kmask1;
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
float4 s = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < n; ++l) {
s.x += y1[l] * (q1[l] & 0xF); s.y += y1[l+32] * (q1[l] >> 4);
s.z += y2[l] * (q2[l] & 0xF); s.w += y2[l+32] * (q2[l] >> 4);
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (s.x * sc[0] + s.y * sc[1] + s.z * sc[4] + s.w * sc[5]) - dmin * smin;
}
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q5_k(const void * vx, const float * yy, float * dst, const int ncols) {
const uint16_t kmask1 = 0x3f3f;
const uint16_t kmask2 = 0x0f0f;
const uint16_t kmask3 = 0xc0c0;
//const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int row = blockIdx.x;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const int tid = threadIdx.x/2; // 0...15
const int ix = threadIdx.x%2;
const int il = tid/4; // 0...3
const int ir = tid - 4*il;// 0...3
const int n = 4;
const int im = il/2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
const int in = il%2;
const int l0 = n*(2*ir + in);
const int q_offset = 32*im + l0;
const int y_offset = 64*im + l0;
const uint8_t hm1 = 1 << (2*im);
const uint8_t hm2 = hm1 << 4;
uint16_t aux[4];
const uint8_t * sc = (const uint8_t *)aux;
const block_q5_K * x = (const block_q5_K *)vx + ib0;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += 2) {
const uint8_t * ql1 = x[i].qs + q_offset;
const uint8_t * ql2 = ql1 + 64;
const uint8_t * qh = x[i].qh + l0;
const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128;
const float dall = x[i].d;
const float dmin = x[i].dmin;
const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1;
aux[1] = a[im+2] & kmask1;
aux[2] = ((a[im+4] >> 0) & kmask2) | ((a[im+0] & kmask3) >> 2);
aux[3] = ((a[im+4] >> 4) & kmask2) | ((a[im+2] & kmask3) >> 2);
float4 sum = {0.f, 0.f, 0.f, 0.f};
float smin = 0;
for (int l = 0; l < n; ++l) {
sum.x += y1[l+ 0] * ((ql1[l] & 0xF) + (qh[l] & (hm1 << 0) ? 16 : 0));
sum.y += y1[l+32] * ((ql1[l] >> 4) + (qh[l] & (hm1 << 1) ? 16 : 0));
sum.z += y2[l+ 0] * ((ql2[l] & 0xF) + (qh[l] & (hm2 << 0) ? 16 : 0));
sum.w += y2[l+32] * ((ql2[l] >> 4) + (qh[l] & (hm2 << 1) ? 16 : 0));
smin += y1[l] * sc[2] + y1[l+32] * sc[3] + y2[l] * sc[6] + y2[l+32] * sc[7];
}
tmp += dall * (sum.x * sc[0] + sum.y * sc[1] + sum.z * sc[4] + sum.w * sc[5]) - dmin * smin;
}
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
}
static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const float * yy, float * dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q6_K * x = (const block_q6_K *)vx + ib0;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif
const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is;
const int y_offset = 128*im + l0;
float tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const float * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset;
const uint8_t * qh = x[i].qh + qh_offset;
const int8_t * s = x[i].scales + s_offset;
const float d = x[i].d;
#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
float sum = 0;
for (int l = 0; l < 4; ++l) {
sum += y[l+ 0] * s[0] * d * ((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ y[l+32] * s[2] * d * ((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ y[l+64] * s[4] * d * ((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ y[l+96] * s[6] * d * ((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
}
tmp += sum;
#endif
}
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
} }
static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){ static __device__ void convert_f16(const void * vx, const int ib, const int iqs, float & v0, float & v1){
@ -767,46 +973,6 @@ static __global__ void dequantize_mul_mat_vec(const void * vx, const float * y,
} }
} }
template <int n_thread, dot_kernel_k_t dot_kernel>
static __global__ void dequantize_mul_mat_vec_k(const void * vx, const float * y, float * dst, const int ncols, const int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row >= nrows) {
return;
}
const int tid = threadIdx.x;
const int iter_stride = QK_K;
const int vals_per_iter = iter_stride / n_thread;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
float tmp = 0; // partial sum for thread in warp
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = ib0 + col/QK_K; // x block index
const int iqs = col%QK_K; // x quant index
const int iybs = col - col%QK_K; // y block start index
float v;
dot_kernel(vx, ib, iqs, y + iybs, v);
tmp += v;
}
// sum up partial sums and write back result
__syncthreads();
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
}
static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) { static __global__ void mul_mat_p021_f16_f32(const void * vx, const float * y, float * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
const half * x = (half *) vx; const half * x = (half *) vx;
@ -1149,43 +1315,34 @@ static void dequantize_mul_mat_vec_q2_K_cuda(const void * vx, const float * y, f
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1); const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_k<32, vec_dot_q2_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); dequantize_mul_mat_vec_q2_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q3_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; const dim3 block_dims(32, 1, 1);
const int block_num_y = (nrows + ny - 1) / ny; dequantize_mul_mat_vec_q3_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_k<32, vec_dot_q3_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q4_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; const dim3 block_dims(32, 1, 1);
const int block_num_y = (nrows + ny - 1) / ny; dequantize_mul_mat_vec_q4_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_k<32, vec_dot_q4_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q5_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; const dim3 block_dims(32, 1, 1);
const int block_num_y = (nrows + ny - 1) / ny; dequantize_mul_mat_vec_q5_k<<<nrows, block_dims, 0, stream>>>(vx, y, dst, ncols);
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_k<32, vec_dot_q5_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) { static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0); GGML_ASSERT(ncols % QK_K == 0);
const int ny = 2; const int ny = 2 / K_QUANTS_PER_ITERATION;
const int block_num_y = (nrows + ny - 1) / ny; const int block_num_y = (nrows + ny - 1) / ny;
const dim3 block_nums(1, block_num_y, 1); const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(32, ny, 1); const dim3 block_dims(32, ny, 1);
dequantize_mul_mat_vec_k<32, vec_dot_q6_K><<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows); dequantize_mul_mat_vec_q6_k<<<block_nums, block_dims, 0, stream>>>(vx, y, dst, ncols, nrows);
} }
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) { static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
@ -2421,7 +2578,7 @@ void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor) {
} }
void ggml_cuda_set_main_device(int main_device) { void ggml_cuda_set_main_device(int main_device) {
if (main_device > g_device_count) { if (main_device >= g_device_count) {
fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n", fprintf(stderr, "warning: cannot set main_device=%d because there are only %d devices. Using device %d instead.\n",
main_device, g_device_count, g_main_device); main_device, g_device_count, g_main_device);
return; return;

View file

@ -55,6 +55,7 @@ void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor *
void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t); void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
// same as ggml_graph_compute but uses Metal // same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf); void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
#ifdef __cplusplus #ifdef __cplusplus

View file

@ -284,528 +284,551 @@ void ggml_metal_get_tensor(
void ggml_metal_graph_compute( void ggml_metal_graph_compute(
struct ggml_metal_context * ctx, struct ggml_metal_context * ctx,
struct ggml_cgraph * gf) { struct ggml_cgraph * gf) {
metal_printf("%s: evaluating graph\n", __func__); metal_printf("%s: evaluating graph\n", __func__);
size_t offs_src0 = 0; // create multiple command buffers and enqueue them
size_t offs_src1 = 0; // then, we encode the graph into the command buffers in parallel
size_t offs_dst = 0;
id<MTLCommandBuffer> command_buffer = [ctx->queue commandBuffer]; const int n_cb = gf->n_threads;
id<MTLComputeCommandEncoder> encoder = nil;
for (int i = 0; i < gf->n_nodes; ++i) { NSMutableArray * command_buffers = [NSMutableArray arrayWithCapacity:n_cb];
//metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
struct ggml_tensor * src0 = gf->nodes[i]->src0; for (int i = 0; i < n_cb; ++i) {
struct ggml_tensor * src1 = gf->nodes[i]->src1; command_buffers[i] = [ctx->queue commandBuffer];
struct ggml_tensor * dst = gf->nodes[i];
const int64_t ne00 = src0 ? src0->ne[0] : 0; // enqueue the command buffers in order to specify their execution order
const int64_t ne01 = src0 ? src0->ne[1] : 0; [command_buffers[i] enqueue];
const int64_t ne02 = src0 ? src0->ne[2] : 0; }
const int64_t ne03 = src0 ? src0->ne[3] : 0;
const uint64_t nb00 = src0 ? src0->nb[0] : 0; // TODO: is this the best way to start threads?
const uint64_t nb01 = src0 ? src0->nb[1] : 0; dispatch_queue_t queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
const uint64_t nb02 = src0 ? src0->nb[2] : 0;
const uint64_t nb03 = src0 ? src0->nb[3] : 0;
const int64_t ne10 = src1 ? src1->ne[0] : 0; for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
const int64_t ne11 = src1 ? src1->ne[1] : 0; const int n_nodes_per_cb = (gf->n_nodes + n_cb - 1) / n_cb;
const int64_t ne12 = src1 ? src1->ne[2] : 0;
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
const uint64_t nb10 = src1 ? src1->nb[0] : 0; dispatch_async(queue, ^{
const uint64_t nb11 = src1 ? src1->nb[1] : 0; size_t offs_src0 = 0;
const uint64_t nb12 = src1 ? src1->nb[2] : 0; size_t offs_src1 = 0;
const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13); size_t offs_dst = 0;
const int64_t ne0 = dst ? dst->ne[0] : 0; id<MTLCommandBuffer> command_buffer = command_buffers[cb_idx];
const int64_t ne1 = dst ? dst->ne[1] : 0;
const int64_t ne2 = dst ? dst->ne[2] : 0;
const int64_t ne3 = dst ? dst->ne[3] : 0;
const uint64_t nb0 = dst ? dst->nb[0] : 0; id<MTLComputeCommandEncoder> encoder = nil;
const uint64_t nb1 = dst ? dst->nb[1] : 0;
const uint64_t nb2 = dst ? dst->nb[2] : 0;
const uint64_t nb3 = dst ? dst->nb[3] : 0;
const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT; const int node_start = (cb_idx + 0) * n_nodes_per_cb;
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT; const int node_end = (cb_idx == n_cb - 1) ? gf->n_nodes : (cb_idx + 1) * n_nodes_per_cb;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil; for (int i = node_start; i < node_end; ++i) {
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil; metal_printf("%s: encoding node %3d, op = %8s\n", __func__, i, ggml_op_name(gf->nodes[i]->op));
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
//metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op)); struct ggml_tensor * src0 = gf->nodes[i]->src0;
//if (src0) { struct ggml_tensor * src1 = gf->nodes[i]->src1;
// metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02, struct ggml_tensor * dst = gf->nodes[i];
// ggml_is_contiguous(src0), src0->name);
//}
//if (src1) {
// metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
// ggml_is_contiguous(src1), src1->name);
//}
//if (dst) {
// metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
// dst->name);
//}
switch (dst->op) { const int64_t ne00 = src0 ? src0->ne[0] : 0;
case GGML_OP_RESHAPE: const int64_t ne01 = src0 ? src0->ne[1] : 0;
case GGML_OP_VIEW: const int64_t ne02 = src0 ? src0->ne[2] : 0;
case GGML_OP_TRANSPOSE: const int64_t ne03 = src0 ? src0->ne[3] : 0;
case GGML_OP_PERMUTE:
{
// noop
} break;
case GGML_OP_ADD:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder setComputePipelineState:ctx->pipeline_add]; const uint64_t nb00 = src0 ? src0->nb[0] : 0;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; const uint64_t nb01 = src0 ? src0->nb[1] : 0;
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; const uint64_t nb02 = src0 ? src0->nb[2] : 0;
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; const uint64_t nb03 = src0 ? src0->nb[3] : 0;
const int64_t n = ggml_nelements(dst); const int64_t ne10 = src1 ? src1->ne[0] : 0;
const int64_t ne11 = src1 ? src1->ne[1] : 0;
const int64_t ne12 = src1 ? src1->ne[2] : 0;
const int64_t ne13 = src1 ? src1->ne[3] : 0; UNUSED(ne13);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; const uint64_t nb10 = src1 ? src1->nb[0] : 0;
} break; const uint64_t nb11 = src1 ? src1->nb[1] : 0;
case GGML_OP_MUL: const uint64_t nb12 = src1 ? src1->nb[2] : 0;
{ const uint64_t nb13 = src1 ? src1->nb[3] : 0; UNUSED(nb13);
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
if (ggml_nelements(src1) == ne10) { const int64_t ne0 = dst ? dst->ne[0] : 0;
// src1 is a row const int64_t ne1 = dst ? dst->ne[1] : 0;
[encoder setComputePipelineState:ctx->pipeline_mul_row]; const int64_t ne2 = dst ? dst->ne[2] : 0;
} else { const int64_t ne3 = dst ? dst->ne[3] : 0;
[encoder setComputePipelineState:ctx->pipeline_mul];
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
const int64_t n = ggml_nelements(dst); const uint64_t nb0 = dst ? dst->nb[0] : 0;
const uint64_t nb1 = dst ? dst->nb[1] : 0;
const uint64_t nb2 = dst ? dst->nb[2] : 0;
const uint64_t nb3 = dst ? dst->nb[3] : 0;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; const enum ggml_type src0t = src0 ? src0->type : GGML_TYPE_COUNT;
} break; const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
case GGML_OP_SCALE: const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const float scale = *(const float *) src1->data; id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
[encoder setComputePipelineState:ctx->pipeline_scale]; //metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; //if (src0) {
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; // metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
[encoder setBytes:&scale length:sizeof(scale) atIndex:2]; // ggml_is_contiguous(src0), src0->name);
//}
//if (src1) {
// metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
// ggml_is_contiguous(src1), src1->name);
//}
//if (dst) {
// metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
// dst->name);
//}
const int64_t n = ggml_nelements(dst); switch (dst->op) {
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
{
// noop
} break;
case GGML_OP_ADD:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder setComputePipelineState:ctx->pipeline_add];
} break; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
case GGML_OP_SILU: [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
{ [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder setComputePipelineState:ctx->pipeline_silu]; const int64_t n = ggml_nelements(dst);
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst); [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_MUL:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; if (ggml_nelements(src1) == ne10) {
} break; // src1 is a row
case GGML_OP_RELU: [encoder setComputePipelineState:ctx->pipeline_mul_row];
{ } else {
if (encoder == nil) { [encoder setComputePipelineState:ctx->pipeline_mul];
encoder = [command_buffer computeCommandEncoder]; }
} [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setComputePipelineState:ctx->pipeline_relu]; const int64_t n = ggml_nelements(dst);
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
const int64_t n = ggml_nelements(dst); [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_SCALE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; const float scale = *(const float *) src1->data;
} break;
case GGML_OP_GELU:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
[encoder setComputePipelineState:ctx->pipeline_gelu]; [encoder setComputePipelineState:ctx->pipeline_scale];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&scale length:sizeof(scale) atIndex:2];
const int64_t n = ggml_nelements(dst); const int64_t n = ggml_nelements(dst);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_SOFT_MAX: case GGML_OP_SILU:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];
} }
const int nth = 32; [encoder setComputePipelineState:ctx->pipeline_silu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setComputePipelineState:ctx->pipeline_soft_max]; const int64_t n = ggml_nelements(dst);
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_DIAG_MASK_INF: case GGML_OP_RELU:
{ {
if (encoder == nil) { if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder]; encoder = [command_buffer computeCommandEncoder];
} }
const int n_past = ((int32_t *)(src1->data))[0]; [encoder setComputePipelineState:ctx->pipeline_relu];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf]; const int64_t n = ggml_nelements(dst);
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
[encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break; } break;
case GGML_OP_MUL_MAT: case GGML_OP_GELU:
{ {
// TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224 if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
GGML_ASSERT(ne00 == ne10); [encoder setComputePipelineState:ctx->pipeline_gelu];
GGML_ASSERT(ne02 == ne12); [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
if (ggml_is_contiguous(src0) && const int64_t n = ggml_nelements(dst);
ggml_is_contiguous(src1) &&
(src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder != nil) { [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
[encoder endEncoding]; } break;
encoder = nil; case GGML_OP_SOFT_MAX:
} {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16; const int nth = 32;
MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
// for F32 x F32 we use MPS [encoder setComputePipelineState:ctx->pipeline_soft_max];
MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt]; } break;
case GGML_OP_DIAG_MASK_INF:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
MPSMatrixDescriptor * desc = [MPSMatrixDescriptor const int n_past = ((int32_t *)(src1->data))[0];
matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc] [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
initWithDevice:ctx->device transposeLeft:false transposeRight:true [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0]; [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
[encoder setBytes:&n_past length:sizeof(int) atIndex:4];
// we need to do ne02 multiplications [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
// TODO: is there a way to do this in parallel - currently very slow .. } break;
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS case GGML_OP_MUL_MAT:
for (int64_t i02 = 0; i02 < ne02; ++i02) { {
size_t offs_src0_cur = offs_src0 + i02*nb02; // TODO: needs to be updated after PR: https://github.com/ggerganov/ggml/pull/224
size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2;
MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0]; GGML_ASSERT(ne00 == ne10);
MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1]; GGML_ASSERT(ne02 == ne12);
MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
[mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst]; if (ggml_is_contiguous(src0) &&
} ggml_is_contiguous(src1) &&
} else { (src0t == GGML_TYPE_F32 || src0t == GGML_TYPE_F16) && ne11 > 1) {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
int nth0 = 32; if (encoder != nil) {
int nth1 = 1; [encoder endEncoding];
encoder = nil;
// use custom matrix x vector kernel
switch (src0t) {
case GGML_TYPE_F16:
{
GGML_ASSERT(ne02 == ne12);
nth0 = 64;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
} break;
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
} break;
case GGML_TYPE_Q4_1:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 8;
nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
} break;
case GGML_TYPE_Q2_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
} break;
case GGML_TYPE_Q3_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
} break;
case GGML_TYPE_Q4_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
} break;
case GGML_TYPE_Q5_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
} break;
case GGML_TYPE_Q6_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
} break;
default:
{
fprintf(stderr, "Asserting on type %d\n",(int)src0t);
GGML_ASSERT(false && "not implemented");
} }
};
MPSDataType src0dt = src0t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
MPSDataType src1dt = src1t == GGML_TYPE_F32 ? MPSDataTypeFloat32 : MPSDataTypeFloat16;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; // for F32 x F32 we use MPS
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; MPSMatrixDescriptor * desc0 = [MPSMatrixDescriptor
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; matrixDescriptorWithRows:ne01 columns:ne00 rowBytes:src0->nb[1] dataType:src0dt];
[encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
[encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
[encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
[encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
[encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
[encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
[encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) { MPSMatrixDescriptor * desc1 = [MPSMatrixDescriptor
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0]; matrixDescriptorWithRows:ne11 columns:ne10 rowBytes:src1->nb[1] dataType:src1dt];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
else if (src0t == GGML_TYPE_Q2_K ||
src0t == GGML_TYPE_Q3_K ||
src0t == GGML_TYPE_Q4_K ||
src0t == GGML_TYPE_Q5_K ||
src0t == GGML_TYPE_Q6_K) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
switch (src0->type) { MPSMatrixDescriptor * desc = [MPSMatrixDescriptor
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break; matrixDescriptorWithRows:ne1 columns:ne0 rowBytes:dst->nb[1] dataType:MPSDataTypeFloat32];
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; MPSMatrixMultiplication * mul = [[MPSMatrixMultiplication alloc]
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1]; initWithDevice:ctx->device transposeLeft:false transposeRight:true
[encoder setBuffer:id_dst offset:offs_dst atIndex:2]; resultRows:ne11 resultColumns:ne01 interiorColumns:ne00 alpha:1.0 beta:0.0];
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
const int64_t n = ggml_nelements(src1); // we need to do ne02 multiplications
// TODO: is there a way to do this in parallel - currently very slow ..
// TODO: might be possible to offload part of the computation to ANE using Accelerate's CBLAS
for (int64_t i02 = 0; i02 < ne02; ++i02) {
size_t offs_src0_cur = offs_src0 + i02*nb02;
size_t offs_src1_cur = offs_src1 + i02*nb12;
size_t offs_dst_cur = offs_dst + i02*nb2;
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; MPSMatrix * mat_src0 = [[MPSMatrix alloc] initWithBuffer:id_src0 offset:offs_src0_cur descriptor:desc0];
} break; MPSMatrix * mat_src1 = [[MPSMatrix alloc] initWithBuffer:id_src1 offset:offs_src1_cur descriptor:desc1];
case GGML_OP_RMS_NORM: MPSMatrix * mat_dst = [[MPSMatrix alloc] initWithBuffer:id_dst offset:offs_dst_cur descriptor:desc ];
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const float eps = 1e-6f; [mul encodeToCommandBuffer:command_buffer leftMatrix:mat_src1 rightMatrix:mat_src0 resultMatrix:mat_dst];
}
} else {
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const int nth = 256; int nth0 = 32;
int nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_rms_norm]; // use custom matrix x vector kernel
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; switch (src0t) {
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; case GGML_TYPE_F16:
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; {
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3]; GGML_ASSERT(ne02 == ne12);
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0); nth0 = 64;
nth1 = 1;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
} break;
case GGML_TYPE_Q4_0:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; nth0 = 8;
} break; nth1 = 8;
case GGML_OP_ROPE: [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_0_f32];
{ } break;
if (encoder == nil) { case GGML_TYPE_Q4_1:
encoder = [command_buffer computeCommandEncoder]; {
} GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
const int n_dims = ((int32_t *) src1->data)[1]; nth0 = 8;
const int mode = ((int32_t *) src1->data)[2]; nth1 = 8;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_1_f32];
} break;
case GGML_TYPE_Q2_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
const int n_past = ((int32_t *)(src1->data))[0]; nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q2_k_f32];
} break;
case GGML_TYPE_Q3_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
[encoder setComputePipelineState:ctx->pipeline_rope]; nth0 = 4;
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; nth1 = 16;
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setComputePipelineState:ctx->pipeline_mul_mat_q3_k_f32];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; } break;
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; case GGML_TYPE_Q4_K:
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; {
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; GGML_ASSERT(ne02 == 1);
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; GGML_ASSERT(ne12 == 1);
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)]; nth0 = 4;
} break; nth1 = 16;
case GGML_OP_CPY: [encoder setComputePipelineState:ctx->pipeline_mul_mat_q4_k_f32];
{ } break;
if (encoder == nil) { case GGML_TYPE_Q5_K:
encoder = [command_buffer computeCommandEncoder]; {
} GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
const int nth = 32; nth0 = 4;
nth1 = 16;
[encoder setComputePipelineState:ctx->pipeline_mul_mat_q5_k_f32];
} break;
case GGML_TYPE_Q6_K:
{
GGML_ASSERT(ne02 == 1);
GGML_ASSERT(ne12 == 1);
switch (src0t) { nth0 = 4;
case GGML_TYPE_F32: nth1 = 16;
{ [encoder setComputePipelineState:ctx->pipeline_mul_mat_q6_k_f32];
switch (dstt) { } break;
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break; default:
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break; {
default: GGML_ASSERT(false && "not implemented"); fprintf(stderr, "Asserting on type %d\n",(int)src0t);
GGML_ASSERT(false && "not implemented");
}
}; };
} break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0]; [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1]; [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2]; [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3]; [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4]; [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5]; [encoder setBytes:&nb00 length:sizeof(nb00) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6]; [encoder setBytes:&nb01 length:sizeof(nb01) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7]; [encoder setBytes:&nb02 length:sizeof(nb02) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8]; [encoder setBytes:&ne10 length:sizeof(ne10) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9]; [encoder setBytes:&ne11 length:sizeof(ne11) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10]; [encoder setBytes:&nb10 length:sizeof(nb10) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11]; [encoder setBytes:&nb11 length:sizeof(nb11) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12]; [encoder setBytes:&nb12 length:sizeof(nb12) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13]; [encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14]; [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)]; if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
} break; [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
default: [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op)); }
GGML_ASSERT(false); else if (src0t == GGML_TYPE_Q2_K ||
} src0t == GGML_TYPE_Q3_K ||
src0t == GGML_TYPE_Q4_K ||
src0t == GGML_TYPE_Q5_K ||
src0t == GGML_TYPE_Q6_K) {
[encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
} else {
[encoder setThreadgroupMemoryLength:nth0*sizeof(float) atIndex:0];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
}
} break;
case GGML_OP_GET_ROWS:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
switch (src0->type) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_get_rows_f16]; break;
case GGML_TYPE_Q4_0: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_0]; break;
case GGML_TYPE_Q4_1: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_1]; break;
case GGML_TYPE_Q2_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q2_k]; break;
case GGML_TYPE_Q3_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q3_k]; break;
case GGML_TYPE_Q4_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q4_k]; break;
case GGML_TYPE_Q5_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q5_k]; break;
case GGML_TYPE_Q6_K: [encoder setComputePipelineState:ctx->pipeline_get_rows_q6_k]; break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
[encoder setBytes:&(src0->ne[0]) length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&(src0->nb[1]) length:sizeof(uint64_t) atIndex:4];
[encoder setBytes:&(dst->nb[1]) length:sizeof(uint64_t) atIndex:5];
const int64_t n = ggml_nelements(src1);
[encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_RMS_NORM:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const float eps = 1e-6f;
const int nth = 256;
[encoder setComputePipelineState:ctx->pipeline_rms_norm];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:3];
[encoder setBytes:&eps length:sizeof( float) atIndex:4];
[encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
const int64_t nrows = ggml_nrows(src0);
[encoder dispatchThreadgroups:MTLSizeMake(nrows, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
case GGML_OP_ROPE:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_past = ((int32_t *)(src1->data))[0];
[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder setBytes:&n_past length:sizeof( int) atIndex:18];
[encoder setBytes:&n_dims length:sizeof( int) atIndex:19];
[encoder setBytes:&mode length:sizeof( int) atIndex:20];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
} break;
case GGML_OP_CPY:
{
if (encoder == nil) {
encoder = [command_buffer computeCommandEncoder];
}
const int nth = 32;
switch (src0t) {
case GGML_TYPE_F32:
{
switch (dstt) {
case GGML_TYPE_F16: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f16]; break;
case GGML_TYPE_F32: [encoder setComputePipelineState:ctx->pipeline_cpy_f32_f32]; break;
default: GGML_ASSERT(false && "not implemented");
};
} break;
default: GGML_ASSERT(false && "not implemented");
}
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
[encoder setBuffer:id_dst offset:offs_dst atIndex:1];
[encoder setBytes:&ne00 length:sizeof( int64_t) atIndex:2];
[encoder setBytes:&ne01 length:sizeof( int64_t) atIndex:3];
[encoder setBytes:&ne02 length:sizeof( int64_t) atIndex:4];
[encoder setBytes:&ne03 length:sizeof( int64_t) atIndex:5];
[encoder setBytes:&nb00 length:sizeof(uint64_t) atIndex:6];
[encoder setBytes:&nb01 length:sizeof(uint64_t) atIndex:7];
[encoder setBytes:&nb02 length:sizeof(uint64_t) atIndex:8];
[encoder setBytes:&nb03 length:sizeof(uint64_t) atIndex:9];
[encoder setBytes:&ne0 length:sizeof( int64_t) atIndex:10];
[encoder setBytes:&ne1 length:sizeof( int64_t) atIndex:11];
[encoder setBytes:&ne2 length:sizeof( int64_t) atIndex:12];
[encoder setBytes:&ne3 length:sizeof( int64_t) atIndex:13];
[encoder setBytes:&nb0 length:sizeof(uint64_t) atIndex:14];
[encoder setBytes:&nb1 length:sizeof(uint64_t) atIndex:15];
[encoder setBytes:&nb2 length:sizeof(uint64_t) atIndex:16];
[encoder setBytes:&nb3 length:sizeof(uint64_t) atIndex:17];
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
} break;
default:
fprintf(stderr, "%s: node %3d, op = %8s not implemented\n", __func__, i, ggml_op_name(dst->op));
GGML_ASSERT(false);
}
}
if (encoder != nil) {
[encoder endEncoding];
encoder = nil;
}
[command_buffer commit];
});
} }
if (encoder != nil) { // wait for all threads to finish
[encoder endEncoding]; dispatch_barrier_sync(queue, ^{});
encoder = nil;
}
[command_buffer commit]; [command_buffers[n_cb - 1] waitUntilCompleted];
[command_buffer waitUntilCompleted];
{
const double time_elapsed = [command_buffer GPUEndTime] - [command_buffer GPUStartTime];
UNUSED(time_elapsed);
metal_printf("%s: time elapsed = %f ms\n", __func__, time_elapsed * 1000.0);
}
} }

View file

@ -15,7 +15,7 @@
#include "ggml.h" #include "ggml.h"
#define CL_DMMV_BLOCK_SIZE 32; #define CL_DMMV_BLOCK_SIZE 32
#define MULTILINE_QUOTE(...) #__VA_ARGS__ #define MULTILINE_QUOTE(...) #__VA_ARGS__
static std::string program_source = MULTILINE_QUOTE( static std::string program_source = MULTILINE_QUOTE(
@ -59,6 +59,46 @@ struct __attribute__ ((packed)) block_q8_0
int8_t qs[QK8_0]; int8_t qs[QK8_0];
}; };
struct __attribute__((packed)) block_q2_K
{
uint8_t scales[16];
uint8_t qs[64];
half d;
half dmin;
};
struct __attribute__((packed)) block_q3_K
{
uint8_t hmask[32];
uint8_t qs[64];
uint8_t scales[12];
half d;
};
struct __attribute__((packed)) block_q4_K
{
half d;
half dmin;
uint8_t scales[12];
uint8_t qs[128];
};
struct __attribute__((packed)) block_q5_K
{
half d;
half dmin;
uint8_t scales[12];
uint8_t qh[32];
uint8_t qs[128];
};
struct __attribute__((packed)) block_q6_K
{
uint8_t ql[128];
uint8_t qh[64];
int8_t scales[16];
half d;
};
__kernel void convert_fp16_to_fp32(__global half* x, __global float* y) { __kernel void convert_fp16_to_fp32(__global half* x, __global float* y) {
const uint i = get_global_id(0); const uint i = get_global_id(0);
@ -131,8 +171,314 @@ void convert_f16(__global half* x, const int ib, const int iqs, float* v0, float
*v0 = vload_half(0, &x[ib + 0]); *v0 = vload_half(0, &x[ib + 0]);
*v1 = vload_half(0, &x[ib + 1]); *v1 = vload_half(0, &x[ib + 1]);
} }
inline void get_scale_min_k4(int j, const __global uint8_t *q, uint8_t *d, uint8_t *m)
{
if (j < 4)
{
*d = q[j] & 63;
*m = q[j + 4] & 63;
}
else
{
*d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4);
*m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4);
}
}
__kernel void dequantize_block_q2_K(__global const struct block_q2_K *x, __global float *yy)
{
const int i = get_group_id(0);
const int tid = get_local_id(0);
const int n = tid / 32;
const int l = tid - 32 * n;
const int is = 8 * n + l / 16;
const uint8_t q = x[i].qs[32 * n + l];
__global float *y = yy + i * 256 + 128 * n;
const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);
y[l + 0] = dall * (x[i].scales[is + 0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is + 0] >> 4);
y[l + 32] = dall * (x[i].scales[is + 2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is + 2] >> 4);
y[l + 64] = dall * (x[i].scales[is + 4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is + 4] >> 4);
y[l + 96] = dall * (x[i].scales[is + 6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is + 6] >> 4);
}
__kernel void dequantize_block_q3_K(__global const struct block_q3_K *x, __global float *yy)
{
int r = get_local_id(0) / 4;
int i = get_group_id(0);
int tid = r / 2;
int is0 = r % 2;
int l0 = 16 * is0 + 4 * (get_local_id(0) % 4);
int n = tid / 4;
int j = tid - 4 * n;
uint8_t m = 1 << (4 * n + j);
int is = 8 * n + 2 * j + is0;
int shift = 2 * j;
int8_t us = is < 4 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 8] >> 0) & 3) << 4)
: is < 8 ? (x[i].scales[is - 0] & 0xF) | (((x[i].scales[is + 4] >> 2) & 3) << 4)
: is < 12 ? (x[i].scales[is - 8] >> 4) | (((x[i].scales[is + 0] >> 4) & 3) << 4)
: (x[i].scales[is - 8] >> 4) | (((x[i].scales[is - 4] >> 6) & 3) << 4);
float d_all = vload_half(0, &x[i].d);
float dl = d_all * (us - 32);
__global float *y = yy + i * 256 + 128 * n + 32 * j;
const __global uint8_t *q = x[i].qs + 32 * n;
const __global uint8_t *hm = x[i].hmask;
for (int l = l0; l < l0 + 4; ++l)
y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
}
__kernel void dequantize_block_q4_K(__global const struct block_q4_K *x, __global float *yy)
{
const int i = get_group_id(0);
const int tid = get_local_id(0);
const int il = tid / 8;
const int ir = tid % 8;
const int is = 2 * il;
const int n = 4;
__global float *y = yy + i * 256 + 64 * il + n * ir;
const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);
__global const uint8_t *q = x[i].qs + 32 * il + n * ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
float d1 = dall * sc;
float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
float d2 = dall * sc;
float m2 = dmin * m;
for (int l = 0; l < n; ++l)
{
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l + 32] = d2 * (q[l] >> 4) - m2;
}
}
__kernel void dequantize_block_q5_K(__global const struct block_q5_K *x, __global float *yy)
{
const int i = get_group_id(0);
const int tid = get_local_id(0);
const int il = tid / 16;
const int ir = tid % 16;
const int is = 2 * il;
__global float *y = yy + i * 256 + 64 * il + 2 * ir;
const float dall = vload_half(0, &x[i].d);
const float dmin = vload_half(0, &x[i].dmin);
__global const uint8_t *ql = x[i].qs + 32 * il + 2 * ir;
__global const uint8_t *qh = x[i].qh + 2 * ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
const float d2 = dall * sc;
const float m2 = dmin * m;
uint8_t hm = 1 << (2 * il);
y[0] = d1 * ((ql[0] & 0xF) + (qh[0] & hm ? 16 : 0)) - m1;
y[1] = d1 * ((ql[1] & 0xF) + (qh[1] & hm ? 16 : 0)) - m1;
hm <<= 1;
y[32] = d2 * ((ql[0] >> 4) + (qh[0] & hm ? 16 : 0)) - m2;
y[33] = d2 * ((ql[1] >> 4) + (qh[1] & hm ? 16 : 0)) - m2;
}
__kernel void dequantize_block_q6_K(__global const struct block_q6_K *x, __global float *yy)
{
const int i = get_group_id(0);
const int tid = get_local_id(0);
const int ip = tid / 32;
const int il = tid - 32 * ip;
const int is = 8 * ip + il / 16;
__global float *y = yy + i * 256 + 128 * ip + il;
const float d = vload_half(0, &x[i].d);
__global const uint8_t *ql = x[i].ql + 64 * ip + il;
const uint8_t qh = x[i].qh[32 * ip + il];
__global const int8_t *sc = x[i].scales + is;
y[0] = d * sc[0] * ((int8_t)((ql[0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
y[64] = d * sc[4] * ((int8_t)((ql[0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}
void vec_dot_q2_K(__global const struct block_q2_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
int n = iqs / 128;
int r = iqs - 128 * n;
int l = r / 8;
__global const float *y = yy + 128 * n + l;
__global const uint8_t *q = x[ib].qs + 32 * n + l;
__global const uint8_t *s = x[ib].scales + 8 * n;
const float dall = vload_half(0, &x[ib].d);
const float dmin = vload_half(0, &x[ib].dmin);
float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
+ y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
+ y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
+ y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
+ y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
+ y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
+ y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
+ y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
*result = sum;
}
void vec_dot_q3_K(__global const struct block_q3_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
uint32_t aux[3];
uint32_t utmp[4];
int n = iqs/128;
int r = iqs - 128*n;
int l = r/8;
__global const float * y = yy + 128*n + l;
__global const uint8_t * q = x[ib].qs + 32*n + l;
__global const uint8_t * hm = x[ib].hmask + l;
const int8_t * s = (const int8_t *)utmp + 8*n;
aux[0] = x[ib].scales[0] | x[ib].scales[1] << 8 | x[ib].scales[2] << 16 | x[ib].scales[3] << 24;
aux[1] = x[ib].scales[4] | x[ib].scales[5] << 8 | x[ib].scales[6] << 16 | x[ib].scales[7] << 24;
aux[2] = x[ib].scales[8] | x[ib].scales[9] << 8 | x[ib].scales[10] << 16 | x[ib].scales[11] << 24;
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
const float dall = vload_half(0, &x[ib].d);
const uint8_t m = 1 << (4*n);
float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
+ y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
+ y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
+ y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
+ y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
+ y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
+ y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
+ y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
*result = sum * dall;
}
void vec_dot_q4_K(__global const struct block_q4_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
const int j = iqs / 64; // j is in 0...3
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
const int is = 2*j; // is is in 0...6 in steps of 2
__global const float * y = yy + 64*j + ir;
__global const uint8_t * q = x[ib].qs + 32*j + ir;
const float dall = vload_half(0, &x[ib].d);
const float dmin = vload_half(0, &x[ib].dmin);
uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
const float d2 = dall * sc;
const float m2 = dmin * m;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
}
*result = sum;
}
void vec_dot_q5_K(__global const struct block_q5_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
const int j = iqs / 64;
const int ir = (iqs - 64*j)/2;
const int is = 2*j;
__global const float * y = yy + 64*j + ir;
__global const uint8_t * ql = x[ib].qs + 32*j + ir;
__global const uint8_t * qh = x[ib].qh + ir;
const float dall = vload_half(0, &x[ib].d);
const float dmin = vload_half(0, &x[ib].dmin);
uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, &sc, &m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, &sc, &m);
const float d2 = dall * sc;
const float m2 = dmin * m;
uint8_t hm = 1 << is;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
}
hm <<= 1;
for (int k = 0; k < 4; ++k) {
sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
}
*result = sum;
}
void vec_dot_q6_K(__global const struct block_q6_K* x, const int ib, const int iqs, const __global float *yy, float *result) {
const int ip = iqs / 128; // 0 or 1
const int il = (iqs - 128*ip)/8; // 0...15
const int is = 8*ip;
__global const float * y = yy + 128*ip + il;
const float d = vload_half(0, &x[ib].d);
__global const uint8_t * ql = x[ib].ql + 64*ip + il;
__global const uint8_t * qh = x[ib].qh + 32*ip + il;
__global const int8_t * sc = x[ib].scales + is;
*result = y[ 0] * d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh[ 0] >> 0) & 3) << 4)) - 32)
+ y[ 32] * d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh[ 0] >> 2) & 3) << 4)) - 32)
+ y[ 64] * d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh[ 0] >> 4) & 3) << 4)) - 32)
+ y[ 96] * d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh[ 0] >> 6) & 3) << 4)) - 32)
+ y[ 16] * d * sc[1] * ((int8_t)((ql[16] & 0xF) | (((qh[16] >> 0) & 3) << 4)) - 32)
+ y[ 48] * d * sc[3] * ((int8_t)((ql[48] & 0xF) | (((qh[16] >> 2) & 3) << 4)) - 32)
+ y[ 80] * d * sc[5] * ((int8_t)((ql[16] >> 4) | (((qh[16] >> 4) & 3) << 4)) - 32)
+ y[112] * d * sc[7] * ((int8_t)((ql[48] >> 4) | (((qh[16] >> 6) & 3) << 4)) - 32);
}
); );
std::string dequant_template = MULTILINE_QUOTE( std::string dequant_template = MULTILINE_QUOTE(
__kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) { __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2; const int i = get_group_id(0)*get_local_size(0) + get_local_id(0)*2;
@ -160,7 +506,7 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __global float* y) {
std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE( std::string dequant_mul_mat_vec_template = MULTILINE_QUOTE(
__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) { __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
const int block_size = get_local_size(0); const int block_size = get_local_size(0);
const int row = get_global_id(0) / block_size; const int row = get_group_id(0);
const int tid = get_local_id(0); const int tid = get_local_id(0);
const uint qk = QUANT_K; const uint qk = QUANT_K;
@ -199,6 +545,45 @@ __kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float
} }
); );
std::string dequant_mul_mat_vec_k_template = MULTILINE_QUOTE(
__kernel void KERNEL_NAME(__global X_TYPE* x, __local float* tmp, __global float* y, __global float* dst, const int ncols) {
const int block_size = get_local_size(0);
const int row = get_group_id(0);
const int tid = get_local_id(0);
const int iter_stride = 256;
const int vals_per_iter = iter_stride / block_size;
const int num_blocks_per_row = ncols / 256;
const int ib0 = row*num_blocks_per_row;
tmp[tid] = 0;
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = ib0 + col/256; // x block index
const int iqs = col%256; // x quant index
const int iybs = col - col%256; // y block start index
// dequantize
float v;
DOT_KERNEL(x, ib, iqs, y + iybs, &v);
tmp[tid] += v;
}
// sum up partial sums and write back result
barrier(CLK_LOCAL_MEM_FENCE);
for (int s=block_size/2; s>0; s>>=1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier(CLK_LOCAL_MEM_FENCE);
}
if (tid == 0) {
dst[row] = tmp[0];
}
}
);
std::string mul_template = MULTILINE_QUOTE( std::string mul_template = MULTILINE_QUOTE(
__kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) { __kernel void KERNEL_NAME(__global TYPE* x, const int x_offset, __global TYPE* y, const int y_offset, __global TYPE* dst, const int dst_offset, const int ky) {
const int i = get_group_id(0)*get_local_size(0) + get_local_id(0); const int i = get_group_id(0)*get_local_size(0) + get_local_id(0);
@ -260,6 +645,18 @@ std::array<std::string, 2> mul_str_values = {
"mul_f32", "float" "mul_f32", "float"
}; };
std::array<std::string, 3> dmmv_k_str_keys = {
"KERNEL_NAME", "X_TYPE", "DOT_KERNEL"
};
std::array<std::string, 15> dmmv_k_str_values = {
"dequantize_mul_mat_vec_q2_K", "struct block_q2_K", "vec_dot_q2_K",
"dequantize_mul_mat_vec_q3_K", "struct block_q3_K", "vec_dot_q3_K",
"dequantize_mul_mat_vec_q4_K", "struct block_q4_K", "vec_dot_q4_K",
"dequantize_mul_mat_vec_q5_K", "struct block_q5_K", "vec_dot_q5_K",
"dequantize_mul_mat_vec_q6_K", "struct block_q6_K", "vec_dot_q6_K",
};
std::string& replace(std::string& s, const std::string& from, const std::string& to) { std::string& replace(std::string& s, const std::string& from, const std::string& to) {
size_t pos = 0; size_t pos = 0;
while ((pos = s.find(from, pos)) != std::string::npos) { while ((pos = s.find(from, pos)) != std::string::npos) {
@ -289,6 +686,14 @@ std::string generate_kernels() {
} }
src << mul_kernel << '\n'; src << mul_kernel << '\n';
} }
for (size_t i = 0; i < dmmv_k_str_values.size(); i += dmmv_k_str_keys.size()) {
std::string dmmv_k_kernel = dequant_mul_mat_vec_k_template;
for (size_t j = 0; j < dmmv_k_str_keys.size(); j++) {
replace(dmmv_k_kernel, dmmv_k_str_keys[j], dmmv_k_str_values[i + j]);
}
src << dmmv_k_kernel << '\n';
}
return src.str(); return src.str();
} }
@ -300,6 +705,8 @@ static cl_program program;
static cl_kernel convert_row_f16_cl; static cl_kernel convert_row_f16_cl;
static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl; static cl_kernel dequantize_row_q4_0_cl, dequantize_row_q4_1_cl, dequantize_row_q5_0_cl, dequantize_row_q5_1_cl, dequantize_row_q8_0_cl;
static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl; static cl_kernel dequantize_mul_mat_vec_q4_0_cl, dequantize_mul_mat_vec_q4_1_cl, dequantize_mul_mat_vec_q5_0_cl, dequantize_mul_mat_vec_q5_1_cl, dequantize_mul_mat_vec_q8_0_cl, convert_mul_mat_vec_f16_cl;
static cl_kernel dequantize_block_q2_k_cl, dequantize_block_q3_k_cl, dequantize_block_q4_k_cl, dequantize_block_q5_k_cl, dequantize_block_q6_k_cl;
static cl_kernel dequantize_mul_mat_vec_q2_K_cl, dequantize_mul_mat_vec_q3_K_cl, dequantize_mul_mat_vec_q4_K_cl, dequantize_mul_mat_vec_q5_K_cl, dequantize_mul_mat_vec_q6_K_cl;
static cl_kernel mul_f32_cl; static cl_kernel mul_f32_cl;
static bool fp16_support; static bool fp16_support;
@ -529,6 +936,12 @@ void ggml_cl_init(void) {
CL_CHECK((dequantize_row_q5_0_cl = clCreateKernel(program, "dequantize_row_q5_0", &err), err)); CL_CHECK((dequantize_row_q5_0_cl = clCreateKernel(program, "dequantize_row_q5_0", &err), err));
CL_CHECK((dequantize_row_q5_1_cl = clCreateKernel(program, "dequantize_row_q5_1", &err), err)); CL_CHECK((dequantize_row_q5_1_cl = clCreateKernel(program, "dequantize_row_q5_1", &err), err));
CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err)); CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
CL_CHECK((dequantize_row_q8_0_cl = clCreateKernel(program, "dequantize_row_q8_0", &err), err));
CL_CHECK((dequantize_block_q2_k_cl = clCreateKernel(program, "dequantize_block_q2_K", &err), err));
CL_CHECK((dequantize_block_q3_k_cl = clCreateKernel(program, "dequantize_block_q3_K", &err), err));
CL_CHECK((dequantize_block_q4_k_cl = clCreateKernel(program, "dequantize_block_q4_K", &err), err));
CL_CHECK((dequantize_block_q5_k_cl = clCreateKernel(program, "dequantize_block_q5_K", &err), err));
CL_CHECK((dequantize_block_q6_k_cl = clCreateKernel(program, "dequantize_block_q6_K", &err), err));
// dequant mul mat kernel // dequant mul mat kernel
CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err)); CL_CHECK((dequantize_mul_mat_vec_q4_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_0", &err), err));
@ -537,6 +950,11 @@ void ggml_cl_init(void) {
CL_CHECK((dequantize_mul_mat_vec_q5_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_1", &err), err)); CL_CHECK((dequantize_mul_mat_vec_q5_1_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_1", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err)); CL_CHECK((dequantize_mul_mat_vec_q8_0_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q8_0", &err), err));
CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err)); CL_CHECK((convert_mul_mat_vec_f16_cl = clCreateKernel(program, "convert_mul_mat_vec_f16", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q2_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q2_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q3_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q3_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q4_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q4_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q5_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q5_K", &err), err));
CL_CHECK((dequantize_mul_mat_vec_q6_K_cl = clCreateKernel(program, "dequantize_mul_mat_vec_q6_K", &err), err));
// mul kernel // mul kernel
CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err)); CL_CHECK((mul_f32_cl = clCreateKernel(program, "mul_f32", &err), err));
@ -554,6 +972,16 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
return &dequantize_row_q5_1_cl; return &dequantize_row_q5_1_cl;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
return &dequantize_row_q8_0_cl; return &dequantize_row_q8_0_cl;
case GGML_TYPE_Q2_K:
return &dequantize_block_q2_k_cl;
case GGML_TYPE_Q3_K:
return &dequantize_block_q3_k_cl;
case GGML_TYPE_Q4_K:
return &dequantize_block_q4_k_cl;
case GGML_TYPE_Q5_K:
return &dequantize_block_q5_k_cl;
case GGML_TYPE_Q6_K:
return &dequantize_block_q6_k_cl;
case GGML_TYPE_F16: case GGML_TYPE_F16:
return &convert_row_f16_cl; return &convert_row_f16_cl;
default: default:
@ -561,6 +989,50 @@ static cl_kernel* ggml_get_to_fp32_cl(ggml_type type) {
} }
} }
static size_t ggml_cl_global_denom(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return 1;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
return 4;
case GGML_TYPE_Q4_K:
return 8;
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return 4;
case GGML_TYPE_F16:
default:
return 1;
}
}
static size_t ggml_cl_local_size(ggml_type type) {
switch (type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
return 0;
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
return 64;
case GGML_TYPE_Q4_K:
return 32;
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
return 64;
case GGML_TYPE_F16:
default:
return 0;
}
}
static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) { static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
switch (type) { switch (type) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
@ -575,6 +1047,16 @@ static cl_kernel* ggml_get_dequantize_mul_mat_vec_cl(ggml_type type) {
return &dequantize_mul_mat_vec_q8_0_cl; return &dequantize_mul_mat_vec_q8_0_cl;
case GGML_TYPE_F16: case GGML_TYPE_F16:
return &convert_mul_mat_vec_f16_cl; return &convert_mul_mat_vec_f16_cl;
case GGML_TYPE_Q2_K:
return &dequantize_mul_mat_vec_q2_K_cl;
case GGML_TYPE_Q3_K:
return &dequantize_mul_mat_vec_q3_K_cl;
case GGML_TYPE_Q4_K:
return &dequantize_mul_mat_vec_q4_K_cl;
case GGML_TYPE_Q5_K:
return &dequantize_mul_mat_vec_q5_K_cl;
case GGML_TYPE_Q6_K:
return &dequantize_mul_mat_vec_q6_K_cl;
default: default:
return nullptr; return nullptr;
} }
@ -1017,6 +1499,9 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type); cl_kernel* dmmv = ggml_get_dequantize_mul_mat_vec_cl(type);
GGML_ASSERT(to_fp32_cl != nullptr); GGML_ASSERT(to_fp32_cl != nullptr);
const size_t global_denom = ggml_cl_global_denom(type);
const size_t local = ggml_cl_local_size(type);
size_t ev_idx = 0; size_t ev_idx = 0;
std::vector<cl_event> events; std::vector<cl_event> events;
@ -1049,10 +1534,10 @@ static void ggml_cl_mul_mat_q_f32(const ggml_tensor * src0, const ggml_tensor *
CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++)); CL_CHECK(clEnqueueNDRangeKernel(queue, *dmmv, 1, NULL, &global, &local, events.size() - 1, events.data(), events.data() + ev_idx++));
} else { // general dequantization kernel + CLBlast matrix matrix multiplication } else { // general dequantization kernel + CLBlast matrix matrix multiplication
// convert src0 to fp32 on device // convert src0 to fp32 on device
const size_t global = x_ne; const size_t global = x_ne / global_denom;
CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q)); CL_CHECK(clSetKernelArg(*to_fp32_cl, 0, sizeof(cl_mem), &d_Q));
CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X)); CL_CHECK(clSetKernelArg(*to_fp32_cl, 1, sizeof(cl_mem), &d_X));
CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, NULL, events.size(), !events.empty() ? events.data() : NULL, NULL)); CL_CHECK(clEnqueueNDRangeKernel(queue, *to_fp32_cl, 1, NULL, &global, local > 0 ? &local : NULL, events.size(), !events.empty() ? events.data() : NULL, NULL));
// copy src1 to device // copy src1 to device
CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL)); CL_CHECK(ggml_cl_h2d_tensor_2d(queue, d_Y, 0, src1, i03, i02, NULL));

6
ggml.c
View file

@ -35,6 +35,12 @@
#define static_assert(cond, msg) struct global_scope_noop_trick #define static_assert(cond, msg) struct global_scope_noop_trick
#endif #endif
#if defined(_MSC_VER)
// disable "possible loss of data" to avoid hundreds of casts
// we should just be careful :)
#pragma warning(disable: 4244 4267)
#endif
#if defined(_WIN32) #if defined(_WIN32)
#include <windows.h> #include <windows.h>

View file

@ -40,6 +40,10 @@
#include <sstream> #include <sstream>
#include <numeric> #include <numeric>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#define LLAMA_USE_SCRATCH #define LLAMA_USE_SCRATCH
#define LLAMA_MAX_SCRATCH_BUFFERS 16 #define LLAMA_MAX_SCRATCH_BUFFERS 16
@ -1654,7 +1658,7 @@ static bool llama_eval_internal(
// cur = cur*norm(broadcasted) // cur = cur*norm(broadcasted)
cur = ggml_mul(ctx0, cur, model.norm); cur = ggml_mul(ctx0, cur, model.norm);
offload_func_nr(cur); // offload_func_nr(cur); // TODO CPU + GPU mirrored backend
ggml_set_name(cur, "result_norm"); ggml_set_name(cur, "result_norm");
embeddings = cur; embeddings = cur;

View file

@ -244,9 +244,9 @@ extern "C" {
LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token); LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(); LLAMA_API llama_token llama_token_bos(); // beginning-of-sentence
LLAMA_API llama_token llama_token_eos(); LLAMA_API llama_token llama_token_eos(); // end-of-sentence
LLAMA_API llama_token llama_token_nl(); LLAMA_API llama_token llama_token_nl(); // next-line
// Sampling functions // Sampling functions

View file

@ -10,6 +10,10 @@
#include <ggml.h> #include <ggml.h>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
constexpr int kVecSize = 1 << 18; constexpr int kVecSize = 1 << 18;
float drawFromGaussianPdf(std::mt19937& rndm) { float drawFromGaussianPdf(std::mt19937& rndm) {

View file

@ -1,9 +1,10 @@
import os import os
import hashlib import hashlib
def sha256sum(file): def sha256sum(file):
block_size = 16 * 1024 * 1024 # 16 MB block size block_size = 16 * 1024 * 1024 # 16 MB block size
b = bytearray(block_size) b = bytearray(block_size)
file_hash = hashlib.sha256() file_hash = hashlib.sha256()
mv = memoryview(b) mv = memoryview(b)
with open(file, 'rb', buffering=0) as f: with open(file, 'rb', buffering=0) as f:
@ -15,6 +16,7 @@ def sha256sum(file):
return file_hash.hexdigest() return file_hash.hexdigest()
# Define the path to the llama directory (parent folder of script directory) # Define the path to the llama directory (parent folder of script directory)
llama_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir)) llama_path = os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))

1
spm-headers/ggml.h Symbolic link
View file

@ -0,0 +1 @@
../ggml.h

View file

@ -9,12 +9,15 @@
#include <string> #include <string>
#include <vector> #include <vector>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001; const float MAX_QUANTIZATION_REFERENCE_ERROR = 0.0001f;
const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002; const float MAX_QUANTIZATION_TOTAL_ERROR = 0.002f;
const float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075; const float MAX_QUANTIZATION_TOTAL_ERROR_2BITS = 0.0075f;
const float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040; const float MAX_QUANTIZATION_TOTAL_ERROR_3BITS = 0.0040f;
const float MAX_DOT_PRODUCT_ERROR = 0.02; const float MAX_DOT_PRODUCT_ERROR = 0.02f;
const char* RESULT_STR[] = {"ok", "FAILED"}; const char* RESULT_STR[] = {"ok", "FAILED"};

View file

@ -13,6 +13,10 @@
#include <string> #include <string>
#include <vector> #include <vector>
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
#define MAX_ALIGNMENT 64 #define MAX_ALIGNMENT 64
#define QK 32 #define QK 32
#define WARMUP 5 #define WARMUP 5

View file

@ -176,27 +176,27 @@ void test_frequency_presence_penalty(
int main(void) { int main(void) {
ggml_time_init(); ggml_time_init();
test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4}, 1); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 1);
test_top_k({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2}, 3); test_top_k({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f}, 3);
test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4}, 0); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f}, 0);
test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3}, 0.7); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f}, 0.7f);
test_top_p({0.1, 0.2, 0.3, 0.4}, {0.4, 0.3, 0.2, 0.1}, 1); test_top_p({0.1f, 0.2f, 0.3f, 0.4f}, {0.4f, 0.3f, 0.2f, 0.1f}, 1);
test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3}, 0.25); test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f}, 0.25f);
test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.75); test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.75f);
test_tfs({0.1, 0.15, 0.2, 0.25, 0.3}, {0.3, 0.25}, 0.99); test_tfs({0.1f, 0.15f, 0.2f, 0.25f, 0.3f}, {0.3f, 0.25f}, 0.99f);
test_typical({0.97, 0.01, 0.01, 0.01}, {0.97}, 0.5); test_typical({0.97f, 0.01f, 0.01f, 0.01f}, {0.97f}, 0.5f);
test_typical({0.4, 0.2, 0.2, 0.2}, {0.2, 0.2, 0.2}, 0.5); test_typical({0.4f, 0.2f, 0.2f, 0.2f}, {0.2f, 0.2f, 0.2f}, 0.5f);
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.25, 0.25, 0.25, 0.25, 0}, 50.0); test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.25f, 0.25f, 0.25f, 0.25f, 0}, 50.0f);
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.5, 0.5, 0, 0, 0}, 50.0); test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
test_repetition_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.5, 0.5, 0, 0, 0}, 50.0); test_repetition_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.5f, 0.5f, 0, 0, 0}, 50.0f);
test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0}, {0.249997, 0.249997, 0.249997, 0.249997, 0.000011}, 5.0, 5.0); test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0}, {0.249997f, 0.249997f, 0.249997f, 0.249997f, 0.000011f}, 5.0f, 5.0f);
test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2}, {0.499966, 0.499966, 0.000023, 0.000023, 0.000023}, 5.0, 5.0); test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2}, {0.499966f, 0.499966f, 0.000023f, 0.000023f, 0.000023f}, 5.0f, 5.0f);
test_frequency_presence_penalty({0.2, 0.2, 0.2, 0.2, 0.2}, {0, 1, 2, 0, 0}, {0.499977, 0.499977, 0.000023, 0.000023, 0.000000}, 5.0, 5.0); test_frequency_presence_penalty({0.2f, 0.2f, 0.2f, 0.2f, 0.2f}, {0, 1, 2, 0, 0}, {0.499977f, 0.499977f, 0.000023f, 0.000023f, 0.000000f}, 5.0f, 5.0f);
printf("OK\n"); printf("OK\n");
} }

View file

@ -53,7 +53,7 @@ int main(int argc, char **argv) {
for (const auto & test_kv : k_tests()) { for (const auto & test_kv : k_tests()) {
std::vector<llama_token> res(test_kv.first.size()); std::vector<llama_token> res(test_kv.first.size());
const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), res.size(), true); const int n = llama_tokenize(ctx, test_kv.first.c_str(), res.data(), int(res.size()), true);
res.resize(n); res.resize(n);
bool correct = res.size() == test_kv.second.size(); bool correct = res.size() == test_kv.second.size();