Compare commits

...
Sign in to create a new pull request.

18 commits

Author SHA1 Message Date
Georgi Gerganov
d45c1631bc
metal : rewrite to fit new backend interface correctly (WIP) 2023-07-20 22:51:19 +03:00
Georgi Gerganov
cb82adadb8
metal : first working version of the inference without prompt processing
Bonus: supports partial inference on the CPU
2023-07-20 14:56:29 +03:00
Georgi Gerganov
290cb700bf
metal : map the CPU buffers to Metal buffers (WIP) 2023-07-20 14:30:34 +03:00
Georgi Gerganov
f38433ef5d
Merge remote-tracking branch 'origin/ggml-backends' into ggml-backends-metal 2023-07-19 17:45:45 +03:00
Georgi Gerganov
70c55c17c7
metal : create backend, mostly reuse CPU backend interface 2023-07-19 16:47:43 +03:00
slaren
295f85654a allocators wip
renamed ggml_backend functions
changed ggml_buffer and ggml_backend to always be used as pointers
rename ggml_tensor::params -> op_params
2023-07-19 02:43:44 +02:00
Georgi Gerganov
ed960fa1ab
llama : separate compute buffer for metal 2023-07-18 19:19:59 +03:00
Georgi Gerganov
652c849643
ggml : add is_ram_shared to ggml_backend
Metal can share the RAM memory and can utilize mmap without temp buffer
2023-07-18 18:51:02 +03:00
Georgi Gerganov
90503f150d
llama : init metal backend as CPU backend for now 2023-07-18 17:54:16 +03:00
Georgi Gerganov
0a3861c47b
metal : adapting to ggml_backend (WIP) 2023-07-18 16:54:41 +03:00
slaren
1102ff56db fix double-free with --no-mmap 2023-07-17 12:00:17 +02:00
slaren
4e94af3060 improve layer backend printing with ranges 2023-07-17 11:53:01 +02:00
slaren
c2beeb8e3a only allocate as much memory as is required in each backend for the model 2023-07-17 11:21:32 +02:00
slaren
9c72e7e916 rebase to master (except ggml-cuda) 2023-07-16 15:10:46 +02:00
slaren
33ab185dd1 fix NVCC version on Makefile, __halves2half2 -> make_half2 2023-07-16 14:56:52 +02:00
slaren
24cc6f008f minor fixes 2023-07-16 14:56:52 +02:00
slaren
5765d7a587 restore simple.cpp for now 2023-07-16 14:56:52 +02:00
slaren
0d2b66c638 ggml backend interface wip
refactor ggml-cuda
2023-07-16 14:56:46 +02:00
18 changed files with 4996 additions and 4921 deletions

View file

@ -308,13 +308,13 @@ jobs:
path: |
llama-${{ env.BRANCH_NAME }}-${{ steps.commit.outputs.short }}-bin-win-${{ matrix.build }}-x64.zip
windows-latest-cmake-cublas:
windows-latest-cmake-cuda:
runs-on: windows-latest
strategy:
matrix:
cuda: ['12.1.0', '11.7.1']
build: ['cublas']
build: ['cuda']
steps:
- name: Clone
@ -333,7 +333,7 @@ jobs:
run: |
mkdir build
cd build
cmake .. -DLLAMA_BUILD_SERVER=ON -DLLAMA_CUBLAS=ON
cmake .. -DLLAMA_BUILD_SERVER=ON -DLLAMA_CUDA=ON
cmake --build . --config Release
- name: Get commit hash
@ -395,7 +395,7 @@ jobs:
- macOS-latest-make
- macOS-latest-cmake
- windows-latest-cmake
- windows-latest-cmake-cublas
- windows-latest-cmake-cuda
steps:
- name: Download artifacts

View file

@ -67,7 +67,7 @@ endif()
option(LLAMA_ACCELERATE "llama: enable Accelerate framework" ON)
option(LLAMA_BLAS "llama: use BLAS" OFF)
set(LLAMA_BLAS_VENDOR "Generic" CACHE STRING "llama: BLAS library vendor")
option(LLAMA_CUBLAS "llama: use cuBLAS" OFF)
option(LLAMA_CUDA "llama: use CUDA" OFF)
option(LLAMA_CUDA_FORCE_DMMV "llama: use dmmv instead of mmvq CUDA kernels" OFF)
set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kernels")
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
@ -239,18 +239,18 @@ if (LLAMA_K_QUANTS)
endif()
endif()
if (LLAMA_CUBLAS)
if (LLAMA_CUDA)
cmake_minimum_required(VERSION 3.17)
find_package(CUDAToolkit)
if (CUDAToolkit_FOUND)
message(STATUS "cuBLAS found")
message(STATUS "CUDA found")
enable_language(CUDA)
set(GGML_SOURCES_CUDA ggml-cuda.cu ggml-cuda.h)
add_compile_definitions(GGML_USE_CUBLAS)
add_compile_definitions(GGML_USE_CUDA)
if (LLAMA_CUDA_FORCE_DMMV)
add_compile_definitions(GGML_CUDA_FORCE_DMMV)
endif()
@ -280,7 +280,7 @@ if (LLAMA_CUBLAS)
message(STATUS "Using CUDA architectures: ${CMAKE_CUDA_ARCHITECTURES}")
else()
message(WARNING "cuBLAS not found")
message(WARNING "CUDA not found")
endif()
endif()

View file

@ -55,6 +55,12 @@ else
CXXFLAGS += -DNDEBUG
endif
ifdef LLAMA_SANITIZE
CFLAGS += -g -fsanitize=$(LLAMA_SANITIZE) -fno-omit-frame-pointer
CXXFLAGS += -g -fsanitize=$(LLAMA_SANITIZE) -fno-omit-frame-pointer
LDFLAGS += -g -fsanitize=$(LLAMA_SANITIZE)
endif
ifdef LLAMA_SERVER_VERBOSE
CXXFLAGS += -DSERVER_VERBOSE=$(LLAMA_SERVER_VERBOSE)
endif
@ -163,13 +169,17 @@ ifdef LLAMA_BLIS
LDFLAGS += -lblis -L/usr/local/lib
endif # LLAMA_BLIS
ifdef LLAMA_CUBLAS
CFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUBLAS -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
ifdef LLAMA_CUDA
CFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
CXXFLAGS += -DGGML_USE_CUDA -I/usr/local/cuda/include -I/opt/cuda/include -I$(CUDA_PATH)/targets/x86_64-linux/include
LDFLAGS += -lcublas -lculibos -lcudart -lcublasLt -lpthread -ldl -lrt -L/usr/local/cuda/lib64 -L/opt/cuda/lib64 -L$(CUDA_PATH)/targets/x86_64-linux/lib
OBJS += ggml-cuda.o
NVCC = nvcc
NVCCFLAGS = --forward-unknown-to-host-compiler
NVCCV := $(shell $(NVCC) --version | tail -n 1)
ifdef LLAMA_DEBUG
NVCCFLAGS += -lineinfo
endif # LLAMA_DEBUG
ifdef CUDA_DOCKER_ARCH
NVCCFLAGS += -Wno-deprecated-gpu-targets -arch=$(CUDA_DOCKER_ARCH)
else
@ -198,10 +208,9 @@ ifdef LLAMA_CUDA_KQUANTS_ITER
else
NVCCFLAGS += -DK_QUANTS_PER_ITERATION=2
endif
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h ggml-cuda-kern.h ggml-cuda-quant.h
$(NVCC) $(NVCCFLAGS) $(CXXFLAGS) -Wno-pedantic -c $< -o $@
endif # LLAMA_CUBLAS
endif # LLAMA_CUDA
ifdef LLAMA_CLBLAST
CFLAGS += -DGGML_USE_CLBLAST
@ -275,6 +284,9 @@ $(info I CXXFLAGS: $(CXXFLAGS))
$(info I LDFLAGS: $(LDFLAGS))
$(info I CC: $(CCV))
$(info I CXX: $(CXXV))
ifdef LLAMA_CUDA
$(info I NVCC: $(NVCCV))
endif # LLAMA_CUDA
$(info )
#
@ -284,6 +296,12 @@ $(info )
ggml.o: ggml.c ggml.h ggml-cuda.h
$(CC) $(CFLAGS) -c $< -o $@
# temporary, probably will be added to ggml.c
ggml-backend.o: ggml-backend.c ggml-backend.h ggml.h
$(CC) $(CFLAGS) -c $< -o $@
OBJS += ggml-backend.o
llama.o: llama.cpp ggml.h ggml-cuda.h ggml-metal.h llama.h llama-util.h
$(CXX) $(CXXFLAGS) -c $< -o $@

View file

@ -327,24 +327,24 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
params.n_gpu_layers = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: not compiled with GPU offload support, --n-gpu-layers option will be ignored\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU BLAS support\n");
fprintf(stderr, "warning: see main README.md for information on enabling GPU support\n");
#endif
} else if (arg == "--main-gpu" || arg == "-mg") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
#ifdef GGML_USE_CUDA
params.main_gpu = std::stoi(argv[i]);
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a main GPU.\n");
fprintf(stderr, "warning: llama.cpp was compiled without CUDA. It is not possible to set a main GPU.\n");
#endif
} else if (arg == "--tensor-split" || arg == "-ts") {
if (++i >= argc) {
invalid_param = true;
break;
}
#ifdef GGML_USE_CUBLAS
#ifdef GGML_USE_CUDA
std::string arg_next = argv[i];
// split string by , and /
@ -361,14 +361,14 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
}
}
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without CUDA. It is not possible to set a tensor split.\n");
#endif // GGML_USE_CUDA
} else if (arg == "--low-vram" || arg == "-lv") {
#ifdef GGML_USE_CUBLAS
#ifdef GGML_USE_CUDA
params.low_vram = true;
#else
fprintf(stderr, "warning: llama.cpp was compiled without cuBLAS. It is not possible to set lower vram usage.\n");
#endif // GGML_USE_CUBLAS
fprintf(stderr, "warning: llama.cpp was compiled without CUDA. It is not possible to set lower vram usage.\n");
#endif // GGML_USE_CUDA
} else if (arg == "--no-mmap") {
params.use_mmap = false;
} else if (arg == "--mtest") {

View file

@ -175,6 +175,8 @@ int main(int argc, char ** argv)
llama_backend_free();
llama_backend_free();
return 0;
}

680
ggml-backend.c Normal file
View file

@ -0,0 +1,680 @@
#include "ggml-backend.h"
#include <assert.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#define UNUSED(x) (void)(x)
// allocator
static size_t aligned_offset(const void * buffer, size_t offset, size_t alignment) {
assert(alignment && !(alignment & (alignment - 1))); // power of 2
size_t align = (alignment - (((uintptr_t)buffer + offset) % alignment)) % alignment;
return offset + align;
}
static inline size_t ggml_backend_buffer_get_alloc_size(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor) { return alloc->interface.get_alloc_size(alloc, tensor); }
static inline void ggml_backend_buffer_init_tensor(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor) { alloc->interface.init_tensor(alloc, tensor); }
void ggml_backend_buffer_free(struct ggml_backend_buffer * alloc) {
alloc->interface.free_buffer(alloc);
free(alloc);
}
// backend buffer allocator - simple
struct ggml_allocator_simple_context {
void * data;
size_t size;
size_t offset;
size_t alignment;
};
static void ggml_allocator_simple_free_buffer(struct ggml_backend_buffer * alloc) {
struct ggml_allocator_simple_context * context = (struct ggml_allocator_simple_context *)alloc->context;
free(context);
}
static void ggml_allocator_simple_alloc_tensor(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor) {
struct ggml_allocator_simple_context * context = (struct ggml_allocator_simple_context *)alloc->context;
size_t size = ggml_backend_buffer_get_alloc_size(alloc, tensor);
if (context->offset + size > context->size) {
fprintf(stderr, "%s: not enough space in the buffer (needed %zu, available %zu)\n",
__func__, size, context->size - context->offset);
GGML_ASSERT(!"not enough space in the buffer");
return;
}
void * ptr = (char*)context->data + context->offset;
context->offset = aligned_offset(context->data, context->offset + size, context->alignment);
tensor->data = ptr;
if (alloc->interface.init_tensor) {
alloc->interface.init_tensor(alloc, tensor);
}
}
static void ggml_allocator_simple_free_tensor(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor) {
GGML_ASSERT(!"ggml_simple_allocator cannot free individual tensors");
UNUSED(alloc);
UNUSED(tensor);
}
static void ggml_allocator_simple_reset(struct ggml_backend_buffer * alloc) {
struct ggml_allocator_simple_context * context = (struct ggml_allocator_simple_context *)alloc->context;
context->offset = aligned_offset(context->data, 0, context->alignment);
}
size_t ggml_allocator_simple_get_alloc_size(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor) {
return ggml_nbytes(tensor);
UNUSED(alloc);
}
static const struct ggml_backend_buffer_interface ggml_allocator_simple_interface = {
/* .free_buffer = */ ggml_allocator_simple_free_buffer,
/* .alloc_tensor = */ ggml_allocator_simple_alloc_tensor,
/* .free_tensor = */ ggml_allocator_simple_free_tensor,
/* .reset = */ ggml_allocator_simple_reset,
/* .get_alloc_size = */ ggml_allocator_simple_get_alloc_size,
/* .init_tensor = */ NULL,
/* .free_data = */ NULL,
};
struct ggml_backend_buffer * ggml_allocator_simple_init(void * data, size_t size, size_t alignment) {
struct ggml_allocator_simple_context * ctx = malloc(sizeof(struct ggml_allocator_simple_context));
ctx->data = data;
ctx->size = size;
ctx->offset = aligned_offset(data, 0, alignment);
ctx->alignment = alignment;
struct ggml_backend_buffer * allocator = malloc(sizeof(struct ggml_backend_buffer));
*allocator = (struct ggml_backend_buffer){
/* .interface = */ ggml_allocator_simple_interface,
/* .context = */ ctx,
/* .backend_data = */ NULL,
};
return allocator;
}
// buffer
struct ggml_buffer * ggml_buffer_alloc(struct ggml_backend * backend, size_t size, size_t max_tensors) {
struct ggml_buffer * buffer = malloc(sizeof(struct ggml_buffer));
buffer->mem_size = ggml_tensor_overhead() * max_tensors;
buffer->mem_buffer = malloc(buffer->mem_size);
buffer->backend = backend;
size += 128 * max_tensors; // alignment overhead
buffer->backend_buffer = backend->interface.alloc_buffer(backend, size);
return buffer;
}
void ggml_buffer_free(struct ggml_buffer * buffer) {
ggml_backend_buffer_free(buffer->backend_buffer);
free(buffer->mem_buffer);
free(buffer);
}
// backend copy
static bool ggml_are_same_layout(const struct ggml_tensor * a, const struct ggml_tensor * b) {
if (a->type != b->type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (a->ne[i] != b->ne[i]) {
return false;
}
if (a->nb[i] != b->nb[i]) {
return false;
}
}
return true;
}
void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst) {
//printf("src: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", src->name, (int)src->ne[0], (int)src->ne[1], (int)src->ne[2], (int)src->ne[3], (int)src->nb[0], (int)src->nb[1], (int)src->nb[2], (int)src->nb[3]);
//printf("dst: %s ne: [%d %d %d %d] nb: [%d %d %d %d]\n", dst->name, (int)dst->ne[0], (int)dst->ne[1], (int)dst->ne[2], (int)dst->ne[3], (int)dst->nb[0], (int)dst->nb[1], (int)dst->nb[2], (int)dst->nb[3]);
GGML_ASSERT(ggml_are_same_layout(src, dst) && "cannot copy tensors with different layouts");
// printf("cpy tensor %s from %s to %s (%lu bytes)\n", src->name, ggml_backend_name(src->backend), ggml_backend_name(dst->backend), ggml_nbytes(src));
if (src == dst) {
return;
}
//printf("src->data = %p, src->extra = %p\n", src->data, src->extra);
//printf("dst->data = %p, dst->extra = %p\n", dst->data, dst->extra);
if (dst->backend->interface.cpy_tensor_from != NULL) {
dst->backend->interface.cpy_tensor_from(dst->backend->context, src, dst);
} else if (src->backend->interface.cpy_tensor_to != NULL) {
src->backend->interface.cpy_tensor_to(src->backend->context, src, dst);
} else {
// not ideal, but shouldn't be hit when copying from/to CPU
// TODO: print a performance warning in debug builds
size_t nbytes = ggml_nbytes(src);
void * data = malloc(nbytes);
ggml_backend_tensor_get(src, data, 0, nbytes);
ggml_backend_tensor_set(dst, data, 0, nbytes);
free(data);
}
}
// backend CPU
struct ggml_backend_cpu_context {
int n_threads;
void * work_data;
size_t work_size;
};
static const char * ggml_backend_cpu_name(struct ggml_backend * backend) {
return "CPU";
UNUSED(backend);
}
static void ggml_backend_cpu_free(struct ggml_backend * backend) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
free(cpu_ctx->work_data);
free(cpu_ctx);
free(backend);
}
static const size_t TENSOR_ALIGNMENT = 64; // should be enough for AVX 512
static void ggml_backend_cpu_free_buffer(struct ggml_backend_buffer * alloc) {
free(alloc->backend_data);
}
static struct ggml_backend_buffer * ggml_backend_cpu_alloc_buffer(struct ggml_backend * backend, size_t size) {
void * data = malloc(size);
struct ggml_backend_buffer * buffer = ggml_allocator_simple_init(data, size, TENSOR_ALIGNMENT);
buffer->interface.free_data = ggml_backend_cpu_free_buffer;
buffer->backend_data = data;
return buffer;
UNUSED(backend);
}
static void ggml_backend_cpu_set_tensor_async(struct ggml_backend * backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
memcpy((char *)tensor->data + offset, data, size);
UNUSED(backend);
}
static void ggml_backend_cpu_get_tensor_async(struct ggml_backend * backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
memcpy(data, (const char *)tensor->data + offset, size);
UNUSED(backend);
}
static void ggml_backend_cpu_synchronize(struct ggml_backend * backend) {
UNUSED(backend);
}
static void ggml_backend_cpu_cpy_tensor_from(struct ggml_backend * backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
ggml_backend_tensor_get(src, dst->data, 0, ggml_nbytes(src));
UNUSED(backend);
}
static void ggml_backend_cpu_cpy_tensor_to(struct ggml_backend * backend, struct ggml_tensor * src, struct ggml_tensor * dst) {
// for a backend such as CUDA that can queue async calls, it is ok to do this asynchronously, but it may not be the case for other backends
ggml_backend_tensor_set_async(dst, src->data, 0, ggml_nbytes(src));
UNUSED(backend);
}
struct ggml_backend_cpu_plan {
struct ggml_cplan cplan;
struct ggml_cgraph cgraph;
};
static ggml_graph_plan_t ggml_backend_cpu_graph_plan_create(struct ggml_backend * backend, struct ggml_cgraph * cgraph) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
struct ggml_backend_cpu_plan * cpu_plan = malloc(sizeof(struct ggml_backend_cpu_plan));
cpu_plan->cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
cpu_plan->cgraph = *cgraph;
if (cpu_plan->cplan.work_size > 0) {
cpu_plan->cplan.work_data = malloc(cpu_plan->cplan.work_size);
}
return cpu_plan;
}
static void ggml_backend_cpu_graph_plan_free(struct ggml_backend * backend, ggml_graph_plan_t plan) {
struct ggml_backend_cpu_plan * cpu_plan = (struct ggml_backend_cpu_plan *)plan;
free(cpu_plan->cplan.work_data);
free(cpu_plan);
UNUSED(backend);
}
static void ggml_backend_cpu_graph_plan_compute(struct ggml_backend * backend, ggml_graph_plan_t plan) {
struct ggml_backend_cpu_plan * cpu_plan = (struct ggml_backend_cpu_plan *)plan;
ggml_graph_compute(&cpu_plan->cgraph, &cpu_plan->cplan);
UNUSED(backend);
}
static void ggml_backend_cpu_graph_compute(struct ggml_backend * backend, struct ggml_cgraph * cgraph) {
struct ggml_backend_cpu_context * cpu_ctx = (struct ggml_backend_cpu_context *)backend->context;
struct ggml_cplan cplan = ggml_graph_plan(cgraph, cpu_ctx->n_threads);
if (cpu_ctx->work_size < cplan.work_size) {
// TODO: may be faster to free and use malloc to avoid the copy
cpu_ctx->work_data = realloc(cpu_ctx->work_data, cplan.work_size);
cpu_ctx->work_size = cplan.work_size;
}
cplan.work_data = cpu_ctx->work_data;
ggml_graph_compute(cgraph, &cplan);
}
static struct ggml_backend_interface cpu_backend_interface = {
/* .get_name = */ ggml_backend_cpu_name,
/* .free = */ ggml_backend_cpu_free,
/* .alloc_buffer = */ ggml_backend_cpu_alloc_buffer,
/* .set_tensor_async = */ ggml_backend_cpu_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_cpu_get_tensor_async,
/* .synchronize = */ ggml_backend_cpu_synchronize,
/* .cpy_tensor_from = */ ggml_backend_cpu_cpy_tensor_from,
/* .cpy_tensor_to = */ ggml_backend_cpu_cpy_tensor_to,
/* .graph_plan_create = */ ggml_backend_cpu_graph_plan_create,
/* .graph_plan_free = */ ggml_backend_cpu_graph_plan_free,
/* .graph_plan_compute = */ ggml_backend_cpu_graph_plan_compute,
/* .graph_compute = */ ggml_backend_cpu_graph_compute
};
struct ggml_backend * ggml_backend_cpu_init(void) {
struct ggml_backend_cpu_context * ctx = malloc(sizeof(struct ggml_backend_cpu_context));
ctx->n_threads = GGML_DEFAULT_N_THREADS;
ctx->work_data = NULL;
ctx->work_size = 0;
struct ggml_backend * cpu_backend = malloc(sizeof(struct ggml_backend));
*cpu_backend = (struct ggml_backend) {
/* .interface = */ cpu_backend_interface,
/* .context = */ ctx,
/* .is_ram_shared = */ true,
};
return cpu_backend;
}
void ggml_backend_cpu_set_n_threads(struct ggml_backend * backend_cpu, int n_threads) {
struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
ctx->n_threads = n_threads;
}
// splits
struct ggml_graph_splits ggml_graph_split_init(void) {
struct ggml_graph_splits splits = {0};
return splits;
}
// TODO: this can be removed after allocating the graphs in a ggml_context
void ggml_graph_splits_free(struct ggml_graph_splits * splits) {
for (int i = 0; i < splits->n_splits; i++) {
if (splits->splits[i].graph) {
free(splits->splits[i].graph);
}
}
}
void ggml_graph_splits_add_n_va(struct ggml_graph_splits * splits, struct ggml_tensor *** inputs, struct ggml_context * ctx, const char * fmt, va_list args) {
GGML_ASSERT(splits->n_splits < GGML_MAX_SPLITS);
struct ggml_graph_split * split = &splits->splits[splits->n_splits];
// check if the split is on the same backend as the previous one
// FIXME: need to check all the inputs
if ((*inputs[0])->backend == ggml_get_ctx_backend(ctx)) {
if (splits->n_splits == 0) {
// always add the first split
int i = 0;
while (inputs[i] != NULL) {
GGML_ASSERT(i < GGML_MAX_SPLIT_INPUTS);
split->src_inputs[i] = *inputs[i];
split->dst_inputs[i] = *inputs[i];
i++;
}
split->src_inputs[i] = NULL;
split->dst_inputs[i] = NULL;
} else {
// add to the previous split
char name[GGML_MAX_NAME - 2];
int n = vsnprintf(name, sizeof(name), fmt, args);
char new_name[GGML_MAX_NAME];
snprintf(new_name, sizeof(new_name), "%.*s,%s", GGML_MAX_NAME - n - 2, splits->splits[splits->n_splits - 1].name, name);
strcpy(splits->splits[splits->n_splits - 1].name, new_name);
return;
}
} else {
// add a new split
int i = 0;
while (inputs[i] != NULL) {
GGML_ASSERT(i < GGML_MAX_SPLIT_INPUTS);
split->src_inputs[i] = *inputs[i];
split->dst_inputs[i] = ggml_dup_tensor(ctx, *inputs[i]);
// TODO: maybe support different layings in ggml_backend_cpy_tensor instead
for (int j = 0; j < GGML_MAX_DIMS; j++) {
split->dst_inputs[i]->nb[j] = split->src_inputs[i]->nb[j];
}
ggml_set_name(split->dst_inputs[i], ggml_get_name(*inputs[i]));
*inputs[i] = split->dst_inputs[i];
i++;
}
split->src_inputs[i] = NULL;
split->dst_inputs[i] = NULL;
}
vsnprintf(split->name, GGML_MAX_NAME, fmt, args);
split->graph = NULL;
splits->n_splits++;
}
void ggml_graph_splits_add_n(struct ggml_graph_splits * splits, struct ggml_tensor *** input, struct ggml_context * ctx, const char * fmt, ...) {
va_list args;
va_start(args, fmt);
ggml_graph_splits_add_n_va(splits, input, ctx, fmt, args);
va_end(args);
}
void ggml_graph_splits_add(struct ggml_graph_splits * splits, struct ggml_tensor ** input, struct ggml_context * ctx, const char * fmt, ...) {
va_list args;
va_start(args, fmt);
ggml_graph_splits_add_n_va(splits, (struct ggml_tensor**[2]){ input, NULL }, ctx, fmt, args);
va_end(args);
}
void ggml_graph_splits_build_forward(struct ggml_graph_splits * splits, struct ggml_tensor * output) {
struct ggml_tensor *last_outputs[2] = { output, NULL };
struct ggml_tensor ** outputs;
for (int i = 0; i < splits->n_splits; i++) {
struct ggml_graph_split * split = &splits->splits[i];
if (i < splits->n_splits - 1) {
outputs = splits->splits[i + 1].src_inputs;
} else {
outputs = last_outputs;
}
// build the graph
// TODO: allocate graphs in context
split->graph = (struct ggml_cgraph *) malloc(sizeof(struct ggml_cgraph));
memset(split->graph, 0, sizeof(struct ggml_cgraph));
for (int j = 0; outputs[j] != NULL; j++) {
ggml_build_forward_expand(split->graph, outputs[j]);
}
for (int j = 1; j < split->graph->n_nodes; j++) {
if (split->graph->nodes[j]->backend != split->graph->nodes[0]->backend) {
fprintf(stderr, "split %s: node %s has different backend (%s) than the first node (%s)\n",
split->name, split->graph->nodes[j]->name,
ggml_backend_name(split->graph->nodes[j]->backend),
ggml_backend_name(split->graph->nodes[0]->backend));
}
}
for (int j = 1; j < split->graph->n_leafs; j++) {
if (split->graph->leafs[j]->backend != split->graph->leafs[0]->backend) {
fprintf(stderr, "split %s: leaf %s has different backend (%s) than the first leaf (%s)\n",
split->name, split->graph->leafs[j]->name,
ggml_backend_name(split->graph->leafs[j]->backend),
ggml_backend_name(split->graph->leafs[0]->backend));
}
}
}
// close graphs
for (int i = 0; i < splits->n_splits; i++) {
struct ggml_graph_split * split = &splits->splits[i];
ggml_graph_close(split->graph);
}
}
void ggml_graph_splits_compute(struct ggml_graph_splits * splits) {
uint64_t copy_us = 0;
uint64_t compute_cpu_us = 0;
uint64_t compute_gpu_us = 0;
int n_nodes = 0;
for (int i = 0; i < splits->n_splits; i++) {
struct ggml_graph_split * split = &splits->splits[i];
//printf("computing split %i (%s) on backend %s (%i nodes)\n", i, split->name, ggml_backend_name(split->dst_inputs[0]->backend), split->graph->n_nodes);
// copy the input tensor to the backend
uint64_t copy_start_us = ggml_time_us();
for (int j = 0; split->src_inputs[j] != NULL; j++) {
//printf("\tcopying tensor %d (%s) (%lu bytes)\n", j, split->src_inputs[j]->name, ggml_nbytes(split->src_inputs[j]));
ggml_backend_tensor_copy(split->src_inputs[j], split->dst_inputs[j]);
}
// ggml_backend_synchronize(split->dst_inputs[0]->backend);
copy_us += ggml_time_us() - copy_start_us;
#if 0
char split_filename[GGML_MAX_NAME];
snprintf(split_filename, GGML_MAX_NAME, "split_%i.dot", i);
ggml_graph_dump_dot(split->graph, NULL, split_filename);
#endif
uint64_t start = ggml_time_us();
ggml_backend_graph_compute(split->dst_inputs[0]->backend, split->graph);
//ggml_backend_synchronize(split->dst_inputs[0]->backend);
uint64_t end = ggml_time_us();
if (strcmp(ggml_backend_name(split->dst_inputs[0]->backend), "CPU") == 0) {
compute_cpu_us += end - start;
} else {
compute_gpu_us += end - start;
}
n_nodes += split->graph->n_nodes;
}
//printf("splits: %d, nodes: %d, copy: %.2fms, compute_cpu: %.2fms, compute_gpu: %.2fms\n", splits->n_splits, n_nodes, copy_us / 1000.0, compute_cpu_us / 1000.0, compute_gpu_us / 1000.0);
//exit(0);
}
#if 0
// default allocator
struct free_block {
void * addr;
size_t size;
};
struct ggml_backend_default_allocator_context {
void * data;
size_t alignment;
int n_free_blocks;
struct free_block free_blocks[];
};
void ggml_backend_default_allocator_free_context(ggml_allocator_context_t ctx) {
struct ggml_backend_default_allocator_context * allocator_ctx = ctx;
free(allocator_ctx);
}
ggml_allocator_context_t ggml_backend_default_allocator_context(void * data, size_t size, size_t alignment, int n_free_blocks) {
struct ggml_backend_default_allocator_context * ctx = malloc(sizeof(struct ggml_backend_default_allocator_context) + n_free_blocks * sizeof(struct free_block));
ctx->data = data;
ctx->alignment = alignment;
ctx->n_free_blocks = 1;
size_t align_offset = align_offset(data, alignment);
ctx->free_blocks[0].addr = (char *)data + align_offset;
ctx->free_blocks[0].size = size - align_offset;
return ctx;
}
void * ggml_backend_default_allocator_alloc(ggml_allocator_context_t ctx, size_t size) {
struct ggml_backend_default_allocator_context * allocator_ctx = ctx;
size = align_size(size, allocator_ctx->alignment);
// find a free block
for (int i = 0; i < allocator_ctx->n_free_blocks; i++) {
struct free_block * block = &allocator_ctx->free_blocks[i];
if (block->size >= size) {
void * addr = block->addr;
block->addr += size;
block->size -= size;
if (block->size == 0) {
// remove block if empty
allocator_ctx->n_free_blocks--;
for (int j = i; j < allocator_ctx->n_free_blocks; j++) {
allocator_ctx->free_blocks[j] = allocator_ctx->free_blocks[j+1];
}
}
return addr;
}
}
return NULL;
}
// this is a very naive implementation, but for our case the number of free blocks should be very small
void ggml_backend_default_allocator_free(ggml_allocator_context_t ctx, void * ptr, size_t size) {
struct ggml_backend_default_allocator_context * allocator_ctx = ctx;
size = align_size(size, allocator_ctx->alignment);
// see if we can merge with an existing block
for (int i = 0; i < allocator_ctx->n_free_blocks; i++) {
struct free_block * block = &allocator_ctx->free_blocks[i];
// check if ptr is at the end of the block
if (block->addr + block->size == ptr) {
block->size += size;
// check if we can merge with the next block
if (i < allocator_ctx->n_free_blocks - 1 && block->addr + block->size == allocator_ctx->free_blocks[i+1].addr) {
block->size += allocator_ctx->free_blocks[i+1].size;
allocator_ctx->n_free_blocks--;
for (int j = i+1; j < allocator_ctx->n_free_blocks; j++) {
allocator_ctx->free_blocks[j] = allocator_ctx->free_blocks[j+1];
}
}
return;
}
// check if ptr is at the beginning of the block
if (ptr + size == block->addr) {
block->addr = ptr;
block->size += size;
// check if we can merge with the previous block
if (i > 0 && allocator_ctx->free_blocks[i-1].addr + allocator_ctx->free_blocks[i-1].size == block->addr) {
allocator_ctx->free_blocks[i-1].size += block->size;
allocator_ctx->n_free_blocks--;
for (int j = i; j < allocator_ctx->n_free_blocks; j++) {
allocator_ctx->free_blocks[j] = allocator_ctx->free_blocks[j+1];
}
}
return;
}
}
// otherwise, add a new block
if (allocator_ctx->n_free_blocks < MAX_FREE_BLOCKS) {
// insert the new block in the correct position to keep the array sorted
int insert_pos = 0;
while (insert_pos < allocator_ctx->n_free_blocks && allocator_ctx->free_blocks[insert_pos].addr < ptr) {
insert_pos++;
}
// shift all blocks from insert_pos onward to make room for the new block
for (int i = allocator_ctx->n_free_blocks; i > insert_pos; i--) {
allocator_ctx->free_blocks[i] = allocator_ctx->free_blocks[i-1];
}
// insert the new block
allocator_ctx->free_blocks[insert_pos].addr = ptr;
allocator_ctx->free_blocks[insert_pos].size = size;
allocator_ctx->n_free_blocks++;
}
else {
GGML_ASSERT(!"out of free blocks");
}
}
static bool ggml_is_view(struct ggml_tensor * t) {
return t->op == GGML_OP_RESHAPE || t->op == GGML_OP_VIEW || t->op == GGML_OP_TRANSPOSE ||
t->op == GGML_OP_PERMUTE || t->op == GGML_OP_NONE;
}
NOTE: id can be n_leaf OR n_node instead, we can determine the type by checking if the node is a leaf or not
void allocate_graph(struct ggml_cgraph * gf, struct ggml_buffer * buffer) {
int node_children_count[GGML_MAX_NODES*2];
int node_view_count[GGML_MAX_NODES*2];
memset(node_children_count, 0, sizeof(int) * (gf->n_nodes + gf->n_leafs));
memset(node_view_count, 0, sizeof(int) * (gf->n_nodes + gf->n_leafs));
// count number of children and views
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
// todo: ....
node_children_count[parent->id] += 1;
if (ggml_is_view(parent)) {
struct ggml_tensor * ancestor = parent;
do {
node_view_count[ancestor->id] += 1;
ancestor = ancestor->src[0];
} while (ggml_is_view(ancestor));
}
}
}
// allocate tensors
for (int i = 0; i < gf->n_nodes; i++) {
struct ggml_tensor * node = gf->nodes[i];
bool is_view = ggml_is_view(node);
if (is_view) {
// allocate view accordingly to the OP
node->data = node->src[0]->data; // + offset
struct ggml_tensor * ancestor = node->src[0];
while (ggml_is_view(ancestor)) {
ancestor = ancestor->src[0];
}
node_view_count[ancestor->id] -= 1;
} else {
if (node->data == NULL) {
// allocate tensor
// TODO: if last children and size == parent.size, then reuse parent tensor (auto in-place)
// may need a list of ops that can be in-place
ggml_backend_alloc_tensor(buffer, node);
}
}
// update parents
for (int j = 0; j < GGML_MAX_SRC; j++) {
struct ggml_tensor * parent = node->src[j];
if (parent == NULL) {
break;
}
if (is_view) {
node_view_count[parent->id] -= 1;
}
node_children_count[parent->id] -= 1;
if (node_children_count[parent->id] == 0 && node_view_count[parent->id] == 0) {
// free parent
ggml_backend_free_tensor(buffer, parent);
}
}
}
}
#endif

159
ggml-backend.h Normal file
View file

@ -0,0 +1,159 @@
#pragma once
#include "ggml.h"
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_backend;
// backend buffers
typedef void * ggml_buffer_context_t;
struct ggml_backend_buffer;
struct ggml_backend_buffer_interface {
// allocator functions
void (*free_buffer) (struct ggml_backend_buffer * alloc);
void (*alloc_tensor) (struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor);
void (*free_tensor) (struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor);
void (*reset) (struct ggml_backend_buffer * alloc);
// functions overriden by the backend
size_t (*get_alloc_size)(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor); // pre-allocation callback
void (*init_tensor) (struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor); // post-allocation callback
void (*free_data) (struct ggml_backend_buffer * alloc); // free backend-specific data // TODO: better name
};
struct ggml_backend_buffer {
struct ggml_backend_buffer_interface interface;
ggml_buffer_context_t context;
void * backend_data;
};
// backend buffer helper functions
GGML_API void ggml_backend_buffer_free(struct ggml_backend_buffer * alloc);
static inline void ggml_backend_buffer_tensor_alloc(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor) { alloc->interface.alloc_tensor(alloc, tensor); }
static inline void ggml_backend_buffer_free_tensor(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor) { alloc->interface.free_tensor(alloc, tensor); }
static inline void ggml_backend_buffer_reset(struct ggml_backend_buffer * alloc) { alloc->interface.reset(alloc); }
// default buffer allocators
// simple buffer allocator: cannot free tensors, good for weights and small contexts
// default buffer allocator: can free tensors, good for compute contexts
GGML_API struct ggml_backend_buffer * ggml_allocator_simple_init(void * data, size_t size, size_t alignment);
GGML_API struct ggml_backend_buffer * ggml_allocator_default_init(void * data, size_t size, size_t alignment, int max_free_blocks);
// buffer
// buffers have space for the tensor structs in host memory, and tensor data in backend-specific memory
struct ggml_buffer {
// host memory
size_t mem_size;
void * mem_buffer;
// tensor data
struct ggml_backend * backend;
struct ggml_backend_buffer * backend_buffer;
};
GGML_API struct ggml_buffer * ggml_buffer_alloc(struct ggml_backend * backend, size_t size, size_t max_tensors);
GGML_API void ggml_buffer_free(struct ggml_buffer * buffer);
// backend
typedef void * ggml_backend_context_t;
typedef void * ggml_graph_plan_t;
struct ggml_backend_interface {
const char * (*get_name)(struct ggml_backend * backend);
void (*free)(struct ggml_backend * backend);
// buffer allocation
struct ggml_backend_buffer * (*alloc_buffer)(struct ggml_backend * backend, size_t size);
// tensor data access
// these functions can be asynchronous. helper functions are provided for synchronous access that automatically call synchronize
void (*set_tensor_async)(struct ggml_backend * backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size);
void (*get_tensor_async)(struct ggml_backend * backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size);
void (*synchronize) (struct ggml_backend * backend);
// (optional) copy tensor between different backends, allow for single-copy tranfers
void (*cpy_tensor_from)(struct ggml_backend * backend, struct ggml_tensor * src, struct ggml_tensor * dst);
void (*cpy_tensor_to) (struct ggml_backend * backend, struct ggml_tensor * src, struct ggml_tensor * dst);
// compute graph with a plan
ggml_graph_plan_t (*graph_plan_create) (struct ggml_backend * backend, struct ggml_cgraph * cgraph);
void (*graph_plan_free) (struct ggml_backend * backend, ggml_graph_plan_t plan);
void (*graph_plan_compute)(struct ggml_backend * backend, ggml_graph_plan_t plan);
// compute graph without a plan
void (*graph_compute) (struct ggml_backend * backend, struct ggml_cgraph * cgraph);
// check if a backend supports a given operation
// this could be used to fallback automatically to the CPU backend if a backend doesn't support an operation
// bool (*supports_op)(struct ggml_backend * backend, struct ggml_tensor * op);
};
struct ggml_backend {
struct ggml_backend_interface interface;
ggml_backend_context_t context;
bool is_ram_shared;
};
// backend helper functions
static inline const char * ggml_backend_name(struct ggml_backend * backend) { return backend->interface.get_name(backend); }
static inline void ggml_backend_free(struct ggml_backend * backend) { backend->interface.free(backend); }
static inline void ggml_backend_tensor_set_async(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { tensor->backend->interface.set_tensor_async(tensor->backend, tensor, data, offset, size); }
static inline void ggml_backend_tensor_get_async(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { tensor->backend->interface.get_tensor_async(tensor->backend, tensor, data, offset, size); }
static inline void ggml_backend_tensor_set(struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { tensor->backend->interface.set_tensor_async(tensor->backend, tensor, data, offset, size); tensor->backend->interface.synchronize(tensor->backend); }
static inline void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { tensor->backend->interface.get_tensor_async(tensor->backend, tensor, data, offset, size); tensor->backend->interface.synchronize(tensor->backend); }
static inline void ggml_backend_synchronize(struct ggml_backend * backend) { backend->interface.synchronize(backend); }
static inline ggml_graph_plan_t ggml_backend_graph_plan_create(struct ggml_backend * backend, struct ggml_cgraph * cgraph) { return backend->interface.graph_plan_create(backend, cgraph); }
static inline void ggml_backend_graph_plan_free(struct ggml_backend * backend, ggml_graph_plan_t plan) { backend->interface.graph_plan_free(backend, plan); }
static inline void ggml_backend_graph_plan_compute(struct ggml_backend * backend, ggml_graph_plan_t plan) { backend->interface.graph_plan_compute(backend, plan); }
static inline void ggml_backend_graph_compute(struct ggml_backend * backend, struct ggml_cgraph * cgraph) { backend->interface.graph_compute(backend, cgraph); }
// tensor copy between different backends
GGML_API void ggml_backend_tensor_copy(struct ggml_tensor * src, struct ggml_tensor * dst);
// CPU backend
GGML_API struct ggml_backend * ggml_backend_cpu_init(void);
GGML_API void ggml_backend_cpu_set_n_threads(struct ggml_backend * backend_cpu, int n_threads);
///////////////////////////
// graph splitting
#define GGML_MAX_SPLITS 200
#define GGML_MAX_SPLIT_INPUTS 4
struct ggml_graph_split {
char name[GGML_MAX_NAME];
struct ggml_tensor * src_inputs[GGML_MAX_SPLIT_INPUTS + 1];
struct ggml_tensor * dst_inputs[GGML_MAX_SPLIT_INPUTS + 1];
struct ggml_cgraph * graph;
};
// TODO: this shouldn't be fixed size, allocate from ggml_context
struct ggml_graph_splits {
int n_splits;
struct ggml_graph_split splits[GGML_MAX_SPLITS];
};
// TODO: allocate in ggml_context
struct ggml_graph_splits ggml_graph_split_init(void);
// this won't be needed once we can allocate graphs from a ggml_context
GGML_API void ggml_graph_splits_free(struct ggml_graph_splits * splits);
// add a split to the graph - single and multiple inputs versions
GGML_API void ggml_graph_splits_add(struct ggml_graph_splits * splits, struct ggml_tensor ** input, struct ggml_context * ctx, const char * fmt, ...);
GGML_API void ggml_graph_splits_add_n(struct ggml_graph_splits * splits, struct ggml_tensor *** inputs, struct ggml_context * ctx, const char * fmt, ...);
// build graphs for all splits
GGML_API void ggml_graph_splits_build_forward(struct ggml_graph_splits * splits, struct ggml_tensor * output);
// compute
GGML_API void ggml_graph_splits_compute(struct ggml_graph_splits * splits);
#ifdef __cplusplus
}
#endif

468
ggml-cuda-kern.h Normal file
View file

@ -0,0 +1,468 @@
// kernels for ggml-cuda
#include <cuda.h>
#include <cuda_fp16.h>
template<typename dst_t>
using to_t_cuda_t = void (*)(const void * x, dst_t * y, int k, cudaStream_t stream);
// support for vector types in generic code
template<typename T> struct vec2_t_impl;
template<> struct vec2_t_impl<half> { typedef half2 type; };
template<> struct vec2_t_impl<float> { typedef float2 type; };
template<typename T> using vec2_t = typename vec2_t_impl<T>::type;
template<typename T> inline __host__ __device__ vec2_t<T> make_vec2_t(const T & x, const T & y);
template<> inline __host__ __device__ vec2_t<half> make_vec2_t(const half & x, const half & y) { return make_half2 (x, y); }
template<> inline __host__ __device__ vec2_t<float> make_vec2_t(const float & x, const float & y) { return make_float2(x, y); }
// the cuda headers define operators for half2, but not for float2
// they are defined here to simplify generic code
inline __host__ __device__ float2 operator+(const float2 & a, const float2 & b) { return make_float2(a.x + b.x, a.y + b.y); }
inline __host__ __device__ float2 operator-(const float2 & a, const float2 & b) { return make_float2(a.x - b.x, a.y - b.y); }
inline __host__ __device__ float2 operator*(const float2 & a, const float2 & b) { return make_float2(a.x * b.x, a.y * b.y); }
inline __host__ __device__ float2 operator/(const float2 & a, const float2 & b) { return make_float2(a.x / b.x, a.y / b.y); }
inline __host__ __device__ float2 & operator+=( float2 & a, const float2 & b) { a.x += b.x; a.y += b.y; return a; }
inline __host__ __device__ float2 & operator-=( float2 & a, const float2 & b) { a.x -= b.x; a.y -= b.y; return a; }
inline __host__ __device__ float2 & operator*=( float2 & a, const float2 & b) { a.x *= b.x; a.y *= b.y; return a; }
inline __host__ __device__ float2 & operator/=( float2 & a, const float2 & b) { a.x /= b.x; a.y /= b.y; return a; }
template<typename dst_t>
using dequantize_kernel_t = void (*)(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v);
__device__ half sqrt(const half x) { return hsqrt(x); }
__device__ half exp(const half x) { return hexp(x); }
__device__ half2 exp(const half2 x) { return h2exp(x); }
__device__ half cos(const half x) { return hcos(x); }
__device__ half sin(const half x) { return hsin(x); }
__device__ half max(const half x, const half y) { return __hmax(x, y); }
__device__ half2 max(const half2 x, const half2 y) { return __hmax2(x, y); }
template<typename T> struct op_max { __device__ T operator()(T a, T b) const { return max(a, b); } };
template<typename T> struct op_sum { __device__ T operator()(T a, T b) const { return a + b; } };
template<template<typename> class op_t, typename T>
static inline __device__ T warp_reduce_all(T val) {
op_t<T> op;
#pragma unroll
for (int mask = warpSize/2; mask > 0; mask /= 2) {
val = op(val, __shfl_xor_sync(0xffffffff, val, mask, 32));
}
return val;
}
template<typename T>
static __device__ T zero_init() { return T(0); }
template<>
__device__ half2 zero_init() { return half2(0.0f, 0.0f); }
template<template<typename> class op_t, typename T>
static __device__ T block_reduce_all(const T val, const T init = zero_init<T>()) {
const int warp_id = threadIdx.x / warpSize; // warp id within the block
const int lane_id = threadIdx.x % warpSize; // lane id within the warp
const int num_warps = blockDim.x / warpSize; // number of warps in the block
__shared__ T lane_result[32]; // max 32 warps per block
// reduce warps
T warp_reduction = warp_reduce_all<op_t>(val);
__syncthreads();
// first thread within a warp writes reduction to shared memory
if (lane_id == 0) {
lane_result[warp_id] = warp_reduction;
}
// wait for all warps to finish writing their reductions
__syncthreads();
// reduce the results of all warps
T block_reduction = init;
if (lane_id < num_warps) {
block_reduction = lane_result[lane_id];
}
block_reduction = warp_reduce_all<op_t>(block_reduction);
return block_reduction;
}
template<typename dst_t>
static __device__ void convert_fp16(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v) {
const half * x = (const half *) vx;
v.x = (dst_t)(x[ib + iqs + 0]);
v.y = (dst_t)(x[ib + iqs + 1]);
}
template<typename dst_t>
static __device__ void convert_fp32(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v) {
const float * x = (const float *) vx;
v.x = (dst_t)(x[ib + iqs + 0]);
v.y = (dst_t)(x[ib + iqs + 1]);
}
template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_mul_mat_p021(const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x, const int nchannels_x) {
const src0_t * x = vx;
// const int col_x = blockDim.x*blockIdx.x + threadIdx.x;
// const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
dst_t tmp = 0;
for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
const int col_x = col_x0 + threadIdx.x;
if (col_x >= ncols_x) {
break;
}
// x is transposed and permuted
const int ix = row_x*nchannels_x*ncols_x + channel*ncols_x + col_x;
const dst_t xi = (dst_t)(x[ix]);
const int row_y = col_x;
// y is not transposed but permuted
const int iy = channel*nrows_y + row_y;
tmp += xi * y[iy];
}
// dst is not transposed and not permuted
const int idst = channel*nrows_dst + row_dst;
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (threadIdx.x == 0) {
dst[idst] = tmp;
}
}
template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_mul_mat_vec_nc(
const src0_t * vx, const src1_t * y, dst_t * dst, const int ncols_x, const int nrows_x,
const int row_stride_x, const int nchannels_x, const int channel_stride_x) {
const src0_t * x = vx;
const int row_x = blockDim.y*blockIdx.y + threadIdx.y;
const int channel = blockDim.z*blockIdx.z + threadIdx.z;
const int nrows_y = ncols_x;
const int nrows_dst = nrows_x;
const int row_dst = row_x;
const int idst = channel*nrows_dst + row_dst;
dst_t tmp = 0;
for (int col_x0 = 0; col_x0 < ncols_x; col_x0 += blockDim.x) {
const int col_x = col_x0 + threadIdx.x;
if (col_x >= ncols_x) {
break;
}
const int ix = channel*channel_stride_x + row_x*row_stride_x + col_x;
const dst_t xi = (dst_t)(x[ix]);
const int row_y = col_x;
const int iy = channel*nrows_y + row_y;
tmp += xi * y[iy];
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (threadIdx.x == 0) {
dst[idst] = tmp;
}
}
template <typename src_t, typename dst_t>
static __global__ void k_cpy(const char * cx, char * cdst, const int ne,
const int ne00, const int ne01, const int nb00, const int nb01, const int nb02,
const int ne10, const int ne11, const int nb10, const int nb11, const int nb12) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) {
return;
}
const int i02 = i / (ne00*ne01);
const int i01 = (i - i02*ne01*ne00) / ne00;
const int i00 = i - i02*ne01*ne00 - i01*ne00;
const int x_offset = i00*nb00 + i01*nb01 + i02*nb02;
const int i12 = i / (ne10*ne11);
const int i11 = (i - i12*ne10*ne11) / ne10;
const int i10 = i - i12*ne10*ne11 - i11*ne10;
const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12;
*(dst_t *)(cdst + dst_offset) = *(const src_t *)(cx + x_offset);
}
template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_add(const src0_t * x, const src1_t * y, dst_t * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = (dst_t)x[i] + (dst_t)y[i];
}
template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_mul(const src0_t * x, const src1_t * y, dst_t * dst, const int kx, const int ky) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= kx) {
return;
}
dst[i] = (dst_t)x[i] * (dst_t)y[i%ky];
}
template<typename src0_t, typename dst_t>
static __global__ void k_silu(const src0_t * x, dst_t * dst, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = x[i] / (src0_t(1) + exp(-x[i]));
}
// TODO: unstable with f16 compute, using f32 compute for now
template<typename src0_t, typename dst_t>
static __global__ void k_rms_norm(const src0_t * x, dst_t * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const float eps = 1e-6;
float tmp = 0; // partial sum for thread in warp
for (int col = tid; col < ncols; col += WARP_SIZE) {
const float xi = x[row*ncols + col];
tmp += xi * xi;
}
// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
const float mean = tmp / (float)ncols;
const float scale = 1.0f / sqrtf(mean + eps);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = scale * (float)x[row*ncols + col];
}
}
template<typename src0_t, typename dst_t>
static __global__ void k_rope(const src0_t * x, dst_t * dst, const int ncols, const float p, const float theta_scale) {
const int col = 2*(blockDim.x*blockIdx.x + threadIdx.x);
if (col >= ncols) {
return;
}
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;
const dst_t theta = p * powf(theta_scale, col/2);
const dst_t sin_theta = sin(theta);
const dst_t cos_theta = cos(theta);
const dst_t x0 = x[i + 0];
const dst_t x1 = x[i + 1];
dst[i + 0] = (dst_t)x0*cos_theta - (dst_t)x1*sin_theta;
dst[i + 1] = (dst_t)x0*sin_theta + (dst_t)x1*cos_theta;
}
template<typename src0_t, typename dst_t>
static __global__ void k_diag_mask_inf(const src0_t * x, dst_t * dst, const int ncols, const int rows_per_channel, const int n_past) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int row = blockDim.y*blockIdx.y + threadIdx.y;
if (col >= ncols) {
return;
}
const int i = row*ncols + col;
//dst[i] = col > (n_past + row % rows_per_channel) ? (dst_t)-INFINITY : (dst_t)x[i];
dst[i] = (dst_t)x[i] - (dst_t)((col > n_past + row % rows_per_channel) * INT_MAX); // equivalent within rounding error but slightly faster on GPU
}
// TODO: numerically stable version - low prio since the softmax is computed in the fused attention kernel
// check: https://arxiv.org/pdf/2001.04438.pdf
template<typename src0_t, typename dst_t>
static __global__ void k_soft_max_orig(const src0_t * x, dst_t * dst, const int ncols) {
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int block_size = blockDim.x;
const int tid = threadIdx.x;
float tmp = 0;
for (int block_start = 0; block_start < ncols; block_start += block_size) {
const int col = block_start + tid;
if (col >= ncols) {
break;
}
const int i = row*ncols + col;
const float val = expf(x[i]);
tmp += val;
dst[i] = val;
}
// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
for (int block_start = 0; block_start < ncols; block_start += block_size) {
const int col = block_start + tid;
if (col >= ncols) {
break;
}
const int i = row*ncols + col;
dst[i] /= tmp;
}
}
template<typename src_t, typename dst_t, int pack_size, int block_size>
static __global__ void k_soft_max(const src_t * x, dst_t * dst, const int64_t nrows, const int64_t ncols) {
//assert(ncols % pack_size == 0);
const int tid = threadIdx.x;
const int num_packs = ncols / pack_size;
for (int row = blockIdx.x; row < nrows; row += gridDim.x) {
src_t th_max = -INFINITY;
// row max thread
#pragma unroll
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
// load pack
src_t pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; i++) {
pack[i] = x[row * ncols + pack_id * pack_size + i];
}
// reduce max pack
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
th_max = max(th_max, pack[i]);
}
}
// reduce max row warp threads
src_t row_max = block_reduce_all<op_max>(th_max, (src_t)-INFINITY);
// row exp sum thread
src_t th_sum = 0;
#pragma unroll
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
// load pack
src_t pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; i++) {
pack[i] = x[row * ncols + pack_id * pack_size + i];
}
// reduce pack
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
th_sum += exp(pack[i] - row_max);
}
}
// reduce row exp sum all threads
src_t row_sum = block_reduce_all<op_sum>(th_sum);
// store (row - row_max) / row exp sum
#pragma unroll
for (int pack_id = tid; pack_id < num_packs; pack_id += block_size) {
// load pack
src_t pack[pack_size];
#pragma unroll
for (int i = 0; i < pack_size; i++) {
pack[i] = x[row * ncols + pack_id * pack_size + i];
}
// reduce pack
#pragma unroll
for (int i = 0; i < pack_size; ++i) {
pack[i] = exp(pack[i] - row_max) / row_sum;
}
// store pack
#pragma unroll
for (int i = 0; i < pack_size; i++) {
dst[row * ncols + pack_id * pack_size + i] = pack[i];
}
}
}
}
template<typename src0_t, typename src1_t, typename dst_t>
static __global__ void k_scale(const src0_t * x, dst_t * dst, const src1_t * scale, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
dst[i] = (dst_t)(*scale) * (dst_t)x[i];
}
template<typename dst_t, int qk, int qr, dequantize_kernel_t<dst_t> dequantize_kernel>
static __global__ void k_get_rows(const void * x, const int * y, dst_t * dst, const int ncols) {
const int col = (blockIdx.x*blockDim.x + threadIdx.x)*2;
const int row = blockDim.y*blockIdx.y + threadIdx.y;
if (col >= ncols) {
return;
}
const int r = y[row];
// copy x[r*ncols + col] to dst[row*ncols + col]
const int xi = r*ncols + col;
const int di = row*ncols + col;
const int ib = xi/qk; // block index
const int iqs = (xi%qk)/qr; // quant index
const int iybs = di - di%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
vec2_t<dst_t> v;
dequantize_kernel(x, ib, iqs, v);
dst[iybs + iqs + 0] = v.x;
dst[iybs + iqs + y_offset] = v.y;
}

920
ggml-cuda-quant.h Normal file
View file

@ -0,0 +1,920 @@
// quants kernels for ggml-cuda
// QK = number of values after dequantization
// QR = QK / number of values before dequantization
// QI = number of 32 bit integers before dequantization
#define QK4_0 32
#define QR4_0 2
#define QI4_0 4
typedef struct {
half d; // delta
uint8_t qs[QK4_0 / 2]; // nibbles / quants
} block_q4_0;
static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0 block size/padding");
#define QK4_1 32
#define QR4_1 2
#define QI4_1 4
typedef struct {
half d; // delta
half m; // min
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong q4_1 block size/padding");
#define QK5_0 32
#define QR5_0 2
#define QI5_0 4
typedef struct {
half d; // delta
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_0 / 2]; // nibbles / quants
} block_q5_0;
static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_0 / 2, "wrong q5_0 block size/padding");
#define QK5_1 32
#define QR5_1 2
#define QI5_1 4
typedef struct {
half d; // delta
half m; // min
uint8_t qh[4]; // 5-th bit of quants
uint8_t qs[QK5_1 / 2]; // nibbles / quants
} block_q5_1;
static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5_1 / 2, "wrong q5_1 block size/padding");
#define QK8_0 32
#define QR8_0 1
#define QI8_0 8
typedef struct {
half d; // delta
int8_t qs[QK8_0]; // quants
} block_q8_0;
static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 block size/padding");
#define QK8_1 32
#define QR8_1 1
#define QI8_1 8
typedef struct {
half d; // delta
half s; // unquantized sum
int8_t qs[QK8_0]; // quants
} block_q8_1;
static_assert(sizeof(block_q8_1) == 2*sizeof(ggml_fp16_t) + QK8_0, "wrong q8_1 block size/padding");
//================================= k-quants
#define QK_K 256
typedef struct {
uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
uint8_t qs[QK_K/4]; // quants
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
typedef struct {
uint8_t hmask[QK_K/8];
uint8_t qs[QK_K/4]; // nibbles / quants
uint8_t scales[3*QK_K/64];
half d;
} block_q3_K;
static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + 11 * QK_K / 64, "wrong q3_K block size/padding");
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
uint8_t qs[QK_K/2]; // 4--bit quants
} block_q4_K;
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uint8_t scales[3*QK_K/64]; // scales, quantized with 6 bits
uint8_t qh[QK_K/8]; // quants, high bit
uint8_t qs[QK_K/2]; // quants, low 4 bits
} block_q5_K;
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
typedef struct {
uint8_t ql[QK_K/2]; // quants, lower 4 bits
uint8_t qh[QK_K/4]; // quants, upper 2 bits
int8_t scales[QK_K/16]; // scales
half d; // delta
} block_q6_K;
static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_K block size/padding");
template<typename src1_t, typename dst_t>
using dot_kernel_k_t = void (*)(const void * vx, const int ib, const int iqs, const src1_t * y, dst_t & v);
template<typename dst_t>
using vec_dot_q_cuda_t = dst_t (*)(const void * vbq, const block_q8_1 * bq8_1, const int iqs);
// TODO: f16
template<typename src_t>
static __global__ void quantize_q8_1(const src_t * x, void * vy, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
block_q8_1 * y = (block_q8_1 *) vy;
const int ib = i / QK8_0; // block index
const int iqs = i % QK8_0; // quant index
const float xi = x[i];
float amax = fabsf(xi);
float sum = xi;
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32));
sum += __shfl_xor_sync(0xffffffff, sum, mask, 32);
}
const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q;
if (iqs > 0) {
return;
}
y[ib].d = d;
y[ib].s = sum;
}
template<typename dst_t>
static __device__ void dequantize_q4_0(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v){
const block_q4_0 * x = (const block_q4_0 *) vx;
const dst_t d = x[ib].d;
const uint8_t vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
const vec2_t<dst_t> off2 = make_vec2_t<dst_t>(8, 8);
const vec2_t<dst_t> d2 = make_vec2_t<dst_t>(d, d);
v = (v - off2) * d2;
}
template<typename dst_t>
static __device__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v){
const block_q4_1 * x = (const block_q4_1 *) vx;
const dst_t d = x[ib].d;
const dst_t m = x[ib].m;
const uint8_t vui = x[ib].qs[iqs];
v.x = vui & 0xF;
v.y = vui >> 4;
const vec2_t<dst_t> d2 = make_vec2_t<dst_t>(d, d);
const vec2_t<dst_t> m2 = make_vec2_t<dst_t>(m, m);
v = v * d2 + m2;
}
template<typename dst_t>
static __device__ void dequantize_q5_0(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v){
const block_q5_0 * x = (const block_q5_0 *) vx;
const dst_t d = x[ib].d;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
const vec2_t<dst_t> off2 = make_vec2_t<dst_t>(16, 16);
const vec2_t<dst_t> d2 = make_vec2_t<dst_t>(d, d);
v = (v - off2) * d2;
}
template<typename dst_t>
static __device__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v){
const block_q5_1 * x = (const block_q5_1 *) vx;
const dst_t d = x[ib].d;
const dst_t m = x[ib].m;
uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh));
const uint8_t xh_0 = ((qh >> (iqs + 0)) << 4) & 0x10;
const uint8_t xh_1 = ((qh >> (iqs + 12)) ) & 0x10;
v.x = ((x[ib].qs[iqs] & 0xf) | xh_0);
v.y = ((x[ib].qs[iqs] >> 4) | xh_1);
const vec2_t<dst_t> d2 = make_vec2_t<dst_t>(d, d);
const vec2_t<dst_t> m2 = make_vec2_t<dst_t>(m, m);
v = v * d2 + m2;
}
template<typename dst_t>
static __device__ void dequantize_q8_0(const void * vx, const int ib, const int iqs, vec2_t<dst_t> & v){
const block_q8_0 * x = (const block_q8_0 *) vx;
const dst_t d = x[ib].d;
v.x = x[ib].qs[iqs + 0];
v.y = x[ib].qs[iqs + 1];
const vec2_t<dst_t> d2 = make_vec2_t<dst_t>(d, d);
v = v * d2;
}
//================================== k-quants
static __global__ void dequantize_block_q2_K(const void * vx, float * yy) {
const int i = blockIdx.x;
const int tid = threadIdx.x;
const int n = tid/32;
const int l = tid - 32*n;
const int is = 8*n + l/16;
const block_q2_K * x = (const block_q2_K *) vx;
const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n;
float dall = x[i].d;
float dmin = x[i].dmin;
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
y[l+96] = dall * (x[i].scales[is+6] & 0xF) * ((q >> 6) & 3) - dmin * (x[i].scales[is+6] >> 4);
}
static __device__ void vec_dot_q2_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
const block_q2_K * x = (const block_q2_K *) vx;
// if n is 0, we want to do the lower 128, else the upper 128,
// covering y[l+0], y[l+32], y[l+64], y[l+96] and
// y[l+16], y[l+48], y[l+80], y[l+112]
int n = iqs/128; // 0 or 1
int r = iqs - 128*n; // 0...120 in steps of 8
int l = r/8; // 0...15 in steps of 1
const float * y = yy + 128*n + l;
const uint8_t * q = x[ib].qs + 32*n + l;
const uint8_t * s = x[ib].scales + 8*n;
const float dall = x[ib].d;
const float dmin = x[ib].dmin;
float sum = y[ 0] * (dall * ((s[0] & 0xF) * ((q[ 0] >> 0) & 3)) - dmin * (s[0] >> 4))
+ y[ 32] * (dall * ((s[2] & 0xF) * ((q[ 0] >> 2) & 3)) - dmin * (s[2] >> 4))
+ y[ 64] * (dall * ((s[4] & 0xF) * ((q[ 0] >> 4) & 3)) - dmin * (s[4] >> 4))
+ y[ 96] * (dall * ((s[6] & 0xF) * ((q[ 0] >> 6) & 3)) - dmin * (s[6] >> 4))
+ y[ 16] * (dall * ((s[1] & 0xF) * ((q[16] >> 0) & 3)) - dmin * (s[1] >> 4))
+ y[ 48] * (dall * ((s[3] & 0xF) * ((q[16] >> 2) & 3)) - dmin * (s[3] >> 4))
+ y[ 80] * (dall * ((s[5] & 0xF) * ((q[16] >> 4) & 3)) - dmin * (s[5] >> 4))
+ y[112] * (dall * ((s[7] & 0xF) * ((q[16] >> 6) & 3)) - dmin * (s[7] >> 4));
result = sum;
}
static __global__ void dequantize_block_q3_K(const void * vx, float * yy) {
int r = threadIdx.x/4;
int i = blockIdx.x;
int tid = r/2;
int is0 = r%2;
int l0 = 16*is0 + 4*(threadIdx.x%4);
int n = tid / 4;
int j = tid - 4*n;
const block_q3_K * x = (const block_q3_K *) vx;
uint8_t m = 1 << (4*n + j);
int is = 8*n + 2*j + is0;
int shift = 2*j;
int8_t us = is < 4 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+8] >> 0) & 3) << 4) :
is < 8 ? (x[i].scales[is-0] & 0xF) | (((x[i].scales[is+4] >> 2) & 3) << 4) :
is < 12 ? (x[i].scales[is-8] >> 4) | (((x[i].scales[is+0] >> 4) & 3) << 4) :
(x[i].scales[is-8] >> 4) | (((x[i].scales[is-4] >> 6) & 3) << 4);
float d_all = x[i].d;
float dl = d_all * (us - 32);
float * y = yy + i*QK_K + 128*n + 32*j;
const uint8_t * q = x[i].qs + 32*n;
const uint8_t * hm = x[i].hmask;
for (int l = l0; l < l0+4; ++l) y[l] = dl * ((int8_t)((q[l] >> shift) & 3) - ((hm[l] & m) ? 0 : 4));
}
static __device__ void vec_dot_q3_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
const block_q3_K * x = (const block_q3_K *) vx;
const uint32_t kmask1 = 0x03030303;
const uint32_t kmask2 = 0x0f0f0f0f;
uint32_t aux[3];
uint32_t utmp[4];
// if n is 0, we want to do the lower 128, else the upper 128,
// covering y[l+0], y[l+32], y[l+64], y[l+96] and
// y[l+16], y[l+48], y[l+80], y[l+112]
int n = iqs/128; // 0 or 1
int r = iqs - 128*n; // 0...120 in steps of 8
int l = r/8; // 0...15 in steps of 1
const float * y = yy + 128*n + l;
const uint8_t * q = x[ib].qs + 32*n + l;
const uint8_t * hm = x[ib].hmask + l;
const int8_t * s = (const int8_t *)utmp + 8*n;
memcpy(aux, x[ib].scales, 12);
utmp[3] = ((aux[1] >> 4) & kmask2) | (((aux[2] >> 6) & kmask1) << 4);
utmp[2] = ((aux[0] >> 4) & kmask2) | (((aux[2] >> 4) & kmask1) << 4);
utmp[1] = (aux[1] & kmask2) | (((aux[2] >> 2) & kmask1) << 4);
utmp[0] = (aux[0] & kmask2) | (((aux[2] >> 0) & kmask1) << 4);
const float dall = x[ib].d;
const uint8_t m = 1 << (4*n);
float sum = y[ 0] * (s[0] - 32) * (((q[ 0] >> 0) & 3) - (hm[ 0] & (m << 0) ? 0 : 4))
+ y[ 32] * (s[2] - 32) * (((q[ 0] >> 2) & 3) - (hm[ 0] & (m << 1) ? 0 : 4))
+ y[ 64] * (s[4] - 32) * (((q[ 0] >> 4) & 3) - (hm[ 0] & (m << 2) ? 0 : 4))
+ y[ 96] * (s[6] - 32) * (((q[ 0] >> 6) & 3) - (hm[ 0] & (m << 3) ? 0 : 4))
+ y[ 16] * (s[1] - 32) * (((q[16] >> 0) & 3) - (hm[16] & (m << 0) ? 0 : 4))
+ y[ 48] * (s[3] - 32) * (((q[16] >> 2) & 3) - (hm[16] & (m << 1) ? 0 : 4))
+ y[ 80] * (s[5] - 32) * (((q[16] >> 4) & 3) - (hm[16] & (m << 2) ? 0 : 4))
+ y[112] * (s[7] - 32) * (((q[16] >> 6) & 3) - (hm[16] & (m << 3) ? 0 : 4));
result = sum * dall;
}
static inline __device__ void get_scale_min_k4(int j, const uint8_t * q, uint8_t & d, uint8_t & m) {
if (j < 4) {
d = q[j] & 63; m = q[j + 4] & 63;
} else {
d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
}
}
static __global__ void dequantize_block_q4_K(const void * vx, float * yy) {
const block_q4_K * x = (const block_q4_K *) vx;
const int i = blockIdx.x;
//// assume 64 threads - this is very slightly better than the one below
//const int tid = threadIdx.x;
//const int il = tid/16;
//const int ir = tid%16;
//const int is = 2*il;
//const int n = 2;
// assume 32 threads
const int tid = threadIdx.x;
const int il = tid/8;
const int ir = tid%8;
const int is = 2*il;
const int n = 4;
float * y = yy + i*QK_K + 64*il + n*ir;
const float dall = x[i].d;
const float dmin = x[i].dmin;
const uint8_t * q = x[i].qs + 32*il + n*ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
for (int l = 0; l < n; ++l) {
y[l + 0] = d1 * (q[l] & 0xF) - m1;
y[l +32] = d2 * (q[l] >> 4) - m2;
}
}
static __device__ void vec_dot_q4_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
const block_q4_K * x = (const block_q4_K *) vx;
// iqs is in 0...248 in steps of 8 =>
const int j = iqs / 64; // j is in 0...3
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
const int is = 2*j; // is is in 0...6 in steps of 2
const float * y = yy + 64*j + ir;
const uint8_t * q = x[ib].qs + 32*j + ir;
const float dall = x[ib].d;
const float dmin = x[ib].dmin;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * (q[k] & 0xF) - m1);
sum += y[k + 32] * (d2 * (q[k] >> 4) - m2);
}
result = sum;
}
static __global__ void dequantize_block_q5_K(const void * vx, float * yy) {
const block_q5_K * x = (const block_q5_K *) vx;
const int i = blockIdx.x;
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
const int il = tid/16; // il is in 0...3
const int ir = tid%16; // ir is in 0...15
const int is = 2*il; // is is in 0...6
float * y = yy + i*QK_K + 64*il + 2*ir;
const float dall = x[i].d;
const float dmin = x[i].dmin;
const uint8_t * ql = x[i].qs + 32*il + 2*ir;
const uint8_t * qh = x[i].qh + 2*ir;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[i].scales, sc, m);
const float d1 = dall * sc; const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[i].scales, sc, m);
const float d2 = dall * sc; const float m2 = dmin * m;
uint8_t hm = 1 << (2*il);
y[ 0] = d1 * ((ql[ 0] & 0xF) + (qh[ 0] & hm ? 16 : 0)) - m1;
y[ 1] = d1 * ((ql[ 1] & 0xF) + (qh[ 1] & hm ? 16 : 0)) - m1;
hm <<= 1;
y[32] = d2 * ((ql[ 0] >> 4) + (qh[ 0] & hm ? 16 : 0)) - m2;
y[33] = d2 * ((ql[ 1] >> 4) + (qh[ 1] & hm ? 16 : 0)) - m2;
}
static __device__ void vec_dot_q5_K(const void * vx, const int ib, const int iqs, const float * yy, float & result) {
const block_q5_K * x = (const block_q5_K *) vx;
// iqs is in 0...248 in steps of 8 =>
const int j = iqs / 64; // j is in 0...3
const int ir = (iqs - 64*j)/2; // ir is in 0...28 in steps of 4
const int is = 2*j; // is is in 0...6 in steps of 2
const float * y = yy + 64*j + ir;
const uint8_t * ql = x[ib].qs + 32*j + ir;
const uint8_t * qh = x[ib].qh + ir;
const float dall = x[ib].d;
const float dmin = x[ib].dmin;
uint8_t sc, m;
get_scale_min_k4(is + 0, x[ib].scales, sc, m);
const float d1 = dall * sc;
const float m1 = dmin * m;
get_scale_min_k4(is + 1, x[ib].scales, sc, m);
const float d2 = dall * sc;
const float m2 = dmin * m;
uint8_t hm = 1 << is;
float sum = 0;
for (int k = 0; k < 4; ++k) {
sum += y[k + 0] * (d1 * ((ql[k] & 0xF) + (qh[k] & hm ? 16 : 0)) - m1);
}
hm <<= 1;
for (int k = 0; k < 4; ++k) {
sum += y[k + 32] * (d2 * ((ql[k] >> 4) + (qh[k] & hm ? 16 : 0)) - m2);
}
result = sum;
}
template<typename dst_t>
static __global__ void dequantize_block_q6_K(const void * vx, dst_t * yy) {
const block_q6_K * x = (const block_q6_K *) vx;
const int i = blockIdx.x;
// assume 64 threads - this is very slightly better than the one below
const int tid = threadIdx.x;
const int ip = tid/32; // ip is 0 or 1
const int il = tid - 32*ip; // 0...32
const int is = 8*ip + il/16;
// TODO: fp16 compute
dst_t * y = yy + i*QK_K + 128*ip + il;
const float d = x[i].d;
const uint8_t * ql = x[i].ql + 64*ip + il;
const uint8_t qh = x[i].qh[32*ip + il];
const int8_t * sc = x[i].scales + is;
y[ 0] = d * sc[0] * ((int8_t)((ql[ 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32);
y[32] = d * sc[2] * ((int8_t)((ql[32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32);
y[64] = d * sc[4] * ((int8_t)((ql[ 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32);
y[96] = d * sc[6] * ((int8_t)((ql[32] >> 4) | (((qh >> 6) & 3) << 4)) - 32);
}
template<typename src1_t, typename dst_t>
static __global__ void dequantize_mul_mat_vec_q6_k(const void * vx, const src1_t * yy, dst_t * dst, const int ncols, int nrows) {
static_assert(16%K_QUANTS_PER_ITERATION == 0, "16 must be divisible by K_QUANTS_PER_ITERATION");
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row > nrows) return;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
const block_q6_K * x = (const block_q6_K *)vx + ib0;
const int tid = threadIdx.x/K_QUANTS_PER_ITERATION; // 0...31 or 0...16
const int ix = threadIdx.x%K_QUANTS_PER_ITERATION; // 0 or 0, 1
const int step = 16/K_QUANTS_PER_ITERATION; // 16 or 8
const int im = tid/step; // 0 or 1. 0 computes 0..., 1 computes 128...
const int in = tid - step*im; // 0...15 or 0...7
#if K_QUANTS_PER_ITERATION == 1
const int l0 = K_QUANTS_PER_ITERATION*in; // 0...15
const int is = 0;
#else
const int l0 = 4 * in; // 0, 4, 8, ..., 28
const int is = in / 4;
#endif
const int ql_offset = 64*im + l0;
const int qh_offset = 32*im + l0;
const int s_offset = 8*im + is;
const int y_offset = 128*im + l0;
dst_t tmp = 0; // partial sum for thread in warp
for (int i = ix; i < num_blocks_per_row; i += K_QUANTS_PER_ITERATION) {
const src1_t * y = yy + i * QK_K + y_offset;
const uint8_t * ql = x[i].ql + ql_offset;
const uint8_t * qh = x[i].qh + qh_offset;
const int8_t * s = x[i].scales + s_offset;
const dst_t d = x[i].d;
#if K_QUANTS_PER_ITERATION == 1
float sum = y[ 0] * s[0] * d * ((int8_t)((ql[ 0] & 0xF) | ((qh[ 0] & 0x03) << 4)) - 32)
+ y[16] * s[1] * d * ((int8_t)((ql[16] & 0xF) | ((qh[16] & 0x03) << 4)) - 32)
+ y[32] * s[2] * d * ((int8_t)((ql[32] & 0xF) | ((qh[ 0] & 0x0c) << 2)) - 32)
+ y[48] * s[3] * d * ((int8_t)((ql[48] & 0xF) | ((qh[16] & 0x0c) << 2)) - 32)
+ y[64] * s[4] * d * ((int8_t)((ql[ 0] >> 4) | ((qh[ 0] & 0x30) >> 0)) - 32)
+ y[80] * s[5] * d * ((int8_t)((ql[16] >> 4) | ((qh[16] & 0x30) >> 0)) - 32)
+ y[96] * s[6] * d * ((int8_t)((ql[32] >> 4) | ((qh[ 0] & 0xc0) >> 2)) - 32)
+y[112] * s[7] * d * ((int8_t)((ql[48] >> 4) | ((qh[16] & 0xc0) >> 2)) - 32);
tmp += sum;
#else
dst_t sum = 0;
for (int l = 0; l < 4; ++l) {
sum += (dst_t)y[l+ 0] * (dst_t)s[0] * d * (dst_t)((int8_t)((ql[l+ 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32)
+ (dst_t)y[l+32] * (dst_t)s[2] * d * (dst_t)((int8_t)((ql[l+32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32)
+ (dst_t)y[l+64] * (dst_t)s[4] * d * (dst_t)((int8_t)((ql[l+ 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32)
+ (dst_t)y[l+96] * (dst_t)s[6] * d * (dst_t)((int8_t)((ql[l+32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32);
}
tmp += sum;
#endif
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
}
template <typename dst_t, int qk, int qr, dequantize_kernel_t<dst_t> dequantize_kernel>
static __global__ void dequantize_block(const void * vx, dst_t * y, const int k) {
const int i = blockDim.x*blockIdx.x + 2*threadIdx.x;
if (i >= k) {
return;
}
const int ib = i/qk; // block index
const int iqs = (i%qk)/qr; // quant index
const int iybs = i - i%qk; // y block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
vec2_t<dst_t> v;
dequantize_kernel(vx, ib, iqs, v);
y[iybs + iqs + 0] = v.x;
y[iybs + iqs + y_offset] = v.y;
}
template<typename dst_t>
static __device__ __forceinline__ dst_t vec_dot_q4_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
int vi;
memcpy(&vi, &bq4_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_0)]);
const float d = __half2float(bq4_0->d) * __half2float(bq8_1->d);
// subtract 8 from each quantized value
const int vi0 = __vsub4((vi >> 0) & 0x0F0F0F0F, 0x08080808);
const int vi1 = __vsub4((vi >> 4) & 0x0F0F0F0F, 0x08080808);
// SIMD dot product of quantized values
int sumi = __dp4a(vi0, ui0, 0);
sumi = __dp4a(vi1, ui1, sumi);
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= 600
}
template<typename dst_t>
static __device__ __forceinline__ dst_t vec_dot_q4_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI4_1)]);
const float d = __half2float(bq4_1->d) * __half2float(bq8_1->d);
const float m = bq4_1->m;
const float s = bq8_1->s;
const int vi0 = (vi >> 0) & 0x0F0F0F0F;
const int vi1 = (vi >> 4) & 0x0F0F0F0F;
// SIMD dot product of quantized values
int sumi = __dp4a(vi0, ui0, 0);
sumi = __dp4a(vi1, ui1, sumi);
return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= 600
}
template<typename dst_t>
static __device__ __forceinline__ dst_t vec_dot_q5_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
int qs;
memcpy(&qs, &bq5_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
const int qh0 = bq5_0->qh[iqs/2 + 0] >> 4*(iqs%2);
const int qh1 = bq5_0->qh[iqs/2 + 2] >> 4*(iqs%2);
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_0)]);
const float d = __half2float(bq5_0->d) * __half2float(bq8_1->d);
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
vi0 = __vsub4(vi0, 0x10101010); // subtract 16 from quantized values
int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
vi1 = __vsub4(vi1, 0x10101010); // subtract 16 from quantized values
sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= 600
}
template<typename dst_t>
static __device__ __forceinline__ dst_t vec_dot_q5_1_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
const int qh0 = bq5_1->qh[iqs/2 + 0] >> 4*(iqs%2);
const int qh1 = bq5_1->qh[iqs/2 + 2] >> 4*(iqs%2);
const int ui0 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const int ui1 = *((int *) &bq8_1->qs[sizeof(int) * (iqs + QI5_1)]);
const float d = __half2float(bq5_1->d) * __half2float(bq8_1->d);
const float m = bq5_1->m;
const float s = bq8_1->s;
int vi0 = (qs >> 0) & 0x0F0F0F0F; // lower 4 qs bits, still need qh0 as 5th bits
vi0 |= (qh0 << 4) & 0x00000010; // 1 -> 5
vi0 |= (qh0 << 11) & 0x00001000; // 2 -> 13
vi0 |= (qh0 << 18) & 0x00100000; // 3 -> 21
vi0 |= (qh0 << 25) & 0x10000000; // 4 -> 29
int sumi = __dp4a(vi0, ui0, 0); // SIMD dot product of quantized values
int vi1 = (qs >> 4) & 0x0F0F0F0F; // upper 4 qs bits, still need qh1 as 5th bits
vi1 |= (qh1 << 4) & 0x00000010; // 1 -> 5
vi1 |= (qh1 << 11) & 0x00001000; // 2 -> 13
vi1 |= (qh1 << 18) & 0x00100000; // 3 -> 21
vi1 |= (qh1 << 25) & 0x10000000; // 4 -> 29
sumi = __dp4a(vi1, ui1, sumi); // SIMD dot product of quantized values
return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= 600
}
template<typename dst_t>
static __device__ __forceinline__ dst_t vec_dot_q8_0_q8_1(const void * vbq, const block_q8_1 * bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
int vi;
memcpy(&vi, &bq8_0->qs[sizeof(int) * (iqs + 0)], sizeof(int));
const int ui = *((int *) &bq8_1->qs[sizeof(int) * (iqs + 0)]);
const float d = __half2float(bq8_0->d) * __half2float(bq8_1->d);
// SIMD dot product of quantized values
int sumi = __dp4a(vi, ui, 0);
return sumi*d;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= 600
}
template <typename dst_t, int qk, int qi, typename block_q_t, vec_dot_q_cuda_t<dst_t> vec_dot_q_cuda>
static __global__ void mul_mat_vec_q(const void * vx, const void * vy, dst_t * dst, const int ncols, const int nrows) {
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row >= nrows) {
return;
}
const int blocks_per_row = ncols / qk;
const int blocks_per_warp = WARP_SIZE / qi;
// partial sum for each thread
float tmp = 0.0f;
const block_q_t * x = (const block_q_t *) vx;
const block_q8_1 * y = (const block_q8_1 *) vy;
for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
const int iby = i + threadIdx.x / qi; // y block index
const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int
tmp += (float)vec_dot_q_cuda(&x[ibx], &y[iby], iqs);
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (threadIdx.x == 0) {
dst[row] = (dst_t)tmp;
}
}
template <typename src1_t, typename dst_t, int qk, int qr, dequantize_kernel_t<dst_t> dequantize_kernel>
static __global__ void dequantize_mul_mat_vec(const void * vx, const src1_t * y, dst_t * dst, const int ncols, const int nrows) {
// qk = quantized weights per x block
// qr = number of quantized weights per data value in x block
const int row = blockIdx.y*blockDim.y + threadIdx.y;
if (row >= nrows) {
return;
}
const int tid = threadIdx.x;
const int iter_stride = 2*GGML_CUDA_DMMV_X;
const int vals_per_iter = iter_stride / WARP_SIZE; // num quantized vals per thread and i iter
const int y_offset = qr == 1 ? 1 : qk/2;
vec2_t<dst_t> tmp2 = make_vec2_t<dst_t>(0, 0); // partial sum for thread in warp
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = (row*ncols + col)/qk; // x block index
const int iqs = (col%qk)/qr; // x quant index
const int iybs = col - col%qk; // y block start index
// processing >2 values per i iter is faster for fast GPUs
#pragma unroll
for (int j = 0; j < vals_per_iter; j += 2) {
// process 2 vals per j iter
// for qr = 2 the iqs needs to increase by 1 per j iter because 2 weights per data val
// dequantize
vec2_t<dst_t> xc;
dequantize_kernel(vx, ib, iqs + j/qr, xc);
// matrix multiplication
vec2_t<dst_t> yc = make_vec2_t<dst_t>(
y[iybs + iqs + j/qr + 0],
y[iybs + iqs + j/qr + y_offset]);
tmp2 += xc * yc;
}
}
// sum up partial sums and write back result
// TODO: reducing as half2 may be faster, but requires special handling for float2
dst_t tmp = tmp2.x + tmp2.y;
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
}
template <typename src1_t, typename dst_t, int n_thread, dot_kernel_k_t<src1_t, dst_t> dot_kernel>
static __global__ void dequantize_mul_mat_vec_k(const void * vx, const src1_t * y, dst_t * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const int iter_stride = QK_K;
const int vals_per_iter = iter_stride / n_thread;
const int num_blocks_per_row = ncols / QK_K;
const int ib0 = row*num_blocks_per_row;
dst_t tmp = 0; // partial sum for thread in warp
for (int i = 0; i < ncols; i += iter_stride) {
const int col = i + vals_per_iter*tid;
const int ib = ib0 + col/QK_K; // x block index
const int iqs = col%QK_K; // x quant index
const int iybs = col - col%QK_K; // y block start index
dst_t v;
dot_kernel(vx, ib, iqs, y + iybs, v);
tmp += v;
}
// sum up partial sums and write back result
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32);
}
if (tid == 0) {
dst[row] = tmp;
}
}

File diff suppressed because it is too large Load diff

View file

@ -6,30 +6,15 @@
extern "C" {
#endif
#define GGML_CUDA_MAX_DEVICES 16
GGML_API void * ggml_cuda_host_malloc(size_t size);
GGML_API void ggml_cuda_host_free(void * ptr);
GGML_API void ggml_cuda_host_register(void * ptr, size_t size);
GGML_API void ggml_cuda_host_unregister(void * ptr);
void ggml_init_cublas(void);
void ggml_cuda_set_tensor_split(const float * tensor_split);
// backend API
void ggml_cuda_mul(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
size_t ggml_cuda_mul_mat_get_wsize(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
void ggml_cuda_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst, void * wdata, size_t wsize);
GGML_API struct ggml_backend * ggml_backend_cuda_init();
// TODO: export these with GGML_API
void * ggml_cuda_host_malloc(size_t size);
void ggml_cuda_host_free(void * ptr);
void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
void ggml_cuda_free_data(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
void ggml_cuda_assign_buffers_force_inplace(struct ggml_tensor * tensor);
void ggml_cuda_set_main_device(int main_device);
void ggml_cuda_set_scratch_size(size_t scratch_size);
void ggml_cuda_free_scratch(void);
bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_tensor * tensor);
#ifdef __cplusplus
}

View file

@ -22,48 +22,49 @@
#include <stddef.h>
#include <stdbool.h>
// max memory buffers that can be mapped to the device
#define GGML_METAL_MAX_BUFFERS 16
struct ggml_tensor;
struct ggml_cgraph;
//struct ggml_tensor;
//struct ggml_cgraph;
#ifdef __cplusplus
extern "C" {
#endif
struct ggml_metal_context;
struct ggml_backend;
// number of command buffers to use
struct ggml_metal_context * ggml_metal_init(int n_cb);
void ggml_metal_free(struct ggml_metal_context * ctx);
struct ggml_backend * ggml_backend_metal_init(void);
// set the number of command buffers to use
void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
// creates a mapping between a host memory buffer and a device memory buffer
// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
// - the mapping is used during computation to determine the arguments of the compute kernels
// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
// - max_size specifies the maximum size of a tensor and is used to create shared views such
// that it is guaranteed that the tensor will fit in at least one of the views
//struct ggml_metal_context;
//
bool ggml_metal_add_buffer(
struct ggml_metal_context * ctx,
const char * name,
void * data,
size_t size,
size_t max_size);
// set data from host memory into the device
void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
// get data from the device into host memory
void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
// same as ggml_graph_compute but uses Metal
// creates gf->n_threads command buffers in parallel
void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
//// number of command buffers to use
//struct ggml_metal_context * ggml_metal_init(int n_cb);
//void ggml_metal_free(struct ggml_metal_context * ctx);
//
//// set the number of command buffers to use
//void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb);
//
//// creates a mapping between a host memory buffer and a device memory buffer
//// - make sure to map all buffers used in the graph before calling ggml_metal_graph_compute
//// - the mapping is used during computation to determine the arguments of the compute kernels
//// - you don't need to keep the host memory buffer allocated as it is never accessed by Metal
//// - max_size specifies the maximum size of a tensor and is used to create shared views such
//// that it is guaranteed that the tensor will fit in at least one of the views
////
//bool ggml_metal_add_buffer(
// struct ggml_metal_context * ctx,
// const char * name,
// void * data,
// size_t size,
// size_t max_size);
//
//// set data from host memory into the device
//void ggml_metal_set_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
//
//// get data from the device into host memory
//void ggml_metal_get_tensor(struct ggml_metal_context * ctx, struct ggml_tensor * t);
//
//// same as ggml_graph_compute but uses Metal
//// creates gf->n_threads command buffers in parallel
//void ggml_metal_graph_compute(struct ggml_metal_context * ctx, struct ggml_cgraph * gf);
#ifdef __cplusplus
}

View file

@ -12,18 +12,16 @@
#else
#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
#endif
//#define metal_printf(...) fprintf(stderr, __VA_ARGS__)
#define UNUSED(x) (void)(x)
struct ggml_metal_buffer {
const char * name;
void * data;
size_t size;
id<MTLBuffer> metal;
struct ggml_metal_buffer_wrapper {
id<MTLBuffer> buffer;
};
static void * g_ptr_base = (void *)0x1000;
struct ggml_metal_context {
int n_cb;
@ -33,9 +31,6 @@ struct ggml_metal_context {
id<MTLCommandQueue> queue;
id<MTLLibrary> library;
int n_buffers;
struct ggml_metal_buffer buffers[GGML_METAL_MAX_BUFFERS];
// custom kernels
#define GGML_METAL_DECL_KERNEL(name) \
id<MTLFunction> function_##name; \
@ -96,7 +91,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
ctx->n_cb = n_cb;
ctx->device = MTLCreateSystemDefaultDevice();
ctx->queue = [ctx->device newCommandQueue];
ctx->n_buffers = 0;
// determine if we can use MPS
if (MPSSupportsMTLDevice(ctx->device)) {
@ -205,9 +199,6 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
void ggml_metal_free(struct ggml_metal_context * ctx) {
fprintf(stderr, "%s: deallocating\n", __func__);
for (int i = 0; i < ctx->n_buffers; ++i) {
[ctx->buffers[i].metal release];
}
free(ctx);
}
@ -215,142 +206,29 @@ void ggml_metal_set_n_cb(struct ggml_metal_context * ctx, int n_cb) {
ctx->n_cb = n_cb;
}
// finds the Metal buffer that contains the tensor data on the GPU device
// the assumption is that there is 1-to-1 mapping between the host and device memory buffers, so we can find the
// Metal buffer based on the host memory pointer
//
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_metal_context * ctx, struct ggml_tensor * t, size_t * offs) {
//fprintf(stderr, "%s: data tensor '%16s', offs_data = %8ld, offs_eval = %8ld, offs_cach = %8ld\n", __func__, t->name, offs_data, offs_eval, offs_cach);
const int64_t tsize = ggml_nbytes(t);
// find the view that contains the tensor fully
for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) t->data - (int64_t) ctx->buffers[i].data;
if (ioffs >= 0 && ioffs + tsize <= (int64_t) ctx->buffers[i].size) {
*offs = (size_t) ioffs;
//fprintf(stderr, "%s: '%s' tensor '%16s', offs = %8ld\n", __func__, ctx->buffers[i].name, t->name, *offs);
return ctx->buffers[i].metal;
}
static id<MTLBuffer> ggml_metal_get_buffer(struct ggml_tensor * tensor, size_t * offs) {
if (tensor == nil) {
return nil;
}
fprintf(stderr, "%s: error: buffer is nil\n", __func__);
return nil;
}
bool ggml_metal_add_buffer(
struct ggml_metal_context * ctx,
const char * name,
void * data,
size_t size,
size_t max_size) {
if (ctx->n_buffers >= GGML_METAL_MAX_BUFFERS) {
fprintf(stderr, "%s: too many buffers\n", __func__);
return false;
}
if (data) {
// verify that the buffer does not overlap with any of the existing buffers
for (int i = 0; i < ctx->n_buffers; ++i) {
const int64_t ioffs = (int64_t) data - (int64_t) ctx->buffers[i].data;
if (ioffs >= 0 && ioffs < (int64_t) ctx->buffers[i].size) {
fprintf(stderr, "%s: error: buffer '%s' overlaps with '%s'\n", __func__, name, ctx->buffers[i].name);
return false;
}
}
const size_t size_page = getpagesize();
size_t size_aligned = size;
if ((size_aligned % size_page) != 0) {
size_aligned += (size_page - (size_aligned % size_page));
}
// the buffer fits into the max buffer size allowed by the device
if (size_aligned <= ctx->device.maxBufferLength) {
ctx->buffers[ctx->n_buffers].name = name;
ctx->buffers[ctx->n_buffers].data = data;
ctx->buffers[ctx->n_buffers].size = size;
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:data length:size_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_aligned / 1024.0 / 1024.0);
return false;
}
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB", __func__, name, size_aligned / 1024.0 / 1024.0);
++ctx->n_buffers;
} else {
// this overlap between the views will guarantee that the tensor with the maximum size will fully fit into
// one of the views
const size_t size_ovlp = ((max_size + size_page - 1) / size_page + 1) * size_page; // round-up 2 pages just in case
const size_t size_step = ctx->device.maxBufferLength - size_ovlp;
const size_t size_view = ctx->device.maxBufferLength;
for (size_t i = 0; i < size; i += size_step) {
const size_t size_step_aligned = (i + size_view <= size) ? size_view : (size_aligned - i);
ctx->buffers[ctx->n_buffers].name = name;
ctx->buffers[ctx->n_buffers].data = (void *) ((uint8_t *) data + i);
ctx->buffers[ctx->n_buffers].size = size_step_aligned;
ctx->buffers[ctx->n_buffers].metal = [ctx->device newBufferWithBytesNoCopy:(void *) ((uint8_t *) data + i) length:size_step_aligned options:MTLResourceStorageModeShared deallocator:nil];
if (ctx->buffers[ctx->n_buffers].metal == nil) {
fprintf(stderr, "%s: failed to allocate '%-16s' buffer, size = %8.2f MB\n", __func__, name, size_step_aligned / 1024.0 / 1024.0);
return false;
switch (tensor->op) {
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
{
if (tensor->op == GGML_OP_VIEW) {
//printf("view offs = %zu\n", *(size_t *)tensor->op_params);
}
fprintf(stderr, "%s: allocated '%-16s' buffer, size = %8.2f MB, offs = %12ld", __func__, name, size_step_aligned / 1024.0 / 1024.0, i);
if (i + size_step < size) {
fprintf(stderr, "\n");
}
++ctx->n_buffers;
return ggml_metal_get_buffer(tensor->src[0], offs);
}
}
fprintf(stderr, ", (%8.2f / %8.2f)",
ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
if (ctx->device.currentAllocatedSize > ctx->device.recommendedMaxWorkingSetSize) {
fprintf(stderr, ", warning: current allocated size is greater than the recommended max working set size\n");
} else {
fprintf(stderr, "\n");
}
default: {}
}
return true;
}
void ggml_metal_set_tensor(
struct ggml_metal_context * ctx,
struct ggml_tensor * t) {
metal_printf("%s: set input for tensor '%s'\n", __func__, t->name);
size_t offs;
id<MTLBuffer> id_dst = ggml_metal_get_buffer(ctx, t, &offs);
memcpy((void *) ((uint8_t *) id_dst.contents + offs), t->data, ggml_nbytes(t));
}
void ggml_metal_get_tensor(
struct ggml_metal_context * ctx,
struct ggml_tensor * t) {
metal_printf("%s: extract results for tensor '%s'\n", __func__, t->name);
size_t offs;
id<MTLBuffer> id_src = ggml_metal_get_buffer(ctx, t, &offs);
memcpy(t->data, (void *) ((uint8_t *) id_src.contents + offs), ggml_nbytes(t));
*offs = (size_t) tensor->data - (size_t) g_ptr_base;
//printf("%s: offs = %zu, %p, op = %s\n", __func__, *offs, tensor->extra, ggml_op_name(tensor->op));
return ((struct ggml_metal_buffer_wrapper *) tensor->extra)->buffer;
}
void ggml_metal_graph_compute(
@ -431,23 +309,35 @@ void ggml_metal_graph_compute(
const enum ggml_type src1t = src1 ? src1->type : GGML_TYPE_COUNT;
const enum ggml_type dstt = dst ? dst->type : GGML_TYPE_COUNT;
id<MTLBuffer> id_src0 = src0 ? ggml_metal_get_buffer(ctx, src0, &offs_src0) : nil;
id<MTLBuffer> id_src1 = src1 ? ggml_metal_get_buffer(ctx, src1, &offs_src1) : nil;
id<MTLBuffer> id_dst = dst ? ggml_metal_get_buffer(ctx, dst, &offs_dst) : nil;
switch (dst->op) {
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
case GGML_OP_TRANSPOSE:
case GGML_OP_PERMUTE:
{
continue;
} break;
default: break;
}
//metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
//if (src0) {
// metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
// ggml_is_contiguous(src0), src0->name);
//}
//if (src1) {
// metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
// ggml_is_contiguous(src1), src1->name);
//}
//if (dst) {
// metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
// dst->name);
//}
id<MTLBuffer> id_src0 = ggml_metal_get_buffer(src0, &offs_src0);
id<MTLBuffer> id_src1 = ggml_metal_get_buffer(src1, &offs_src1);
id<MTLBuffer> id_dst = ggml_metal_get_buffer(dst, &offs_dst);
metal_printf("%s: op - %s\n", __func__, ggml_op_name(dst->op));
if (src0) {
metal_printf("%s: src0 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src0t), ne00, ne01, ne02,
ggml_is_contiguous(src0), src0->name);
}
if (src1) {
metal_printf("%s: src1 - %4s [%5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(src1t), ne10, ne11, ne12,
ggml_is_contiguous(src1), src1->name);
}
if (dst) {
metal_printf("%s: dst - %4s [%5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(dstt), ne0, ne1, ne2,
dst->name);
}
switch (dst->op) {
case GGML_OP_NONE:
@ -500,7 +390,9 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}
const float scale = *(const float *) src1->data;
//const float scale = *(const float *) src1->data;
const float scale = ((float *)((char *)[((struct ggml_metal_buffer_wrapper *)(src1->extra))->buffer contents] + (size_t) src1->data - (size_t)g_ptr_base))[0];
//printf("scale: %f, src1->data: %p, src1->extra: %p, src1->extra->buffer: %p\n", scale, src1->data, src1->extra, ((struct ggml_metal_buffer_wrapper *)(src1->extra))->buffer);
[encoder setComputePipelineState:ctx->pipeline_scale];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -577,7 +469,8 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}
const int n_past = ((int32_t *)(src1->data))[0];
//const int n_past = ((int32_t *)(src1->data))[0];
const int n_past = ((int32_t *)(dst->op_params))[0];
[encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -739,6 +632,10 @@ void ggml_metal_graph_compute(
[encoder setBytes:&ne0 length:sizeof(ne0) atIndex:13];
[encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
//printf("id_src0 %p, offs_src0 %zu\n", id_src0, offs_src0);
//printf("id_src1 %p, offs_src1 %zu\n", id_src1, offs_src1);
//printf("id_dst %p, offs_dst %zu\n", id_dst, offs_dst);
if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
}
@ -876,15 +773,14 @@ void ggml_metal_graph_compute(
encoder = [command_buffer computeCommandEncoder];
}
const int n_dims = ((int32_t *) src1->data)[1];
const int mode = ((int32_t *) src1->data)[2];
const int n_past = ((int32_t *)(src1->data))[0];
const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
const int mode = ((int32_t *) dst->op_params)[2];
float freq_base;
float freq_scale;
memcpy(&freq_base, (int32_t *) src1->data + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) src1->data + 5, sizeof(float));
memcpy(&freq_base, (int32_t *) dst->op_params + 4, sizeof(float));
memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
[encoder setComputePipelineState:ctx->pipeline_rope];
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
@ -992,3 +888,141 @@ void ggml_metal_graph_compute(
}
}
}
static const char * ggml_backend_metal_name(struct ggml_backend * ctx) {
return "Metal";
UNUSED(ctx);
}
static void ggml_backend_metal_free(struct ggml_backend * backend) {
struct ggml_metal_context * ctx_metal = (struct ggml_metal_context *)backend->context;
ggml_metal_free(ctx_metal);
free(backend);
}
static const size_t TENSOR_ALIGNMENT = 128;
static void ggml_backend_metal_init_tensor(struct ggml_backend_buffer * alloc, struct ggml_tensor * tensor) {
tensor->extra = alloc->backend_data;
}
static void ggml_backend_metal_free_data(struct ggml_backend_buffer * alloc) {
struct ggml_metal_buffer_wrapper * wrapper = (struct ggml_metal_buffer_wrapper *)alloc->backend_data;
[wrapper->buffer release];
free(wrapper);
}
static struct ggml_backend_buffer * ggml_backend_metal_alloc_buffer(struct ggml_backend * backend, size_t size) {
struct ggml_metal_context * ctx_metal = (struct ggml_metal_context *)backend->context;
struct ggml_metal_buffer_wrapper * wrapper = malloc(sizeof(struct ggml_metal_buffer_wrapper));
wrapper->buffer = [ctx_metal->device newBufferWithLength:size options:MTLResourceStorageModeShared];
if (wrapper->buffer == nil) {
fprintf(stderr, "%s: failed to allocate buffer of size %zu\n", __func__, size);
GGML_ASSERT(false);
}
//printf("XXXXXXXXXXXXXXX ALOC: %p %p %p size = %zu\n", (void * )wrapper, (void *)&wrapper->buffer, (void *)[wrapper->buffer contents], size);
struct ggml_backend_buffer * buffer = ggml_allocator_simple_init(g_ptr_base, size, TENSOR_ALIGNMENT);
buffer->interface.init_tensor = ggml_backend_metal_init_tensor;
buffer->interface.free_data = ggml_backend_metal_free_data;
buffer->backend_data = wrapper;
return buffer;
}
static void ggml_backend_metal_set_tensor_async(struct ggml_backend * backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
GGML_ASSERT(tensor->extra != nil && "tensor not allocated");
struct ggml_metal_buffer_wrapper * wrapper = (struct ggml_metal_buffer_wrapper *)tensor->extra;
char * contents = (char *)[wrapper->buffer contents];
const size_t t_data = (size_t) tensor->data - (size_t) g_ptr_base;
//printf("XXXXXXXXXXXXXXX SET : %p %p %p offset = %zu\n", (void *)(tensor->data), (void *)&wrapper->buffer, (void *)contents, offset);
memcpy((char *)contents + t_data + offset, data, size);
//memcpy((char *)tensor->data, data, size);
UNUSED(backend);
}
static void ggml_backend_metal_get_tensor_async(struct ggml_backend * backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
//printf("XXXXXXXXXXXXXXX GET : %d %p, backend = %s\n", (void *)(tensor->data), (void *)tensor->extra, tensor->backend->interface.get_name(tensor->backend));
GGML_ASSERT(tensor->extra != nil && "tensor not allocated");
struct ggml_metal_buffer_wrapper * wrapper = (struct ggml_metal_buffer_wrapper *)tensor->extra;
const char * contents = (const char *)[wrapper->buffer contents];
const size_t t_data = (size_t) tensor->data - (size_t) g_ptr_base;
//printf("XXXXXXXXXXXXXXX GET : %p %p %p offset = %zu\n", (void *)(tensor->data), (void *)&wrapper->buffer, (void *)contents, offset);
memcpy(data, (const char *)contents + t_data + offset, size);
UNUSED(backend);
}
static void ggml_backend_metal_synchronize(struct ggml_backend * backend) {
UNUSED(backend);
}
static ggml_graph_plan_t ggml_backend_metal_graph_plan_create(struct ggml_backend * backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(false);
return nil;
UNUSED(backend);
UNUSED(cgraph);
}
static void ggml_backend_metal_graph_plan_free(struct ggml_backend * backend, ggml_graph_plan_t plan) {
GGML_ASSERT(false);
UNUSED(backend);
UNUSED(plan);
}
static void ggml_backend_metal_graph_plan_compute(struct ggml_backend * backend, ggml_graph_plan_t plan) {
GGML_ASSERT(false);
UNUSED(backend);
UNUSED(plan);
}
static void ggml_backend_metal_graph_compute(struct ggml_backend * backend, struct ggml_cgraph * cgraph) {
ggml_metal_graph_compute(backend->context, cgraph);
}
static struct ggml_backend_interface metal_backend_interface = {
/* .get_name = */ ggml_backend_metal_name,
/* .free = */ ggml_backend_metal_free,
/* .alloc_buffer = */ ggml_backend_metal_alloc_buffer,
/* .set_tensor_async = */ ggml_backend_metal_set_tensor_async,
/* .get_tensor_async = */ ggml_backend_metal_get_tensor_async,
/* .synchronize = */ ggml_backend_metal_synchronize,
/* .cpy_tensor_from = */ nil, //ggml_backend_metal_get_tensor_async,
/* .cpy_tensor_to = */ nil, //ggml_backend_metal_synchronize,
/* .graph_plan_create = */ ggml_backend_metal_graph_plan_create,
/* .graph_plan_free = */ ggml_backend_metal_graph_plan_free,
/* .graph_plan_compute = */ ggml_backend_metal_graph_plan_compute,
/* .graph_compute = */ ggml_backend_metal_graph_compute,
};
struct ggml_backend * ggml_backend_metal_init(void) {
struct ggml_metal_context * ctx = ggml_metal_init(1);
struct ggml_backend * backend_metal = malloc(sizeof(struct ggml_backend));
*backend_metal = (struct ggml_backend){
/* .interface = */ metal_backend_interface,
/* .context = */ ctx,
/* .is_ram_shared = */ false,
};
return backend_metal;
}

545
ggml.c

File diff suppressed because it is too large Load diff

65
ggml.h
View file

@ -199,6 +199,7 @@
#define GGML_MAX_CONTEXTS 64
#define GGML_MAX_SRC 6
#define GGML_MAX_NAME 48
#define GGML_MAX_OP_PARAMS 32
#define GGML_DEFAULT_N_THREADS 4
@ -285,12 +286,6 @@ extern "C" {
GGML_TYPE_COUNT,
};
enum ggml_backend {
GGML_BACKEND_CPU = 0,
GGML_BACKEND_GPU = 10,
GGML_BACKEND_GPU_SPLIT = 20,
};
// model file types
enum ggml_ftype {
GGML_FTYPE_UNKNOWN = -1,
@ -405,8 +400,9 @@ extern "C" {
// n-dimensional tensor
struct ggml_tensor {
enum ggml_type type;
enum ggml_backend backend;
struct ggml_backend * backend;
enum ggml_type type;
int n_dims;
int64_t ne[GGML_MAX_DIMS]; // number of elements
@ -418,23 +414,30 @@ extern "C" {
// compute data
enum ggml_op op;
// op params - allocated as int32_t for alignment
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(uint32_t)];
bool is_param;
struct ggml_tensor * grad;
struct ggml_tensor * src[GGML_MAX_SRC];
int node_id; // used to build graphs
// performance
int perf_runs;
int64_t perf_cycles;
int64_t perf_time_us;
void * data;
char name[GGML_MAX_NAME];
void * extra; // extra things e.g. for ggml-cuda.cu
char padding[8];
char padding[4];
};
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
@ -459,6 +462,7 @@ extern "C" {
struct ggml_cgraph {
int n_nodes;
int n_leafs;
bool closed;
struct ggml_tensor * nodes[GGML_MAX_NODES];
struct ggml_tensor * grads[GGML_MAX_NODES];
@ -470,23 +474,27 @@ extern "C" {
int64_t perf_time_us;
};
// scratch buffer
struct ggml_scratch {
size_t offs;
size_t size;
void * data;
/*
TODO
enum ggml_alloc_mode {
GGML_ALLOC_IMMEDIATE,
GGML_ALLOC_NONE,
GGML_ALLOC_COMPUTE_SEQ,
GGML_ALLOC_COMPUTE_PAR,
};
*/
// context parameters
struct ggml_init_params {
// memory pool
size_t mem_size; // bytes
void * mem_buffer; // if NULL, memory will be allocated internally
struct ggml_buffer * buffer;
bool no_alloc; // don't allocate memory for the tensor data
//enum ggml_alloc_mode alloc_mode; // TODO: replace the above with this
enum ggml_type compute_type; // type of intermediate results
};
// compute types
// task types
// NOTE: the INIT or FINALIZE pass is not scheduled unless explicitly enabled.
// This behavior was changed since https://github.com/ggerganov/llama.cpp/pull/1995.
enum ggml_task_type {
@ -547,19 +555,20 @@ extern "C" {
GGML_API size_t ggml_tensor_overhead(void);
// main
GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
GGML_API void ggml_free(struct ggml_context * ctx);
GGML_API struct ggml_init_params ggml_init_params_default(void);
GGML_API struct ggml_context * ggml_init(struct ggml_init_params params);
GGML_API void ggml_free(struct ggml_context * ctx);
GGML_API size_t ggml_used_mem(const struct ggml_context * ctx);
GGML_API size_t ggml_set_scratch (struct ggml_context * ctx, struct ggml_scratch scratch);
GGML_API void ggml_set_no_alloc(struct ggml_context * ctx, bool no_alloc);
GGML_API void * ggml_get_mem_buffer (const struct ggml_context * ctx);
GGML_API size_t ggml_get_mem_size (const struct ggml_context * ctx);
GGML_API size_t ggml_get_max_tensor_size(const struct ggml_context * ctx);
GGML_API struct ggml_backend * ggml_get_ctx_backend(struct ggml_context * ctx);
GGML_API struct ggml_tensor * ggml_new_tensor(
struct ggml_context * ctx,
enum ggml_type type,
@ -1347,6 +1356,8 @@ extern "C" {
GGML_API struct ggml_cgraph ggml_build_forward (struct ggml_tensor * tensor);
GGML_API struct ggml_cgraph ggml_build_backward(struct ggml_context * ctx, struct ggml_cgraph * gf, bool keep);
GGML_API void ggml_graph_close (struct ggml_cgraph * cgraph);
// ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
@ -1561,9 +1572,8 @@ extern "C" {
GGML_API int ggml_cpu_has_fp16_va (void);
GGML_API int ggml_cpu_has_wasm_simd (void);
GGML_API int ggml_cpu_has_blas (void);
GGML_API int ggml_cpu_has_cublas (void);
GGML_API int ggml_cpu_has_cuda (void);
GGML_API int ggml_cpu_has_clblast (void);
GGML_API int ggml_cpu_has_gpublas (void);
GGML_API int ggml_cpu_has_sse3 (void);
GGML_API int ggml_cpu_has_vsx (void);
@ -1594,3 +1604,6 @@ extern "C" {
#ifdef __cplusplus
}
#endif
#include "ggml-backend.h"

View file

@ -203,6 +203,17 @@ struct llama_mmap {
}
}
void discard(void * addr, size_t len) {
// align to the page size
int page_size = sysconf(_SC_PAGESIZE);
addr = (void *) (((uintptr_t) addr) & ~(page_size - 1));
len = (len + page_size - 1) & ~(page_size - 1);
if (madvise(addr, len, MADV_DONTNEED)) {
fprintf(stderr, "warning: madvise(.., MADV_DONTNEED) failed: %s\n",
strerror(errno));
}
}
~llama_mmap() {
munmap(addr, size);
}
@ -247,6 +258,10 @@ struct llama_mmap {
#endif // _WIN32_WINNT >= _WIN32_WINNT_WIN8
}
void discard(void * addr, size_t len) {
VirtualAlloc(addr, len, MEM_RESET, PAGE_NOACCESS);
}
~llama_mmap() {
if (!UnmapViewOfFile(addr)) {
fprintf(stderr, "warning: UnmapViewOfFile failed: %s\n",
@ -262,6 +277,13 @@ struct llama_mmap {
throw std::runtime_error(std::string("mmap not supported"));
}
void discard(void * addr, size_t len) {
(void) addr;
(void) len;
throw std::runtime_error(std::string("mmap not supported"));
}
#endif
};
@ -451,14 +473,14 @@ struct llama_buffer {
llama_buffer& operator=(llama_buffer&&) = delete;
};
#ifdef GGML_USE_CUBLAS
#if defined(GGML_USE_CUDA)
#include "ggml-cuda.h"
struct llama_ctx_buffer {
struct llama_host_buffer {
uint8_t * addr = NULL;
bool is_cuda;
size_t size = 0;
llama_ctx_buffer() = default;
llama_host_buffer() = default;
void resize(size_t size) {
free();
@ -487,18 +509,19 @@ struct llama_ctx_buffer {
addr = NULL;
}
~llama_ctx_buffer() {
~llama_host_buffer() {
free();
}
// disable copy and move
llama_ctx_buffer(const llama_ctx_buffer&) = delete;
llama_ctx_buffer(llama_ctx_buffer&&) = delete;
llama_ctx_buffer& operator=(const llama_ctx_buffer&) = delete;
llama_ctx_buffer& operator=(llama_ctx_buffer&&) = delete;
llama_host_buffer(const llama_host_buffer&) = delete;
llama_host_buffer(llama_host_buffer&&) = delete;
llama_host_buffer& operator=(const llama_host_buffer&) = delete;
llama_host_buffer& operator=(llama_host_buffer&&) = delete;
};
#else
typedef llama_buffer llama_ctx_buffer;
typedef llama_buffer llama_host_buffer;
#endif
typedef llama_buffer llama_ctx_buffer;
#endif

1700
llama.cpp

File diff suppressed because it is too large Load diff

View file

@ -2,12 +2,7 @@
#define LLAMA_H
#include "ggml.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h"
#define LLAMA_MAX_DEVICES GGML_CUDA_MAX_DEVICES
#else
#define LLAMA_MAX_DEVICES 1
#endif // GGML_USE_CUBLAS
#include <stddef.h>
#include <stdint.h>
#include <stdbool.h>
@ -48,7 +43,7 @@
#define LLAMA_DEFAULT_SEED 0xFFFFFFFF
#if defined(GGML_USE_CUBLAS) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
#if defined(GGML_USE_CUDA) || defined(GGML_USE_CLBLAST) || defined(GGML_USE_METAL)
// Defined when llama.cpp is compiled with support for offloading model layers to GPU.
#define LLAMA_SUPPORTS_GPU_OFFLOAD
#endif