Merge branch 'ggerganov:master' into betterlogs2

Merge commit ec86950b09, authored by staviq on 2023-09-11 17:20:18 +02:00, committed by GitHub.
23 changed files with 594 additions and 274 deletions.


@@ -12,7 +12,7 @@ FROM ${BASE_CUDA_DEV_CONTAINER} as build
 ARG CUDA_DOCKER_ARCH=all
 
 RUN apt-get update && \
-    apt-get install -y build-essential python3 python3-pip
+    apt-get install -y build-essential python3 python3-pip git
 
 COPY requirements.txt requirements.txt


@@ -12,7 +12,7 @@ FROM ${BASE_CUDA_DEV_CONTAINER} as build
 ARG CUDA_DOCKER_ARCH=all
 
 RUN apt-get update && \
-    apt-get install -y build-essential
+    apt-get install -y build-essential git
 
 WORKDIR /app


@@ -197,6 +197,62 @@ jobs:
         cd build
         ctest --verbose --timeout 900
 
+  macOS-latest-cmake-ios:
+    runs-on: macos-latest
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v1
+
+      - name: Dependencies
+        id: depends
+        continue-on-error: true
+        run: |
+          brew update
+
+      - name: Build
+        id: cmake_build
+        run: |
+          sysctl -a
+          mkdir build
+          cd build
+          cmake -G Xcode .. \
+            -DLLAMA_BUILD_EXAMPLES=OFF \
+            -DLLAMA_BUILD_TESTS=OFF \
+            -DLLAMA_BUILD_SERVER=OFF \
+            -DCMAKE_SYSTEM_NAME=iOS \
+            -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0
+          cmake --build . --config Release
+
+  macOS-latest-cmake-tvos:
+    runs-on: macos-latest
+
+    steps:
+      - name: Clone
+        id: checkout
+        uses: actions/checkout@v1
+
+      - name: Dependencies
+        id: depends
+        continue-on-error: true
+        run: |
+          brew update
+
+      - name: Build
+        id: cmake_build
+        run: |
+          sysctl -a
+          mkdir build
+          cd build
+          cmake -G Xcode .. \
+            -DLLAMA_BUILD_EXAMPLES=OFF \
+            -DLLAMA_BUILD_TESTS=OFF \
+            -DLLAMA_BUILD_SERVER=OFF \
+            -DCMAKE_SYSTEM_NAME=tvOS \
+            -DCMAKE_OSX_DEPLOYMENT_TARGET=14.0
+          cmake --build . --config Release
+
   windows-latest-cmake:
     runs-on: windows-latest


@@ -476,7 +476,7 @@ if (NOT MSVC)
     endif()
 endif()
 
-if (${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm" OR ${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64")
+if ((${CMAKE_SYSTEM_PROCESSOR} MATCHES "arm") OR (${CMAKE_SYSTEM_PROCESSOR} MATCHES "aarch64"))
     message(STATUS "ARM detected")
     if (MSVC)
         # TODO: arm msvc?
@@ -551,6 +551,55 @@
     message(STATUS "Unknown architecture")
 endif()
 
+#
+# POSIX conformance
+#
+
+# clock_gettime came in POSIX.1b (1993)
+# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
+# posix_memalign came in POSIX.1-2001 / SUSv3
+# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
+add_compile_definitions(_XOPEN_SOURCE=600)
+
+# Somehow in OpenBSD whenever POSIX conformance is specified
+# some string functions rely on locale_t availability,
+# which was introduced in POSIX.1-2008, forcing us to go higher
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+    remove_definitions(-D_XOPEN_SOURCE=600)
+    add_compile_definitions(_XOPEN_SOURCE=700)
+endif()
+
+# Data types, macros and functions related to controlling CPU affinity and
+# some memory allocation are available on Linux through GNU extensions in libc
+if (CMAKE_SYSTEM_NAME MATCHES "Linux")
+    add_compile_definitions(_GNU_SOURCE)
+endif()
+
+# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
+# and on macOS its availability depends on enabling Darwin extensions
+# similarly on DragonFly, enabling BSD extensions is necessary
+if (
+    CMAKE_SYSTEM_NAME MATCHES "Darwin" OR
+    CMAKE_SYSTEM_NAME MATCHES "iOS" OR
+    CMAKE_SYSTEM_NAME MATCHES "tvOS" OR
+    CMAKE_SYSTEM_NAME MATCHES "DragonFly"
+)
+    add_compile_definitions(_DARWIN_C_SOURCE)
+endif()
+
+# alloca is a non-standard interface that is not visible on BSDs when
+# POSIX conformance is specified, but not all of them provide a clean way
+# to enable it in such cases
+if (CMAKE_SYSTEM_NAME MATCHES "FreeBSD")
+    add_compile_definitions(__BSD_VISIBLE)
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "NetBSD")
+    add_compile_definitions(_NETBSD_SOURCE)
+endif()
+if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD")
+    add_compile_definitions(_BSD_SOURCE)
+endif()
+
 #
 # libraries
 #
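As a quick sanity check of what these feature-test macros gate, here is a minimal standalone sketch (not part of the commit) exercising the symbols named in the comments above; on a strict libc it builds with -D_XOPEN_SOURCE=600 but loses the clock_gettime, posix_memalign and M_PI declarations without it:

    #include <math.h>    // M_PI (XSI extension)
    #include <stdio.h>
    #include <stdlib.h>  // posix_memalign (POSIX.1-2001)
    #include <time.h>    // clock_gettime, CLOCK_MONOTONIC (POSIX.1b / POSIX.1-2001)

    int main(void) {
        struct timespec ts;
        clock_gettime(CLOCK_MONOTONIC, &ts);        // monotonic clock, optional in SUSv3

        void * ptr = NULL;
        if (posix_memalign(&ptr, 64, 1024) != 0) {  // 64-byte aligned allocation
            return 1;
        }
        printf("t = %lld.%09ld s, pi = %f\n", (long long) ts.tv_sec, ts.tv_nsec, M_PI);
        free(ptr);
        return 0;
    }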


@@ -106,6 +106,56 @@ MK_CFLAGS   = $(OPT) -std=c11   -fPIC
 MK_CXXFLAGS = $(OPT) -std=c++11 -fPIC
 MK_LDFLAGS  =
 
+# clock_gettime came in POSIX.1b (1993)
+# CLOCK_MONOTONIC came in POSIX.1-2001 / SUSv3 as optional
+# posix_memalign came in POSIX.1-2001 / SUSv3
+# M_PI is an XSI extension since POSIX.1-2001 / SUSv3, came in XPG1 (1985)
+MK_CFLAGS   += -D_XOPEN_SOURCE=600
+MK_CXXFLAGS += -D_XOPEN_SOURCE=600
+
+# Somehow in OpenBSD whenever POSIX conformance is specified
+# some string functions rely on locale_t availability,
+# which was introduced in POSIX.1-2008, forcing us to go higher
+ifeq ($(UNAME_S),OpenBSD)
+	MK_CFLAGS   += -U_XOPEN_SOURCE -D_XOPEN_SOURCE=700
+	MK_CXXFLAGS += -U_XOPEN_SOURCE -D_XOPEN_SOURCE=700
+endif
+
+# Data types, macros and functions related to controlling CPU affinity and
+# some memory allocation are available on Linux through GNU extensions in libc
+ifeq ($(UNAME_S),Linux)
+	MK_CFLAGS   += -D_GNU_SOURCE
+	MK_CXXFLAGS += -D_GNU_SOURCE
+endif
+
+# RLIMIT_MEMLOCK came in BSD, is not specified in POSIX.1,
+# and on macOS its availability depends on enabling Darwin extensions
+# similarly on DragonFly, enabling BSD extensions is necessary
+ifeq ($(UNAME_S),Darwin)
+	MK_CFLAGS   += -D_DARWIN_C_SOURCE
+	MK_CXXFLAGS += -D_DARWIN_C_SOURCE
+endif
+ifeq ($(UNAME_S),DragonFly)
+	MK_CFLAGS   += -D__BSD_VISIBLE
+	MK_CXXFLAGS += -D__BSD_VISIBLE
+endif
+
+# alloca is a non-standard interface that is not visible on BSDs when
+# POSIX conformance is specified, but not all of them provide a clean way
+# to enable it in such cases
+ifeq ($(UNAME_S),FreeBSD)
+	MK_CFLAGS   += -D__BSD_VISIBLE
+	MK_CXXFLAGS += -D__BSD_VISIBLE
+endif
+ifeq ($(UNAME_S),NetBSD)
+	MK_CFLAGS   += -D_NETBSD_SOURCE
+	MK_CXXFLAGS += -D_NETBSD_SOURCE
+endif
+ifeq ($(UNAME_S),OpenBSD)
+	MK_CFLAGS   += -D_BSD_SOURCE
+	MK_CXXFLAGS += -D_BSD_SOURCE
+endif
+
 ifdef LLAMA_DEBUG
 	MK_CFLAGS   += -O0 -g
 	MK_CXXFLAGS += -O0 -g


@@ -2,8 +2,30 @@
 
 import PackageDescription
 
+#if arch(arm) || arch(arm64)
+let platforms: [SupportedPlatform]? = [
+    .macOS(.v11),
+    .iOS(.v14),
+    .watchOS(.v4),
+    .tvOS(.v14)
+]
+let exclude: [String] = []
+let additionalSources: [String] = ["ggml-metal.m"]
+let additionalSettings: [CSetting] = [
+    .unsafeFlags(["-fno-objc-arc"]),
+    .define("GGML_SWIFT"),
+    .define("GGML_USE_METAL")
+]
+#else
+let platforms: [SupportedPlatform]? = nil
+let exclude: [String] = ["ggml-metal.metal"]
+let additionalSources: [String] = []
+let additionalSettings: [CSetting] = []
+#endif
+
 let package = Package(
     name: "llama",
+    platforms: platforms,
     products: [
         .library(name: "llama", targets: ["llama"]),
     ],
@@ -11,23 +33,23 @@ let package = Package(
         .target(
             name: "llama",
             path: ".",
-            exclude: ["ggml-metal.metal"],
+            exclude: exclude,
             sources: [
                 "ggml.c",
                 "llama.cpp",
                 "ggml-alloc.c",
-                "k_quants.c"
-            ],
+                "k_quants.c",
+            ] + additionalSources,
            publicHeadersPath: "spm-headers",
             cSettings: [
                 .unsafeFlags(["-Wno-shorten-64-to-32"]),
                 .define("GGML_USE_K_QUANTS"),
                 .define("GGML_USE_ACCELERATE")
-            ],
+            ] + additionalSettings,
             linkerSettings: [
                 .linkedFramework("Accelerate")
             ]
-        ),
+        )
     ],
     cxxLanguageStandard: .cxx11
 )


@@ -11,21 +11,9 @@ Inference of [LLaMA](https://arxiv.org/abs/2302.13971) model in pure C/C++
 
 ### Hot topics
 
-- #### IMPORTANT: Tokenizer fixes and API change (developers and projects using `llama.cpp` built-in tokenization must read): https://github.com/ggerganov/llama.cpp/pull/2810
-
-- GGUFv2 adds support for 64-bit sizes + backwards compatible: https://github.com/ggerganov/llama.cpp/pull/2821
-
-- Added support for Falcon models: https://github.com/ggerganov/llama.cpp/pull/2717
-
-- A new file format has been introduced: [GGUF](https://github.com/ggerganov/llama.cpp/pull/2398)
-
-  Last revision compatible with the old format: [dadbed9](https://github.com/ggerganov/llama.cpp/commit/dadbed99e65252d79f81101a392d0d6497b86caa)
-
-### Current `master` should be considered in Beta - expect some issues for a few days!
-
-### Be prepared to re-convert and / or re-quantize your GGUF models while this notice is up!
-
-### Issues with non-GGUF models will be considered with low priority!
+- Local Falcon 180B inference on Mac Studio
+
+  https://github.com/ggerganov/llama.cpp/assets/1991296/98abd4e8-7077-464c-ae89-aebabca7757e
 
 ----
@@ -413,7 +401,7 @@ Building the program with BLAS support may lead to some performance improvements
 
 - #### hipBLAS
 
-  This provide BLAS acceleation on HIP supported GPU like AMD GPU.
+  This provides BLAS acceleration on HIP-supported AMD GPUs.
   Make sure to have ROCm installed.
   You can download it from your Linux distro's package manager or from here: [ROCm Quick Start (Linux)](https://rocm.docs.amd.com/en/latest/deploy/linux/quick_start.html).
   Windows support is coming soon...
@@ -737,12 +725,12 @@ python3 convert.py pygmalion-7b/ --outtype q4_1
 - Refer to [Facebook's LLaMA download page](https://ai.meta.com/resources/models-and-libraries/llama-downloads/) if you want to access the model data.
 - Alternatively, if you want to save time and space, you can download already converted and quantized models from [TheBloke](https://huggingface.co/TheBloke), including:
-  - [LLaMA 2 7B base](https://huggingface.co/TheBloke/Llama-2-7B-GGML)
-  - [LLaMA 2 13B base](https://huggingface.co/TheBloke/Llama-2-13B-GGML)
-  - [LLaMA 2 70B base](https://huggingface.co/TheBloke/Llama-2-70B-GGML)
-  - [LLaMA 2 7B chat](https://huggingface.co/TheBloke/Llama-2-7B-chat-GGML)
-  - [LLaMA 2 13B chat](https://huggingface.co/TheBloke/Llama-2-13B-chat-GGML)
-  - [LLaMA 2 70B chat](https://huggingface.co/TheBloke/Llama-2-70B-chat-GGML)
+  - [LLaMA 2 7B base](https://huggingface.co/TheBloke/Llama-2-7B-GGUF)
+  - [LLaMA 2 13B base](https://huggingface.co/TheBloke/Llama-2-13B-GGUF)
+  - [LLaMA 2 70B base](https://huggingface.co/TheBloke/Llama-2-70B-GGUF)
+  - [LLaMA 2 7B chat](https://huggingface.co/TheBloke/Llama-2-7B-chat-GGUF)
+  - [LLaMA 2 13B chat](https://huggingface.co/TheBloke/Llama-2-13B-chat-GGUF)
+  - [LLaMA 2 70B chat](https://huggingface.co/TheBloke/Llama-2-70B-chat-GGUF)
 
 ### Verifying the model files


@@ -145,7 +145,6 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
 class Params:
     n_vocab: int
     n_embd:  int
-    n_mult:  int
     n_layer: int
     n_ctx:   int
     n_ff:    int
@@ -161,15 +160,6 @@ class Params:
     # path to the directory containing the model files
     path_model: Path | None = None
 
-    @staticmethod
-    def find_n_mult(n_ff: int, n_embd: int) -> int:
-        # hardcoded magic range
-        for n_mult in range(8192, 1, -1):
-            calc_ff = (((8*n_embd) // 3 + n_mult - 1) // n_mult)*n_mult
-            if calc_ff == n_ff:
-                return n_mult
-        raise Exception(f"failed to find n_mult for (n_ff={n_ff}, n_embd={n_embd}).")
-
     @staticmethod
     def guessed(model: LazyModel) -> Params:
         # try transformer naming first
@@ -197,7 +187,6 @@ class Params:
         return Params(
             n_vocab = n_vocab,
             n_embd  = n_embd,
-            n_mult  = n_mult,
             n_layer = n_layer,
             n_ctx   = -1,
             n_ff    = n_ff,
@@ -225,8 +214,6 @@ class Params:
         else:
             f_rope_scale = None
 
-        n_mult = Params.find_n_mult(n_ff, n_embd)
-
         if "max_sequence_length" in config:
             n_ctx = config["max_sequence_length"]
         elif "max_position_embeddings" in config:
@@ -238,7 +225,6 @@ class Params:
         return Params(
             n_vocab = n_vocab,
             n_embd  = n_embd,
-            n_mult  = n_mult,
             n_layer = n_layer,
             n_ctx   = n_ctx,
             n_ff    = n_ff,
@@ -250,7 +236,7 @@ class Params:
         )
 
     # LLaMA v2 70B params.json
-    # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1
+    # {"dim": 8192, "multiple_of": 4096, "ffn_dim_multiplier": 1.3, "n_heads": 64, "n_kv_heads": 8, "n_layers": 80, "norm_eps": 1e-05, "vocab_size": -1}
     @staticmethod
     def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
         config = json.load(open(config_path))
@@ -258,7 +244,6 @@ class Params:
         n_vocab = config["vocab_size"] if "vocab_size" in config else -1
         n_embd  = config["dim"]
         n_layer = config["n_layers"]
-        n_mult  = config["multiple_of"]
         n_ff    = -1
         n_head  = config["n_heads"]
         n_head_kv = config["n_kv_heads"] if "n_kv_heads" in config else n_head
@@ -285,7 +270,6 @@ class Params:
         return Params(
             n_vocab = n_vocab,
             n_embd  = n_embd,
-            n_mult  = n_mult,
             n_layer = n_layer,
             n_ctx   = n_ctx,
             n_ff    = n_ff,
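For context, the removed find_n_mult helper brute-force inverted the feed-forward sizing rule of the original LLaMA checkpoints: it scanned n_mult downward from 8192 until

    n_ff == ceil((8*n_embd // 3) / n_mult) * n_mult

held. GGUF records the feed-forward size explicitly, so the conversion script no longer needs to recover n_mult at all.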


@@ -1,7 +1,3 @@
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#endif
-
 #include "common.h"
 #include "llama.h"
 #include "build-info.h"


@@ -1,8 +1,3 @@
-// Defines sigaction on msys:
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#endif
-
 #include "embd-input.h"
 
 #include <cassert>


@@ -17,11 +17,6 @@ int main(int argc, char ** argv) {
 
     params.embedding = true;
 
-    if (params.n_ctx > 2048) {
-        fprintf(stderr, "%s: warning: model might not support context sizes greater than 2048 tokens (%d specified);"
-                "expect poor results\n", __func__, params.n_ctx);
-    }
-
     fprintf(stderr, "%s: build = %d (%s)\n", __func__, BUILD_NUMBER, BUILD_COMMIT);
 
     if (params.seed == LLAMA_DEFAULT_SEED) {
@@ -47,6 +42,12 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
+    const int n_ctx_train = llama_n_ctx_train(ctx);
+    if (params.n_ctx > n_ctx_train) {
+        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, params.n_ctx);
+    }
+
     // print system information
     {
         fprintf(stderr, "\n");


@@ -1,8 +1,3 @@
-// Defines sigaction on msys:
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#endif
-
 #include "common.h"
 
 #include "console.h"
@@ -187,8 +182,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    if (params.n_ctx > llama_n_ctx(ctx)) {
-        LOG_TEE("%s: warning: base model only supports context sizes no greater than %d tokens (%d specified)\n", __func__, llama_n_ctx(ctx), params.n_ctx);
+    const int n_ctx_train = llama_n_ctx_train(ctx);
+    if (params.n_ctx > n_ctx_train) {
+        LOG_TEE("%s: warning: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, params.n_ctx);
     } else if (params.n_ctx < 8) {
         LOG_TEE("%s: warning: minimum context size is 8, using minimum size.\n", __func__);
         params.n_ctx = 8;
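The examples above all converge on one validation pattern; a standalone sketch (hypothetical helper name, assuming the llama_n_ctx_train() API these hunks build on):

    #include <stdio.h>
    #include "llama.h"

    // Clamp/validate a requested context size against the model's training context.
    static int validate_n_ctx(struct llama_context * ctx, int n_ctx_requested) {
        const int n_ctx_train = llama_n_ctx_train(ctx);
        if (n_ctx_requested > n_ctx_train) {
            fprintf(stderr, "warning: model was trained on only %d context tokens (%d specified)\n",
                    n_ctx_train, n_ctx_requested);
        } else if (n_ctx_requested < 8) {
            fprintf(stderr, "warning: minimum context size is 8, using minimum size\n");
            return 8; // main.cpp clamps to this minimum
        }
        return n_ctx_requested;
    }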


@@ -693,9 +693,10 @@ int main(int argc, char ** argv) {
         return 1;
     }
 
-    if (params.n_ctx > llama_n_ctx(ctx)) {
-        fprintf(stderr, "%s: warning: model might not support context sizes greater than %d tokens (%d specified);"
-                "expect poor results\n", __func__, llama_n_ctx(ctx), params.n_ctx);
+    const int n_ctx_train = llama_n_ctx_train(ctx);
+    if (params.n_ctx > n_ctx_train) {
+        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
+                __func__, n_ctx_train, params.n_ctx);
     }
 
     // print system information


@@ -1,7 +1,3 @@
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#endif
-
 #include "build-info.h"
 #include "common.h"


@@ -1,7 +1,3 @@
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#endif
-
 #include "build-info.h"
 #include "common.h"


@@ -93,6 +93,10 @@
           type = "app";
           program = "${self.packages.${system}.default}/bin/quantize";
         };
+        apps.train-text-from-scratch = {
+          type = "app";
+          program = "${self.packages.${system}.default}/bin/train-text-from-scratch";
+        };
         apps.default = self.apps.${system}.llama;
         devShells.default = pkgs.mkShell {
           buildInputs = [ llama-python ];


@@ -1,8 +1,3 @@
-// defines MAP_ANONYMOUS
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#endif
-
 #include "ggml-alloc.h"
 #include "ggml.h"
 #include <assert.h>


@@ -144,8 +144,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cudaError_t err_ = (err); \
         if (err_ != cudaSuccess) { \
-            fprintf(stderr, "CUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
+            int id; \
+            cudaGetDevice(&id); \
+            fprintf(stderr, "\nCUDA error %d at %s:%d: %s\n", err_, __FILE__, __LINE__, \
                 cudaGetErrorString(err_)); \
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
@@ -155,8 +158,11 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cublasStatus_t err_ = (err); \
         if (err_ != CUBLAS_STATUS_SUCCESS) { \
+            int id; \
+            cudaGetDevice(&id); \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d: %s\n", \
                 err_, __FILE__, __LINE__, cublasGetStatusString(err_)); \
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
@@ -165,7 +171,10 @@ static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
     do { \
         cublasStatus_t err_ = (err); \
         if (err_ != CUBLAS_STATUS_SUCCESS) { \
+            int id; \
+            cudaGetDevice(&id); \
             fprintf(stderr, "\ncuBLAS error %d at %s:%d\n", err_, __FILE__, __LINE__); \
+            fprintf(stderr, "current device: %d\n", id); \
             exit(1); \
         } \
     } while (0)
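Call sites wrap every CUDA runtime and cuBLAS call; a small usage sketch (the wrapped call is the standard CUDA runtime API, CUDA_CHECK is the macro above):

    #include <cuda_runtime.h>

    // On failure this now aborts with file:line, the error string, and the
    // current device id — useful when only one GPU of a multi-GPU setup
    // is misbehaving.
    void * cuda_alloc_pinned(size_t size) {
        void * ptr = NULL;
        CUDA_CHECK(cudaMallocHost(&ptr, size));
        return ptr;
    }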
@@ -4086,7 +4095,8 @@ static __global__ void rope_neox_f32(const float * x, float * dst, const int nco
     dst[i + ncols/2] = x0*sin_theta + x1*cos_theta;
 }
 
-static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
+static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p0,
+                                    const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx) {
     const int col = blockDim.x*blockIdx.x + threadIdx.x;
     const int half_n_dims = ncols/4;
@@ -4098,8 +4108,9 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     const int i = row*ncols + col;
 
     const float col_theta_scale = powf(theta_scale, col);
+    const float p = p0 + p_delta*(row/p_delta_rows);
 
-    const float theta = p*col_theta_scale;
+    const float theta = min(p, p_delta*(n_ctx - 2))*col_theta_scale;
     const float sin_theta = sinf(theta);
     const float cos_theta = cosf(theta);
@@ -4109,7 +4120,7 @@ static __global__ void rope_glm_f32(const float * x, float * dst, const int ncol
     dst[i + 0]           = x0*cos_theta - x1*sin_theta;
     dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
 
-    const float block_theta = block_p*col_theta_scale;
+    const float block_theta = max(p - p_delta*(n_ctx - 2), 0.f)*col_theta_scale;
     const float sin_block_theta = sinf(block_theta);
     const float cos_block_theta = cosf(block_theta);
@@ -4984,12 +4995,13 @@ static void rope_neox_f32_cuda(const float * x, float * dst, const int ncols, co
     rope_neox_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale);
 }
 
-static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
-    GGML_ASSERT(nrows % 4 == 0);
-    const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
-    const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
+static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p0,
+                              const float p_delta, const int p_delta_rows, const float theta_scale, const int n_ctx, cudaStream_t stream) {
+    GGML_ASSERT(ncols % 4 == 0);
+    const dim3 block_dims(CUDA_ROPE_BLOCK_SIZE/4, 1, 1);
+    const int num_blocks_x = (ncols + CUDA_ROPE_BLOCK_SIZE - 1) / CUDA_ROPE_BLOCK_SIZE;
     const dim3 block_nums(num_blocks_x, nrows, 1);
-    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
+    rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p0, p_delta, p_delta_rows, theta_scale, n_ctx);
 }
 
 static void alibi_f32_cuda(const float * x, float * dst, const int ncols, const int nrows,
@@ -5723,22 +5735,18 @@ inline void ggml_cuda_op_rope(
     memcpy(&freq_scale, (int32_t *) dst->op_params + 5, sizeof(float));
 
     const float theta_scale = powf(freq_base, -2.0f/n_dims);
+    const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
 
     const bool is_neox = mode & 2;
     const bool is_glm  = mode & 4;
 
     // compute
     if (is_glm) {
-        const float p = (((mode & 1) == 0 ? n_past + i02 : i02)) * freq_scale;
-        const float id_p = min(p, n_ctx - 2.f);
-        const float block_p = max(p - (n_ctx - 2.f), 0.f);
-        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
+        rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, n_ctx, cudaStream_main);
     } else if (is_neox) {
         GGML_ASSERT(ne00 == n_dims && "ne00 != n_dims is not implemented for CUDA yet");
-        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_neox_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     } else {
-        const float p0 = (((mode & 1) == 0 ? n_past : 0)) * freq_scale;
         rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p0, freq_scale, ne01, theta_scale, cudaStream_main);
     }
 
@@ -6400,10 +6408,7 @@ void ggml_cuda_rope(const ggml_tensor * src0, const ggml_tensor * src1, ggml_ten
     GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
     GGML_ASSERT(ggml_is_contiguous(src0)); // TODO: this restriction is temporary until non-cont support is implemented
 
-    const int mode = ((int32_t *) dst->op_params)[2];
-    const bool is_glm = mode & 4;
-
-    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, !is_glm); // flatten support not implemented for glm
+    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rope, true, true);
 }
 
 void ggml_cuda_alibi(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {


@@ -63,7 +63,9 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(relu);
     GGML_METAL_DECL_KERNEL(gelu);
     GGML_METAL_DECL_KERNEL(soft_max);
+    GGML_METAL_DECL_KERNEL(soft_max_4);
     GGML_METAL_DECL_KERNEL(diag_mask_inf);
+    GGML_METAL_DECL_KERNEL(diag_mask_inf_8);
     GGML_METAL_DECL_KERNEL(get_rows_f16);
     GGML_METAL_DECL_KERNEL(get_rows_q4_0);
     GGML_METAL_DECL_KERNEL(get_rows_q4_1);
@@ -77,6 +79,7 @@ struct ggml_metal_context {
     GGML_METAL_DECL_KERNEL(norm);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_1row);
+    GGML_METAL_DECL_KERNEL(mul_mat_f16_f32_l4);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DECL_KERNEL(mul_mat_q8_0_f32);
@@ -117,14 +120,17 @@ static NSString * const msl_library_source = @"see metal.metal";
 struct ggml_metal_context * ggml_metal_init(int n_cb) {
     metal_printf("%s: allocating\n", __func__);
 
-    // Show all the Metal device instances in the system
-    NSArray * devices = MTLCopyAllDevices();
     id <MTLDevice> device;
     NSString * s;
 
+#if TARGET_OS_OSX
+    // Show all the Metal device instances in the system
+    NSArray * devices = MTLCopyAllDevices();
     for (device in devices) {
         s = [device name];
         metal_printf("%s: found device: %s\n", __func__, [s UTF8String]);
     }
+#endif
 
     // Pick and show default Metal device
     device = MTLCreateSystemDefaultDevice();
@@ -141,12 +147,20 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
     ctx->d_queue = dispatch_queue_create("llama.cpp", DISPATCH_QUEUE_CONCURRENT);
 
-#if 0
-    // compile from source string and show compile log
+#ifdef GGML_SWIFT
+    // load the default.metallib file
     {
         NSError * error = nil;
 
-        ctx->library = [ctx->device newLibraryWithSource:msl_library_source options:nil error:&error];
+        NSBundle * bundle = [NSBundle bundleForClass:[GGMLMetalClass class]];
+        NSString * llamaBundlePath = [bundle pathForResource:@"llama_llama" ofType:@"bundle"];
+        NSBundle * llamaBundle = [NSBundle bundleWithPath:llamaBundlePath];
+        NSString * libPath = [llamaBundle pathForResource:@"default" ofType:@"metallib"];
+        NSURL * libURL = [NSURL fileURLWithPath:libPath];
+
+        // Load the metallib file into a Metal library
+        ctx->library = [ctx->device newLibraryWithURL:libURL error:&error];
+
         if (error) {
             metal_printf("%s: error: %s\n", __func__, [[error description] UTF8String]);
             return NULL;
@@ -207,7 +221,9 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(relu);
         GGML_METAL_ADD_KERNEL(gelu);
         GGML_METAL_ADD_KERNEL(soft_max);
+        GGML_METAL_ADD_KERNEL(soft_max_4);
         GGML_METAL_ADD_KERNEL(diag_mask_inf);
+        GGML_METAL_ADD_KERNEL(diag_mask_inf_8);
         GGML_METAL_ADD_KERNEL(get_rows_f16);
         GGML_METAL_ADD_KERNEL(get_rows_q4_0);
         GGML_METAL_ADD_KERNEL(get_rows_q4_1);
@@ -221,6 +237,7 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
         GGML_METAL_ADD_KERNEL(norm);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_1row);
+        GGML_METAL_ADD_KERNEL(mul_mat_f16_f32_l4);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_0_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q4_1_f32);
         GGML_METAL_ADD_KERNEL(mul_mat_q8_0_f32);
@@ -247,13 +264,15 @@ struct ggml_metal_context * ggml_metal_init(int n_cb) {
 #undef GGML_METAL_ADD_KERNEL
     }
 
-    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     metal_printf("%s: hasUnifiedMemory = %s\n", __func__, ctx->device.hasUnifiedMemory ? "true" : "false");
+#if TARGET_OS_OSX
+    metal_printf("%s: recommendedMaxWorkingSetSize = %8.2f MB\n", __func__, ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
     if (ctx->device.maxTransferRate != 0) {
         metal_printf("%s: maxTransferRate = %8.2f MB/s\n", __func__, ctx->device.maxTransferRate / 1024.0 / 1024.0);
     } else {
         metal_printf("%s: maxTransferRate = built-in GPU\n", __func__);
     }
+#endif
 
     return ctx;
 }
@@ -273,7 +292,8 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(relu);
     GGML_METAL_DEL_KERNEL(gelu);
     GGML_METAL_DEL_KERNEL(soft_max);
-    GGML_METAL_DEL_KERNEL(diag_mask_inf);
+    GGML_METAL_DEL_KERNEL(soft_max_4);
+    GGML_METAL_DEL_KERNEL(diag_mask_inf_8);
     GGML_METAL_DEL_KERNEL(get_rows_f16);
     GGML_METAL_DEL_KERNEL(get_rows_q4_0);
     GGML_METAL_DEL_KERNEL(get_rows_q4_1);
@@ -287,6 +307,7 @@ void ggml_metal_free(struct ggml_metal_context * ctx) {
     GGML_METAL_DEL_KERNEL(norm);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_1row);
+    GGML_METAL_DEL_KERNEL(mul_mat_f16_f32_l4);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_0_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q4_1_f32);
     GGML_METAL_DEL_KERNEL(mul_mat_q8_0_f32);
@@ -454,6 +475,7 @@ bool ggml_metal_add_buffer(
             }
         }
 
+#if TARGET_OS_OSX
         metal_printf(", (%8.2f / %8.2f)",
             ctx->device.currentAllocatedSize / 1024.0 / 1024.0,
             ctx->device.recommendedMaxWorkingSetSize / 1024.0 / 1024.0);
@@ -463,6 +485,9 @@ bool ggml_metal_add_buffer(
         } else {
             metal_printf("\n");
         }
+#else
+        metal_printf(", (%8.2f)\n", ctx->device.currentAllocatedSize / 1024.0 / 1024.0);
+#endif
     }
 
     return true;
@@ -750,7 +775,7 @@ void ggml_metal_graph_compute(
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                 [encoder setBytes:&scale length:sizeof(scale) atIndex:2];
 
-                const int64_t n = ggml_nelements(dst);
+                const int64_t n = ggml_nelements(dst)/4;
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -762,7 +787,7 @@ void ggml_metal_graph_compute(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                const int64_t n = ggml_nelements(dst);
+                const int64_t n = ggml_nelements(dst)/4;
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -782,7 +807,7 @@ void ggml_metal_graph_compute(
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
 
-                const int64_t n = ggml_nelements(dst);
+                const int64_t n = ggml_nelements(dst)/4;
 
                 [encoder dispatchThreadgroups:MTLSizeMake(n, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
             } break;
@@ -796,13 +821,16 @@
             {
                 const int nth = 32;
 
+                if (ne00%4 == 0) {
+                    [encoder setComputePipelineState:ctx->pipeline_soft_max_4];
+                } else {
                     [encoder setComputePipelineState:ctx->pipeline_soft_max];
+                }
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                 [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                 [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                 [encoder setBytes:&ne02 length:sizeof(ne02) atIndex:4];
-                [encoder setThreadgroupMemoryLength:nth*sizeof(float) atIndex:0];
 
                 [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
             } break;
@@ -810,14 +838,23 @@
             {
                 const int n_past = ((int32_t *)(dst->op_params))[0];
 
+                if (ne00%8 == 0) {
+                    [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf_8];
+                } else {
                     [encoder setComputePipelineState:ctx->pipeline_diag_mask_inf];
+                }
                 [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:1];
                 [encoder setBytes:&ne00 length:sizeof(ne00) atIndex:2];
                 [encoder setBytes:&ne01 length:sizeof(ne01) atIndex:3];
                 [encoder setBytes:&n_past length:sizeof(int) atIndex:4];
 
+                if (ne00%8 == 0) {
+                    [encoder dispatchThreadgroups:MTLSizeMake(ne00*ne01*ne02/8, 1, 1) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
+                else {
                     [encoder dispatchThreadgroups:MTLSizeMake(ne00, ne01, ne02) threadsPerThreadgroup:MTLSizeMake(1, 1, 1)];
+                }
             } break;
         case GGML_OP_MUL_MAT:
             {
@@ -864,6 +901,7 @@ void ggml_metal_graph_compute(
                 } else {
                     int nth0 = 32;
                     int nth1 = 1;
+                    int nrows = 1;
 
                     // use custom matrix x vector kernel
                     switch (src0t) {
@@ -873,8 +911,12 @@ void ggml_metal_graph_compute(
                                 nth1 = 1;
                                 if (ne11 * ne12 < 4) {
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_1row];
+                                } else if (ne00 >= 128 && ne01 >= 8 && ne00%4 == 0) {
+                                    [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32_l4];
+                                    nrows = ne11;
                                 } else {
                                     [encoder setComputePipelineState:ctx->pipeline_mul_mat_f16_f32];
+                                    nrows = 4;
                                 }
                             } break;
                         case GGML_TYPE_Q4_0:
@@ -995,7 +1037,7 @@ void ggml_metal_graph_compute(
                     else if (src0t == GGML_TYPE_Q6_K) {
                         [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 1)/2, ne11, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     } else {
-                        int64_t ny = (ne11 + 3)/4;
+                        int64_t ny = (ne11 + nrows - 1)/nrows;
                         [encoder dispatchThreadgroups:MTLSizeMake(ne01, ny, ne12) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                     }
                 }
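The new dispatch arithmetic is the usual integer ceiling division, so a kernel that consumes nrows rows per threadgroup still covers all ne11 rows:

    ny = (ne11 + nrows - 1) / nrows    // = ceil(ne11 / nrows) for integer nrows > 0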


@@ -63,18 +63,18 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
-    device const float * src0,
-    device       float * dst,
+    device const float4 * src0,
+    device       float4 * dst,
     constant     float & scale,
     uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
 
 kernel void kernel_silu(
-    device const float * src0,
-    device       float * dst,
+    device const float4 * src0,
+    device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
     dst[tpig] = x / (1.0f + exp(-x));
 }
@@ -89,10 +89,10 @@ constant float GELU_COEF_A = 0.044715f;
 constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
 
 kernel void kernel_gelu(
-    device const float * src0,
-    device       float * dst,
+    device const float4 * src0,
+    device       float4 * dst,
     uint tpig[[thread_position_in_grid]]) {
-    float x = src0[tpig];
+    device const float4 & x = src0[tpig];
 
     // BEWARE !!!
     // Simply using "tanh" instead of "precise::tanh" will sometimes result in NaNs!
@@ -107,7 +107,6 @@ kernel void kernel_soft_max(
     constant   int64_t & ne00,
     constant   int64_t & ne01,
     constant   int64_t & ne02,
-    threadgroup float  * buf [[threadgroup(0)]],
     uint3 tgpig[[threadgroup_position_in_grid]],
     uint3 tpitg[[thread_position_in_threadgroup]],
     uint3 ntg[[threads_per_threadgroup]]) {
@@ -119,64 +118,70 @@ kernel void kernel_soft_max(
     device       float * pdst  = dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00;
 
     // parallel max
-    buf[tpitg[0]] = -INFINITY;
-    for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
-        buf[tpitg[0]] = MAX(buf[tpitg[0]], psrc0[i00]);
+    float lmax = psrc0[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00; i00 += ntg[0]) {
+        lmax = MAX(lmax, psrc0[i00]);
     }
-
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] = MAX(buf[tpitg[0]], buf[tpitg[0] + i]);
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    //// broadcast - not needed. There is a threadgroup barrier above in the last iteration of
-    // the loop, and when that is done, buf[0] has the correct (synchronized) value
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float max = buf[0];
+    const float max = simd_max(lmax);
 
     // parallel sum
-    buf[tpitg[0]] = 0.0f;
+    float lsum = 0.0f;
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         const float exp_psrc0 = exp(psrc0[i00] - max);
-        buf[tpitg[0]] += exp_psrc0;
+        lsum += exp_psrc0;
         // Remember the result of exp here. exp is expensive, so we really do not
         // wish to compute it twice.
         pdst[i00] = exp_psrc0;
     }
 
-    // reduce
-    threadgroup_barrier(mem_flags::mem_threadgroup);
-    for (uint i = ntg[0]/2; i > 0; i /= 2) {
-        if (tpitg[0] < i) {
-            buf[tpitg[0]] += buf[tpitg[0] + i];
-        }
-        threadgroup_barrier(mem_flags::mem_threadgroup);
-    }
-
-    //// broadcast
-    //if (tpitg[0] == 0) {
-    //    buf[0] = buf[0];
-    //}
-
-    //threadgroup_barrier(mem_flags::mem_threadgroup);
-
-    const float sum = buf[0];
+    const float sum = simd_sum(lsum);
 
     for (int i00 = tpitg[0]; i00 < ne00; i00 += ntg[0]) {
         pdst[i00] /= sum;
     }
 }
 
+kernel void kernel_soft_max_4(
+    device const float * src0,
+    device       float * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint3 tpitg[[thread_position_in_threadgroup]],
+    uint3 ntg[[threads_per_threadgroup]]) {
+    const int64_t i03 = tgpig[2];
+    const int64_t i02 = tgpig[1];
+    const int64_t i01 = tgpig[0];
+
+    device const float4 * psrc4 = (device const float4 *)(src0 + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+    device       float4 * pdst4 = (device       float4 *)(dst  + i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
+
+    // parallel max
+    float4 lmax4 = psrc4[tpitg[0]];
+    for (int i00 = tpitg[0] + ntg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        lmax4 = fmax(lmax4, psrc4[i00]);
+    }
+    float lmax = MAX(MAX(lmax4[0], lmax4[1]), MAX(lmax4[2], lmax4[3]));
+
+    const float max = simd_max(lmax);
+
+    // parallel sum
+    float4 lsum4 = 0.0f;
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        const float4 exp_psrc4 = exp(psrc4[i00] - max);
+        lsum4 += exp_psrc4;
+        pdst4[i00] = exp_psrc4;
+    }
+    float lsum = lsum4[0] + lsum4[1] + lsum4[2] + lsum4[3];
+
+    const float sum = simd_sum(lsum);
+
+    for (int i00 = tpitg[0]; i00 < ne00/4; i00 += ntg[0]) {
+        pdst4[i00] /= sum;
+    }
+}
+
 kernel void kernel_diag_mask_inf(
     device const float * src0,
     device       float * dst,
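Both soft_max variants compute the same numerically stable softmax, subtracting the row maximum before exponentiating to avoid overflow; only the reduction strategy changed, from a threadgroup-memory tree to simd_max/simd_sum:

    softmax(x_i) = exp(x_i - m) / \sum_j exp(x_j - m),   m = \max_j x_j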
@@ -195,6 +200,33 @@ kernel void kernel_diag_mask_inf(
     }
 }
 
+kernel void kernel_diag_mask_inf_8(
+    device const float4 * src0,
+    device       float4 * dst,
+    constant    int64_t & ne00,
+    constant    int64_t & ne01,
+    constant        int & n_past,
+    uint3 tpig[[thread_position_in_grid]]) {
+
+    const int64_t i = 2*tpig[0];
+
+    dst[i+0] = src0[i+0];
+    dst[i+1] = src0[i+1];
+    int64_t i4 = 4*i;
+    const int64_t i02 = i4/(ne00*ne01); i4 -= i02*ne00*ne01;
+    const int64_t i01 = i4/(ne00);      i4 -= i01*ne00;
+    const int64_t i00 = i4;
+    for (int k = 3; k >= 0; --k) {
+        if (i00 + 4 + k <= n_past + i01) {
+            break;
+        }
+        dst[i+1][k] = -INFINITY;
+        if (i00 + k > n_past + i01) {
+            dst[i][k] = -INFINITY;
+        }
+    }
+}
+
 kernel void kernel_norm(
     device const void * src0,
     device       float * dst,
@@ -616,6 +648,49 @@ kernel void kernel_mul_mat_f16_f32(
     }
 }
 
+// Assumes row size (ne00) is a multiple of 4
+kernel void kernel_mul_mat_f16_f32_l4(
+    device const  char * src0,
+    device const  char * src1,
+    device       float * dst,
+    constant   int64_t & ne00,
+    constant   int64_t & ne01,
+    constant   int64_t & ne02,
+    constant  uint64_t & nb00,
+    constant  uint64_t & nb01,
+    constant  uint64_t & nb02,
+    constant   int64_t & ne10,
+    constant   int64_t & ne11,
+    constant   int64_t & ne12,
+    constant  uint64_t & nb10,
+    constant  uint64_t & nb11,
+    constant  uint64_t & nb12,
+    constant   int64_t & ne0,
+    constant   int64_t & ne1,
+    uint3 tgpig[[threadgroup_position_in_grid]],
+    uint  tiisg[[thread_index_in_simdgroup]]) {
+
+    const int nrows = ne11;
+    const int64_t r0 = tgpig.x;
+    const int64_t im = tgpig.z;
+
+    device const half4 * x4 = (device const half4 *) (src0 + r0*nb01 + im/(ne12/ne02)*nb02);
+
+    for (int r1 = 0; r1 < nrows; ++r1) {
+        device const float4 * y4 = (device const float4 *) (src1 + r1*nb11 + im*nb12);
+
+        float sumf = 0;
+        for (int i = tiisg; i < ne00/4; i += 32) {
+            for (int k = 0; k < 4; ++k) sumf += (float) x4[i][k] * y4[i][k];
+        }
+
+        float all_sum = simd_sum(sumf);
+        if (tiisg == 0) {
+            dst[im*ne1*ne0 + r1*ne0 + r0] = all_sum;
+        }
+    }
+}
+
 kernel void kernel_alibi_f32(
     device const float * src0,
     device       float * dst,
@@ -1123,31 +1198,40 @@ kernel void kernel_mul_mat_q3_K_f32(
     device const block_q3_K * x = (device const block_q3_K *) src0 + first_row*nb + offset0;
     device const float     * yy = (device const float      *) src1 + r1*ne10 + r2*ne00*ne1;
 
-    float yl[16];
+    float yl[32];
 
-    const uint16_t kmask1 = 0x0303;
+    const uint16_t kmask1 = 0x3030;
     const uint16_t kmask2 = 0x0f0f;
 
-    const int tid = tiisg/2;
-    const int ix  = tiisg%2;
-    const int ip  = tid/8;          // 0 or 1
-    const int il  = tid/2 - 4*ip;   // 0...3
+    const int tid = tiisg/4;
+    const int ix  = tiisg%4;
+    const int ip  = tid/4;          // 0 or 1
+    const int il  = 2*((tid%4)/2);  // 0 or 2
     const int ir  = tid%2;
     const int n   = 8;
     const int l0  = n*ir;
 
-    const uint16_t m1 = 1 << (4*ip + il);
-    const uint16_t m2 = m1 << 8;
+    // One would think that the Metal compiler would figure out that ip and il can only have
+    // 4 possible states, and optimize accordingly. Well, no. It needs help, and we do it
+    // with these two tables.
+    //
+    // Possible masks for the high bit
+    const ushort4 mm[4] = {{0x0001, 0x0100, 0x0002, 0x0200},  // ip = 0, il = 0
+                           {0x0004, 0x0400, 0x0008, 0x0800},  // ip = 0, il = 2
+                           {0x0010, 0x1000, 0x0020, 0x2000},  // ip = 1, il = 0
+                           {0x0040, 0x4000, 0x0080, 0x8000}}; // ip = 1, il = 2
+
+    // Possible masks for the low 2 bits
+    const int4 qm[2] = {{0x0003, 0x0300, 0x000c, 0x0c00}, {0x0030, 0x3000, 0x00c0, 0xc000}};
+
+    const ushort4 hm = mm[2*ip + il/2];
 
     const int shift = 2*il;
-    const uint16_t qm1 = 0x0003 << shift;
-    const uint16_t qm2 = 0x0300 << shift;
-    const int32_t v1 = 4 << shift;
-    const int32_t v2 = 1024 << shift;
+    const float   v1 = il == 0 ? 4.f : 64.f;
+    const float   v2 = 4.f * v1;
 
     const uint16_t s_shift1 = 4*ip;
-    const uint16_t s_shift2 = s_shift1 + 2*(il/2);
-    const int ik = 4 + (il%2);
+    const uint16_t s_shift2 = s_shift1 + il;
 
     const int q_offset = 32*ip + l0;
     const int y_offset = 128*ip + 32*il + l0;
@@ -1156,12 +1240,19 @@ kernel void kernel_mul_mat_q3_K_f32(
 
     device const float * y1 = yy + ix*QK_K + y_offset;
 
-    float sumf1[2] = {0.f}, sumf2[2] = {0.f};
-    for (int i = ix; i < nb; i += 2) {
+    uint32_t scales32, aux32;
+    thread uint16_t * scales16 = (thread uint16_t *)&scales32;
+    thread const int8_t * scales = (thread const int8_t *)&scales32;
+
+    float sumf1[2] = {0.f};
+    float sumf2[2] = {0.f};
+    for (int i = ix; i < nb; i += 4) {
 
         for (int l = 0; l < 8; ++l) {
             yl[l+ 0] = y1[l+ 0];
             yl[l+ 8] = y1[l+16];
+            yl[l+16] = y1[l+32];
+            yl[l+24] = y1[l+48];
         }
 
         device const uint16_t * q = (device const uint16_t *)(x[i].qs + q_offset);
@@ -1172,27 +1263,43 @@ kernel void kernel_mul_mat_q3_K_f32(
 
         for (int row = 0; row < 2; ++row) {
 
             const float d_all = (float)dh[0];
-            const char2 scales = as_type<char2>((uint16_t)(((a[il] >> s_shift1) & kmask2) | (((a[ik] >> s_shift2) & kmask1) << 4)));
 
-            float s1 = 0, s2 = 0;
+            scales16[0] = a[4];
+            scales16[1] = a[5];
+            aux32 = ((scales32 >> s_shift2) << 4) & 0x30303030;
+            scales16[0] = a[il+0];
+            scales16[1] = a[il+1];
+            scales32 = ((scales32 >> s_shift1) & 0x0f0f0f0f) | aux32;
+
+            float s1 = 0, s2 = 0, s3 = 0, s4 = 0, s5 = 0, s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2];
-                s1 += yl[l+0] * ((int32_t)(qs & qm1) - ((h[l/2] & m1) ? 0 : v1));
-                s2 += yl[l+1] * ((int32_t)(qs & qm2) - ((h[l/2] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2];
+                s1 += yl[l+0] * (qs & qm[il/2][0]);
+                s2 += yl[l+1] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2] & hm[0]) ? 0.f : yl[l+0]) + ((h[l/2] & hm[1]) ? 0.f : yl[l+1]);
+                s4 += yl[l+16] * (qs & qm[il/2][2]);
+                s5 += yl[l+17] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2] & hm[2]) ? 0.f : yl[l+16]) + ((h[l/2] & hm[3]) ? 0.f : yl[l+17]);
             }
-            float d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[0];
-            sumf2[row] += d;
 
-            s1 = s2 = 0;
+            float d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            float d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[0] - 32);
+            sumf2[row] += d2 * (scales[2] - 32);
+
+            s1 = s2 = s3 = s4 = s5 = s6 = 0;
             for (int l = 0; l < n; l += 2) {
-                const uint16_t qs = q[l/2+8];
-                s1 += yl[l+8] * ((int32_t)(qs & qm1) - ((h[l/2+8] & m1) ? 0 : v1));
-                s2 += yl[l+9] * ((int32_t)(qs & qm2) - ((h[l/2+8] & m2) ? 0 : v2));
+                const int32_t qs = q[l/2+8];
+                s1 += yl[l+8] * (qs & qm[il/2][0]);
+                s2 += yl[l+9] * (qs & qm[il/2][1]);
+                s3 += ((h[l/2+8] & hm[0]) ? 0.f : yl[l+8]) + ((h[l/2+8] & hm[1]) ? 0.f : yl[l+9]);
+                s4 += yl[l+24] * (qs & qm[il/2][2]);
+                s5 += yl[l+25] * (qs & qm[il/2][3]);
+                s6 += ((h[l/2+8] & hm[2]) ? 0.f : yl[l+24]) + ((h[l/2+8] & hm[3]) ? 0.f : yl[l+25]);
             }
-            d = d_all * (s1 + 1.f/256.f * s2);
-            sumf1[row] += d * scales[1];
-            sumf2[row] += d;
+            d1 = d_all * (s1 + 1.f/256.f * s2 - s3*v1);
+            d2 = d_all * (s4 + 1.f/256.f * s5 - s6*v2);
+            sumf1[row] += d1 * (scales[1] - 32);
+            sumf2[row] += d2 * (scales[3] - 32);
 
             q  += step;
             h  += step;
@@ -1201,17 +1308,20 @@ kernel void kernel_mul_mat_q3_K_f32(
         }
 
-        y1 += 2 * QK_K;
+        y1 += 4 * QK_K;
     }
 
     for (int row = 0; row < 2; ++row) {
-        const float sumf = (sumf1[row] - 32.f*sumf2[row]) / (1 << shift);
-        const float tot = simd_sum(sumf);
-        if (tiisg == 0) {
-            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = tot;
+        const float sumf = (sumf1[row] + 0.25f * sumf2[row]) / (1 << shift);
+        sumf1[row] = simd_sum(sumf);
+    }
+
+    if (tiisg == 0) {
+        for (int row = 0; row < 2; ++row) {
+            dst[r1*ne0 + r2*ne0*ne1 + first_row + row] = sumf1[row];
         }
     }
 }
 
 #else
 kernel void kernel_mul_mat_q3_K_f32(
@@ -1564,17 +1674,25 @@ kernel void kernel_mul_mat_q5_K_f32(
         sc16[2] = ((a[4] >> 0) & kmask2) | ((a[0] & kmask3) >> 2);
         sc16[3] = ((a[4] >> 4) & kmask2) | ((a[2] & kmask3) >> 2);
 
-        float4 acc = {0.f, 0.f, 0.f, 0.f};
+        float4 acc1 = {0.f};
+        float4 acc2 = {0.f};
         for (int l = 0; l < n; ++l) {
             uint8_t h = qh[l];
-            acc[0] += yl[l+0] * ((uint16_t)(q1[l] & 0x0F) + (h & hm1 ? 16 : 0));
-            acc[1] += yl[l+8] * ((uint16_t)(q1[l] & 0xF0) + (h & hm2 ? 256 : 0));
-            acc[2] += yh[l+0] * ((uint16_t)(q2[l] & 0x0F) + (h & hm3 ? 16 : 0));
-            acc[3] += yh[l+8] * ((uint16_t)(q2[l] & 0xF0) + (h & hm4 ? 256 : 0));
+            acc1[0] += yl[l+0] * (q1[l] & 0x0F);
+            acc1[1] += yl[l+8] * (q1[l] & 0xF0);
+            acc1[2] += yh[l+0] * (q2[l] & 0x0F);
+            acc1[3] += yh[l+8] * (q2[l] & 0xF0);
+            acc2[0] += h & hm1 ? yl[l+0] : 0.f;
+            acc2[1] += h & hm2 ? yl[l+8] : 0.f;
+            acc2[2] += h & hm3 ? yh[l+0] : 0.f;
+            acc2[3] += h & hm4 ? yh[l+8] : 0.f;
         }
         const float dall = dh[0];
         const float dmin = dh[1];
-        sumf[row] += dall * (acc[0] * sc8[0] + acc[1] * sc8[1] * 1.f/16.f + acc[2] * sc8[4] + acc[3] * sc8[5] * 1.f/16.f) -
+        sumf[row] += dall * (sc8[0] * (acc1[0] + 16.f*acc2[0]) +
+                             sc8[1] * (acc1[1]/16.f + 16.f*acc2[1]) +
+                             sc8[4] * (acc1[2] + 16.f*acc2[2]) +
+                             sc8[5] * (acc1[3]/16.f + 16.f*acc2[3])) -
                      dmin * (sumy[0] * sc8[2] + sumy[1] * sc8[3] + sumy[2] * sc8[6] + sumy[3] * sc8[7]);
 
         q1 += step;
@@ -1757,29 +1875,34 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
 template <typename type4x4>
 void dequantize_q4_0(device const block_q4_0 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 1);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = il ? ( -8.h * 16.h) : -8.h;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float md = -8.h * xb->d;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) + m) * d;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) + m) * d;
+        reg[i/2][2*(i%2)+0] = d1 * (qs[i] & mask0) + md;
+        reg[i/2][2*(i%2)+1] = d2 * (qs[i] & mask1) + md;
     }
 }
 
 template <typename type4x4>
 void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg) {
     device const uint16_t * qs = ((device const uint16_t *)xb + 2);
-    const half d = il ? (xb->d / 16.h) : xb->d;
-    const half m = xb->m;
+    const float d1 = il ? (xb->d / 16.h) : xb->d;
+    const float d2 = d1 / 256.f;
+    const float m  = xb->m;
     const ushort mask0 = il ? 0x00F0 : 0x000F;
-    const ushort mask1 = il ? 0xF000 : 0x0F00;
+    const ushort mask1 = mask0 << 8;
 
     for (int i=0;i<8;i++) {
-        reg[i/2][2*(i%2)]   = (((qs[i] & mask0)     ) * d) + m;
-        reg[i/2][2*(i%2)+1] = (((qs[i] & mask1) >> 8) * d) + m;
+        reg[i/2][2*(i%2)+0] = ((qs[i] & mask0) * d1) + m;
+        reg[i/2][2*(i%2)+1] = ((qs[i] & mask1) * d2) + m;
     }
 }
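For reference, both routines reconstruct the standard ggml Q4 mappings, where d is the block scale, m the block minimum and q a 4-bit quant:

    Q4_0:  v = d * (q - 8)
    Q4_1:  v = d * q + m

The rewrite folds the old ">> 8" nibble shift into the second scale d2 = d1/256, so both bytes of each 16-bit load are dequantized without an explicit shift.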
@ -1815,7 +1938,7 @@ void dequantize_q2_K(device const block_q2_K *xb, short il, thread type4x4 & reg
 template <typename type4x4>
 void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * q = (device const uint8_t *)xb->qs;
     device const uint8_t * h = (device const uint8_t *)xb->hmask;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@ -1828,17 +1951,20 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
                       ((il/4)>0 ? 12 : 3);
     uint16_t kmask2 = il/8 ? 0xF0 : 0x0F;
     uint16_t scale_2 = scales[il%8], scale_1 = scales[8 + il%4];
-    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2) : \
-                     (scale_2&kmask2) | ((scale_1&kmask1) << 4);
-    float dl = il<8 ? d_all * (dl_int - 32.f) : d_all * (dl_int / 16.f - 32.f);
+    int16_t dl_int = (il/4)&1 ? (scale_2&kmask2) | ((scale_1&kmask1) << 2)
+                              : (scale_2&kmask2) | ((scale_1&kmask1) << 4);
+    half dl = il<8 ? d_all * (dl_int - 32.h) : d_all * (dl_int / 16.h - 32.h);
+    const half ml = 4.h * dl;

-    il = (il/2)%4;
-    float coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
-    uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    il = (il/2) & 3;
+    const half    coef = il>1 ? (il>2 ? 1/64.h : 1/16.h) : (il>0 ? 1/4.h : 1.h);
+    const uint8_t mask = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    dl *= coef;

     for (int i = 0; i < 16; ++i) {
-        reg[i/4][i%4] = coef * dl * ((q[i] & mask) - ((h[i] & m) ? 0 : 4.f/coef));
+        reg[i/4][i%4] = dl * (q[i] & mask) - (h[i] & m ? 0 : ml);
     }
 #else
     float kcoef = il&1 ? 1.f/16.f : 1.f;
     uint16_t kmask = il&1 ? 0xF0 : 0x0F;
@ -1852,19 +1978,24 @@ void dequantize_q3_K(device const block_q3_K *xb, short il, thread type4x4 & reg
 #endif
 }
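The q3_K rewrite above hoists the shift (coef) into dl and replaces the per-element `4.f/coef` with a precomputed ml = 4.h * dl, taken before dl is rescaled, so the inner loop is one multiply-subtract. A plain-C check of the identity (names local to the sketch):

    #include <stdio.h>

    int main(void) {
        const float dl_orig = 0.5f, coef = 1.f/16.f;
        const float ml = 4.f * dl_orig;    /* hoisted before dl is rescaled */
        const float dl = dl_orig * coef;   /* shift folded into the scale */
        for (int q = 0; q < 256; q += 37) {
            for (int hbit = 0; hbit < 2; ++hbit) {
                float old_v = coef * dl_orig * ((float)q - (hbit ? 0.f : 4.f/coef));
                float new_v = dl * (float)q - (hbit ? 0.f : ml);
                printf("q=%3d h=%d old=%8.4f new=%8.4f\n", q, hbit, old_v, new_v);
            }
        }
        return 0;
    }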
+static inline uchar2 get_scale_min_k4_just2(int j, int k, device const uchar * q) {
+    return j < 4 ? uchar2{uchar(q[j+0+k] & 63), uchar(q[j+4+k] & 63)}
+                 : uchar2{uchar((q[j+4+k] & 0xF) | ((q[j-4+k] & 0xc0) >> 2)), uchar((q[j+4+k] >> 4) | ((q[j-0+k] & 0xc0) >> 2))};
+}
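The new helper unpacks a single (scale, min) pair from the 12-byte K-quant scale block: pairs 0..3 live in the low 6 bits of bytes 0..7, while pairs 4..7 keep their low 4 bits in bytes 8..11 and spill their top 2 bits into the high bits of bytes 0..7. A plain-C version of the same unpacking for the k = 0 case, with a round-trip test (all names are local to the sketch):

    #include <stdint.h>
    #include <stdio.h>

    static void get_scale_min(int j, const uint8_t *q, uint8_t *sc, uint8_t *mn) {
        if (j < 4) {                       /* pairs 0..3: plain low 6 bits */
            *sc = q[j]     & 63;
            *mn = q[j + 4] & 63;
        } else {                           /* pairs 4..7: 4 low bits + 2 spilled bits */
            *sc = (q[j + 4] & 0x0F) | ((q[j - 4] & 0xC0) >> 2);
            *mn = (q[j + 4] >>   4) | ((q[j    ] & 0xC0) >> 2);
        }
    }

    int main(void) {
        uint8_t scales[12] = {0};
        /* pack pair 5 as (sc=45, mn=21) the way the K-quant layout does */
        scales[5 + 4]  = (uint8_t)((45 & 0x0F) | ((21 & 0x0F) << 4));
        scales[5 - 4] |= (uint8_t)((45 >> 4) << 6);
        scales[5]     |= (uint8_t)((21 >> 4) << 6);
        uint8_t sc, mn;
        get_scale_min(5, scales, &sc, &mn);
        printf("sc=%u mn=%u\n", sc, mn);   /* expect sc=45 mn=21 */
        return 0;
    }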
 template <typename type4x4>
 void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg) {
-    device const uint8_t * q = xb->qs;
+    device const uchar * q = xb->qs;

 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q = q + (il/4) * 32 + 16 * (il&1);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
 #else
     q = q + 16 * (il&1);
     device const uint8_t * s = xb->scales;
@ -1877,6 +2008,7 @@ void dequantize_q4_K(device const block_q4_K *xb, short il, thread type4x4 & reg
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * (q[i] & mask) - ml;
     }
 }
template <typename type4x4> template <typename type4x4>
@ -1885,19 +2017,19 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
     device const uint8_t * qh = xb->qh;

 #if QK_K == 256
-    const float d = (float)(xb->d);
-    const float min = (float)(xb->dmin);
     short is = (il/4) * 2;
     q  = q  + 32 * (il/4) + 16 * (il&1);
     qh = qh + 16 * (il&1);
     uint8_t ul = 1 << (il/2);
-    il = il%4;
-    const uchar4 sc = get_scale_min_k4(is, xb->scales);
-    const float dl = il<2 ? d * sc[0] : d * sc[2]/16.h;
-    const float ml = il<2 ? min * sc[1] : min * sc[3];
+    il = il & 3;
+    const uchar2 sc = get_scale_min_k4_just2(is, il/2, xb->scales);
+    const half d   = il < 2 ? xb->d : xb->d / 16.h;
+    const half min = xb->dmin;
+    const half dl = d * sc[0];
+    const half ml = min * sc[1];
     const ushort mask  = il<2 ? 0x0F : 0xF0;
-    const float qh_val = il<2 ? 16.f : 256.f;
+    const half qh_val = il<2 ? 16.h : 256.h;
     for (int i = 0; i < 16; ++i) {
         reg[i/4][i%4] = dl * ((q[i] & mask) + (qh[i] & ul ? qh_val : 0)) - ml;
     }
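The 16-vs-256 choice for qh_val follows from decoding the high nibble without shifting: the fifth bit is worth 16 in the final value, but when the nibble is left in the high byte position and the scale is divided by 16 instead, the bit must be added as 256 = 16 * 16 before that implicit shift. A small C illustration (names local to the sketch):

    #include <stdio.h>

    int main(void) {
        const float d = 0.5f;       /* block scale */
        unsigned q = 0xB7;          /* one packed byte: low nibble 7, high nibble 11 */
        int hbit = 1;               /* the fifth bit for this element */
        /* low-nibble path: scale d, bit adds 16 */
        float lo = d * (float)((q & 0x0F) + (hbit ? 16 : 0));
        /* high-nibble path: scale d/16, unshifted nibble, bit adds 256 */
        float hi = (d / 16.f) * (float)((q & 0xF0) + (hbit ? 256 : 0));
        /* reference: shift first, then treat like the low path */
        float hi_ref = d * (float)(((q >> 4) & 0x0F) + (hbit ? 16 : 0));
        printf("lo=%g hi=%g hi_ref=%g\n", lo, hi, hi_ref);  /* hi == hi_ref */
        return 0;
    }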
@ -1916,7 +2048,7 @@ void dequantize_q5_K(device const block_q5_K *xb, short il, thread type4x4 & reg
 template <typename type4x4>
 void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg) {
-    const float d_all = (float)(xb->d);
+    const half d_all = xb->d;
     device const uint8_t * ql = (device const uint8_t *)xb->ql;
     device const uint8_t * qh = (device const uint8_t *)xb->qh;
     device const int8_t * scales = (device const int8_t *)xb->scales;
@ -1924,19 +2056,21 @@ void dequantize_q6_K(device const block_q6_K *xb, short il, thread type4x4 & reg
 #if QK_K == 256
     ql = ql + 64*(il/8) + 32*((il/2)&1) + 16*(il&1);
     qh = qh + 32*(il/8) + 16*(il&1);
-    float sc = scales[(il%2) + 2 * ((il/2))];
-    il = (il/2)%4;
+    half sc = scales[(il%2) + 2 * ((il/2))];
+    il = (il/2) & 3;
 #else
     ql = ql + 16 * (il&1);
-    float sc = scales[il];
+    half sc = scales[il];
 #endif
+    const uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
+    const uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
+    const half     coef   = il>1 ? 1.f/16.h : 1.h;
+    const half     ml     = d_all * sc * 32.h;
+    const half     dl     = d_all * sc * coef;
     for (int i = 0; i < 16; ++i) {
-        uint16_t kmask1 = il>1 ? (il>2 ? 192 : 48) : (il>0 ? 12 : 3);
-        uint16_t kmask2 = il>1 ? 0xF0 : 0x0F;
-        const float coef = il>1 ? 1.f/16.f : 1.f;
-        float q = il&1 ? ((ql[i]&kmask2)|((qh[i]&kmask1)<<2)) - 32.f/coef : \
-                  ((ql[i]&kmask2)|((qh[i]&kmask1)<<4)) - 32.f/coef;
-        reg[i/4][i%4] = d_all * sc * q * coef;
+        const half q = il&1 ? ((ql[i] & kmask2) | ((qh[i] & kmask1) << 2))
+                            : ((ql[i] & kmask2) | ((qh[i] & kmask1) << 4));
+        reg[i/4][i%4] = dl * q - ml;
     }
 }
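The q6_K change hoists the masks, coef, and the -32 centering out of the per-element loop, using the identity d*sc*(q - 32/coef)*coef = (d*sc*coef)*q - 32*d*sc. A quick C check (names local to the sketch):

    #include <stdio.h>

    int main(void) {
        const float d_all = 0.25f, sc = 3.f, coef = 1.f/16.f;
        const float dl = d_all * sc * coef;   /* per-element scale, hoisted */
        const float ml = d_all * sc * 32.f;   /* centering term, hoisted */
        for (int q = 0; q < 64; q += 9) {     /* 6-bit quant values */
            float old_v = d_all * sc * ((float)q - 32.f/coef) * coef;
            float new_v = dl * (float)q - ml;
            printf("q=%2d old=%9.4f new=%9.4f\n", q, old_v, new_v);
        }
        return 0;
    }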

ggml.c
View file

@ -1,4 +1,3 @@
-#define _GNU_SOURCE // Defines CLOCK_MONOTONIC on Linux
 #define _CRT_SECURE_NO_DEPRECATE // Disables ridiculous "unsafe" warnings on Windows
 #include "ggml.h"
@ -47,6 +46,10 @@
// disable "possible loss of data" to avoid hundreds of casts // disable "possible loss of data" to avoid hundreds of casts
// we should just be careful :) // we should just be careful :)
#pragma warning(disable: 4244 4267) #pragma warning(disable: 4244 4267)
// disable POSIX deprecation warnigns
// these functions are never going away, anyway
#pragma warning(disable: 4996)
#endif #endif
#if defined(_WIN32) #if defined(_WIN32)
@ -307,12 +310,14 @@ typedef double ggml_float;
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <intrin.h>
 #else
+#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) || defined(__SSSE3__) || defined(__SSE3__)
 #if !defined(__riscv)
 #include <immintrin.h>
 #endif
 #endif
 #endif
 #endif
+#endif
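The added guard keeps immintrin.h out of builds whose target has no x86 SIMD at all; macros such as __AVX__ and __SSE3__ are predefined by the compiler from the target flags. A tiny C probe of the same macros (sketch only):

    #include <stdio.h>

    int main(void) {
    #if defined(__AVX2__)
        puts("built with -mavx2: immintrin.h intrinsics available");
    #elif defined(__SSE3__)
        puts("built with SSE3 support");
    #else
        puts("no x86 SIMD flags: immintrin.h would be skipped");
    #endif
        return 0;
    }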
 #ifdef __riscv_v_intrinsic
 #include <riscv_vector.h>
@ -18872,7 +18877,6 @@ static enum ggml_opt_result linesearch_backtracking(
             // strong Wolfe condition (GGML_LINESEARCH_BACKTRACKING_STRONG_WOLFE)
             return count;
         }
-        return count;
     }
 }

View file

@ -1,8 +1,3 @@
-// Defines fileno on msys:
-#ifndef _GNU_SOURCE
-#define _GNU_SOURCE
-#endif
-
 #include "llama.h"
 #include "ggml.h"
@ -5638,15 +5633,19 @@ void llama_free(struct llama_context * ctx) {
 }

 int llama_n_vocab(const struct llama_context * ctx) {
-    return ctx->model.vocab.id_to_token.size();
+    return llama_model_n_vocab(&ctx->model);
 }

 int llama_n_ctx(const struct llama_context * ctx) {
-    return ctx->model.hparams.n_ctx;
+    return llama_model_n_ctx(&ctx->model);
+}
+
+int llama_n_ctx_train(const struct llama_context * ctx) {
+    return llama_model_n_ctx_train(&ctx->model);
 }

 int llama_n_embd(const struct llama_context * ctx) {
-    return ctx->model.hparams.n_embd;
+    return llama_model_n_embd(&ctx->model);
 }
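With this change the context-level getters forward to the model-level ones, and the new llama_n_ctx_train / llama_model_n_ctx_train expose the context length the model was trained with, as distinct from the runtime n_ctx. A hedged usage sketch (assumes model and ctx were already created through the usual llama.h entry points; the warning text is illustrative, not from the source):

    #include <stdio.h>
    #include "llama.h"

    void report_ctx(const struct llama_model * model, const struct llama_context * ctx) {
        const int n_ctx       = llama_n_ctx(ctx);               /* runtime size */
        const int n_ctx_train = llama_model_n_ctx_train(model); /* training size */
        if (n_ctx > n_ctx_train) {
            fprintf(stderr, "warning: n_ctx (%d) > trained context (%d)\n",
                    n_ctx, n_ctx_train);
        }
    }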
 enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx) {
@ -5661,6 +5660,10 @@ int llama_model_n_ctx(const struct llama_model * model) {
     return model->hparams.n_ctx;
 }

+int llama_model_n_ctx_train(const struct llama_model * model) {
+    return model->hparams.n_ctx_train;
+}
+
 int llama_model_n_embd(const struct llama_model * model) {
     return model->hparams.n_embd;
 }

View file

@ -247,12 +247,14 @@ extern "C" {
     LLAMA_API int llama_n_vocab    (const struct llama_context * ctx);
     LLAMA_API int llama_n_ctx      (const struct llama_context * ctx);
+    LLAMA_API int llama_n_ctx_train(const struct llama_context * ctx);
     LLAMA_API int llama_n_embd     (const struct llama_context * ctx);

     LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_context * ctx);

     LLAMA_API int llama_model_n_vocab    (const struct llama_model * model);
     LLAMA_API int llama_model_n_ctx      (const struct llama_model * model);
+    LLAMA_API int llama_model_n_ctx_train(const struct llama_model * model);
     LLAMA_API int llama_model_n_embd     (const struct llama_model * model);

     // Get a string describing the model type