Merge branch 'master' into convert_rope_scale

This commit is contained in:
Nigel Bosch 2023-08-25 09:16:50 -05:00 committed by GitHub
commit 4950b2dbbc
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
26 changed files with 2579 additions and 1275 deletions

View file

@ -0,0 +1,44 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG ROCM_VERSION=5.6
# Target the CUDA build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
FROM ${BASE_ROCM_DEV_CONTAINER} as build
# Unless otherwise specified, we make a fat build.
# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878
# This is mostly tied to rocBLAS supported archs.
ARG ROCM_DOCKER_ARCH=\
gfx803 \
gfx900 \
gfx906 \
gfx908 \
gfx90a \
gfx1010 \
gfx1030 \
gfx1100 \
gfx1101 \
gfx1102
COPY requirements.txt requirements.txt
RUN pip install --upgrade pip setuptools wheel \
&& pip install -r requirements.txt
WORKDIR /app
COPY . .
# Set nvcc architecture
ENV GPU_TARGETS=${ROCM_DOCKER_ARCH}
# Enable ROCm
ENV LLAMA_HIPBLAS=1
ENV CC=/opt/rocm/llvm/bin/clang
ENV CXX=/opt/rocm/llvm/bin/clang++
RUN make
ENTRYPOINT ["/app/.devops/tools.sh"]

View file

@ -0,0 +1,44 @@
ARG UBUNTU_VERSION=22.04
# This needs to generally match the container host's environment.
ARG ROCM_VERSION=5.6
# Target the CUDA build image
ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete
FROM ${BASE_ROCM_DEV_CONTAINER} as build
# Unless otherwise specified, we make a fat build.
# List from https://github.com/ggerganov/llama.cpp/pull/1087#issuecomment-1682807878
# This is mostly tied to rocBLAS supported archs.
ARG ROCM_DOCKER_ARCH=\
gfx803 \
gfx900 \
gfx906 \
gfx908 \
gfx90a \
gfx1010 \
gfx1030 \
gfx1100 \
gfx1101 \
gfx1102
COPY requirements.txt requirements.txt
RUN pip install --upgrade pip setuptools wheel \
&& pip install -r requirements.txt
WORKDIR /app
COPY . .
# Set nvcc architecture
ENV GPU_TARGETS=${ROCM_DOCKER_ARCH}
# Enable ROCm
ENV LLAMA_HIPBLAS=1
ENV CC=/opt/rocm/llvm/bin/clang
ENV CXX=/opt/rocm/llvm/bin/clang++
RUN make
ENTRYPOINT [ "/app/main" ]

View file

@ -5,14 +5,7 @@
.vscode/ .vscode/
.DS_Store .DS_Store
build/ build*/
build-em/
build-debug/
build-release/
build-static/
build-no-accel/
build-sanitize-addr/
build-sanitize-thread/
models/* models/*

17
.gitignore vendored
View file

@ -16,20 +16,7 @@
.vs/ .vs/
.vscode/ .vscode/
build/ build*/
build-em/
build-debug/
build-release/
build-ci-debug/
build-ci-release/
build-static/
build-cublas/
build-opencl/
build-metal/
build-mpi/
build-no-accel/
build-sanitize-addr/
build-sanitize-thread/
out/ out/
tmp/ tmp/
@ -60,6 +47,7 @@ compile_commands.json
CMakeSettings.json CMakeSettings.json
__pycache__ __pycache__
dist
zig-out/ zig-out/
zig-cache/ zig-cache/
@ -70,7 +58,6 @@ perf-*.txt
examples/jeopardy/results.txt examples/jeopardy/results.txt
pyproject.toml
poetry.lock poetry.lock
poetry.toml poetry.toml

View file

@ -74,6 +74,7 @@ set(LLAMA_CUDA_DMMV_X "32" CACHE STRING "llama: x stride for dmmv CUDA kern
set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels") set(LLAMA_CUDA_MMV_Y "1" CACHE STRING "llama: y block size for mmv CUDA kernels")
option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF) option(LLAMA_CUDA_F16 "llama: use 16 bit floats for some calculations" OFF)
set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K") set(LLAMA_CUDA_KQUANTS_ITER "2" CACHE STRING "llama: iters./thread per block for Q2_K/Q6_K")
option(LLAMA_HIPBLAS "llama: use hipBLAS" OFF)
option(LLAMA_CLBLAST "llama: use CLBlast" OFF) option(LLAMA_CLBLAST "llama: use CLBlast" OFF)
option(LLAMA_METAL "llama: use Metal" OFF) option(LLAMA_METAL "llama: use Metal" OFF)
option(LLAMA_MPI "llama: use MPI" OFF) option(LLAMA_MPI "llama: use MPI" OFF)
@ -352,6 +353,43 @@ if (LLAMA_CLBLAST)
endif() endif()
endif() endif()
if (LLAMA_HIPBLAS)
list(APPEND CMAKE_PREFIX_PATH /opt/rocm)
if (NOT ${CMAKE_C_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CC=/opt/rocm/llvm/bin/clang")
endif()
if (NOT ${CMAKE_CXX_COMPILER_ID} MATCHES "Clang")
message(WARNING "Only LLVM is supported for HIP, hint: CXX=/opt/rocm/llvm/bin/clang++")
endif()
find_package(hip)
find_package(hipblas)
find_package(rocblas)
if (${hipblas_FOUND} AND ${hip_FOUND})
message(STATUS "HIP and hipBLAS found")
add_compile_definitions(GGML_USE_HIPBLAS GGML_USE_CUBLAS)
add_library(ggml-rocm OBJECT ggml-cuda.cu ggml-cuda.h)
if (LLAMA_CUDA_FORCE_DMMV)
target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_FORCE_DMMV)
endif()
target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_DMMV_X=${LLAMA_CUDA_DMMV_X})
target_compile_definitions(ggml-rocm PRIVATE GGML_CUDA_MMV_Y=${LLAMA_CUDA_MMV_Y})
target_compile_definitions(ggml-rocm PRIVATE K_QUANTS_PER_ITERATION=${LLAMA_CUDA_KQUANTS_ITER})
target_compile_definitions(ggml-rocm PRIVATE CC_TURING=1000000000)
set_source_files_properties(ggml-cuda.cu PROPERTIES LANGUAGE CXX)
target_link_libraries(ggml-rocm PRIVATE hip::device PUBLIC hip::host roc::rocblas roc::hipblas)
if (LLAMA_STATIC)
message(FATAL_ERROR "Static linking not supported for HIP/ROCm")
endif()
set(LLAMA_EXTRA_LIBS ${LLAMA_EXTRA_LIBS} ggml-rocm)
else()
message(WARNING "hipBLAS or HIP not found. Try setting CMAKE_PREFIX_PATH=/opt/rocm")
endif()
endif()
if (LLAMA_ALL_WARNINGS) if (LLAMA_ALL_WARNINGS)
if (NOT MSVC) if (NOT MSVC)
set(c_flags set(c_flags

View file

@ -280,6 +280,30 @@ ggml-opencl.o: ggml-opencl.cpp ggml-opencl.h
$(CXX) $(CXXFLAGS) -c $< -o $@ $(CXX) $(CXXFLAGS) -c $< -o $@
endif # LLAMA_CLBLAST endif # LLAMA_CLBLAST
ifdef LLAMA_HIPBLAS
ROCM_PATH ?= /opt/rocm
HIPCC ?= $(ROCM_PATH)/bin/hipcc
GPU_TARGETS ?= $(shell $(ROCM_PATH)/llvm/bin/amdgpu-arch)
LLAMA_CUDA_DMMV_X ?= 32
LLAMA_CUDA_MMV_Y ?= 1
LLAMA_CUDA_KQUANTS_ITER ?= 2
CFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
CXXFLAGS += -DGGML_USE_HIPBLAS -DGGML_USE_CUBLAS
LDFLAGS += -L$(ROCM_PATH)/lib -Wl,-rpath=$(ROCM_PATH)/lib
LDFLAGS += -lhipblas -lamdhip64 -lrocblas
HIPFLAGS += $(addprefix --offload-arch=,$(GPU_TARGETS))
HIPFLAGS += -DGGML_CUDA_DMMV_X=$(LLAMA_CUDA_DMMV_X)
HIPFLAGS += -DGGML_CUDA_MMV_Y=$(LLAMA_CUDA_MMV_Y)
HIPFLAGS += -DK_QUANTS_PER_ITERATION=$(LLAMA_CUDA_KQUANTS_ITER)
HIPFLAGS += -DCC_TURING=1000000000
ifdef LLAMA_CUDA_FORCE_DMMV
HIPFLAGS += -DGGML_CUDA_FORCE_DMMV
endif # LLAMA_CUDA_FORCE_DMMV
OBJS += ggml-cuda.o
ggml-cuda.o: ggml-cuda.cu ggml-cuda.h
$(HIPCC) $(CXXFLAGS) $(HIPFLAGS) -x hip -c -o $@ $<
endif # LLAMA_HIPBLAS
ifdef LLAMA_METAL ifdef LLAMA_METAL
CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG CFLAGS += -DGGML_USE_METAL -DGGML_METAL_NDEBUG
CXXFLAGS += -DGGML_USE_METAL CXXFLAGS += -DGGML_USE_METAL

View file

@ -422,6 +422,35 @@ Building the program with BLAS support may lead to some performance improvements
| LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. | | LLAMA_CUDA_F16 | Boolean | false | If enabled, use half-precision floating point arithmetic for the CUDA dequantization + mul mat vec kernels and for the q4_1 and q5_1 matrix matrix multiplication kernels. Can improve performance on relatively recent GPUs. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. | | LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per CUDA thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
- #### hipBLAS
This provide BLAS acceleation on HIP supported GPU like AMD GPU.
Make sure to have ROCm installed.
You can download it from your Linux distro's package manager or from here: [ROCm Quick Start (Linux)](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html).
Windows support is coming soon...
- Using `make`:
```bash
make LLAMA_HIPBLAS=1
```
- Using `CMake`:
```bash
mkdir build
cd build
CC=/opt/rocm/llvm/bin/clang CXX=/opt/rocm/llvm/bin/clang++ cmake .. -DLLAMA_HIPBLAS=ON
cmake --build .
```
The environment variable [`HIP_VISIBLE_DEVICES`](https://rocm.docs.amd.com/en/latest/understand/gpu_isolation.html#hip-visible-devices) can be used to specify which GPU(s) will be used.
If your GPU is not officialy supported you can use the environment variable [`HSA_OVERRIDE_GFX_VERSION`] set to a similar GPU, for example 10.3.0 on RDNA2 or 11.0.0 on RDNA3.
The following compilation options are also available to tweak performance (yes, they refer to CUDA, not HIP, because it uses the same code as the cuBLAS version above):
| Option | Legal values | Default | Description |
|-------------------------|------------------------|---------|-------------|
| LLAMA_CUDA_DMMV_X | Positive integer >= 32 | 32 | Number of values in x direction processed by the HIP dequantization + matrix vector multiplication kernel per iteration. Increasing this value can improve performance on fast GPUs. Power of 2 heavily recommended. Does not affect k-quants. |
| LLAMA_CUDA_MMV_Y | Positive integer | 1 | Block size in y direction for the HIP mul mat vec kernels. Increasing this value can improve performance on fast GPUs. Power of 2 recommended. Does not affect k-quants. |
| LLAMA_CUDA_KQUANTS_ITER | 1 or 2 | 2 | Number of values processed per iteration and per HIP thread for Q2_K and Q6_K quantization formats. Setting this value to 1 can improve performance for slow GPUs. |
- #### CLBlast - #### CLBlast
OpenCL acceleration is provided by the matrix multiplication kernels from the [CLBlast](https://github.com/CNugteren/CLBlast) project and custom kernels for ggml that can generate tokens on the GPU. OpenCL acceleration is provided by the matrix multiplication kernels from the [CLBlast](https://github.com/CNugteren/CLBlast) project and custom kernels for ggml that can generate tokens on the GPU.

View file

@ -391,6 +391,7 @@ if [ -z ${GG_BUILD_LOW_PERF} ]; then
ln -sfn ${mnt_models} ${SRC}/models-mnt ln -sfn ${mnt_models} ${SRC}/models-mnt
python3 -m pip install -r ${SRC}/requirements.txt python3 -m pip install -r ${SRC}/requirements.txt
python3 -m pip install --editable gguf-py
fi fi
ret=0 ret=0

View file

@ -613,9 +613,11 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n"); fprintf(stdout, " how to split tensors across multiple GPUs, comma-separated list of proportions, e.g. 3,1\n");
fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n"); fprintf(stdout, " -mg i, --main-gpu i the GPU to use for scratch and small tensors\n");
fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n"); fprintf(stdout, " -lv, --low-vram don't allocate VRAM scratch buffer\n");
#ifdef GGML_USE_CUBLAS
fprintf(stdout, " -nommq, --no-mul-mat-q\n"); fprintf(stdout, " -nommq, --no-mul-mat-q\n");
fprintf(stdout, " use cuBLAS instead of custom mul_mat_q CUDA kernels.\n"); fprintf(stdout, " use " GGML_CUBLAS_NAME " instead of custom mul_mat_q " GGML_CUDA_NAME " kernels.\n");
fprintf(stdout, " Not recommended since this is both slower and uses more VRAM.\n"); fprintf(stdout, " Not recommended since this is both slower and uses more VRAM.\n");
#endif // GGML_USE_CUBLAS
#endif #endif
fprintf(stdout, " --mtest compute maximum memory usage\n"); fprintf(stdout, " --mtest compute maximum memory usage\n");
fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n"); fprintf(stdout, " --export export the computation graph to 'llama.ggml'\n");

View file

@ -168,6 +168,7 @@ class Params:
n_head = config["num_attention_heads"] n_head = config["num_attention_heads"]
n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head n_head_kv = config["num_key_value_heads"] if "num_key_value_heads" in config else n_head
f_norm_eps = config["rms_norm_eps"] f_norm_eps = config["rms_norm_eps"]
f_rope_freq_base = config["rope_theta"] if "rope_theta" in config else None
if "rope_scaling" in config and config["rope_scaling"].get("type") == "linear": if "rope_scaling" in config and config["rope_scaling"].get("type") == "linear":
f_rope_scale = config["rope_scaling"].get("factor") f_rope_scale = config["rope_scaling"].get("factor")
@ -194,6 +195,7 @@ class Params:
n_head = n_head, n_head = n_head,
n_head_kv = n_head_kv, n_head_kv = n_head_kv,
f_norm_eps = f_norm_eps, f_norm_eps = f_norm_eps,
f_rope_freq_base = f_rope_freq_base,
f_rope_scale = f_rope_scale, f_rope_scale = f_rope_scale,
) )

View file

@ -18,9 +18,7 @@
#include "llama.h" #include "llama.h"
#include "common.h" #include "common.h"
#include "build-info.h" #include "build-info.h"
#ifdef GGML_USE_CUBLAS
#include "ggml-cuda.h" #include "ggml-cuda.h"
#endif
// utils // utils
static uint64_t get_time_ns() { static uint64_t get_time_ns() {
@ -443,6 +441,8 @@ struct test {
static const std::string gpu_info; static const std::string gpu_info;
std::string model_filename; std::string model_filename;
std::string model_type; std::string model_type;
uint64_t model_size;
uint64_t model_n_params;
int n_batch; int n_batch;
int n_threads; int n_threads;
bool f32_kv; bool f32_kv;
@ -459,8 +459,10 @@ struct test {
test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) { test(const cmd_params_instance & inst, const llama_model * lmodel, const llama_context * ctx) {
model_filename = inst.model; model_filename = inst.model;
char buf[128]; char buf[128];
llama_model_type(lmodel, buf, sizeof(buf)); llama_model_desc(lmodel, buf, sizeof(buf));
model_type = buf; model_type = buf;
model_size = llama_model_size(lmodel);
model_n_params = llama_model_n_params(lmodel);
n_batch = inst.n_batch; n_batch = inst.n_batch;
n_threads = inst.n_threads; n_threads = inst.n_threads;
f32_kv = inst.f32_kv; f32_kv = inst.f32_kv;
@ -504,7 +506,7 @@ struct test {
static std::string get_backend() { static std::string get_backend() {
if (cuda) { if (cuda) {
return "CUDA"; return GGML_CUDA_NAME;
} }
if (opencl) { if (opencl) {
return "OpenCL"; return "OpenCL";
@ -526,7 +528,7 @@ struct test {
"build_commit", "build_number", "build_commit", "build_number",
"cuda", "opencl", "metal", "gpu_blas", "blas", "cuda", "opencl", "metal", "gpu_blas", "blas",
"cpu_info", "gpu_info", "cpu_info", "gpu_info",
"model_filename", "model_type", "model_filename", "model_type", "model_size", "model_n_params",
"n_batch", "n_threads", "f16_kv", "n_batch", "n_threads", "f16_kv",
"n_gpu_layers", "main_gpu", "mul_mat_q", "low_vram", "tensor_split", "n_gpu_layers", "main_gpu", "mul_mat_q", "low_vram", "tensor_split",
"n_prompt", "n_gen", "test_time", "n_prompt", "n_gen", "test_time",
@ -540,6 +542,7 @@ struct test {
static field_type get_field_type(const std::string & field) { static field_type get_field_type(const std::string & field) {
if (field == "build_number" || field == "n_batch" || field == "n_threads" || if (field == "build_number" || field == "n_batch" || field == "n_threads" ||
field == "model_size" || field == "model_n_params" ||
field == "n_gpu_layers" || field == "main_gpu" || field == "n_gpu_layers" || field == "main_gpu" ||
field == "n_prompt" || field == "n_gen" || field == "n_prompt" || field == "n_gen" ||
field == "avg_ns" || field == "stddev_ns") { field == "avg_ns" || field == "stddev_ns") {
@ -575,7 +578,7 @@ struct test {
build_commit, std::to_string(build_number), build_commit, std::to_string(build_number),
std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas), std::to_string(cuda), std::to_string(opencl), std::to_string(metal), std::to_string(gpu_blas), std::to_string(blas),
cpu_info, gpu_info, cpu_info, gpu_info,
model_filename, model_type, model_filename, model_type, std::to_string(model_size), std::to_string(model_n_params),
std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv), std::to_string(n_batch), std::to_string(n_threads), std::to_string(!f32_kv),
std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), std::to_string(low_vram), tensor_split_str, std::to_string(n_gpu_layers), std::to_string(main_gpu), std::to_string(mul_mat_q), std::to_string(low_vram), tensor_split_str,
std::to_string(n_prompt), std::to_string(n_gen), test_time, std::to_string(n_prompt), std::to_string(n_gen), test_time,
@ -711,8 +714,15 @@ struct markdown_printer : public printer {
return -30; return -30;
} }
if (field == "t/s") { if (field == "t/s") {
return 15; return 16;
} }
if (field == "size" || field == "params") {
return 10;
}
if (field == "n_gpu_layers") {
return 3;
}
int width = std::max((int)field.length(), 10); int width = std::max((int)field.length(), 10);
if (test::get_field_type(field) == test::STRING) { if (test::get_field_type(field) == test::STRING) {
@ -721,9 +731,28 @@ struct markdown_printer : public printer {
return width; return width;
} }
static std::string get_field_display_name(const std::string & field) {
if (field == "n_gpu_layers") {
return "ngl";
}
if (field == "n_threads") {
return "threads";
}
if (field == "mul_mat_q") {
return "mmq";
}
if (field == "tensor_split") {
return "ts";
}
return field;
}
void print_header(const cmd_params & params) override { void print_header(const cmd_params & params) override {
// select fields to print // select fields to print
fields = { "model", "backend" }; fields.push_back("model");
fields.push_back("size");
fields.push_back("params");
fields.push_back("backend");
bool is_cpu_backend = test::get_backend() == "CPU" || test::get_backend() == "BLAS"; bool is_cpu_backend = test::get_backend() == "CPU" || test::get_backend() == "BLAS";
if (!is_cpu_backend) { if (!is_cpu_backend) {
fields.push_back("n_gpu_layers"); fields.push_back("n_gpu_layers");
@ -754,7 +783,7 @@ struct markdown_printer : public printer {
fprintf(fout, "|"); fprintf(fout, "|");
for (const auto & field : fields) { for (const auto & field : fields) {
fprintf(fout, " %*s |", get_field_width(field), field.c_str()); fprintf(fout, " %*s |", get_field_width(field), get_field_display_name(field).c_str());
} }
fprintf(fout, "\n"); fprintf(fout, "\n");
fprintf(fout, "|"); fprintf(fout, "|");
@ -771,12 +800,26 @@ struct markdown_printer : public printer {
fprintf(fout, "|"); fprintf(fout, "|");
for (const auto & field : fields) { for (const auto & field : fields) {
std::string value; std::string value;
char buf[128];
if (field == "model") { if (field == "model") {
value = t.model_type; value = t.model_type;
} else if (field == "size") {
if (t.model_size < 1024*1024*1024) {
snprintf(buf, sizeof(buf), "%.2f MiB", t.model_size / 1024.0 / 1024.0);
} else {
snprintf(buf, sizeof(buf), "%.2f GiB", t.model_size / 1024.0 / 1024.0 / 1024.0);
}
value = buf;
} else if (field == "params") {
if (t.model_n_params < 1000*1000*1000) {
snprintf(buf, sizeof(buf), "%.2f M", t.model_n_params / 1e6);
} else {
snprintf(buf, sizeof(buf), "%.2f B", t.model_n_params / 1e9);
}
value = buf;
} else if (field == "backend") { } else if (field == "backend") {
value = test::get_backend(); value = test::get_backend();
} else if (field == "test") { } else if (field == "test") {
char buf[128];
if (t.n_prompt > 0 && t.n_gen == 0) { if (t.n_prompt > 0 && t.n_gen == 0) {
snprintf(buf, sizeof(buf), "pp %d", t.n_prompt); snprintf(buf, sizeof(buf), "pp %d", t.n_prompt);
} else if (t.n_gen > 0 && t.n_prompt == 0) { } else if (t.n_gen > 0 && t.n_prompt == 0) {
@ -787,7 +830,6 @@ struct markdown_printer : public printer {
} }
value = buf; value = buf;
} else if (field == "t/s") { } else if (field == "t/s") {
char buf[128];
snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts()); snprintf(buf, sizeof(buf), "%.2f ± %.2f", t.avg_ts(), t.stdev_ts());
value = buf; value = buf;
} else if (vmap.find(field) != vmap.end()) { } else if (vmap.find(field) != vmap.end()) {

File diff suppressed because it is too large Load diff

View file

@ -102,6 +102,17 @@
padding: 0.5em; padding: 0.5em;
} }
.prob-set {
padding: 0.3em;
border-bottom: 1px solid #ccc;
}
.popover-content {
position: absolute;
background-color: white;
padding: 0.2em;
box-shadow: 0 0 10px rgba(0, 0, 0, 0.1);
}
textarea { textarea {
padding: 5px; padding: 5px;
@ -133,11 +144,17 @@
font-size: 80%; font-size: 80%;
color: #888; color: #888;
} }
@media (prefers-color-scheme: dark) {
.popover-content {
background-color: black;
}
}
</style> </style>
<script type="module"> <script type="module">
import { import {
html, h, signal, effect, computed, render, useSignal, useEffect, useRef html, h, signal, effect, computed, render, useSignal, useEffect, useRef, Component
} from '/index.js'; } from '/index.js';
import { llama } from '/completion.js'; import { llama } from '/completion.js';
@ -168,6 +185,7 @@
mirostat_tau: 5, // target entropy mirostat_tau: 5, // target entropy
mirostat_eta: 0.1, // learning rate mirostat_eta: 0.1, // learning rate
grammar: '', grammar: '',
n_probs: 0, // no completion_probabilities
}) })
/* START: Support for storing prompt templates and parameters in borwser LocalStorage */ /* START: Support for storing prompt templates and parameters in borwser LocalStorage */
@ -334,10 +352,21 @@
const prompt = template(session.value.template, { const prompt = template(session.value.template, {
message: msg, message: msg,
history: session.value.transcript.flatMap(([name, message]) => template(session.value.historyTemplate, {name, message})).join("\n"), history: session.value.transcript.flatMap(
([name, data]) =>
template(
session.value.historyTemplate,
{
name,
message: Array.isArray(data) ?
data.map(msg => msg.content).join('').replace(/^\s/, '') :
data,
}
)
).join("\n"),
}); });
let currentMessage = ''; const currentMessages = [];
const history = session.value.transcript const history = session.value.transcript
const llamaParams = { const llamaParams = {
@ -347,15 +376,19 @@
for await (const chunk of llama(prompt, llamaParams, { controller: controller.value })) { for await (const chunk of llama(prompt, llamaParams, { controller: controller.value })) {
const data = chunk.data; const data = chunk.data;
currentMessage += data.content;
// remove leading whitespace
currentMessage = currentMessage.replace(/^\s+/, "")
transcriptUpdate([...history, ["{{char}}", currentMessage]])
if (data.stop) { if (data.stop) {
console.log("Completion finished: '", currentMessage, "', summary: ", data); while (
currentMessages.length > 0 &&
currentMessages[currentMessages.length - 1].content.match(/\n$/) != null
) {
currentMessages.pop();
}
transcriptUpdate([...history, ["{{char}}", currentMessages]])
console.log("Completion finished: '", currentMessages.map(msg => msg.content).join(''), "', summary: ", data);
} else {
currentMessages.push(data);
transcriptUpdate([...history, ["{{char}}", currentMessages]])
} }
if (data.timings) { if (data.timings) {
@ -420,8 +453,18 @@
} }
}, [messages]) }, [messages])
const chatLine = ([user, msg]) => { const chatLine = ([user, data], index) => {
return html`<p key=${msg}><strong>${template(user)}:</strong> <${Markdownish} text=${template(msg)} /></p>` let message
const isArrayMessage = Array.isArray(data)
if (params.value.n_probs > 0 && isArrayMessage) {
message = html`<${Probabilities} data=${data} />`
} else {
const text = isArrayMessage ?
data.map(msg => msg.content).join('').replace(/^\s+/, '') :
data;
message = html`<${Markdownish} text=${template(text)} />`
}
return html`<p key=${index}><strong>${template(user)}:</strong> ${message}</p>`
}; };
return html` return html`
@ -568,10 +611,71 @@
${FloatField({label: "Mirostat tau", max: 10.0, min: 0.0, name: "mirostat_tau", step: 0.01, value: params.value.mirostat_tau})} ${FloatField({label: "Mirostat tau", max: 10.0, min: 0.0, name: "mirostat_tau", step: 0.01, value: params.value.mirostat_tau})}
${FloatField({label: "Mirostat eta", max: 1.0, min: 0.0, name: "mirostat_eta", step: 0.01, value: params.value.mirostat_eta})} ${FloatField({label: "Mirostat eta", max: 1.0, min: 0.0, name: "mirostat_eta", step: 0.01, value: params.value.mirostat_eta})}
</fieldset> </fieldset>
<fieldset>
${IntField({label: "Show Probabilities", max: 10, min: 0, name: "n_probs", value: params.value.n_probs})}
</fieldset>
</details> </details>
</form> </form>
` `
} }
const probColor = (p) => {
const r = Math.floor(192 * (1 - p));
const g = Math.floor(192 * p);
return `rgba(${r},${g},0,0.3)`;
}
const Probabilities = (params) => {
return params.data.map(msg => {
const { completion_probabilities } = msg;
if (
!completion_probabilities ||
completion_probabilities.length === 0
) return msg.content
if (completion_probabilities.length > 1) {
// Not for byte pair
if (completion_probabilities[0].content.startsWith('byte: \\')) return msg.content
const splitData = completion_probabilities.map(prob => ({
content: prob.content,
completion_probabilities: [prob]
}))
return html`<${Probabilities} data=${splitData} />`
}
const { probs, content } = completion_probabilities[0]
const found = probs.find(p => p.tok_str === msg.content)
const pColor = found ? probColor(found.prob) : 'transparent'
const popoverChildren = html`
<div class="prob-set">
${probs.map((p, index) => {
return html`
<div
key=${index}
title=${`prob: ${p.prob}`}
style=${{
padding: '0.3em',
backgroundColor: p.tok_str === content ? probColor(p.prob) : 'transparent'
}}
>
<span>${p.tok_str}: </span>
<span>${Math.floor(p.prob * 100)}%</span>
</div>
`
})}
</div>
`
return html`
<${Popover} style=${{ backgroundColor: pColor }} popoverChildren=${popoverChildren}>
${msg.content.match(/\n/gim) ? html`<br />` : msg.content}
</>
`
});
}
// poor mans markdown replacement // poor mans markdown replacement
const Markdownish = (params) => { const Markdownish = (params) => {
const md = params.text const md = params.text
@ -600,10 +704,121 @@
` `
} }
// simple popover impl
const Popover = (props) => {
const isOpen = useSignal(false);
const position = useSignal({ top: '0px', left: '0px' });
const buttonRef = useRef(null);
const popoverRef = useRef(null);
const togglePopover = () => {
if (buttonRef.current) {
const rect = buttonRef.current.getBoundingClientRect();
position.value = {
top: `${rect.bottom + window.scrollY}px`,
left: `${rect.left + window.scrollX}px`,
};
}
isOpen.value = !isOpen.value;
};
const handleClickOutside = (event) => {
if (popoverRef.current && !popoverRef.current.contains(event.target) && !buttonRef.current.contains(event.target)) {
isOpen.value = false;
}
};
useEffect(() => {
document.addEventListener('mousedown', handleClickOutside);
return () => {
document.removeEventListener('mousedown', handleClickOutside);
};
}, []);
return html`
<span style=${props.style} ref=${buttonRef} onClick=${togglePopover}>${props.children}</span>
${isOpen.value && html`
<${Portal} into="#portal">
<div
ref=${popoverRef}
class="popover-content"
style=${{
top: position.value.top,
left: position.value.left,
}}
>
${props.popoverChildren}
</div>
</${Portal}>
`}
`;
};
// Source: preact-portal (https://github.com/developit/preact-portal/blob/master/src/preact-portal.js)
/** Redirect rendering of descendants into the given CSS selector */
class Portal extends Component {
componentDidUpdate(props) {
for (let i in props) {
if (props[i] !== this.props[i]) {
return setTimeout(this.renderLayer);
}
}
}
componentDidMount() {
this.isMounted = true;
this.renderLayer = this.renderLayer.bind(this);
this.renderLayer();
}
componentWillUnmount() {
this.renderLayer(false);
this.isMounted = false;
if (this.remote && this.remote.parentNode) this.remote.parentNode.removeChild(this.remote);
}
findNode(node) {
return typeof node === 'string' ? document.querySelector(node) : node;
}
renderLayer(show = true) {
if (!this.isMounted) return;
// clean up old node if moving bases:
if (this.props.into !== this.intoPointer) {
this.intoPointer = this.props.into;
if (this.into && this.remote) {
this.remote = render(html`<${PortalProxy} />`, this.into, this.remote);
}
this.into = this.findNode(this.props.into);
}
this.remote = render(html`
<${PortalProxy} context=${this.context}>
${show && this.props.children || null}
</${PortalProxy}>
`, this.into, this.remote);
}
render() {
return null;
}
}
// high-order component that renders its first child if it exists.
// used as a conditional rendering proxy.
class PortalProxy extends Component {
getChildContext() {
return this.props.context;
}
render({ children }) {
return children || null;
}
}
function App(props) { function App(props) {
return html` return html`
<div id="container"> <div>
<header> <header>
<h1>llama.cpp</h1> <h1>llama.cpp</h1>
</header> </header>
@ -624,11 +839,13 @@
`; `;
} }
render(h(App), document.body); render(h(App), document.querySelector('#container'));
</script> </script>
</head> </head>
<body> <body>
<div id="container"></div>
<div id="portal"></div>
</body> </body>
</html> </html>

View file

@ -124,8 +124,9 @@ static void server_log(const char *level, const char *function, int line,
static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token) static std::string tokens_to_output_formatted_string(const llama_context *ctx, const llama_token token)
{ {
std::string out = token == -1 ? "" : llama_token_to_str(ctx, token); std::string out = token == -1 ? "" : llama_token_to_str(ctx, token);
// if first bit is 1, meaning it's a partial character // if the size is 1 and first bit is 1, meaning it's a partial character
if (out.size() > 0 && (out[0] & 0x80) == 0x80) // (size > 1 meaning it's already a known token)
if (out.size() == 1 && (out[0] & 0x80) == 0x80)
{ {
std::stringstream ss; std::stringstream ss;
ss << std::hex << (out[0] & 0xff); ss << std::hex << (out[0] & 0xff);
@ -1321,27 +1322,36 @@ int main(int argc, char **argv)
while (llama.has_next_token) { while (llama.has_next_token) {
const completion_token_output token_with_probs = llama.doCompletion(); const completion_token_output token_with_probs = llama.doCompletion();
const std::string token_text = token_with_probs.tok == -1 ? "" : llama_token_to_str(llama.ctx, token_with_probs.tok); if (token_with_probs.tok == -1 || llama.multibyte_pending > 0) {
if (llama.multibyte_pending > 0) {
continue; continue;
} }
const std::string token_text = llama_token_to_str(llama.ctx, token_with_probs.tok);
size_t pos = std::min(sent_count, llama.generated_text.size()); size_t pos = std::min(sent_count, llama.generated_text.size());
const std::string str_test = llama.generated_text.substr(pos); const std::string str_test = llama.generated_text.substr(pos);
bool is_stop_full = false;
size_t stop_pos = size_t stop_pos =
llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL); llama.findStoppingStrings(str_test, token_text.size(), STOP_FULL);
if (stop_pos != std::string::npos) { if (stop_pos != std::string::npos) {
is_stop_full = true;
llama.generated_text.erase( llama.generated_text.erase(
llama.generated_text.begin() + pos + stop_pos, llama.generated_text.begin() + pos + stop_pos,
llama.generated_text.end()); llama.generated_text.end());
pos = std::min(sent_count, llama.generated_text.size()); pos = std::min(sent_count, llama.generated_text.size());
} else { } else {
is_stop_full = false;
stop_pos = llama.findStoppingStrings(str_test, token_text.size(), stop_pos = llama.findStoppingStrings(str_test, token_text.size(),
STOP_PARTIAL); STOP_PARTIAL);
} }
const std::string to_send = llama.generated_text.substr(pos, stop_pos); if (
stop_pos == std::string::npos ||
// Send rest of the text if we are at the end of the generation
(!llama.has_next_token && !is_stop_full && stop_pos > 0)
) {
const std::string to_send = llama.generated_text.substr(pos, std::string::npos);
sent_count += to_send.size(); sent_count += to_send.size();
std::vector<completion_token_output> probs_output = {}; std::vector<completion_token_output> probs_output = {};
@ -1356,10 +1366,7 @@ int main(int argc, char **argv)
sent_token_probs_index = probs_stop_pos; sent_token_probs_index = probs_stop_pos;
} }
const json data = llama.has_next_token const json data = format_partial_response(llama, to_send, probs_output);
? format_partial_response(llama, to_send, probs_output)
// Generation is done, send extra information.
: format_final_response(llama, to_send, llama.generated_token_probs);
const std::string str = const std::string str =
"data: " + "data: " +
@ -1377,6 +1384,27 @@ int main(int argc, char **argv)
} }
} }
if (!llama.has_next_token) {
// Generation is done, send extra information.
const json data = format_final_response(llama, "", llama.generated_token_probs);
const std::string str =
"data: " +
data.dump(-1, ' ', false, json::error_handler_t::replace) +
"\n\n";
LOG_VERBOSE("data stream", {
{ "to_send", str }
});
if (!sink.write(str.data(), str.size())) {
LOG_VERBOSE("stream closed", {});
llama_print_timings(llama.ctx);
return false;
}
}
}
llama_print_timings(llama.ctx); llama_print_timings(llama.ctx);
sink.done(); sink.done();
return true; return true;

View file

@ -8,6 +8,7 @@
#define UNUSED(x) (void)(x) #define UNUSED(x) (void)(x)
#define MAX(a, b) ((a) > (b) ? (a) : (b)) #define MAX(a, b) ((a) > (b) ? (a) : (b))
#define GGML_MAX_CONCUR (2*GGML_MAX_NODES)
//#define GGML_ALLOCATOR_DEBUG //#define GGML_ALLOCATOR_DEBUG
@ -67,7 +68,7 @@ struct ggml_allocr {
struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE]; struct hash_node hash_table[GGML_GRAPH_HASHTABLE_SIZE];
size_t max_size; size_t max_size;
bool measure; bool measure;
int parse_seq[GGML_MAX_NODES]; int parse_seq[GGML_MAX_CONCUR];
int parse_seq_len; int parse_seq_len;
#ifdef GGML_ALLOCATOR_DEBUG #ifdef GGML_ALLOCATOR_DEBUG

View file

@ -6,15 +6,116 @@
#include <atomic> #include <atomic>
#include <assert.h> #include <assert.h>
#if defined(GGML_USE_HIPBLAS)
#include <hip/hip_runtime.h>
#include <hipblas/hipblas.h>
#include <hip/hip_fp16.h>
#ifdef __HIP_PLATFORM_AMD__
// for rocblas_initialize()
#include "rocblas/rocblas.h"
#endif
#define CUBLAS_COMPUTE_32F HIPBLAS_R_32F
#define CUBLAS_COMPUTE_32F_FAST_16F HIPBLAS_R_32F
#define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT
#define CUBLAS_OP_N HIPBLAS_OP_N
#define CUBLAS_OP_T HIPBLAS_OP_T
#define CUBLAS_STATUS_SUCCESS HIPBLAS_STATUS_SUCCESS
#define CUBLAS_TF32_TENSOR_OP_MATH 0
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define __shfl_xor_sync(mask, var, laneMask, width) __shfl_xor(var, laneMask, width)
#define cublasCreate hipblasCreate
#define cublasGemmEx hipblasGemmEx
#define cublasHandle_t hipblasHandle_t
#define cublasSetMathMode(handle, mode) CUBLAS_STATUS_SUCCESS
#define cublasSetStream hipblasSetStream
#define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t
#define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t
#define cudaEventCreateWithFlags hipEventCreateWithFlags
#define cudaEventDisableTiming hipEventDisableTiming
#define cudaEventRecord hipEventRecord
#define cudaEvent_t hipEvent_t
#define cudaEventDestroy hipEventDestroy
#define cudaFree hipFree
#define cudaFreeHost hipHostFree
#define cudaGetDevice hipGetDevice
#define cudaGetDeviceCount hipGetDeviceCount
#define cudaGetDeviceProperties hipGetDeviceProperties
#define cudaGetErrorString hipGetErrorString
#define cudaGetLastError hipGetLastError
#define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
#define cudaMemcpy hipMemcpy
#define cudaMemcpy2DAsync hipMemcpy2DAsync
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#define cudaMemcpyDeviceToHost hipMemcpyDeviceToHost
#define cudaMemcpyHostToDevice hipMemcpyHostToDevice
#define cudaMemcpyKind hipMemcpyKind
#define cudaMemset hipMemset
#define cudaOccupancyMaxPotentialBlockSize hipOccupancyMaxPotentialBlockSize
#define cudaSetDevice hipSetDevice
#define cudaStreamCreateWithFlags hipStreamCreateWithFlags
#define cudaStreamNonBlocking hipStreamNonBlocking
#define cudaStreamSynchronize hipStreamSynchronize
#define cudaStreamWaitEvent(stream, event) hipStreamWaitEvent(stream, event, 0)
#define cudaStream_t hipStream_t
#define cudaSuccess hipSuccess
#else
#include <cuda_runtime.h> #include <cuda_runtime.h>
#include <cublas_v2.h> #include <cublas_v2.h>
#include <cuda_fp16.h> #include <cuda_fp16.h>
#endif
#include "ggml-cuda.h" #include "ggml-cuda.h"
#include "ggml.h" #include "ggml.h"
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products #define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#ifndef CC_TURING
#define CC_TURING 700 #define CC_TURING 700
#endif
#if defined(GGML_USE_HIPBLAS)
#define __CUDA_ARCH__ 1300
typedef int8_t int8x4_t __attribute__((ext_vector_type(4)));
static __device__ __forceinline__ int __vsubss4(const int a, const int b) {
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
const int8x4_t c = __builtin_elementwise_sub_sat(va, vb);
return reinterpret_cast<const int&>(c);
}
static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) {
#if defined(__gfx906__) || defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx1030__)
c = __builtin_amdgcn_sdot4(a, b, c, false);
#elif defined(__gfx1100__)
c = __builtin_amdgcn_sudot4( true, a, true, b, c, false);
#elif defined(__gfx1010__) || defined(__gfx900__)
int tmp1;
int tmp2;
asm("\n \
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_0 src1_sel:BYTE_0 \n \
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_1 src1_sel:BYTE_1 \n \
v_add3_u32 %0, %1, %2, %0 \n \
v_mul_i32_i24 %1, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_2 src1_sel:BYTE_2 \n \
v_mul_i32_i24 %2, sext(%3), sext(%4) dst_sel:DWORD dst_unused:UNUSED_PAD src0_sel:BYTE_3 src1_sel:BYTE_3 \n \
v_add3_u32 %0, %1, %2, %0 \n \
"
: "+v"(c), "=&v"(tmp1), "=&v"(tmp2)
: "v"(a), "v"(b)
);
#else
const int8x4_t va = reinterpret_cast<const int8x4_t&>(a);
const int8x4_t vb = reinterpret_cast<const int8x4_t&>(b);
c += va[0] * vb[0] + va[1] * vb[1] + va[2] * vb[2] + va[3] * vb[3];
#endif
return c;
}
#endif
#if defined(_MSC_VER) #if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data #pragma warning(disable: 4244 4267) // possible loss of data
@ -424,8 +525,8 @@ static __device__ __forceinline__ void dequantize_q4_0(const void * vx, const in
static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ static __device__ __forceinline__ void dequantize_q4_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q4_1 * x = (const block_q4_1 *) vx; const block_q4_1 * x = (const block_q4_1 *) vx;
const dfloat d = x[ib].dm.x; const dfloat d = __low2half(x[ib].dm);
const dfloat m = x[ib].dm.y; const dfloat m = __high2half(x[ib].dm);
const int vui = x[ib].qs[iqs]; const int vui = x[ib].qs[iqs];
@ -467,8 +568,8 @@ static __device__ __forceinline__ void dequantize_q5_0(const void * vx, const in
static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){ static __device__ __forceinline__ void dequantize_q5_1(const void * vx, const int ib, const int iqs, dfloat2 & v){
const block_q5_1 * x = (const block_q5_1 *) vx; const block_q5_1 * x = (const block_q5_1 *) vx;
const dfloat d = x[ib].dm.x; const dfloat d = __low2half(x[ib].dm);
const dfloat m = x[ib].dm.y; const dfloat m = __high2half(x[ib].dm);
uint32_t qh; uint32_t qh;
memcpy(&qh, x[ib].qh, sizeof(qh)); memcpy(&qh, x[ib].qh, sizeof(qh));
@ -520,8 +621,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
const uint8_t q = x[i].qs[32*n + l]; const uint8_t q = x[i].qs[32*n + l];
float * y = yy + i*QK_K + 128*n; float * y = yy + i*QK_K + 128*n;
float dall = x[i].dm.x; float dall = __low2half(x[i].dm);
float dmin = x[i].dm.y; float dmin = __high2half(x[i].dm);
y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[l+ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4); y[l+32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 2) & 3) - dmin * (x[i].scales[is+2] >> 4);
y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4); y[l+64] = dall * (x[i].scales[is+4] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+4] >> 4);
@ -531,8 +632,8 @@ static __global__ void dequantize_block_q2_K(const void * __restrict__ vx, float
const int il = tid%16; // 0...15 const int il = tid%16; // 0...15
const uint8_t q = x[i].qs[il] >> (2*is); const uint8_t q = x[i].qs[il] >> (2*is);
float * y = yy + i*QK_K + 16*is + il; float * y = yy + i*QK_K + 16*is + il;
float dall = x[i].dm.x; float dall = __low2half(x[i].dm);
float dmin = x[i].dm.y; float dmin = __high2half(x[i].dm);
y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4); y[ 0] = dall * (x[i].scales[is+0] & 0xF) * ((q >> 0) & 3) - dmin * (x[i].scales[is+0] >> 4);
y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4); y[32] = dall * (x[i].scales[is+2] & 0xF) * ((q >> 4) & 3) - dmin * (x[i].scales[is+2] >> 4);
#endif #endif
@ -618,8 +719,8 @@ static __global__ void dequantize_block_q4_K(const void * __restrict__ vx, float
float * y = yy + i*QK_K + 64*il + n*ir; float * y = yy + i*QK_K + 64*il + n*ir;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint8_t * q = x[i].qs + 32*il + n*ir; const uint8_t * q = x[i].qs + 32*il + n*ir;
@ -657,8 +758,8 @@ static __global__ void dequantize_block_q5_K(const void * __restrict__ vx, float
float * y = yy + i*QK_K + 64*il + 2*ir; float * y = yy + i*QK_K + 64*il + 2*ir;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint8_t * ql = x[i].qs + 32*il + 2*ir; const uint8_t * ql = x[i].qs + 32*il + 2*ir;
const uint8_t * qh = x[i].qh + 2*ir; const uint8_t * qh = x[i].qh + 2*ir;
@ -770,8 +871,8 @@ static __global__ void dequantize_mul_mat_vec_q2_k(const void * __restrict__ vx,
const float * y = yy + i * QK_K + y_offset; const float * y = yy + i * QK_K + y_offset;
const uint8_t * q = x[i].qs + q_offset; const uint8_t * q = x[i].qs + q_offset;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset); const uint32_t * a = (const uint32_t *)(x[i].scales + s_offset);
aux[0] = a[0] & 0x0f0f0f0f; aux[0] = a[0] & 0x0f0f0f0f;
@ -991,8 +1092,8 @@ static __global__ void dequantize_mul_mat_vec_q4_k(const void * __restrict__ vx,
const float * y1 = yy + i*QK_K + y_offset; const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128; const float * y2 = y1 + 128;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales; const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1; aux[0] = a[im+0] & kmask1;
@ -1124,8 +1225,8 @@ static __global__ void dequantize_mul_mat_vec_q5_k(const void * __restrict__ vx,
const float * y1 = yy + i*QK_K + y_offset; const float * y1 = yy + i*QK_K + y_offset;
const float * y2 = y1 + 128; const float * y2 = y1 + 128;
const float dall = x[i].dm.x; const float dall = __low2half(x[i].dm);
const float dmin = x[i].dm.y; const float dmin = __high2half(x[i].dm);
const uint16_t * a = (const uint16_t *)x[i].scales; const uint16_t * a = (const uint16_t *)x[i].scales;
aux[0] = a[im+0] & kmask1; aux[0] = a[im+0] & kmask1;
@ -1348,8 +1449,8 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
return; return;
} }
y[ib].ds.x = d; reinterpret_cast<half&>(y[ib].ds.x) = d;
y[ib].ds.y = sum; reinterpret_cast<half&>(y[ib].ds.y) = sum;
} }
template <int qk, int qr, dequantize_kernel_t dequantize_kernel> template <int qk, int qr, dequantize_kernel_t dequantize_kernel>
@ -2346,7 +2447,7 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i); u[i] = get_int_from_int8_aligned(bq8_1->qs, iqs + i);
} }
return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, bq8_1->ds.x); return vec_dot_q8_0_q8_1_impl<VDR_Q8_0_Q8_1_MMVQ>(v, u, bq8_0->d, __low2half(bq8_1->ds));
} }
template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { template <int mmq_y> static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) {
@ -2432,7 +2533,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
#pragma unroll #pragma unroll
for (int i = 0; i < QR2_K; ++ i) { for (int i = 0; i < QR2_K; ++ i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
d8[i] = bq8_1[bq8_offset + i].ds.x; d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
} }
return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8); return vec_dot_q2_K_q8_1_impl_mmvq(v, u, scales, bq2_K->dm, d8);
@ -2551,7 +2652,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
#pragma unroll #pragma unroll
for (int i = 0; i < QR3_K; ++i) { for (int i = 0; i < QR3_K; ++i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1); u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + i].qs, iqs % QI8_1);
d8[i] = bq8_1[bq8_offset + i].ds.x; d8[i] = __low2half(bq8_1[bq8_offset + i].ds);
} }
return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8); return vec_dot_q3_K_q8_1_impl_mmvq(vl, vh, u, bq3_K->scales, scale_offset, d, d8);
@ -2720,7 +2821,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
for (int i = 0; i < QR4_K; ++i) { for (int i = 0; i < QR4_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
d8[i] = bq8i->ds.x; d8[i] = __low2half(bq8i->ds);
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
u[2*i+0] = q8[0]; u[2*i+0] = q8[0];
@ -2747,8 +2848,8 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
const float dall = bq4_K->d[0]; const float dall = bq4_K->d[0];
const float dmin = bq4_K->d[1]; const float dmin = bq4_K->d[1];
const float d8_1 = bq8_1[0].ds.x; const float d8_1 = __low2float(bq8_1[0].ds);
const float d8_2 = bq8_1[1].ds.x; const float d8_2 = __low2float(bq8_1[1].ds);
const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
@ -2901,7 +3002,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
#pragma unroll #pragma unroll
for (int i = 0; i < QR5_K; ++i) { for (int i = 0; i < QR5_K; ++i) {
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i; const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
d8[i] = bq8i->ds.x; d8[i] = __low2float(bq8i->ds);
const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4); const int * q8 = (const int *)bq8i->qs + ((iqs/2)%4);
u[2*i+0] = q8[0]; u[2*i+0] = q8[0];
@ -2919,8 +3020,8 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
const float d = bq5_K->d; const float d = bq5_K->d;
const float d8_1 = bq8_1[0].ds.x; const float d8_1 = __low2half(bq8_1[0].ds);
const float d8_2 = bq8_1[1].ds.x; const float d8_2 = __low2half(bq8_1[1].ds);
const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2)); const int ui1 = *((const int *)bq8_1[0].qs + (iqs/2));
const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4); const int ui2 = *((const int *)bq8_1[0].qs + (iqs/2) + 4);
@ -3075,7 +3176,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
#pragma unroll #pragma unroll
for (int i = 0; i < QR6_K; ++i) { for (int i = 0; i < QR6_K; ++i) {
u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1); u[i] = get_int_from_int8_aligned(bq8_1[bq8_offset + 2*i].qs, iqs % QI8_1);
d8[i] = bq8_1[bq8_offset + 2*i].ds.x; d8[i] = __low2half(bq8_1[bq8_offset + 2*i].ds);
} }
return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8); return vec_dot_q6_K_q8_1_impl_mmvq(vl, vh, u, scales, bq6_K->d, d8);
@ -3243,7 +3344,7 @@ static __device__ __forceinline__ void mul_mat_q(
*dsi_dst = *dsi_src; *dsi_dst = *dsi_src;
} else { } else {
float * dfi_dst = (float *) dsi_dst; float * dfi_dst = (float *) dsi_dst;
*dfi_dst = (*dsi_src).x; *dfi_dst = __low2half(*dsi_src);
} }
} }
@ -3907,28 +4008,27 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
dst[i + 1] = x0*sin_theta + x1*cos_theta; dst[i + 1] = x0*sin_theta + x1*cos_theta;
} }
// TODO: this implementation is wrong! static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0,
//static __global__ void rope_neox_f32(const float * x, float * dst, const int ncols, const float p0, const float p_delta, const int p_delta_rows, const float theta_scale) {
// const float p_delta, const int p_delta_rows, const float theta_scale) { const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
// const int col = 2*(blockDim.y*blockIdx.y + threadIdx.y);
// if (col >= ncols) {
// if (col >= ncols) { return;
// return; }
// }
// const int row = blockDim.x*blockIdx.x + threadIdx.x;
// const int row = blockDim.x*blockIdx.x + threadIdx.x; const int i = row*ncols + col/2;
// const int i = row*ncols + col/2;
// const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2);
// const float theta = (p0 + p_delta * (row/p_delta_rows))*powf(theta_scale, col/2); const float sin_theta = sinf(theta);
// const float sin_theta = sinf(theta); const float cos_theta = cosf(theta);
// const float cos_theta = cosf(theta);
// const float x0 = x[i + 0];
// const float x0 = x[i + 0]; const float x1 = x[i + ncols/2];
// const float x1 = x[i + ncols/2];
// dst[i + 0] = x0*cos_theta - x1*sin_theta;
// dst[i + 0] = x0*cos_theta - x1*sin_theta; dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
// dst[i + ncols/2] = x0*sin_theta + x1*cos_theta; }
//}
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) { static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
const int col = blockDim.x*blockIdx.x + threadIdx.x; const int col = blockDim.x*blockIdx.x + threadIdx.x;
@ -4799,13 +4899,21 @@ static void scale_f32_cuda(const float * x, float * dst, const float scale, cons
static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0, static void rope_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) { const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(nrows % 2 == 0); GGML_ASSERT(nrows % 2 == 0); // GG: is this assert really needed? I don't see why
const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1); const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1); const dim3 block_nums(nrows, num_blocks_x, 1);
rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale); rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
} }
static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
const float p_delta, const int p_delta_rows, const float theta_scale, cudaStream_t stream) {
const dim3 block_dims(1, 2*CUDA_ROPE_BLOCK_SIZE, 1);
const int num_blocks_x = (ncols + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nrows, num_blocks_x, 1);
rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
}
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) { static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(nrows % 4 == 0); GGML_ASSERT(nrows % 4 == 0);
const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1); const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
@ -4937,10 +5045,18 @@ void ggml_init_cublas() {
static bool initialized = false; static bool initialized = false;
if (!initialized) { if (!initialized) {
#ifdef __HIP_PLATFORM_AMD__
// Workaround for a rocBLAS bug when using multiple graphics cards:
// https://github.com/ROCmSoftwarePlatform/rocBLAS/issues/1346
rocblas_initialize();
CUDA_CHECK(cudaDeviceSynchronize());
#endif
CUDA_CHECK(cudaGetDeviceCount(&g_device_count)); CUDA_CHECK(cudaGetDeviceCount(&g_device_count));
GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES); GGML_ASSERT(g_device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0; int64_t total_vram = 0;
fprintf(stderr, "%s: found %d CUDA devices:\n", __func__, g_device_count); fprintf(stderr, "%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, g_device_count);
for (int id = 0; id < g_device_count; ++id) { for (int id = 0; id < g_device_count; ++id) {
cudaDeviceProp prop; cudaDeviceProp prop;
CUDA_CHECK(cudaGetDeviceProperties(&prop, id)); CUDA_CHECK(cudaGetDeviceProperties(&prop, id));
@ -5548,8 +5664,9 @@ inline void ggml_cuda_op_rope(
const float block_p = max(p - (n_ctx - 2.f), 0.f); const float block_p = max(p - (n_ctx - 2.f), 0.f);
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main); rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
} else if (is_neox) { } else if (is_neox) {
GGML_ASSERT(false && "RoPE NeoX not implemented yet"); GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
#pragma message("TODO: implement RoPE NeoX for CUDA") const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
} else { } else {
const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale; const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main); rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);

View file

@ -2,6 +2,14 @@
#include "ggml.h" #include "ggml.h"
#ifdef GGML_USE_HIPBLAS
#define GGML_CUDA_NAME "ROCm"
#define GGML_CUBLAS_NAME "hipBLAS"
#else
#define GGML_CUDA_NAME "CUDA"
#define GGML_CUBLAS_NAME "cuBLAS"
#endif
#ifdef __cplusplus #ifdef __cplusplus
extern "C" { extern "C" {
#endif #endif

21
gguf-py/LICENSE Normal file
View file

@ -0,0 +1,21 @@
MIT License
Copyright (c) 2023 Georgi Gerganov
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

55
gguf-py/README.md Normal file
View file

@ -0,0 +1,55 @@
## gguf
This is a Python package for writing binary files in the [GGUF](https://github.com/ggerganov/ggml/pull/302)
(GGML Universal File) format.
See [convert-llama-hf-to-gguf.py](https://github.com/ggerganov/llama.cpp/blob/master/convert-llama-hf-to-gguf.py)
as an example for its usage.
## Installation
```sh
pip install gguf
```
## Development
Maintainers who participate in development of this package are advised to install it in editable mode:
```sh
cd /path/to/llama.cpp/gguf-py
pip install --editable .
```
**Note**: This may require to upgrade your Pip installation, with a message saying that editable installation currently requires `setup.py`.
In this case, upgrade Pip to the latest:
```sh
pip install --upgrade pip
```
## Publishing
To publish the package, you need to have `twine` and `build` installed:
```sh
pip install build twine
```
Then, folow these steps to release a new version:
1. Update the version in `pyproject.toml`.
2. Build the package:
```sh
python -m build
```
3. Upload the generated distribution archives:
```sh
python -m twine upload dist/*
```
## TODO
- [ ] Add tests
- [ ] Include conversion scripts as command line entry points in this package.
- Add CI workflow for releasing the package.

1
gguf-py/gguf/__init__.py Normal file
View file

@ -0,0 +1 @@
from .gguf import *

0
gguf.py → gguf-py/gguf/gguf.py Executable file → Normal file
View file

28
gguf-py/pyproject.toml Normal file
View file

@ -0,0 +1,28 @@
[tool.poetry]
name = "gguf"
version = "0.2.1"
description = "Write ML models in GGUF for GGML"
authors = ["GGML <ggml@ggml.ai>"]
packages = [
{include = "gguf"},
]
readme = "README.md"
homepage = "https://ggml.ai"
repository = "https://github.com/ggerganov/llama.cpp"
keywords = ["ggml", "gguf", "llama.cpp"]
classifiers = [
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
]
[tool.poetry.dependencies]
python = ">=3.8"
numpy = ">=1.17"
[tool.poetry.dev-dependencies]
pytest = "^5.2"
[build-system]
requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

View file

@ -0,0 +1,7 @@
import gguf
# TODO: add tests
def test_write_gguf():
pass

View file

@ -1836,7 +1836,7 @@ static void llm_load_tensors(
(void) main_gpu; (void) main_gpu;
(void) mul_mat_q; (void) mul_mat_q;
#if defined(GGML_USE_CUBLAS) #if defined(GGML_USE_CUBLAS)
LLAMA_LOG_INFO("%s: using CUDA for GPU acceleration\n", __func__); LLAMA_LOG_INFO("%s: using " GGML_CUDA_NAME " for GPU acceleration\n", __func__);
ggml_cuda_set_main_device(main_gpu); ggml_cuda_set_main_device(main_gpu);
ggml_cuda_set_mul_mat_q(mul_mat_q); ggml_cuda_set_mul_mat_q(mul_mat_q);
#define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU #define LLAMA_BACKEND_OFFLOAD GGML_BACKEND_GPU
@ -1958,6 +1958,14 @@ static void llm_load_tensors(
model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm); model.output_norm = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, backend_norm);
model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm); model.output_norm_b = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, backend_norm);
model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output); model.output = ml.create_tensor(ctx, tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, backend_output);
if (backend_norm == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(model.output_norm);
vram_weights += ggml_nbytes(model.output_norm_b);
}
if (backend_output == GGML_BACKEND_GPU_SPLIT) {
vram_weights += ggml_nbytes(model.output);
}
} }
const uint32_t n_ff = hparams.n_ff; const uint32_t n_ff = hparams.n_ff;
@ -1978,6 +1986,11 @@ static void llm_load_tensors(
if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) { if (gguf_find_tensor(ml.ctx_gguf, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i).c_str()) >= 0) {
layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend); layer.attn_norm_2 = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "weight", i), {n_embd}, backend);
layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend); layer.attn_norm_2_b = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, backend);
if (backend == GGML_BACKEND_GPU) {
vram_weights += ggml_nbytes(layer.attn_norm_2);
vram_weights += ggml_nbytes(layer.attn_norm_2_b);
}
} }
layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split); layer.wqkv = ml.create_tensor(ctx, tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, backend_split);
@ -1985,6 +1998,13 @@ static void llm_load_tensors(
layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split); layer.w2 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, backend_split);
layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split); layer.w3 = ml.create_tensor(ctx, tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, backend_split);
if (backend == GGML_BACKEND_GPU) {
vram_weights +=
ggml_nbytes(layer.attn_norm) + ggml_nbytes(layer.attn_norm_b) +
ggml_nbytes(layer.wqkv) + ggml_nbytes(layer.wo) +
ggml_nbytes(layer.w2) + ggml_nbytes(layer.w3);
}
} }
} break; } break;
default: default:
@ -5277,13 +5297,29 @@ int llama_model_n_embd(const struct llama_model * model) {
return model->hparams.n_embd; return model->hparams.n_embd;
} }
int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size) { int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size) {
return snprintf(buf, buf_size, "%s %s %s", return snprintf(buf, buf_size, "%s %s %s",
model->name.c_str(), model->name.c_str(),
llama_model_type_name(model->type), llama_model_type_name(model->type),
llama_model_ftype_name(model->ftype).c_str()); llama_model_ftype_name(model->ftype).c_str());
} }
uint64_t llama_model_size(const struct llama_model * model) {
uint64_t size = 0;
for (const auto & it : model->tensors_by_name) {
size += ggml_nbytes(it.second);
}
return size;
}
uint64_t llama_model_n_params(const struct llama_model * model) {
uint64_t nparams = 0;
for (const auto & it : model->tensors_by_name) {
nparams += ggml_nelements(it.second);
}
return nparams;
}
int llama_model_quantize( int llama_model_quantize(
const char * fname_inp, const char * fname_inp,
const char * fname_out, const char * fname_out,

View file

@ -254,7 +254,11 @@ extern "C" {
LLAMA_API int llama_model_n_embd (const struct llama_model * model); LLAMA_API int llama_model_n_embd (const struct llama_model * model);
// Get a string describing the model type // Get a string describing the model type
LLAMA_API int llama_model_type(const struct llama_model * model, char * buf, size_t buf_size); LLAMA_API int llama_model_desc(const struct llama_model * model, char * buf, size_t buf_size);
// Returns the total size of all the tensors in the model in bytes
LLAMA_API uint64_t llama_model_size(const struct llama_model * model);
// Returns the total number of parameters in the model
LLAMA_API uint64_t llama_model_n_params(const struct llama_model * model);
// Returns 0 on success // Returns 0 on success
LLAMA_API int llama_model_quantize( LLAMA_API int llama_model_quantize(
@ -348,7 +352,7 @@ extern "C" {
LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token); LLAMA_API float llama_token_get_score(const struct llama_context * ctx, llama_token token);
LLAMA_API llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token); LLAMA_API enum llama_token_type llama_token_get_type(const struct llama_context * ctx, llama_token token);
// Special tokens // Special tokens
LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence LLAMA_API llama_token llama_token_bos(const struct llama_context * ctx); // beginning-of-sentence

View file

@ -1,2 +1,3 @@
numpy==1.24 numpy==1.24
sentencepiece==0.1.98 sentencepiece==0.1.98
gguf>=0.1.0