Merge commit 'a6803cab94' into concedo_experimental

# Conflicts:
#	.devops/tools.sh
#	Makefile
#	build.zig
#	flake.nix
#	ggml-cuda.cu
#	ggml.h
#	tests/test-grad0.c
#	tests/test-opt.c
Concedo committed on 2023-07-18 19:12:06 +08:00
commit 6d32e7fc8b
17 changed files with 1261 additions and 325 deletions

View file

@@ -97,7 +97,7 @@ if (LLAMA_CUBLAS)
    if (NOT DEFINED CMAKE_CUDA_ARCHITECTURES)
        if (LLAMA_CUDA_DMMV_F16)
-           set(CMAKE_CUDA_ARCHITECTURES "61") # needed for f16 CUDA intrinsics
+           set(CMAKE_CUDA_ARCHITECTURES "60;61") # needed for f16 CUDA intrinsics
        else()
            set(CMAKE_CUDA_ARCHITECTURES "52;61") # lowest CUDA 12 standard + lowest for integer intrinsics
        endif()

View file

@@ -285,6 +285,7 @@ bool gpt_params_parse(int argc, char ** argv, gpt_params & params) {
                break;
            }
            params.lora_adapter = argv[i];
            params.use_mmap = false;
        } else if (arg == "--lora-base") {
            if (++i >= argc) {
                invalid_param = true;
@@ -520,7 +521,7 @@ void gpt_print_usage(int /*argc*/, char ** argv, const gpt_params & params) {
    fprintf(stderr, " --mtest compute maximum memory usage\n");
    fprintf(stderr, " --export export the computation graph to 'llama.ggml'\n");
    fprintf(stderr, " --verbose-prompt print prompt before generation\n");
-   fprintf(stderr, " --lora FNAME apply LoRA adapter\n");
+   fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
    fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
    fprintf(stderr, " -m FNAME, --model FNAME\n");
    fprintf(stderr, " model path (default: %s)\n", params.model.c_str());

View file

@@ -17,7 +17,7 @@ make
import torch
bin_path = "../LLaVA-13b-delta-v1-1/pytorch_model-00003-of-00003.bin"
-pth_path = "./examples/embd_input/llava_projection.pth"
+pth_path = "./examples/embd-input/llava_projection.pth"
dic = torch.load(bin_path)
used_key = ["model.mm_projector.weight","model.mm_projector.bias"]

View file

@@ -59,7 +59,7 @@ if __name__=="__main__":
    # Also here can use pytorch_model-00003-of-00003.bin directly.
    a.load_projection(os.path.join(
        os.path.dirname(__file__) ,
-       "llava_projetion.pth"))
+       "llava_projection.pth"))
    respose = a.chat_with_image(
        Image.open("./media/llama1-logo.png").convert('RGB'),
        "what is the text in the picture?")

View file

@@ -293,5 +293,5 @@ These options provide extra functionality and customization when running the LLa
- `-mg i, --main-gpu i`: When using multiple GPUs this option controls which GPU is used for small tensors for which the overhead of splitting the computation across all GPUs is not worthwhile. The GPU in question will use slightly more VRAM to store a scratch buffer for temporary results. By default GPU 0 is used. Requires cuBLAS.
- `-ts SPLIT, --tensor-split SPLIT`: When using multiple GPUs this option controls how large tensors should be split across all GPUs. `SPLIT` is a comma-separated list of non-negative values that assigns the proportion of data that each GPU should get in order. For example, "3,2" will assign 60% of the data to GPU 0 and 40% to GPU 1. By default the data is split in proportion to VRAM but this may not be optimal for performance. Requires cuBLAS.
- `-lv, --low-vram`: Do not allocate a VRAM scratch buffer for holding temporary results. Reduces VRAM usage at the cost of performance, particularly prompt processing speed. Requires cuBLAS.
- - `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model. This allows you to adapt the pretrained model to specific tasks or domains.
+ - `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
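With this change, supplying a LoRA adapter also turns off memory mapping: an invocation such as `./main -m models/7B/ggml-model.bin --lora lora-adapter.bin` (the paths here are placeholders, not from the patch) now behaves as if `--no-mmap` had also been passed and loads the full model into RAM rather than mmap-ing it.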

View file

@@ -16,7 +16,7 @@ Command line options:
- `--memory-f32`: Use 32-bit floats instead of 16-bit floats for memory key+value. Not recommended.
- `--mlock`: Lock the model in memory, preventing it from being swapped out when memory-mapped.
- `--no-mmap`: Do not memory-map the model. By default, models are mapped into memory, which allows the system to load only the necessary parts of the model as needed.
- - `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model. This allows you to adapt the pretrained model to specific tasks or domains.
+ - `--lora FNAME`: Apply a LoRA (Low-Rank Adaptation) adapter to the model (implies --no-mmap). This allows you to adapt the pretrained model to specific tasks or domains.
- `--lora-base FNAME`: Optional model to use as a base for the layers modified by the LoRA adapter. This flag is used in conjunction with the `--lora` flag, and specifies the base model for the adaptation.
- `-to N`, `--timeout N`: Server read/write timeout in seconds. Default `600`.
- `--host`: Set the hostname or ip address to listen. Default `127.0.0.1`.

View file

@@ -632,7 +632,7 @@ static void server_print_usage(const char *argv0, const gpt_params &params,
    fprintf(stderr, " model path (default: %s)\n", params.model.c_str());
    fprintf(stderr, " -a ALIAS, --alias ALIAS\n");
    fprintf(stderr, " set an alias for the model, will be added as `model` field in completion response\n");
-   fprintf(stderr, " --lora FNAME apply LoRA adapter\n");
+   fprintf(stderr, " --lora FNAME apply LoRA adapter (implies --no-mmap)\n");
    fprintf(stderr, " --lora-base FNAME optional model to use as a base for the layers modified by the LoRA adapter\n");
    fprintf(stderr, " --host ip address to listen (default (default: %s)\n", sparams.hostname.c_str());
    fprintf(stderr, " --port PORT port to listen (default (default: %d)\n", sparams.port);
@@ -820,6 +820,7 @@ static void server_params_parse(int argc, char **argv, server_params &sparams,
                break;
            }
            params.lora_adapter = argv[i];
            params.use_mmap = false;
        }
        else if (arg == "--lora-base")
        {

View file

@@ -13,6 +13,8 @@
#include "ggml-cuda.h"
#include "ggml.h"
#define MIN_CC_DP4A 610 // minimum compute capability for __dp4a, an intrinsic for byte-wise dot products
#if defined(_MSC_VER)
#pragma warning(disable: 4244 4267) // possible loss of data
#endif
@@ -74,7 +76,7 @@ typedef void (*ggml_cuda_op_t)(
#define QK4_0 32
#define QR4_0 2
-#define QI4_0 4
+#define QI4_0 (QK4_0 / (4 * QR4_0))
typedef struct {
    half d; // delta
    uint8_t qs[QK4_0 / 2]; // nibbles / quants
@@ -83,7 +85,7 @@ static_assert(sizeof(block_q4_0) == sizeof(ggml_fp16_t) + QK4_0 / 2, "wrong q4_0
#define QK4_1 32
#define QR4_1 2
-#define QI4_1 4
+#define QI4_1 (QK4_1 / (4 * QR4_1))
typedef struct {
    half d; // delta
    half m; // min
@@ -93,7 +95,7 @@ static_assert(sizeof(block_q4_1) == sizeof(ggml_fp16_t) * 2 + QK4_1 / 2, "wrong
#define QK5_0 32
#define QR5_0 2
-#define QI5_0 4
+#define QI5_0 (QK5_0 / (4 * QR5_0))
typedef struct {
    half d; // delta
    uint8_t qh[4]; // 5-th bit of quants
@@ -103,7 +105,7 @@ static_assert(sizeof(block_q5_0) == sizeof(ggml_fp16_t) + sizeof(uint32_t) + QK5
#define QK5_1 32
#define QR5_1 2
-#define QI5_1 4
+#define QI5_1 (QK5_1 / (4 * QR5_1))
typedef struct {
    half d; // delta
    half m; // min
@@ -114,7 +116,7 @@ static_assert(sizeof(block_q5_1) == 2 * sizeof(ggml_fp16_t) + sizeof(uint32_t) +
#define QK8_0 32
#define QR8_0 1
-#define QI8_0 8
+#define QI8_0 (QK8_0 / (4 * QR8_0))
typedef struct {
    half d; // delta
    int8_t qs[QK8_0]; // quants
@@ -123,7 +125,7 @@ static_assert(sizeof(block_q8_0) == sizeof(ggml_fp16_t) + QK8_0, "wrong q8_0 blo
#define QK8_1 32
#define QR8_1 1
-#define QI8_1 8
+#define QI8_1 (QK8_1 / (4 * QR8_1))
typedef struct {
    half d; // delta
    half s; // unquantized sum
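The QI* macros above switch from hard-coded literals to a derived form. As a reading aid (not part of the patch): QK is the number of weights in a block, QR is roughly the number of quantized values packed per byte, so QK / (4 * QR) is the number of 32-bit ints of quantized data per block, which reproduces the old constants. A quick C++ check:

```cpp
// Sanity check that the derived QI* expressions equal the old literals.
#define QK4_0 32
#define QR4_0 2
#define QI4_0 (QK4_0 / (4 * QR4_0)) // 32 nibbles -> 16 bytes -> 4 ints (was 4)

#define QK8_0 32
#define QR8_0 1
#define QI8_0 (QK8_0 / (4 * QR8_0)) // 32 bytes -> 8 ints (was 8)

static_assert(QI4_0 == 4, "matches the previous hard-coded value");
static_assert(QI8_0 == 8, "matches the previous hard-coded value");

int main() { return 0; }
```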
@@ -143,6 +145,8 @@ typedef float (*vec_dot_q_cuda_t)(const void * __restrict__ vbq, const block_q8_
#define K_SCALE_SIZE 12
#endif
#define QR2_K 4
#define QI2_K (QK_K / (4*QR2_K))
typedef struct {
    uint8_t scales[QK_K/16]; // scales and mins, quantized with 4 bits
    uint8_t qs[QK_K/4]; // quants
@@ -151,6 +155,8 @@ typedef struct {
} block_q2_K;
static_assert(sizeof(block_q2_K) == 2*sizeof(ggml_fp16_t) + QK_K/16 + QK_K/4, "wrong q2_K block size/padding");
#define QR3_K 4
#define QI3_K (QK_K / (4*QR3_K))
typedef struct {
    uint8_t hmask[QK_K/8]; // quants - high bit
    uint8_t qs[QK_K/4]; // quants - low 2 bits
@@ -163,6 +169,8 @@ typedef struct {
} block_q3_K;
//static_assert(sizeof(block_q3_K) == sizeof(ggml_fp16_t) + QK_K / 4 + QK_K / 8 + K_SCALE_SIZE, "wrong q3_K block size/padding");
#define QR4_K 2
#define QI4_K (QK_K / (4*QR4_K))
#ifdef GGML_QKK_64
typedef struct {
    half d[2]; // super-block scales/mins
@@ -180,6 +188,8 @@ typedef struct {
static_assert(sizeof(block_q4_K) == 2*sizeof(ggml_fp16_t) + 3*QK_K/64 + QK_K/2, "wrong q4_K block size/padding");
#endif
#define QR5_K 2
#define QI5_K (QK_K / (4*QR5_K))
#ifdef GGML_QKK_64
typedef struct {
    half d; // super-block scale
@@ -199,6 +209,8 @@ typedef struct {
static_assert(sizeof(block_q5_K) == 2*sizeof(ggml_fp16_t) + K_SCALE_SIZE + QK_K/2 + QK_K/8, "wrong q5_K block size/padding");
#endif
#define QR6_K 2
#define QI6_K (QK_K / (4*QR6_K))
typedef struct {
    uint8_t ql[QK_K/2]; // quants, lower 4 bits
    uint8_t qh[QK_K/4]; // quants, upper 2 bits
@@ -212,6 +224,7 @@ static_assert(sizeof(block_q6_K) == sizeof(ggml_fp16_t) + 13*QK_K/16, "wrong q6_
#define CUDA_ADD_BLOCK_SIZE 256
#define CUDA_MUL_BLOCK_SIZE 256
#define CUDA_GELU_BLOCK_SIZE 256
#define CUDA_SILU_BLOCK_SIZE 256
#define CUDA_CPY_BLOCK_SIZE 32
#define CUDA_SCALE_BLOCK_SIZE 256
@@ -239,13 +252,13 @@ struct ggml_tensor_extra_gpu {
    cudaEvent_t events[GGML_CUDA_MAX_DEVICES]; // events for synchronizing multiple GPUs
};
-static __global__ void add_f32(const float * x, const float * y, float * dst, const int k) {
+static __global__ void add_f32(const float * x, const float * y, float * dst, const int kx, const int ky) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
-   if (i >= k) {
+   if (i >= kx) {
        return;
    }
-   dst[i] = x[i] + y[i];
+   dst[i] = x[i] + y[i%ky];
}
static __global__ void add_f16_f32_f16(const half * x, const float * y, half * dst, const int k) {
@@ -266,6 +279,19 @@ static __global__ void mul_f32(const float * x, const float * y, float * dst, co
    dst[i] = x[i] * y[i%ky];
}
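With the change above, add_f32 follows the same broadcast pattern as mul_f32 right below it: y is indexed with i % ky, so a src1 that is smaller than src0 repeats every ky elements within a single kernel launch. A host-side C++ sketch of that indexing, with made-up sizes purely for illustration:

```cpp
// Host-side model of the y[i % ky] broadcast used by add_f32 / mul_f32.
#include <cstdio>
#include <vector>

int main() {
    const int kx = 8, ky = 4; // illustrative sizes only
    std::vector<float> x(kx, 1.0f), y = {10, 20, 30, 40}, dst(kx);
    for (int i = 0; i < kx; ++i) {
        dst[i] = x[i] + y[i % ky]; // same formula as the kernel
    }
    for (float v : dst) printf("%g ", v); // 11 21 31 41 11 21 31 41
    printf("\n");
    return 0;
}
```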
static __global__ void gelu_f32(const float * x, float * dst, const int k) {
const float GELU_COEF_A = 0.044715f;
const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) {
return;
}
float xi = x[i];
dst[i] = 0.5f*xi*(1.0f + tanhf(SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi)));
}
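The gelu_f32 kernel added above implements the usual tanh approximation of GELU, 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3))). As a reading aid (not part of the patch), a scalar C++ reference using the same constants:

```cpp
// Scalar reference for gelu_f32 (tanh approximation of GELU).
#include <cmath>
#include <cstdio>

static float gelu_ref(float x) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f * x * (1.0f + std::tanh(SQRT_2_OVER_PI * x * (1.0f + GELU_COEF_A * x * x)));
}

int main() {
    for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
        printf("gelu(% .2f) = % .6f\n", x, gelu_ref(x));
    }
    return 0;
}
```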
static __global__ void silu_f32(const float * x, float * dst, const int k) {
    const int i = blockDim.x*blockIdx.x + threadIdx.x;
@@ -275,16 +301,46 @@ static __global__ void silu_f32(const float * x, float * dst, const int k) {
    dst[i] = x[i] / (1.0f + expf(-x[i]));
}
static __global__ void norm_f32(const float * x, float * dst, const int ncols) {
const int row = blockIdx.x*blockDim.y + threadIdx.y;
const int tid = threadIdx.x;
const float eps = 1e-5f;
float mean = 0.0f;
float var = 0.0f;
for (int col = tid; col < ncols; col += WARP_SIZE) {
const float xi = x[row*ncols + col];
mean += xi;
var += xi * xi;
}
// sum up partial sums
#pragma unroll
for (int mask = 16; mask > 0; mask >>= 1) {
mean += __shfl_xor_sync(0xffffffff, mean, mask, 32);
var += __shfl_xor_sync(0xffffffff, var, mask, 32);
}
mean /= ncols;
var = var / ncols - mean * mean;
const float inv_var = rsqrtf(var + eps);
for (int col = tid; col < ncols; col += WARP_SIZE) {
dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_var;
}
}
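norm_f32 above performs a per-row LayerNorm-style normalization (mean subtraction and variance scaling, without a learned scale or bias), combining per-thread partial sums across the 32-lane warp with __shfl_xor_sync butterfly reductions. A scalar C++ reference of the same math, assuming the kernel's eps of 1e-5f:

```cpp
// Scalar reference for norm_f32: y = (x - mean) / sqrt(var + eps), per row.
#include <cmath>
#include <cstdio>
#include <vector>

static void norm_ref(const float * x, float * dst, int nrows, int ncols) {
    const float eps = 1e-5f;
    for (int row = 0; row < nrows; ++row) {
        float mean = 0.0f, var = 0.0f;
        for (int col = 0; col < ncols; ++col) {
            const float xi = x[row*ncols + col];
            mean += xi;
            var  += xi * xi;
        }
        mean /= ncols;
        var = var / ncols - mean * mean; // E[x^2] - E[x]^2
        const float inv_std = 1.0f / std::sqrt(var + eps);
        for (int col = 0; col < ncols; ++col) {
            dst[row*ncols + col] = (x[row*ncols + col] - mean) * inv_std;
        }
    }
}

int main() {
    std::vector<float> x = {1, 2, 3, 4}, y(4);
    norm_ref(x.data(), y.data(), 1, 4);
    printf("%f %f %f %f\n", y[0], y[1], y[2], y[3]); // ~ -1.342 -0.447 0.447 1.342
    return 0;
}
```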
static __global__ void rms_norm_f32(const float * x, float * dst, const int ncols) {
    const int row = blockIdx.x*blockDim.y + threadIdx.y;
    const int tid = threadIdx.x;
-   const float eps = 1e-6;
+   const float eps = 1e-6f;
    float tmp = 0.0f; // partial sum for thread in warp
-   for (int i = 0; i < ncols; i += WARP_SIZE) {
-       const int col = i + tid;
+   for (int col = tid; col < ncols; col += WARP_SIZE) {
        const float xi = x[row*ncols + col];
        tmp += xi * xi;
    }
@@ -296,10 +352,9 @@ static __global__ void rms_norm_f32(const float * x, float * dst, const int ncol
    }
    const float mean = tmp / ncols;
-   const float scale = 1.0f / sqrtf(mean + eps);
+   const float scale = rsqrtf(mean + eps);
-   for (int i = 0; i < ncols; i += WARP_SIZE) {
-       const int col = i + tid;
+   for (int col = tid; col < ncols; col += WARP_SIZE) {
        dst[row*ncols + col] = scale * x[row*ncols + col];
    }
}
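For contrast, rms_norm_f32 skips the mean subtraction entirely and rescales by the reciprocal root of the mean square, now computed with rsqrtf instead of 1.0f / sqrtf. A scalar sketch (not from the patch) of the same formula, y = x / sqrt(mean(x^2) + eps), with the kernel's eps of 1e-6f:

```cpp
// Scalar reference for rms_norm_f32 over one row of ncols values.
#include <cmath>

static void rms_norm_ref(const float * x, float * dst, int ncols) {
    const float eps = 1e-6f;
    float sum_sq = 0.0f;
    for (int col = 0; col < ncols; ++col) {
        sum_sq += x[col] * x[col];
    }
    const float scale = 1.0f / std::sqrt(sum_sq / ncols + eps); // rsqrtf on the device
    for (int col = 0; col < ncols; ++col) {
        dst[col] = scale * x[col];
    }
}
```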
@@ -1228,8 +1283,9 @@ static __global__ void dequantize_block(const void * __restrict__ vx, float * __
    y[iybs + iqs + y_offset] = v.y;
}
-static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q4_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const block_q4_0 * bq4_0 = (const block_q4_0 *) vbq;
    int vi;
@@ -1250,11 +1306,12 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1(const void * __restric
    return sumi*d;
#else
    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
-static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q4_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const block_q4_1 * bq4_1 = (const block_q4_1 *) vbq;
    const int vi = *((int *) &bq4_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1275,11 +1332,12 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1(const void * __restric
    return sumi*d + m*s / QI4_1; // scale sum by QI4_1 because there are QI4_1 threads working on this block
#else
    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
-static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q5_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const block_q5_0 * bq5_0 = (const block_q5_0 *) vbq;
    int qs;
@@ -1310,11 +1368,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1(const void * __restric
    return sumi*d;
#else
    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
-static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q5_1_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const block_q5_1 * bq5_1 = (const block_q5_1 *) vbq;
    const int qs = *((int *) &bq5_1->qs[sizeof(int) * (iqs + 0)]);
@@ -1344,11 +1403,12 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1(const void * __restric
    return sumi*d + m*s / QI5_1; // scale sum by QI5_1 because there are QI5_1 threads working on this block
#else
    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
-static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
-#if __CUDA_ARCH__ >= 600 // lowest compute capability for integer intrinsics
+static __device__ __forceinline__ float vec_dot_q8_0_q8_1(
+    const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
+#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
    const block_q8_0 * bq8_0 = (const block_q8_0 *) vbq;
    int vi;
@@ -1363,7 +1423,220 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1(const void * __restric
    return sumi*d;
#else
    return 0.0f; // only to satisfy the compiler
-#endif // __CUDA_ARCH__ >= 600
+#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
static __device__ __forceinline__ float vec_dot_q2_K_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q2_K * bq2_K = (const block_q2_K *) vbq;
const int bq8_offset = QR2_K * (iqs / QI8_1);
const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
float sumf_d = 0.0f;
float sumf_m = 0.0f;
const float d = bq2_K->d;
const float dmin = bq2_K->dmin;
const int v = *((int *) &bq2_K->qs[sizeof(int) * iqs]);
for (int i = 0; i < QR2_K; ++i) {
const int sc = bq2_K->scales[scale_offset + 2*i];
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const float d8i = bq8i->d;
const int vi = (v >> (2*i)) & 0x03030303;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
sumf_d += d8i * (__dp4a(vi, ui, 0) * (sc & 0xF)); // SIMD dot product
sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * (sc >> 4)); // multiply constant q2_K part with sum of q8_1 values
}
return d*sumf_d - dmin*sumf_m;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
static __device__ __forceinline__ float vec_dot_q3_K_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q3_K * bq3_K = (const block_q3_K *) vbq;
const int bq8_offset = QR3_K * (iqs / (QI3_K/2));
const int scale_offset = iqs - iqs % QI8_1 + (iqs % QI8_1) / (QI8_1/2);
float sumf = 0.0f;
const float d = bq3_K->d;
int vl;
memcpy(&vl, &bq3_K->qs[sizeof(int) * iqs], sizeof(int));
int vh;
memcpy(&vh, &bq3_K->hmask[sizeof(int) * (iqs % (QI3_K/2))], sizeof(int));
vh = ~vh; // invert the mask so that a 0/1 results in 4/0 being subtracted
vh >>= bq8_offset;
for (int i = 0; i < QR3_K; ++i) {
const int isc = scale_offset + 2*i;
const int isc_low = isc % (QK_K/32);
const int sc_shift_low = 4 * (isc / (QK_K/32));
const int sc_low = (bq3_K->scales[isc_low] >> sc_shift_low) & 0xF;
const int isc_high = isc % (QK_K/64);
const int sc_shift_high = 2 * (isc / (QK_K/64));
const int sc_high = ((bq3_K->scales[(QK_K/32) + isc_high] >> sc_shift_high) & 3) << 4;
const int sc = (sc_low | sc_high) - 32;
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
const int vil = (vl >> (2*i)) & 0x03030303;
const int vih = ((vh >> i) << 2) & 0x04040404;
const int vi = __vsubss4(vil, vih);
sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
}
return d*sumf;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
static __device__ __forceinline__ float vec_dot_q4_K_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q4_K * bq4_K = (const block_q4_K *) vbq;
const int bq8_offset = QR4_K * (iqs / QI8_1);
float sumf_d = 0.0f;
float sumf_m = 0.0f;
const float d = bq4_K->d;
const float dmin = bq4_K->dmin;
const int v = *((int *) &bq4_K->qs[sizeof(int) * iqs]);
for (int i = 0; i < QR4_K; ++i) {
const int isc = bq8_offset + i;
uint8_t sc, m;
get_scale_min_k4(isc, bq4_K->scales, sc, m);
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
const int vi = (v >> (4*i)) & 0x0F0F0F0F;
sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q4_K with sum of q8_1 values
}
return d*sumf_d - dmin*sumf_m;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
static __device__ __forceinline__ float vec_dot_q5_K_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q5_K * bq5_K = (const block_q5_K *) vbq;
const int bq8_offset = QR5_K * (iqs / QI8_1);
float sumf_d = 0.0f;
float sumf_m = 0.0f;
const float d = bq5_K->d;
const float dmin = bq5_K->dmin;
const int vl = *((int *) &bq5_K->qs[sizeof(int) * iqs]);
const int vh = (*((int *) &bq5_K->qh[sizeof(int) * (iqs % (QI5_K/4))])) >> bq8_offset;
for (int i = 0; i < QR5_K; ++i) {
const int isc = bq8_offset + i;
uint8_t sc, m;
get_scale_min_k4(isc, bq5_K->scales, sc, m);
const block_q8_1 * bq8i = bq8_1 + bq8_offset + i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % QI8_1)]);
const float d8i = bq8i->d;
const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
const int vih = ((vh >> i) << 4) & 0x10101010;
const int vi = vil | vih;
sumf_d += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
sumf_m += d8i * (__dp4a(0x01010101, ui, 0) * m); // multiply constant part of q5_K with sum of q8_1 values
}
return d*sumf_d - dmin*sumf_m;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
static __device__ __forceinline__ float vec_dot_q6_K_q8_1(
const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int iqs) {
#if __CUDA_ARCH__ >= MIN_CC_DP4A // lowest compute capability for integer intrinsics
const block_q6_K * bq6_K = (const block_q6_K *) vbq;
const int bq8_offset = 2 * QR6_K * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/4);
const int scale_offset = (QI6_K/4) * (iqs / (QI6_K/2)) + (iqs % (QI6_K/2)) / (QI6_K/8);
const int vh_shift = 2 * ((iqs % (QI6_K/2)) / (QI6_K/4));
float sumf = 0.0f;
const float d = bq6_K->d;
int vl;
memcpy(&vl, &bq6_K->ql[sizeof(int) * iqs], sizeof(int));
int vh;
memcpy(&vh, &bq6_K->qh[sizeof(int) * ((QI6_K/4) * (iqs / (QI6_K/2)) + iqs % (QI6_K/4))], sizeof(int));
for (int i = 0; i < QR6_K; ++i) {
const int sc = bq6_K->scales[scale_offset + 4*i];
const block_q8_1 * bq8i = bq8_1 + bq8_offset + 2*i;
const int ui = *((int*) &bq8i->qs[sizeof(int) * (iqs % (QI8_1))]);
const float d8i = bq8i->d;
const int vil = (vl >> (4*i)) & 0x0F0F0F0F;
const int vih = ((vh >> (vh_shift + 4*i)) << 4) & 0x30303030;
const int vi = __vsubss4((vil | vih), 0x20202020); // vi = (vil | vih) - 32
sumf += d8i * (__dp4a(vi, ui, 0) * sc); // SIMD dot product
}
return d*sumf;
#else
return 0.0f; // only to satisfy the compiler
#endif // __CUDA_ARCH__ >= MIN_CC_DP4A
}
template <int qk, int qi, typename block_q_t, vec_dot_q_cuda_t vec_dot_q_cuda>
@@ -1386,7 +1659,7 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void *
    for (int i = 0; i < blocks_per_row; i += blocks_per_warp) {
        const int ibx = row*blocks_per_row + i + threadIdx.x / qi; // x block index
-       const int iby = i + threadIdx.x / qi; // y block index
+       const int iby = (i + threadIdx.x / qi) * qk/QK8_1; // y block index that aligns with ibx
        const int iqs = threadIdx.x % qi; // x block quant index when casting the quants to int
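The scaling of iby matters once the x blocks are wider than the 32-value q8_1 blocks of y: an x block of qk quantized values lines up with qk / QK8_1 consecutive y blocks, so for the K-quant super-blocks (qk = QK_K = 256) the factor is 8, while for the 32-value formats it stays 1 and nothing changes. A small arithmetic check in C++:

```cpp
// Each x block of qk values spans qk / QK8_1 consecutive q8_1 blocks of y.
#include <cstdio>

int main() {
    const int QK8_1 = 32;
    for (int qk : {32, 256}) {        // 32-value formats vs. K-quant super-blocks
        for (int i : {0, 1, 2}) {     // x block index within a row
            printf("qk=%3d: x block %d -> y block %d\n", qk, i, i * qk / QK8_1);
        }
    }
    return 0;
}
```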
@@ -1624,6 +1897,40 @@ static __global__ void rope_f32(const float * x, float * dst, const int ncols, c
    dst[i + 1] = x0*sin_theta + x1*cos_theta;
}
static __global__ void rope_glm_f32(const float * x, float * dst, const int ncols, const float p, const float block_p, const float theta_scale) {
const int col = blockDim.x*blockIdx.x + threadIdx.x;
const int half_n_dims = ncols/4;
if (col >= half_n_dims) {
return;
}
const int row = blockDim.y*blockIdx.y + threadIdx.y;
const int i = row*ncols + col;
const float col_theta_scale = powf(theta_scale, col);
const float theta = p*col_theta_scale;
const float sin_theta = sinf(theta);
const float cos_theta = cosf(theta);
const float x0 = x[i + 0];
const float x1 = x[i + half_n_dims];
dst[i + 0] = x0*cos_theta - x1*sin_theta;
dst[i + half_n_dims] = x0*sin_theta + x1*cos_theta;
const float block_theta = block_p*col_theta_scale;
const float sin_block_theta = sinf(block_theta);
const float cos_block_theta = cosf(block_theta);
const float x2 = x[i + half_n_dims * 2];
const float x3 = x[i + half_n_dims * 3];
dst[i + half_n_dims * 2] = x2*cos_block_theta - x3*sin_block_theta;
dst[i + half_n_dims * 3] = x2*sin_block_theta + x3*cos_block_theta;
}
static __global__ void diag_mask_inf_f32(const float * x, float * dst, const int ncols, const int rows_per_channel, const int n_past) {
    const int col = blockDim.x*blockIdx.x + threadIdx.x;
    const int row = blockDim.y*blockIdx.y + threadIdx.y;
@@ -1689,9 +1996,9 @@ static __global__ void scale_f32(const float * x, float * dst, const float scale
    dst[i] = scale * x[i];
}
-static void add_f32_cuda(const float * x, const float * y, float * dst, const int k, cudaStream_t stream) {
-   const int num_blocks = (k + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
-   add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, k);
+static void add_f32_cuda(const float * x, const float * y, float * dst, const int kx, const int ky, cudaStream_t stream) {
+   const int num_blocks = (kx + CUDA_ADD_BLOCK_SIZE - 1) / CUDA_ADD_BLOCK_SIZE;
+   add_f32<<<num_blocks, CUDA_ADD_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
}
static void add_f16_f32_f16_cuda(const half * x, const float * y, half * dst, const int k, cudaStream_t stream) {
@@ -1704,11 +2011,22 @@ static void mul_f32_cuda(const float * x, const float * y, float * dst, const in
    mul_f32<<<num_blocks, CUDA_MUL_BLOCK_SIZE, 0, stream>>>(x, y, dst, kx, ky);
}
static void gelu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_GELU_BLOCK_SIZE - 1) / CUDA_GELU_BLOCK_SIZE;
gelu_f32<<<num_blocks, CUDA_GELU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void silu_f32_cuda(const float * x, float * dst, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_SILU_BLOCK_SIZE - 1) / CUDA_SILU_BLOCK_SIZE;
    silu_f32<<<num_blocks, CUDA_SILU_BLOCK_SIZE, 0, stream>>>(x, dst, k);
}
static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % WARP_SIZE == 0);
const dim3 block_dims(WARP_SIZE, 1, 1);
norm_f32<<<nrows, block_dims, 0, stream>>>(x, dst, ncols);
}
static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
    GGML_ASSERT(ncols % WARP_SIZE == 0);
    const dim3 block_dims(WARP_SIZE, 1, 1);
@@ -1874,7 +2192,7 @@ static void dequantize_mul_mat_vec_q6_K_cuda(const void * vx, const float * y, f
}
static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-   GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+   GGML_ASSERT(ncols % QK4_0 == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1883,7 +2201,7 @@ static void mul_mat_vec_q4_0_q8_1_cuda(const void * vx, const void * vy, float *
}
static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-   GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+   GGML_ASSERT(ncols % QK4_1 == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1892,7 +2210,7 @@ static void mul_mat_vec_q4_1_q8_1_cuda(const void * vx, const void * vy, float *
}
static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-   GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+   GGML_ASSERT(ncols % QK5_0 == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1901,7 +2219,7 @@ static void mul_mat_vec_q5_0_q8_1_cuda(const void * vx, const void * vy, float *
}
static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-   GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+   GGML_ASSERT(ncols % QK5_1 == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1910,7 +2228,7 @@ static void mul_mat_vec_q5_1_q8_1_cuda(const void * vx, const void * vy, float *
}
static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
-   GGML_ASSERT(ncols % GGML_CUDA_DMMV_X == 0);
+   GGML_ASSERT(ncols % QK8_0 == 0);
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    const dim3 block_nums(1, block_num_y, 1);
    const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
@@ -1918,6 +2236,51 @@ static void mul_mat_vec_q8_0_q8_1_cuda(const void * vx, const void * vy, float *
        <<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q2_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI2_K, block_q2_K, vec_dot_q2_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q3_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI3_K, block_q3_K, vec_dot_q3_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q4_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI4_K, block_q4_K, vec_dot_q4_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q5_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI5_K, block_q5_K, vec_dot_q5_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
static void mul_mat_vec_q6_K_q8_1_cuda(const void * vx, const void * vy, float * dst, const int ncols, const int nrows, cudaStream_t stream) {
GGML_ASSERT(ncols % QK_K == 0);
const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
const dim3 block_nums(1, block_num_y, 1);
const dim3 block_dims(WARP_SIZE, GGML_CUDA_MMV_Y, 1);
mul_mat_vec_q<QK_K, QI6_K, block_q6_K, vec_dot_q6_K_q8_1>
<<<block_nums, block_dims, 0, stream>>>(vx, vy, dst, ncols, nrows);
}
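The new *_K launchers above reuse the launch shape of the existing mul_mat_vec_*_q8_1_cuda helpers: one WARP_SIZE-lane warp per matrix row, GGML_CUDA_MMV_Y rows per thread block, and ceil(nrows / GGML_CUDA_MMV_Y) blocks along y. A small arithmetic sketch (the GGML_CUDA_MMV_Y value below is illustrative; it is a build-time option in this code base):

```cpp
// Grid/block shape used by the mul_mat_vec_*_q8_1_cuda launchers.
#include <cstdio>

int main() {
    const int WARP_SIZE = 32;
    const int GGML_CUDA_MMV_Y = 2; // illustrative; configured at build time
    const int nrows = 4096;
    const int block_num_y = (nrows + GGML_CUDA_MMV_Y - 1) / GGML_CUDA_MMV_Y;
    printf("block = (%d, %d, 1), grid = (1, %d, 1)\n", WARP_SIZE, GGML_CUDA_MMV_Y, block_num_y);
    return 0;
}
```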
static void convert_fp16_to_fp32_cuda(const void * vx, float * y, const int k, cudaStream_t stream) {
    const int num_blocks = (k + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE;
    dequantize_block<1, 1, convert_f16><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>(vx, y, k);
@@ -2010,6 +2373,14 @@ static void rope_f32_cuda(const float * x, float * dst, const int ncols, const i
    rope_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, theta_scale);
}
static void rope_glm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float p, const float block_p, const float theta_scale, cudaStream_t stream) {
GGML_ASSERT(nrows % 4 == 0);
const dim3 block_dims(4*CUDA_ROPE_BLOCK_SIZE, 1, 1);
const int num_blocks_x = (ncols + 4*CUDA_ROPE_BLOCK_SIZE - 1) / (4*CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(num_blocks_x, nrows, 1);
rope_glm_f32<<<block_nums, block_dims, 0, stream>>>(x, dst, ncols, p, block_p, theta_scale);
}
static void diag_mask_inf_f32_cuda(const float * x, float * dst, const int ncols_x, const int nrows_x, const int rows_per_channel, const int n_past, cudaStream_t stream) {
    const dim3 block_dims(CUDA_DIAG_MASK_INF_BLOCK_SIZE, 1, 1);
    const int block_num_x = (ncols_x + CUDA_DIAG_MASK_INF_BLOCK_SIZE - 1) / CUDA_DIAG_MASK_INF_BLOCK_SIZE;
@@ -2264,14 +2635,17 @@ inline void ggml_cuda_op_add(
    GGML_ASSERT(src1_ddf_i != nullptr);
    GGML_ASSERT(dst_ddf_i != nullptr);
-   const int64_t ne0 = src0->ne[0];
+   const int64_t ne00 = src0->ne[0];
    const int64_t i01_diff = i01_high - i01_low;
+   const int64_t ne10 = src1->ne[0];
+   const int64_t ne11 = src1->ne[1];
    // compute
    if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
-       add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne0*i01_diff, cudaStream_main);
+       add_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
    } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) {
-       add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne0*i01_diff, cudaStream_main);
+       add_f16_f32_f16_cuda((half *) src0_ddq_i, src1_ddf_i, (half *) dst_ddf_i, ne00*i01_diff, cudaStream_main);
    } else {
        GGML_ASSERT(false);
    }
@@ -2293,26 +2667,40 @@ inline void ggml_cuda_op_mul(
    GGML_ASSERT(dst_ddf_i != nullptr);
    const int64_t ne00 = src0->ne[0];
+   const int64_t i01_diff = i01_high - i01_low;
    const int64_t ne10 = src1->ne[0];
    const int64_t ne11 = src1->ne[1];
-   for (int64_t i01 = i01_low; i01 < i01_high; i01++) {
-       const int64_t i11 = i1*ne11 + i01%ne11; // broadcast src1 across src0
-       float * src0_ddf_i01 = src0_ddf_i + i01*ne00;
-       float * src1_ddf_i01 = src1_ddf_i + i11*ne10;
-       float * dst_ddf_i01 = dst_ddf_i + i01*ne00;
-       // compute
-       mul_f32_cuda(src0_ddf_i01, src1_ddf_i01, dst_ddf_i01, ne00, ne10, cudaStream_main);
-   }
+   mul_f32_cuda(src0_ddf_i, src1_ddf_i, dst_ddf_i, ne00*i01_diff, ne10*ne11, cudaStream_main);
    (void) dst;
    (void) src0_ddq_i;
    (void) i02;
}
inline void ggml_cuda_op_gelu(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
cudaStream_t & cudaStream_main){
GGML_ASSERT(src0_ddf_i != nullptr);
GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
// compute
gelu_f32_cuda(src0_ddf_i, dst_ddf_i, ne00*i01_diff, cudaStream_main);
(void) src1;
(void) dst;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i02;
(void) i1;
}
inline void ggml_cuda_op_silu(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2335,6 +2723,28 @@ inline void ggml_cuda_op_silu(
    (void) i1;
}
inline void ggml_cuda_op_norm(
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
cudaStream_t & cudaStream_main){
GGML_ASSERT(src0_ddf_i != nullptr);
GGML_ASSERT(dst_ddf_i != nullptr);
const int64_t ne00 = src0->ne[0];
const int64_t i01_diff = i01_high - i01_low;
// compute
norm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, cudaStream_main);
(void) src1;
(void) dst;
(void) src0_ddq_i;
(void) src1_ddf_i;
(void) i02;
(void) i1;
}
inline void ggml_cuda_op_rms_norm(
    const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, char * src0_ddq_i,
    float * src0_ddf_i, float * src1_ddf_i, float * dst_ddf_i, int64_t i02, int64_t i01_low, int64_t i01_high, int i1,
@@ -2375,13 +2785,22 @@ inline void ggml_cuda_op_mul_mat_vec(
    int id;
    CUDA_CHECK(cudaGetDevice(&id));
-   const bool mul_mat_vec_q_implemented = src0->type == GGML_TYPE_Q4_0 ||
+   bool mul_mat_vec_q_implemented =
+       src0->type == GGML_TYPE_Q4_0 ||
        src0->type == GGML_TYPE_Q4_1 ||
        src0->type == GGML_TYPE_Q5_0 ||
        src0->type == GGML_TYPE_Q5_1 ||
        src0->type == GGML_TYPE_Q8_0;
#if QK_K == 256
mul_mat_vec_q_implemented = mul_mat_vec_q_implemented ||
src0->type == GGML_TYPE_Q2_K ||
src0->type == GGML_TYPE_Q3_K ||
src0->type == GGML_TYPE_Q4_K ||
src0->type == GGML_TYPE_Q5_K ||
src0->type == GGML_TYPE_Q6_K;
#endif // QK_K == 256
-   const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= 600 && mul_mat_vec_q_implemented;
+   const bool use_mul_mat_vec_q = g_compute_capabilities[id] >= MIN_CC_DP4A && mul_mat_vec_q_implemented;
#endif
    if (use_mul_mat_vec_q) {
@ -2407,6 +2826,21 @@ inline void ggml_cuda_op_mul_mat_vec(
            case GGML_TYPE_Q8_0:
                mul_mat_vec_q8_0_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
                break;
case GGML_TYPE_Q2_K:
mul_mat_vec_q2_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q3_K:
mul_mat_vec_q3_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q4_K:
mul_mat_vec_q4_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q5_K:
mul_mat_vec_q5_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
case GGML_TYPE_Q6_K:
mul_mat_vec_q6_K_q8_1_cuda(src0_ddq_i, src1_q8_1, dst_ddf_i, ne00, nrows, cudaStream_main);
break;
            default:
                GGML_ASSERT(false);
                break;
@@ -2542,15 +2976,22 @@ inline void ggml_cuda_op_rope(
    const int n_dims = ((int32_t *) src1->data)[1];
    const int mode = ((int32_t *) src1->data)[2];
    const int n_ctx = ((int32_t *) src1->data)[3];
-   GGML_ASSERT(mode == 0);
    const float theta_scale = get_theta_scale(n_dims,n_past,n_ctx);
    const float p0 = ((mode & 1) == 0 ? n_past + i02 : i02);
    const float p = get_ntk_rope_scale_mode()?p0:(n_ctx <= GGML_TRAINING_CTX ? p0 : p0 * GGML_TRAINING_CTX / n_ctx);
bool is_glm = mode & 4;
// compute // compute
if (is_glm) {
const float id_p = min(p, n_ctx - 2.f);
const float block_p = max(p - (n_ctx - 2.f), 0.f);
rope_glm_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, id_p, block_p, theta_scale, cudaStream_main);
} else {
        rope_f32_cuda(src0_ddf_i, dst_ddf_i, ne00, i01_diff, p, theta_scale, cudaStream_main);
    }
    (void) dst;
    (void) src0_ddq_i;
@@ -2953,11 +3394,21 @@ void ggml_cuda_mul(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tens
    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_mul, true, false); // TODO ggml_cuda_op needs modification for flatten
}
void ggml_cuda_gelu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_gelu, true, true);
}
void ggml_cuda_silu(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_silu, true, true);
}
void ggml_cuda_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
ggml_cuda_op(src0, src1, dst, ggml_cuda_op_norm, true, true);
}
void ggml_cuda_rms_norm(const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32);
    ggml_cuda_op(src0, src1, dst, ggml_cuda_op_rms_norm, true, true);
@@ -3188,7 +3639,7 @@ void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor) {
    }
-   cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice);
+   CUDA_CHECK(cudaMemcpy(buf, buf_host, size, cudaMemcpyHostToDevice));
    extra->data_device[id] = buf;
@@ -3222,6 +3673,22 @@ void ggml_cuda_free_data(struct ggml_tensor * tensor) {
    delete extra;
}
static struct ggml_tensor_extra_gpu * g_temp_tensor_extras = nullptr;
static size_t g_temp_tensor_extra_index = 0;
static struct ggml_tensor_extra_gpu * ggml_cuda_alloc_temp_tensor_extra() {
if (g_temp_tensor_extras == nullptr) {
g_temp_tensor_extras = new ggml_tensor_extra_gpu[GGML_MAX_NODES];
}
size_t alloc_index = g_temp_tensor_extra_index;
g_temp_tensor_extra_index = (g_temp_tensor_extra_index + 1) % GGML_MAX_NODES;
struct ggml_tensor_extra_gpu * extra = &g_temp_tensor_extras[alloc_index];
memset(extra, 0, sizeof(*extra));
return extra;
}
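ggml_cuda_alloc_temp_tensor_extra above hands out ggml_tensor_extra_gpu slots from a lazily allocated array of GGML_MAX_NODES entries and advances the index round-robin, so the view/copy/scratch branches of ggml_cuda_assign_buffers_impl below no longer heap-allocate one extra per call. A stripped-down C++ sketch of the same pool pattern; the names here are illustrative, not from the code base:

```cpp
// Minimal round-robin pool: reuse a fixed array of POD slots instead of
// heap-allocating one per request. Mirrors ggml_cuda_alloc_temp_tensor_extra.
#include <cstring>

struct slot_t { void * data[16]; };      // stand-in for ggml_tensor_extra_gpu

static const int kMaxSlots = 4096;        // stand-in for GGML_MAX_NODES
static slot_t *  g_slots   = nullptr;
static size_t    g_next    = 0;

static slot_t * alloc_temp_slot() {
    if (g_slots == nullptr) {
        g_slots = new slot_t[kMaxSlots];  // allocated once, reused forever
    }
    slot_t * s = &g_slots[g_next];
    g_next = (g_next + 1) % kMaxSlots;    // wraps around; callers must not keep
    memset(s, 0, sizeof(*s));             // a slot alive past kMaxSlots uses
    return s;
}
```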
void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) { void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bool force_inplace) {
if (scratch && g_scratch_size == 0) { if (scratch && g_scratch_size == 0) {
return; return;
@ -3239,8 +3706,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
} }
tensor->backend = GGML_BACKEND_GPU; tensor->backend = GGML_BACKEND_GPU;
struct ggml_tensor_extra_gpu * extra = new ggml_tensor_extra_gpu; struct ggml_tensor_extra_gpu * extra;
memset(extra, 0, sizeof(*extra));
const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) || const bool inplace = (tensor->src[0] != nullptr && tensor->src[0]->data == tensor->data) ||
tensor->op == GGML_OP_VIEW || tensor->op == GGML_OP_VIEW ||
@ -3255,10 +3721,12 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
if (tensor->op == GGML_OP_VIEW) { if (tensor->op == GGML_OP_VIEW) {
memcpy(&offset, tensor->src[2]->data, sizeof(size_t)); memcpy(&offset, tensor->src[2]->data, sizeof(size_t));
} }
extra = ggml_cuda_alloc_temp_tensor_extra();
extra->data_device[g_main_device] = src0_ddc + offset; extra->data_device[g_main_device] = src0_ddc + offset;
} else if (tensor->op == GGML_OP_CPY) { } else if (tensor->op == GGML_OP_CPY) {
struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra; struct ggml_tensor_extra_gpu * src1_extra = (ggml_tensor_extra_gpu * ) tensor->src[1]->extra;
void * src1_ddv = src1_extra->data_device[g_main_device]; void * src1_ddv = src1_extra->data_device[g_main_device];
extra = ggml_cuda_alloc_temp_tensor_extra();
extra->data_device[g_main_device] = src1_ddv; extra->data_device[g_main_device] = src1_ddv;
} else if (scratch) { } else if (scratch) {
GGML_ASSERT(size <= g_scratch_size); GGML_ASSERT(size <= g_scratch_size);
@ -3271,6 +3739,7 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
CUDA_CHECK(cudaMalloc(&data, g_scratch_size)); CUDA_CHECK(cudaMalloc(&data, g_scratch_size));
g_scratch_buffer = data; g_scratch_buffer = data;
} }
extra = ggml_cuda_alloc_temp_tensor_extra();
extra->data_device[g_main_device] = data + g_scratch_offset; extra->data_device[g_main_device] = data + g_scratch_offset;
g_scratch_offset += size; g_scratch_offset += size;
@ -3280,6 +3749,8 @@ void ggml_cuda_assign_buffers_impl(struct ggml_tensor * tensor, bool scratch, bo
void * data; void * data;
CUDA_CHECK(cudaMalloc(&data, size)); CUDA_CHECK(cudaMalloc(&data, size));
CUDA_CHECK(cudaMemset(data, 0, size)); CUDA_CHECK(cudaMemset(data, 0, size));
extra = new ggml_tensor_extra_gpu;
memset(extra, 0, sizeof(*extra));
extra->data_device[g_main_device] = data; extra->data_device[g_main_device] = data;
} }
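The scratch branch above carves temporary tensors out of one preallocated device buffer by bumping g_scratch_offset, while only "real" outputs fall through to a dedicated cudaMalloc. A minimal host-side sketch of that bump allocation, simplified and not the exact ggml-cuda bookkeeping:

    #include <stddef.h>

    static char * g_buf      = NULL;   /* stands in for the device scratch buffer */
    static size_t g_buf_size = 0;
    static size_t g_buf_used = 0;

    /* Returns a slice of the scratch buffer; the real code asserts size <= g_scratch_size
     * and lazily cudaMalloc's the buffer on first use. */
    static void * scratch_bump_alloc(size_t size) {
        if (g_buf_used + size > g_buf_size) {
            return NULL;               /* out of scratch space */
        }
        void * ptr = g_buf + g_buf_used;
        g_buf_used += size;            /* released wholesale by resetting the offset */
        return ptr;
    }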
@@ -3344,12 +3815,24 @@ bool ggml_cuda_compute_forward(struct ggml_compute_params * params, struct ggml_
             }
             func = ggml_cuda_mul;
             break;
+        case GGML_OP_GELU:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_gelu;
+            break;
         case GGML_OP_SILU:
             if (!any_on_device) {
                 return false;
             }
             func = ggml_cuda_silu;
             break;
+        case GGML_OP_NORM:
+            if (!any_on_device) {
+                return false;
+            }
+            func = ggml_cuda_norm;
+            break;
         case GGML_OP_RMS_NORM:
             if (!any_on_device) {
                 return false;
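With GGML_OP_GELU and GGML_OP_NORM wired into the dispatch above, those nodes stay on the GPU instead of falling back to the CPU. For reference, a host-side GELU matching the tanh approximation ggml uses elsewhere (constants assumed from the CPU implementation):

    #include <math.h>

    /* GELU via the tanh approximation: 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3))) */
    static float gelu_ref(float x) {
        const float sqrt_2_over_pi = 0.7978845608f;
        const float coef           = 0.044715f;
        return 0.5f * x * (1.0f + tanhf(sqrt_2_over_pi * (x + coef * x * x * x)));
    }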
@@ -740,8 +740,7 @@ void ggml_metal_graph_compute(
                         [encoder setBytes:&ne1 length:sizeof(ne1) atIndex:14];
 
                         if (src0t == GGML_TYPE_Q4_0 || src0t == GGML_TYPE_Q4_1) {
-                            [encoder setThreadgroupMemoryLength:nth0*nth1*sizeof(float) atIndex:0];
-                            [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
+                            [encoder dispatchThreadgroups:MTLSizeMake((ne01 + 7) / 8, ne11, 1) threadsPerThreadgroup:MTLSizeMake(nth0, nth1, 1)];
                         }
                         else if (src0t == GGML_TYPE_Q2_K ||
                                  src0t == GGML_TYPE_Q3_K ||
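The new dispatch launches one threadgroup per 8 rows of src0 because each threadgroup now carries N_SIMDGROUP = 2 SIMD groups, each producing N_DST = 4 output rows (see the defines in the Metal source below). The grid arithmetic, spelled out:

    /* threadgroups along x for the rewritten Q4_0/Q4_1 kernels */
    static int n_threadgroups_x(int ne01 /* rows of src0 */) {
        const int rows_per_tg = 2 /* N_SIMDGROUP */ * 4 /* N_DST */;   /* = 8 */
        return (ne01 + rows_per_tg - 1) / rows_per_tg;                 /* == (ne01 + 7) / 8 */
    }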
@ -365,6 +365,10 @@ kernel void kernel_rms_norm(
} }
} }
// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 32 // assuming SIMD group size is 32
kernel void kernel_mul_mat_q4_0_f32( kernel void kernel_mul_mat_q4_0_f32(
device const void * src0, device const void * src0,
device const float * src1, device const float * src1,
@ -372,64 +376,83 @@ kernel void kernel_mul_mat_q4_0_f32(
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne10, constant int64_t & ne10,
constant int64_t & ne0, constant int64_t & ne0,
threadgroup float * sum [[threadgroup(0)]], constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]], uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint2 tptg[[threads_per_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK4_0; const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int64_t r0 = tgpig.x; const int r1 = tgpig.y;
const int64_t r1 = tgpig.y; device const block_q4_0 * x = (device const block_q4_0 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
device const block_q4_0 * x = (device const block_q4_0 *) src0 + r0*nb;
device const float * y = (device const float *) src1 + r1*ne10; device const float * y = (device const float *) src1 + r1*ne10;
block_q4_0 qb_curr, qb_next;
float4 y_curr[8]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum;
thread float * yl=(thread float *)y_curr;
const int nth = tptg.x*tptg.y; // bootstrap
const int ith = tptg.y*tpitg.x + tpitg.y; qb_curr = x[tiisg];
// each thread in a SIMD group deals with 1 block.
for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
const int ix = tpitg.y/4; // 0 or 1 float sumy = 0;
const int iy = tpitg.y - 4*ix; // 0...3 for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
const int first = 4 * iy; for (int row = 0; row < N_DST; row++) {
// prefetch next x block
float sumf = 0; qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
const float d = (float)x[i].d;
device const uint8_t * xl = x[i].qs + first;
device const float * yl = y + i * QK4_0 + first;
float2 acc = {0.0f, 0.0f};
for (int j = 0; j < 4; ++j) {
acc[0] += yl[j] * (xl[j] & 0xF) + yl[j+16] * (xl[j] >> 4);
acc[1] += yl[j] + yl[j+16];
// calculate
float d = qb_curr.d;
float acc = sumy;
for (int i = 0; i < 16; i++) {
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
sumf[row] += d * acc;
qb_curr = qb_next;
}
} }
sumf += d * (acc[0] - 8.f*acc[1]); if (nb % N_SIMDWIDTH == 0) {
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
} }
}
} else {
sum[ith] = sumf; float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
}
sumy *= (-8.f);
// for (int row = 0; row < N_DST; row++) {
// Accumulate the sum from all threads in the threadgroup // prefetch next x block
// qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%4 == 0) { // calculate
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3]; float d = qb_curr.d;
float acc = sumy;
for (int i = 0; i < 16; i++) {
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * acc;
}
qb_curr = qb_next;
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%16 == 0) {
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
for (int i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
} }
} }
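The rewritten kernel relies on the Q4_0 identity value = d*(q - 8): pre-summing the activations once lets every row reuse sum(y), so the inner loop is pure q*y accumulation. A scalar C reference of the per-block dot product, assuming the usual block_q4_0 layout (fp16 scale plus 32 packed 4-bit quants; low nibble = element i, high nibble = element i+16):

    #include <stdint.h>

    #define QK4_0 32

    typedef struct {
        float   d;             /* scale (fp16 in ggml, float here for clarity) */
        uint8_t qs[QK4_0/2];   /* 32 quants, two per byte                      */
    } block_q4_0_ref;

    /* dot(block, y[0..31]) = d * sum((q - 8) * y) = d * (sum(q*y) - 8*sum(y)) */
    static float dot_q4_0_block(const block_q4_0_ref * b, const float * y) {
        float sumy = 0.0f, acc = 0.0f;
        for (int i = 0; i < QK4_0; ++i) {
            sumy += y[i];
        }
        for (int i = 0; i < QK4_0/2; ++i) {
            acc += y[i]      * (float)(b->qs[i] & 0x0F);   /* low nibble  -> element i      */
            acc += y[i + 16] * (float)(b->qs[i] >> 4);     /* high nibble -> element i + 16 */
        }
        return b->d * (acc - 8.0f * sumy);
    }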
@ -440,65 +463,83 @@ kernel void kernel_mul_mat_q4_1_f32(
constant int64_t & ne00, constant int64_t & ne00,
constant int64_t & ne10, constant int64_t & ne10,
constant int64_t & ne0, constant int64_t & ne0,
threadgroup float * sum [[threadgroup(0)]], constant int64_t & ne01[[buffer(4)]],
uint2 tgpig[[threadgroup_position_in_grid]], uint2 tgpig[[threadgroup_position_in_grid]],
uint2 tpitg[[thread_position_in_threadgroup]], uint tiisg[[thread_index_in_simdgroup]],
uint2 tptg[[threads_per_threadgroup]]) { uint sgitg[[simdgroup_index_in_threadgroup]]) {
const int nb = ne00/QK4_1; const int nb = ne00/QK4_0;
const int r0 = tgpig.x;
const int64_t r0 = tgpig.x; const int r1 = tgpig.y;
const int64_t r1 = tgpig.y; device const block_q4_1 * x = (device const block_q4_1 *) src0 + (r0 * N_SIMDGROUP + sgitg) * N_DST * nb;
device const block_q4_1 * x = (device const block_q4_1 *) src0 + r0*nb;
device const float * y = (device const float *) src1 + r1*ne10; device const float * y = (device const float *) src1 + r1*ne10;
block_q4_1 qb_curr, qb_next;
float4 y_curr[8]; // src1 vector cache
float sumf[N_DST]={0.f}, all_sum;
thread float * yl=(thread float *)y_curr;
const uint nth = tptg.x*tptg.y; // bootstrap
const uint ith = tptg.y*tpitg.x + tpitg.y; qb_curr = x[tiisg];
// each thread in a SIMD group deals with 1 block.
const int ix = tpitg.y/4; // 0 or 1 for (int column = 0; column < nb / N_SIMDWIDTH; column++) {
const int iy = tpitg.y - 4*ix; // 0...3
const int first = 4 * iy;
float sumf = 0;
for (int i = 2*tpitg.x + ix; i < nb; i += 2*tptg.x) {
const float d = (float)x[i].d;
const float m = (float)x[i].m;
device const uint8_t * xl = x[i].qs + first;
device const float * yl = y + i * QK4_1 + first;
float2 acc = {0.0f, 0.0f};
for (int j = 0; j < 4; ++j) {
acc[0] += yl[j+ 0] * (d * (xl[j] & 0xF) + m);
acc[1] += yl[j+16] * (d * (xl[j] >> 4) + m);
float sumy = 0;
for (int i = 0; i < QK4_0 / 4; i++) {
y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + column * QK4_0) + 4 * i));
sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
} }
sumf += acc[0] + acc[1]; for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (column + ((row + 1) / N_DST)) * N_SIMDWIDTH];
// calculate
const float d = qb_curr.d;
const float m = qb_curr.m;
float acc = 0.f;
for (int i = 0; i < 16; i++) {
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
sumf[row] += d * acc + m * sumy;
qb_curr = qb_next;
}
} }
sum[ith] = sumf; if (nb % N_SIMDWIDTH == 0) {
for (int row = 0; row < N_DST; ++row) {
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
}
}
} else {
// float sumy = 0;
// Accumulate the sum from all threads in the threadgroup for (int i = 0; i < QK4_0 / 4; i++) {
// y_curr[i] = *((device float4 *)(y + N_SIMDWIDTH * (tiisg + (nb / N_SIMDWIDTH) * QK4_0) + 4 * i));
threadgroup_barrier(mem_flags::mem_threadgroup); sumy += y_curr[i][0] + y_curr[i][1] + y_curr[i][2] + y_curr[i][3];
if (ith%4 == 0) { }
sum[ith] += sum[ith+1] + sum[ith+2] + sum[ith+3];
for (int row = 0; row < N_DST; row++) {
// prefetch next x block
qb_next = x[tiisg + ((row + 1) % N_DST) * nb + (nb / N_SIMDWIDTH + ((row + 1) / N_DST)) * N_SIMDWIDTH];
// calculate
const float d = qb_curr.d;
const float m = qb_curr.m;
float acc = 0.f;
for (int i = 0; i < 16; i++) {
acc += yl[i] * (qb_curr.qs[i] & 0xF) + yl[i+16] * (qb_curr.qs[i] >> 4);
}
if (tiisg < nb % N_SIMDWIDTH) {
sumf[row] += d * acc + m * sumy;
}
qb_curr = qb_next;
all_sum = simd_sum(sumf[row]);
if (tiisg == 0 && ((r0 * N_SIMDGROUP + sgitg) * N_DST + row) < ne01) {
dst[r1*ne0 + (r0 * N_SIMDGROUP + sgitg) * N_DST + row] = all_sum;
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith%16 == 0) {
sum[ith] += sum[ith+4] + sum[ith+8] + sum[ith+12];
} }
threadgroup_barrier(mem_flags::mem_threadgroup);
if (ith == 0) {
for (uint i = 16; i < nth; i += 16) sum[0] += sum[i];
dst[r1*ne0 + r0] = sum[0];
} }
} }
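The Q4_1 path has the same structure with an extra per-block minimum m, so the block dot becomes d*sum(q*y) + m*sum(y), which is exactly what sumf[row] += d * acc + m * sumy accumulates above. Scalar reference under the same layout assumptions:

    #include <stdint.h>

    #define QK4_1 32

    typedef struct {
        float   d, m;          /* scale and minimum */
        uint8_t qs[QK4_1/2];
    } block_q4_1_ref;

    static float dot_q4_1_block(const block_q4_1_ref * b, const float * y) {
        float sumy = 0.0f, acc = 0.0f;
        for (int i = 0; i < QK4_1; ++i) {
            sumy += y[i];
        }
        for (int i = 0; i < QK4_1/2; ++i) {
            acc += y[i]      * (float)(b->qs[i] & 0x0F);
            acc += y[i + 16] * (float)(b->qs[i] >> 4);
        }
        return b->d * acc + b->m * sumy;   /* value = d*q + m */
    }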
ggml.c
@ -25,16 +25,23 @@
#include <float.h> #include <float.h>
#include <limits.h> #include <limits.h>
#include <stdarg.h> #include <stdarg.h>
#include <signal.h>
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
#include <unistd.h> #include <unistd.h>
#endif #endif
// static_assert should be a #define, but if it's not,
// fall back to the _Static_assert C11 keyword.
// if C99 - static_assert is noop // if C99 - static_assert is noop
// ref: https://stackoverflow.com/a/53923785/4039976 // ref: https://stackoverflow.com/a/53923785/4039976
#ifndef static_assert #ifndef static_assert
#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
#define static_assert(cond, msg) _Static_assert(cond, msg)
#else
#define static_assert(cond, msg) struct global_scope_noop_trick #define static_assert(cond, msg) struct global_scope_noop_trick
#endif #endif
#endif
#if defined(_MSC_VER) #if defined(_MSC_VER)
// disable "possible loss of data" to avoid hundreds of casts // disable "possible loss of data" to avoid hundreds of casts
@ -49,23 +56,23 @@
typedef volatile LONG atomic_int; typedef volatile LONG atomic_int;
typedef atomic_int atomic_bool; typedef atomic_int atomic_bool;
static void atomic_store(atomic_int* ptr, LONG val) { static void atomic_store(atomic_int * ptr, LONG val) {
InterlockedExchange(ptr, val); InterlockedExchange(ptr, val);
} }
static LONG atomic_load(atomic_int* ptr) { static LONG atomic_load(atomic_int * ptr) {
return InterlockedCompareExchange(ptr, 0, 0); return InterlockedCompareExchange(ptr, 0, 0);
} }
static LONG atomic_fetch_add(atomic_int* ptr, LONG inc) { static LONG atomic_fetch_add(atomic_int * ptr, LONG inc) {
return InterlockedExchangeAdd(ptr, inc); return InterlockedExchangeAdd(ptr, inc);
} }
static LONG atomic_fetch_sub(atomic_int* ptr, LONG dec) { static LONG atomic_fetch_sub(atomic_int * ptr, LONG dec) {
return atomic_fetch_add(ptr, -(dec)); return atomic_fetch_add(ptr, -(dec));
} }
typedef HANDLE pthread_t; typedef HANDLE pthread_t;
typedef DWORD thread_ret_t; typedef DWORD thread_ret_t;
static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void*), void* arg) { static int pthread_create(pthread_t * out, void * unused, thread_ret_t(*func)(void *), void * arg) {
(void) unused; (void) unused;
HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL); HANDLE handle = CreateThread(NULL, 0, (LPTHREAD_START_ROUTINE) func, arg, 0, NULL);
if (handle == NULL) if (handle == NULL)
@ -77,7 +84,7 @@ static int pthread_create(pthread_t* out, void* unused, thread_ret_t(*func)(void
return 0; return 0;
} }
static int pthread_join(pthread_t thread, void* unused) { static int pthread_join(pthread_t thread, void * unused) {
(void) unused; (void) unused;
return (int) WaitForSingleObject(thread, INFINITE); return (int) WaitForSingleObject(thread, INFINITE);
} }
@ -90,7 +97,7 @@ static int sched_yield (void) {
#include <pthread.h> #include <pthread.h>
#include <stdatomic.h> #include <stdatomic.h>
typedef void* thread_ret_t; typedef void * thread_ret_t;
#include <sys/types.h> #include <sys/types.h>
#include <sys/stat.h> #include <sys/stat.h>
@ -111,10 +118,6 @@ typedef void* thread_ret_t;
#endif #endif
#endif #endif
#ifdef __HAIKU__
#define static_assert(cond, msg) _Static_assert(cond, msg)
#endif
/*#define GGML_PERF*/ /*#define GGML_PERF*/
#define GGML_DEBUG 0 #define GGML_DEBUG 0
#define GGML_GELU_FP16 #define GGML_GELU_FP16
@ -3787,6 +3790,8 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CLAMP", "CLAMP",
"CONV_1D", "CONV_1D",
"CONV_2D", "CONV_2D",
"POOL_1D",
"POOL_2D",
"FLASH_ATTN", "FLASH_ATTN",
"FLASH_FF", "FLASH_FF",
@ -3805,7 +3810,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CROSS_ENTROPY_LOSS_BACK", "CROSS_ENTROPY_LOSS_BACK",
}; };
static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66"); static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@ -3865,6 +3870,8 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"clamp(x)", "clamp(x)",
"conv_1d(x)", "conv_1d(x)",
"conv_2d(x)", "conv_2d(x)",
"pool_1d(x)",
"pool_2d(x)",
"flash_attn(x)", "flash_attn(x)",
"flash_ff(x)", "flash_ff(x)",
@ -3883,7 +3890,9 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"cross_entropy_loss_back(x,y)", "cross_entropy_loss_back(x,y)",
}; };
static_assert(GGML_OP_COUNT == 66, "GGML_OP_COUNT != 66"); static_assert(GGML_OP_COUNT == 68, "GGML_OP_COUNT != 68");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_object)%GGML_MEM_ALIGN == 0, "ggml_object size must be a multiple of GGML_MEM_ALIGN");
static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN"); static_assert(sizeof(struct ggml_tensor)%GGML_MEM_ALIGN == 0, "ggml_tensor size must be a multiple of GGML_MEM_ALIGN");
@ -4162,10 +4171,9 @@ static inline bool ggml_is_matrix(const struct ggml_tensor * tensor) {
static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static inline bool ggml_can_mul_mat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
return return (t0->ne[0] == t1->ne[0]) &&
(t0->ne[0] == t1->ne[0]) && (t1->ne[2]%t0->ne[2] == 0) && // verify t0 is broadcastable
(t0->ne[2] == t1->ne[2]) && (t1->ne[3]%t0->ne[3] == 0);
(t0->ne[3] == t1->ne[3]);
} }
static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static inline bool ggml_can_out_prod(const struct ggml_tensor * t0, const struct ggml_tensor * t1) {
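The relaxed check lets src0 be broadcast across src1's 3rd/4th dimensions (for example one shared key/value head matmul'ed against many query heads), instead of requiring identical batch dims. The predicate in isolation:

    #include <stdint.h>

    /* mul_mat shape check with broadcasting over dims 2 and 3 */
    static int can_mul_mat_bcast(const int64_t t0_ne[4], const int64_t t1_ne[4]) {
        return (t0_ne[0] == t1_ne[0])     &&   /* shared inner (K) dimension             */
               (t1_ne[2] % t0_ne[2] == 0) &&   /* t0 repeats an integer number of times  */
               (t1_ne[3] % t0_ne[3] == 0);
    }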
@ -4753,7 +4761,7 @@ struct ggml_tensor * ggml_set_i32 (struct ggml_tensor * tensor, int32_t value) {
{ {
assert(tensor->nb[0] == sizeof(ggml_fp16_t)); assert(tensor->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
} }
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
@ -4805,7 +4813,7 @@ struct ggml_tensor * ggml_set_f32(struct ggml_tensor * tensor, float value) {
{ {
assert(tensor->nb[0] == sizeof(ggml_fp16_t)); assert(tensor->nb[0] == sizeof(ggml_fp16_t));
for (int i = 0; i < n; i++) { for (int i = 0; i < n; i++) {
ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), value); ggml_vec_set_f16(nc, (ggml_fp16_t *)(data + i*n1), GGML_FP32_TO_FP16(value));
} }
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
@ -5065,11 +5073,15 @@ struct ggml_tensor * ggml_add_impl(
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,
bool inplace) { bool inplace) {
GGML_ASSERT(ggml_are_same_shape(a, b)); // TODO: support less-strict constraint
// GGML_ASSERT(ggml_can_repeat(b, a));
GGML_ASSERT(ggml_can_repeat_rows(b, a));
bool is_node = false; bool is_node = false;
if (a->grad || b->grad) { if (!inplace && (a->grad || b->grad)) {
// TODO: support backward pass for broadcasting
GGML_ASSERT(ggml_are_same_shape(a, b));
is_node = true; is_node = true;
} }
@ -6055,8 +6067,8 @@ struct ggml_tensor * ggml_mul_mat(
is_node = true; is_node = true;
} }
const int64_t ne[4] = { a->ne[1], b->ne[1], a->ne[2], b->ne[3] }; const int64_t ne[4] = { a->ne[1], b->ne[1], b->ne[2], b->ne[3] };
struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MIN(a->n_dims, b->n_dims), ne); struct ggml_tensor * result = ggml_new_tensor(ctx, GGML_TYPE_F32, MAX(a->n_dims, b->n_dims), ne);
result->op = GGML_OP_MUL_MAT; result->op = GGML_OP_MUL_MAT;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL; result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
@ -7192,7 +7204,6 @@ struct ggml_tensor* ggml_conv_2d(
int d0, int d0,
int d1) { int d1) {
GGML_ASSERT(b->ne[3] == 1);
GGML_ASSERT(a->ne[2] == b->ne[2]); GGML_ASSERT(a->ne[2] == b->ne[2]);
bool is_node = false; bool is_node = false;
@ -7204,7 +7215,7 @@ struct ggml_tensor* ggml_conv_2d(
const int64_t ne[4] = { const int64_t ne[4] = {
ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0), ggml_calc_conv_output_size(b->ne[0], a->ne[0], s0, p0, d0),
ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1), ggml_calc_conv_output_size(b->ne[1], a->ne[1], s1, p1, d1),
a->ne[3], 1, a->ne[3], b->ne[3],
}; };
struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne); struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 4, ne);
@ -7239,6 +7250,98 @@ struct ggml_tensor* ggml_conv_1d_ph(
return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d); return ggml_conv_1d(ctx, a, b, s, a->ne[0] / 2, d);
} }
// ggml_pool_*
static int64_t ggml_calc_pool_output_size(int64_t ins, int ks, int s, int p) {
return (ins + 2 * p - ks) / s + 1;
}
// ggml_pool_1d
struct ggml_tensor* ggml_pool_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_op_pool op,
int k0,
int s0,
int p0) {
bool is_node = false;
if (a->grad) {
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t ne[3] = {
ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
a->ne[1],
};
struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 2, ne);
ggml_scratch_save(ctx);
struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 4);
((int32_t*)c->data)[0] = op;
((int32_t*)c->data)[1] = k0;
((int32_t*)c->data)[2] = s0;
((int32_t*)c->data)[3] = p0;
ggml_scratch_load(ctx);
result->op = GGML_OP_POOL_1D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = c;
return result;
}
// ggml_pool_2d
struct ggml_tensor* ggml_pool_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_op_pool op,
int k0,
int k1,
int s0,
int s1,
int p0,
int p1) {
bool is_node = false;
if (a->grad) {
GGML_ASSERT(false); // TODO: implement backward
is_node = true;
}
const int64_t ne[3] = {
ggml_calc_pool_output_size(a->ne[0], k0, s0, p0),
ggml_calc_pool_output_size(a->ne[1], k1, s1, p1),
a->ne[2],
};
struct ggml_tensor* result = ggml_new_tensor(ctx, GGML_TYPE_F32, 3, ne);
ggml_scratch_save(ctx);
struct ggml_tensor* c = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, 7);
((int32_t*)c->data)[0] = op;
((int32_t*)c->data)[1] = k0;
((int32_t*)c->data)[2] = k1;
((int32_t*)c->data)[3] = s0;
((int32_t*)c->data)[4] = s1;
((int32_t*)c->data)[5] = p0;
((int32_t*)c->data)[6] = p1;
ggml_scratch_load(ctx);
result->op = GGML_OP_POOL_2D;
result->grad = is_node ? ggml_dup_tensor(ctx, result) : NULL;
result->src[0] = a;
result->src[1] = c;
return result;
}
// ggml_flash_attn // ggml_flash_attn
struct ggml_tensor * ggml_flash_attn( struct ggml_tensor * ggml_flash_attn(
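A hypothetical use of the new pooling ops (context setup and tensor data elided); note from the forward pass further down that only k == s and zero padding are currently supported:

    #include "ggml.h"

    /* 2x2 average pooling over a [W, H, C] feature map -> [W/2, H/2, C] */
    struct ggml_tensor * downsample_2x(struct ggml_context * ctx, struct ggml_tensor * feat) {
        return ggml_pool_2d(ctx, feat, GGML_OP_POOL_AVG,
                            /*k0=*/2, /*k1=*/2,
                            /*s0=*/2, /*s1=*/2,
                            /*p0=*/0, /*p1=*/0);
    }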
@ -8327,7 +8430,7 @@ static void ggml_compute_forward_add_f32(
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
GGML_ASSERT(ggml_are_same_shape(src0, src1) && ggml_are_same_shape(src0, dst)); GGML_ASSERT(ggml_can_repeat_rows(src1, src0) && ggml_are_same_shape(src0, dst));
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) { if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return; return;
@ -8352,23 +8455,23 @@ static void ggml_compute_forward_add_f32(
if (nb10 == sizeof(float)) { if (nb10 == sizeof(float)) {
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices // src1 is broadcastable across src0 and dst in i1, i2, i3
const int i3 = ir/(ne2*ne1); const int64_t i03 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne2*ne1)/ne1; const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1); const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
#ifdef GGML_USE_ACCELERATE #ifdef GGML_USE_ACCELERATE
vDSP_vadd( vDSP_vadd(src0_ptr, 1, src1_ptr, 1, dst_ptr, 1, ne00);
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01), 1,
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11), 1,
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ), 1,
ne0);
#else #else
ggml_vec_add_f32(ne0, ggml_vec_add_f32(ne00, dst_ptr, src0_ptr, src1_ptr);
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 ),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01),
(float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11));
#endif #endif
// } // }
// } // }
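The add kernel now derives src1's indices with a modulo, so a tensor whose higher dimensions are 1 (a per-row bias, say) is repeated across every row of src0. The index mapping on its own:

    #include <stdint.h>

    /* map a flat src0/dst row index ir to src1's (broadcast) row indices */
    static void bcast_row_indices(int64_t ir,
                                  int64_t ne01, int64_t ne02,               /* src0 extents */
                                  int64_t ne11, int64_t ne12, int64_t ne13, /* src1 extents */
                                  int64_t * i11, int64_t * i12, int64_t * i13) {
        const int64_t i03 = ir / (ne02*ne01);
        const int64_t i02 = (ir - i03*ne02*ne01) / ne01;
        const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
        *i11 = i01 % ne11;   /* wraps when src1 has fewer rows */
        *i12 = i02 % ne12;
        *i13 = i03 % ne13;
    }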
@ -8376,15 +8479,20 @@ static void ggml_compute_forward_add_f32(
} else { } else {
// src1 is not contiguous // src1 is not contiguous
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// src0, src1 and dst are same shape => same indices // src1 is broadcastable across src0 and dst in i1, i2, i3
const int i3 = ir/(ne2*ne1); const int64_t i03 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne2*ne1)/ne1; const int64_t i02 = (ir - i03*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne2*ne1 - i2*ne1); const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01);
const int64_t i13 = i03 % ne13;
const int64_t i12 = i02 % ne12;
const int64_t i11 = i01 % ne11;
float * dst_ptr = (float *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
float * dst_ptr = (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1 );
float * src0_ptr = (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01);
for (int i0 = 0; i0 < ne0; i0++) { for (int i0 = 0; i0 < ne0; i0++) {
float * src1_ptr = (float *) ((char *) src1->data + i3*nb13 + i2*nb12 + i1*nb11 + i0*nb10); float * src1_ptr = (float *) ((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11 + i0*nb10);
dst_ptr[i0] = src0_ptr[i0] + *src1_ptr; dst_ptr[i0] = src0_ptr[i0] + *src1_ptr;
} }
@ -10563,7 +10671,6 @@ static void ggml_compute_forward_rms_norm_back(
} }
} }
// ggml_compute_forward_mul_mat // ggml_compute_forward_mul_mat
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
@ -10607,17 +10714,19 @@ static void ggml_compute_forward_mul_mat(
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
const enum ggml_type type = src0->type; const enum ggml_type type = src0->type;
const bool src1_cont = ggml_is_contiguous(src1);
ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot; ggml_vec_dot_t const vec_dot = type_traits[type].vec_dot;
enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type; enum ggml_type const vec_dot_type = type_traits[type].vec_dot_type;
ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float; ggml_from_float_t const from_float_to_vec_dot = type_traits[vec_dot_type].from_float;
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne12);
GGML_ASSERT(ne3 == ne13);
// we don't support permuted src0 or src1 // we don't support permuted src0 or src1
GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]); GGML_ASSERT(nb00 == GGML_TYPE_SIZE[type]);
GGML_ASSERT(nb10 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float));
@ -10628,16 +10737,16 @@ static void ggml_compute_forward_mul_mat(
GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3); GGML_ASSERT(nb2 <= nb3);
GGML_ASSERT(ne0 == ne01);
GGML_ASSERT(ne1 == ne11);
GGML_ASSERT(ne2 == ne02);
GGML_ASSERT(ne3 == ne03);
// nb01 >= nb00 - src0 is not transposed // nb01 >= nb00 - src0 is not transposed
// compute by src0 rows // compute by src0 rows
#if defined(GGML_USE_CLBLAST) #if defined(GGML_USE_CLBLAST)
if (ggml_cl_can_mul_mat(src0, src1, dst)) { if (ggml_cl_can_mul_mat(src0, src1, dst)) {
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
// ref: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) { if (params->ith == 0 && params->type == GGML_TASK_COMPUTE) {
ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize); ggml_cl_mul_mat(src0, src1, dst, params->wdata, params->wsize);
} }
@ -10647,6 +10756,11 @@ static void ggml_compute_forward_mul_mat(
#if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS) #if defined(GGML_USE_ACCELERATE) || defined(GGML_USE_OPENBLAS)
if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) { if (ggml_compute_forward_mul_mat_use_blas(src0, src1, dst)) {
// TODO: handle case when src0 is broadcast-able into src1 across 2nd,3rd dimension
// ref: https://github.com/ggerganov/ggml/pull/224
GGML_ASSERT(ne02 == ne12);
GGML_ASSERT(ne03 == ne13);
if (params->ith != 0) { if (params->ith != 0) {
return; return;
} }
@ -10716,41 +10830,52 @@ static void ggml_compute_forward_mul_mat(
return; return;
} }
// parallelize by src0 rows using ggml_vec_dot_q // parallelize by src0 rows
const int64_t dr = (ne01 + nth - 1)/nth;
// total rows in src0 const int64_t ir10 = dr*ith;
const int nr = ne01*ne02*ne03; const int64_t ir11 = MIN(ir10 + dr, ne01);
// rows per thread // src1 rows
const int dr = (nr + nth - 1)/nth; const int64_t nr1 = ne11*ne12*ne13;
// row range for this thread const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
const int ir0 = dr*ith; const size_t row_size = ne10*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type];
const int ir1 = MIN(ir0 + dr, nr);
void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata; for (int64_t ir1 = 0; ir1 < nr1; ++ir1) {
const size_t row_size = ne00*GGML_TYPE_SIZE[vec_dot_type]/GGML_BLCK_SIZE[vec_dot_type]; const int64_t i13 = (ir1/(ne12*ne11));
const int64_t i12 = (ir1 - i13*ne12*ne11)/ne11;
const int64_t i11 = (ir1 - i13*ne12*ne11 - i12*ne11);
for (int ir = ir0; ir < ir1; ++ir) { const int64_t ir0 = (ir1/ne11)%(ne02*ne03);
// src0 indices const int64_t i03 = (ir0/(ne02));
const int i03 = ir/(ne02*ne01); // Hack for "Falcon multi-query-attention key stutter" / alternative to ggml_repeat2.
const int i02 = (ir - i03*ne02*ne01)/ne01; // See https://github.com/ggerganov/llama.cpp/issues/1602#issuecomment-1606087470:
const int i01 = (ir - i03*ne02*ne01 - i02*ne01); // GG: this is likely the correct way to broadcast, though need some more thought
// therefore leaving the comments to remind us for now
const int64_t i02 = (i12 / (ne12 / ne02));
// Original from PR/224 (and also essential/correct for non-broadcast matmuls in Falcon)
// const int64_t i02 = (ir0 - i03*ne02);
const int i13 = i03; const int64_t i1 = i11;
const int i12 = i02; const int64_t i2 = i12;
const int64_t i3 = i13;
const int i0 = i01; const char * src0_row = (const char *) src0->data + ( 0 + i02*nb02 + i03*nb03 );
const int i2 = i02;
const int i3 = i03;
void * src0_row = (void *) ((char *) src0->data + (i01*nb01 + i02*nb02 + i03*nb03)); // desc: when src1 is not a contiguous memory block we have to calculate the offset using the strides
char * src1_col = ((char *) wdata + ( (0 + i12*ne11 + i13*ne12*ne11)*row_size)); // if it is, then we have either copied the data to params->wdata and made it contiguous or we are using
// the original src1 data pointer, so we should index using the indices directly
// TODO: this is a bit of a hack, we should probably have a better way to handle this
const char * src1_col = (const char *) wdata +
(src1_cont || src1->type != vec_dot_type
? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
: (i11*nb11 + i12*nb12 + i13*nb13));
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3)); float * dst_col = (float *) ((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3));
for (int64_t ic = 0; ic < ne11; ++ic) { for (int64_t ir = ir10; ir < ir11; ++ir) {
vec_dot(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size)); vec_dot(ne00, &dst_col[ir], src0_row + ir*nb01, src1_col);
} }
} }
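Parallelization changes from "one flat row range over ne01*ne02*ne03" to "each thread owns a slice of src0 rows and walks every src1 row against it", which is what makes the broadcast indexing above workable. The row split by itself:

    #include <stdint.h>

    #define MIN(a, b) ((a) < (b) ? (a) : (b))

    /* contiguous slice of src0 rows handled by thread ith of nth */
    static void thread_row_range(int64_t ne01, int ith, int nth,
                                 int64_t * ir10, int64_t * ir11) {
        const int64_t dr = (ne01 + nth - 1) / nth;   /* rows per thread, rounded up */
        *ir10 = dr * ith;
        *ir11 = MIN(*ir10 + dr, ne01);
    }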
@ -11747,7 +11872,7 @@ static void ggml_compute_forward_alibi_f32(
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past const int ne1 = src0->ne[1]; // seq_len_without_past
//const int ne2 = src0->ne[2]; // n_head -> this is k const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz //const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0); const int n = ggml_nrows(src0);
@ -11758,8 +11883,9 @@ static void ggml_compute_forward_alibi_f32(
const int nb2 = src0->nb[2]; const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3]; //const int nb3 = src0->nb[3];
assert(nb0 == sizeof(float)); GGML_ASSERT(nb0 == sizeof(float));
assert(ne1 + n_past == ne0); (void) n_past; GGML_ASSERT(ne1 + n_past == ne0);
GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled) // add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@ -11783,7 +11909,7 @@ static void ggml_compute_forward_alibi_f32(
m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1); m_k = powf(m1, 2 * (k - n_heads_log2_floor) + 1);
} }
pdst[0] = (i-ne0+1) * m_k + src[0]; pdst[0] = i * m_k + src[0];
} }
} }
@ -11812,7 +11938,7 @@ static void ggml_compute_forward_alibi_f16(
const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1 const int ne0 = src0->ne[0]; // all_seq_len = n_past + ne1
const int ne1 = src0->ne[1]; // seq_len_without_past const int ne1 = src0->ne[1]; // seq_len_without_past
//const int ne2 = src0->ne[2]; // n_head -> this is k const int ne2 = src0->ne[2]; // n_head -> this is k
//const int ne3 = src0->ne[3]; // 1 -> bsz //const int ne3 = src0->ne[3]; // 1 -> bsz
const int n = ggml_nrows(src0); const int n = ggml_nrows(src0);
@ -11823,8 +11949,9 @@ static void ggml_compute_forward_alibi_f16(
const int nb2 = src0->nb[2]; const int nb2 = src0->nb[2];
//const int nb3 = src0->nb[3]; //const int nb3 = src0->nb[3];
assert(nb0 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb0 == sizeof(ggml_fp16_t));
assert(ne1 + n_past == ne0); (void) n_past; GGML_ASSERT(ne1 + n_past == ne0); (void) n_past;
GGML_ASSERT(n_head == ne2);
// add alibi to src0 (KQ_scaled) // add alibi to src0 (KQ_scaled)
const int n_heads_log2_floor = 1 << (int) floor(log2(n_head)); const int n_heads_log2_floor = 1 << (int) floor(log2(n_head));
@ -11849,7 +11976,7 @@ static void ggml_compute_forward_alibi_f16(
} }
// we return F32 // we return F32
pdst[0] = (i-ne0+1) * m_k + GGML_FP16_TO_FP32(src[0]); pdst[0] = i * m_k + GGML_FP16_TO_FP32(src[0]);
} }
} }
} }
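The ALiBi change replaces the (i - ne0 + 1) offset with a plain i * m_k bias and asserts the head count against ne2. A sketch of the per-head slope and bias, with the m0/m1 derivation assumed from the max_bias parameter used by the CPU implementation:

    #include <math.h>

    /* bias added to the attention score at column i for head k */
    static float alibi_bias(int i, int k, int n_head, float max_bias) {
        const int   n_floor = 1 << (int) floor(log2((double) n_head));
        const float m0  = powf(2.0f, -(max_bias)        / n_floor);
        const float m1  = powf(2.0f, -(max_bias / 2.0f) / n_floor);
        const float m_k = (k < n_floor) ? powf(m0, (float)(k + 1))
                                        : powf(m1, (float)(2*(k - n_floor) + 1));
        return (float) i * m_k;   /* previously (i - ne0 + 1) * m_k */
    }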
@ -12911,12 +13038,13 @@ static void ggml_compute_forward_conv_1d(
}; };
} }
// ggml_compute_forward_conv_2d_sk_p0 // ggml_compute_forward_conv_2d
static void ggml_compute_forward_conv_2d_sk_p0_f16_f32( static void ggml_compute_forward_conv_2d_f16_f32(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
const struct ggml_tensor * opt0,
struct ggml_tensor * dst) { struct ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F16); GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32); GGML_ASSERT(src1->type == GGML_TYPE_F32);
@ -12936,11 +13064,17 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
// size of the convolution row - the kernel size unrolled across all channels // size of the convolution row - the kernel size unrolled across all channels
const int ew0 = nk0*nk1*ne02; const int ew0 = nk0*nk1*ne02;
const int32_t s0 = ((const int32_t*)(opt0->data))[0];
const int32_t s1 = ((const int32_t*)(opt0->data))[1];
const int32_t p0 = ((const int32_t*)(opt0->data))[2];
const int32_t p1 = ((const int32_t*)(opt0->data))[3];
const int32_t d0 = ((const int32_t*)(opt0->data))[4];
const int32_t d1 = ((const int32_t*)(opt0->data))[5];
GGML_ASSERT(nb00 == sizeof(ggml_fp16_t)); GGML_ASSERT(nb00 == sizeof(ggml_fp16_t));
GGML_ASSERT(nb10 == sizeof(float)); GGML_ASSERT(nb10 == sizeof(float));
if (params->type == GGML_TASK_INIT) { if (params->type == GGML_TASK_INIT) {
// TODO: fix this memset (wsize is overestimated)
memset(params->wdata, 0, params->wsize); memset(params->wdata, 0, params->wsize);
// prepare source data (src1) // prepare source data (src1)
@ -12955,8 +13089,13 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
for (int i0 = 0; i0 < ne0; i0++) { for (int i0 = 0; i0 < ne0; i0++) {
for (int ik1 = 0; ik1 < nk1; ik1++) { for (int ik1 = 0; ik1 < nk1; ik1++) {
for (int ik0 = 0; ik0 < nk0; ik0++) { for (int ik0 = 0; ik0 < nk0; ik0++) {
const int idx0 = i0*s0 + ik0*d0 - p0;
const int idx1 = i1*s1 + ik1*d1 - p1;
if (!(idx1 < 0 || idx1 >= ne11 || idx0 < 0 || idx0 >= ne10)) {
dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] = dst_data[(i1*ne0 + i0)*ew0 + i12*(nk0*nk1) + ik1*nk0 + ik0] =
GGML_FP32_TO_FP16(src[(i1*nk1 + ik1)*ne10 + (i0*nk0 + ik0)]); GGML_FP32_TO_FP16(src[idx1*ne10 + idx0]);
}
} }
} }
} }
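The packing loop now honors stride, padding and dilation when gathering input samples, leaving out-of-range taps as zeros in the im2col-style buffer. The per-tap addressing on its own:

    /* fetch one input sample for output position (i0, i1) and kernel tap (ik0, ik1) */
    static float conv2d_tap(const float * src, int ne10, int ne11,
                            int i0, int i1, int ik0, int ik1,
                            int s0, int s1, int p0, int p1, int d0, int d1) {
        const int idx0 = i0*s0 + ik0*d0 - p0;
        const int idx1 = i1*s1 + ik1*d1 - p1;
        if (idx0 < 0 || idx0 >= ne10 || idx1 < 0 || idx1 >= ne11) {
            return 0.0f;   /* implicit zero padding */
        }
        return src[idx1*ne10 + idx0];
    }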
@ -12983,32 +13122,36 @@ static void ggml_compute_forward_conv_2d_sk_p0_f16_f32(
ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0; ggml_fp16_t * const wdata = (ggml_fp16_t *) params->wdata + 0;
for (int i3 = 0; i3 < ne3; i3++) {
for (int i2 = ip0; i2 < ip1; i2++) { for (int i2 = ip0; i2 < ip1; i2++) {
float * dst_data = (float *)((char *) dst->data + i2*nb2); float * dst_data = (float *)((char *) dst->data + i3*nb3 + i2*nb2);
for (int i1 = 0; i1 < ne1; ++i1) { for (int i1 = 0; i1 < ne1; ++i1) {
for (int i0 = 0; i0 < ne0; ++i0) { for (int i0 = 0; i0 < ne0; ++i0) {
ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0, ggml_vec_dot_f16(ew0, dst_data + i1*ne0 + i0,
(ggml_fp16_t *) ((char *) src0->data + i2*nb03), (ggml_fp16_t *) ((char *) src0->data + i2*nb03),
(ggml_fp16_t *) wdata + (i1*ne0 + i0)*ew0); (ggml_fp16_t *) wdata + i3*nb3 + (i1*ne0 + i0)*ew0);
}
} }
} }
} }
} }
static void ggml_compute_forward_conv_2d_sk_p0( static void ggml_compute_forward_conv_2d(
const struct ggml_compute_params * params, const struct ggml_compute_params * params,
const struct ggml_tensor * src0, const struct ggml_tensor * src0,
const struct ggml_tensor * src1, const struct ggml_tensor * src1,
struct ggml_tensor * dst) { const struct ggml_tensor * opt0,
struct ggml_tensor * dst
) {
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
ggml_compute_forward_conv_2d_sk_p0_f16_f32(params, src0, src1, dst); ggml_compute_forward_conv_2d_f16_f32(params, src0, src1, opt0, dst);
} break; } break;
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
//ggml_compute_forward_conv_2d_sk_p0_f32(params, src0, src1, dst); //ggml_compute_forward_conv_2d_f32(params, src0, src1, opt0, dst);
GGML_ASSERT(false); GGML_ASSERT(false);
} break; } break;
default: default:
@ -13018,31 +13161,164 @@ static void ggml_compute_forward_conv_2d_sk_p0(
} }
} }
// ggml_compute_forward_conv_2d // ggml_compute_forward_pool_1d_sk_p0
static void ggml_compute_forward_conv_2d( static void ggml_compute_forward_pool_1d_sk_p0(
const struct ggml_compute_params * params,
const enum ggml_op_pool op,
const struct ggml_tensor * src,
const int k,
struct ggml_tensor * dst) {
assert(src->type == GGML_TYPE_F32);
assert(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
}
const char * cdata = (const char *)src->data;
const char * const data_end = cdata + ggml_nbytes(src);
float * drow = (float *)dst->data;
const int64_t rs = dst->ne[0];
while (cdata < data_end) {
const float * const srow = (const float *)cdata;
int j = 0;
for (int64_t i = 0; i < rs; ++i) {
switch (op) {
case GGML_OP_POOL_AVG: drow[i] = 0; break;
case GGML_OP_POOL_MAX: drow[i] = -FLT_MAX; break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
}
for (int ki = 0; ki < k; ++ki) {
switch (op) {
case GGML_OP_POOL_AVG: drow[i] += srow[j]; break;
case GGML_OP_POOL_MAX: if (srow[j] > drow[i]) drow[i] = srow[j]; break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
}
++j;
}
switch (op) {
case GGML_OP_POOL_AVG: drow[i] /= k; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
}
}
cdata += src->nb[1];
drow += rs;
}
}
// ggml_compute_forward_pool_1d
static void ggml_compute_forward_pool_1d(
const struct ggml_compute_params* params, const struct ggml_compute_params* params,
const struct ggml_tensor* src0, const struct ggml_tensor* src0,
const struct ggml_tensor* src1,
const struct ggml_tensor* opt0, const struct ggml_tensor* opt0,
struct ggml_tensor* dst) { struct ggml_tensor* dst) {
const int32_t s0 = ((const int32_t*)(opt0->data))[0]; GGML_ASSERT(opt0->ne[0] == 4);
const int32_t s1 = ((const int32_t*)(opt0->data))[1]; const int* opts = (const int*)opt0->data;
const int32_t p0 = ((const int32_t*)(opt0->data))[2]; enum ggml_op_pool op = opts[0];
const int32_t p1 = ((const int32_t*)(opt0->data))[3]; const int k0 = opts[1];
const int32_t d0 = ((const int32_t*)(opt0->data))[4]; const int s0 = opts[2];
const int32_t d1 = ((const int32_t*)(opt0->data))[5]; const int p0 = opts[3];
GGML_ASSERT(d0 == 1); // dilation not supported
GGML_ASSERT(d1 == 1);
GGML_ASSERT(p0 == 0); // padding not supported GGML_ASSERT(p0 == 0); // padding not supported
GGML_ASSERT(p1 == 0); GGML_ASSERT(k0 == s0); // only s = k supported
if (s0 == src0->ne[0] && s1 == src0->ne[1]) { ggml_compute_forward_pool_1d_sk_p0(params, op, src0, k0, dst);
ggml_compute_forward_conv_2d_sk_p0(params, src0, src1, dst); }
// ggml_compute_forward_pool_2d_sk_p0
static void ggml_compute_forward_pool_2d_sk_p0(
const struct ggml_compute_params * params,
const enum ggml_op_pool op,
const struct ggml_tensor * src,
const int k0,
const int k1,
struct ggml_tensor * dst) {
assert(src->type == GGML_TYPE_F32);
assert(params->ith == 0);
if (params->type == GGML_TASK_INIT || params->type == GGML_TASK_FINALIZE) {
return;
} }
else {
GGML_ASSERT(false); // only stride equal to kernel size is supported const char * cdata = (const char*)src->data;
}; const char * const data_end = cdata + ggml_nbytes(src);
const int64_t px = dst->ne[0];
const int64_t py = dst->ne[1];
const int64_t pa = px * py;
float * dplane = (float *)dst->data;
const int ka = k0 * k1;
while (cdata < data_end) {
for (int oy = 0; oy < py; ++oy) {
float * const drow = dplane + oy * px;
for (int ox = 0; ox < px; ++ox) {
float * const out = drow + ox;
switch (op) {
case GGML_OP_POOL_AVG: *out = 0; break;
case GGML_OP_POOL_MAX: *out = -FLT_MAX; break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
}
const int ix = ox * k0;
const int iy = oy * k1;
for (int ky = 0; ky < k1; ++ky) {
const float * const srow = (const float *)(cdata + src->nb[1] * (iy + ky));
for (int kx = 0; kx < k0; ++kx) {
int j = ix + kx;
switch (op) {
case GGML_OP_POOL_AVG: *out += srow[j]; break;
case GGML_OP_POOL_MAX: if (srow[j] > *out) *out = srow[j]; break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
}
}
}
switch (op) {
case GGML_OP_POOL_AVG: *out /= ka; break;
case GGML_OP_POOL_MAX: break;
case GGML_OP_POOL_COUNT: GGML_ASSERT(false); break;
}
}
}
cdata += src->nb[2];
dplane += pa;
}
}
// ggml_compute_forward_pool_2d
static void ggml_compute_forward_pool_2d(
const struct ggml_compute_params * params,
const struct ggml_tensor * src0,
const struct ggml_tensor * opt0,
struct ggml_tensor * dst) {
GGML_ASSERT(opt0->ne[0] == 7);
const int* opts = (const int*)opt0->data;
enum ggml_op_pool op = opts[0];
const int k0 = opts[1];
const int k1 = opts[2];
const int s0 = opts[3];
const int s1 = opts[4];
const int p0 = opts[5];
const int p1 = opts[6];
GGML_ASSERT(p0 == 0);
GGML_ASSERT(p1 == 0); // padding not supported
GGML_ASSERT(k0 == s0);
GGML_ASSERT(k1 == s1); // only s = k supported
ggml_compute_forward_pool_2d_sk_p0(params, op, src0, k0, k1, dst);
} }
@ -14826,6 +15102,14 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{ {
ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor); ggml_compute_forward_conv_2d(params, tensor->src[0], tensor->src[1], tensor->src[2], tensor);
} break; } break;
case GGML_OP_POOL_1D:
{
ggml_compute_forward_pool_1d(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_POOL_2D:
{
ggml_compute_forward_pool_2d(params, tensor->src[0], tensor->src[1], tensor);
} break;
case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN:
{ {
const int32_t t = ggml_get_i32_1d(tensor->src[3], 0); const int32_t t = ggml_get_i32_1d(tensor->src[3], 0);
@ -15526,6 +15810,14 @@ static void ggml_compute_backward(struct ggml_context * ctx, struct ggml_tensor
{ {
GGML_ASSERT(false); // TODO: not implemented GGML_ASSERT(false); // TODO: not implemented
} break; } break;
case GGML_OP_POOL_1D:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_POOL_2D:
{
GGML_ASSERT(false); // TODO: not implemented
} break;
case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN:
{ {
struct ggml_tensor * flash_grad = NULL; struct ggml_tensor * flash_grad = NULL;
@ -15988,6 +16280,9 @@ struct ggml_compute_state_shared {
// synchronization primitives // synchronization primitives
atomic_int n_active; // num active threads atomic_int n_active; // num active threads
atomic_int node_n; // active graph node atomic_int node_n; // active graph node
bool (*abort_callback)(void * data); // abort ggml_graph_compute when true
void * abort_callback_data;
}; };
struct ggml_compute_state { struct ggml_compute_state {
@ -16019,6 +16314,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
int node_n = -1; int node_n = -1;
while (true) { while (true) {
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
state->shared->node_n += 1;
return (thread_ret_t) GGML_EXIT_ABORTED;
}
if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) { if (atomic_fetch_sub(&state->shared->n_active, 1) == 1) {
// all other threads are finished and spinning // all other threads are finished and spinning
// do finalize and init here so we don't have synchronize again // do finalize and init here so we don't have synchronize again
@ -16072,6 +16371,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} else { } else {
break; break;
} }
if (cplan->abort_callback && cplan->abort_callback(cplan->abort_callback_data)) {
break;
}
} }
atomic_store(&state->shared->n_active, n_threads); atomic_store(&state->shared->n_active, n_threads);
@ -16105,7 +16408,7 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
} }
} }
return 0; return GGML_EXIT_SUCCESS;
} }
struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) { struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
@ -16305,8 +16608,6 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
{ {
n_tasks = n_threads; n_tasks = n_threads;
GGML_ASSERT(node->src[1]->ne[3] == 1);
const int64_t ne00 = node->src[0]->ne[0]; // W const int64_t ne00 = node->src[0]->ne[0]; // W
const int64_t ne01 = node->src[0]->ne[1]; // H const int64_t ne01 = node->src[0]->ne[1]; // H
const int64_t ne02 = node->src[0]->ne[2]; // C const int64_t ne02 = node->src[0]->ne[2]; // C
@ -16316,17 +16617,20 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
const int64_t ne11 = node->src[1]->ne[1]; // H const int64_t ne11 = node->src[1]->ne[1]; // H
const int64_t ne12 = node->src[1]->ne[2]; // C const int64_t ne12 = node->src[1]->ne[2]; // C
const int64_t ne0 = node->ne[0];
const int64_t ne1 = node->ne[1];
const int64_t ne2 = node->ne[2];
const int64_t nk = ne00*ne01; const int64_t nk = ne00*ne01;
const int64_t ew0 = nk * ne02;
UNUSED(ne02);
UNUSED(ne03); UNUSED(ne03);
UNUSED(nk); UNUSED(ne2);
size_t cur = 0; size_t cur = 0;
if (node->src[0]->type == GGML_TYPE_F16 && if (node->src[0]->type == GGML_TYPE_F16 &&
node->src[1]->type == GGML_TYPE_F32) { node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(ggml_fp16_t)*(ne10*ne11*ne12); cur = sizeof(ggml_fp16_t)*(ne0*ne1*ew0);
} else if (node->src[0]->type == GGML_TYPE_F32 && } else if (node->src[0]->type == GGML_TYPE_F32 &&
node->src[1]->type == GGML_TYPE_F32) { node->src[1]->type == GGML_TYPE_F32) {
cur = sizeof(float)* (ne10*ne11*ne12); cur = sizeof(float)* (ne10*ne11*ne12);
@ -16336,6 +16640,11 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
work_size = MAX(work_size, cur); work_size = MAX(work_size, cur);
} break; } break;
case GGML_OP_POOL_1D:
case GGML_OP_POOL_2D:
{
n_tasks = 1;
} break;
case GGML_OP_FLASH_ATTN: case GGML_OP_FLASH_ATTN:
{ {
n_tasks = n_threads; n_tasks = n_threads;
@ -16445,7 +16754,7 @@ struct ggml_cplan ggml_graph_plan(struct ggml_cgraph * cgraph, int n_threads) {
return cplan; return cplan;
} }
void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) { int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan) {
{ {
GGML_ASSERT(cplan); GGML_ASSERT(cplan);
GGML_ASSERT(cplan->n_threads > 0); GGML_ASSERT(cplan->n_threads > 0);
@ -16471,6 +16780,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
/*.n_threads =*/ n_threads, /*.n_threads =*/ n_threads,
/*.n_active =*/ n_threads, /*.n_active =*/ n_threads,
/*.node_n =*/ -1, /*.node_n =*/ -1,
/*.abort_callback =*/ NULL,
/*.abort_callback_data =*/ NULL,
}; };
struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads); struct ggml_compute_state * workers = alloca(sizeof(struct ggml_compute_state)*n_threads);
@ -16494,12 +16805,12 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
const int64_t perf_start_time_us = ggml_perf_time_us(); const int64_t perf_start_time_us = ggml_perf_time_us();
// this is a work thread too // this is a work thread too
ggml_graph_compute_thread(&workers[0]); int compute_status = (size_t) ggml_graph_compute_thread(&workers[0]);
// don't leave affinity set on the main thread // don't leave affinity set on the main thread
clear_numa_thread_affinity(); clear_numa_thread_affinity();
// join thread pool // join or kill thread pool
if (n_threads > 1) { if (n_threads > 1) {
for (int j = 1; j < n_threads; j++) { for (int j = 1; j < n_threads; j++) {
const int rc = ggml_thread_join(workers[j].thrd, NULL); const int rc = ggml_thread_join(workers[j].thrd, NULL);
@ -16523,6 +16834,8 @@ void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan)
(double) perf_time_us_cur / 1000.0, (double) perf_time_us_cur / 1000.0,
(double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs); (double) cgraph->perf_time_us / 1000.0 / cgraph->perf_runs);
} }
return compute_status;
} }
void ggml_graph_reset(struct ggml_cgraph * cgraph) { void ggml_graph_reset(struct ggml_cgraph * cgraph) {
ggml.h
@ -201,6 +201,10 @@
#define GGML_MAX_NAME 48 #define GGML_MAX_NAME 48
#define GGML_DEFAULT_N_THREADS 4 #define GGML_DEFAULT_N_THREADS 4
#define GGML_EXIT_SUCCESS 0
#define GGML_EXIT_ABORTED 1
#define GGML_UNUSED(x) (void)(x) #define GGML_UNUSED(x) (void)(x)
// Maximum training context of the model in use // Maximum training context of the model in use
@ -369,6 +373,8 @@ extern "C" {
GGML_OP_CLAMP, GGML_OP_CLAMP,
GGML_OP_CONV_1D, GGML_OP_CONV_1D,
GGML_OP_CONV_2D, GGML_OP_CONV_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
GGML_OP_FLASH_ATTN, GGML_OP_FLASH_ATTN,
GGML_OP_FLASH_FF, GGML_OP_FLASH_FF,
@ -448,6 +454,10 @@ extern "C" {
// the `n_tasks` of nodes, 1:1 mapping to cgraph nodes // the `n_tasks` of nodes, 1:1 mapping to cgraph nodes
int n_tasks[GGML_MAX_NODES]; int n_tasks[GGML_MAX_NODES];
// abort ggml_graph_compute when true
bool (*abort_callback)(void * data);
void * abort_callback_data;
}; };
// computation graph // computation graph
@ -1174,6 +1184,31 @@ extern "C" {
int s, int s,
int d); int d);
enum ggml_op_pool {
GGML_OP_POOL_MAX,
GGML_OP_POOL_AVG,
GGML_OP_POOL_COUNT,
};
GGML_API struct ggml_tensor* ggml_pool_1d(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_op_pool op,
int k0, // kernel size
int s0, // stride
int p0); // padding
GGML_API struct ggml_tensor* ggml_pool_2d(
struct ggml_context * ctx,
struct ggml_tensor * a,
enum ggml_op_pool op,
int k0,
int k1,
int s0,
int s1,
int p0,
int p1);
GGML_API struct ggml_tensor * ggml_flash_attn( GGML_API struct ggml_tensor * ggml_flash_attn(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * q, struct ggml_tensor * q,
@ -1313,7 +1348,7 @@ extern "C" {
// ggml_graph_plan() has to be called before ggml_graph_compute() // ggml_graph_plan() has to be called before ggml_graph_compute()
// when plan.work_size > 0, caller must allocate memory for plan.work_data // when plan.work_size > 0, caller must allocate memory for plan.work_data
GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/); GGML_API struct ggml_cplan ggml_graph_plan (struct ggml_cgraph * cgraph, int n_threads /*= GGML_DEFAULT_N_THREADS*/);
GGML_API void ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan); GGML_API int ggml_graph_compute(struct ggml_cgraph * cgraph, struct ggml_cplan * cplan);
GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph); GGML_API void ggml_graph_reset (struct ggml_cgraph * cgraph);
// same as ggml_graph_compute() but the work data is allocated as a part of the context // same as ggml_graph_compute() but the work data is allocated as a part of the context
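Putting the new pieces together, a caller can now cancel a long graph evaluation from outside the compute threads. A minimal sketch using the public API above (the flag would normally be set from another thread or a signal handler; work-buffer allocation is elided):

    #include <stdbool.h>
    #include "ggml.h"

    static volatile int g_stop = 0;

    static bool should_abort(void * data) {
        (void) data;
        return g_stop != 0;
    }

    /* returns GGML_EXIT_SUCCESS or GGML_EXIT_ABORTED */
    int compute_with_abort(struct ggml_cgraph * gf, int n_threads) {
        struct ggml_cplan plan = ggml_graph_plan(gf, n_threads);
        /* when plan.work_size > 0 the caller must allocate plan.work_data first */
        plan.abort_callback      = should_abort;
        plan.abort_callback_data = NULL;
        return ggml_graph_compute(gf, &plan);
    }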
@@ -15,6 +15,14 @@
 #define K_SCALE_SIZE 12
 #endif
 
+#ifndef static_assert
+#if defined(__STDC_VERSION__) && (__STDC_VERSION__ >= 201100L)
+#define static_assert(cond, msg) _Static_assert(cond, msg)
+#else
+#define static_assert(cond, msg) struct global_scope_noop_trick
+#endif
+#endif
+
 //
 // Super-block quantization structures
 //
@@ -275,7 +275,7 @@ maxhordectx = 1024
 maxhordelen = 256
 modelbusy = False
 defaultport = 5001
-KcppVersion = "1.35"
+KcppVersion = "1.36"
 showdebug = True
 
 class ServerRequestHandler(http.server.SimpleHTTPRequestHandler):
@@ -175,13 +175,13 @@ struct llama_mmap {
     llama_mmap(struct llama_file * file, size_t prefetch = (size_t) -1 /* -1 = max value */, bool numa = false) {
         size = file->size;
         int fd = fileno(file->fp);
-        int flags = MAP_PRIVATE;
+        int flags = MAP_SHARED;
         // prefetch/readahead impairs performance on NUMA systems
         if (numa) { prefetch = 0; }
 #ifdef __linux__
         if (prefetch) { flags |= MAP_POPULATE; }
 #endif
-        addr = mmap(NULL, file->size, PROT_READ | PROT_WRITE, flags, fd, 0);
+        addr = mmap(NULL, file->size, PROT_READ, flags, fd, 0);
         if (addr == MAP_FAILED) {
             throw std::runtime_error(format("mmap failed: %s", strerror(errno)));
         }
@@ -223,7 +223,7 @@ struct llama_mmap {
             throw std::runtime_error(format("CreateFileMappingA failed: %s", llama_format_win_err(error).c_str()));
         }
 
-        addr = MapViewOfFile(hMapping, FILE_MAP_COPY, 0, 0, 0);
+        addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0);
         error = GetLastError();
         CloseHandle(hMapping);
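Switching the mapping to MAP_SHARED with PROT_READ (and FILE_MAP_READ on Windows) makes the model weights a plain read-only, page-cache-backed mapping instead of private copy-on-write pages. A stripped-down POSIX equivalent of what the loader now does:

    #include <fcntl.h>
    #include <stddef.h>
    #include <sys/mman.h>
    #include <sys/stat.h>
    #include <unistd.h>

    /* map a file read-only and shared; returns NULL on failure */
    static void * map_readonly(const char * path, size_t * out_size) {
        int fd = open(path, O_RDONLY);
        if (fd < 0) return NULL;
        struct stat st;
        if (fstat(fd, &st) != 0) { close(fd); return NULL; }
        void * addr = mmap(NULL, st.st_size, PROT_READ, MAP_SHARED, fd, 0);
        close(fd);                     /* the mapping remains valid after close */
        if (addr == MAP_FAILED) return NULL;
        *out_size = (size_t) st.st_size;
        return addr;
    }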
@ -304,7 +304,7 @@ struct llama_model {
}; };
struct llama_context { struct llama_context {
llama_context(const llama_model & model, const llama_vocab & vocab) : model(model), vocab(vocab), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {} llama_context(const llama_model & model) : model(model), t_load_us(model.t_load_us), t_start_us(model.t_start_us) {}
#ifdef GGML_USE_METAL #ifdef GGML_USE_METAL
~llama_context() { ~llama_context() {
if (ctx_metal) { if (ctx_metal) {
@ -325,7 +325,6 @@ struct llama_context {
int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1) int32_t n_p_eval = 0; // number of tokens in eval calls for the prompt (with batch size > 1)
const llama_model & model; const llama_model & model;
const llama_vocab & vocab;
bool model_owner = false; bool model_owner = false;
@ -2699,7 +2698,7 @@ struct llama_context * llama_new_context_with_model(
return nullptr; return nullptr;
} }
llama_context * ctx = new llama_context(*model, model->vocab); llama_context * ctx = new llama_context(*model);
if (params.seed == LLAMA_DEFAULT_SEED) { if (params.seed == LLAMA_DEFAULT_SEED) {
params.seed = time(NULL); params.seed = time(NULL);
@ -3538,13 +3537,13 @@ int llama_eval_export(struct llama_context * ctx, const char * fname) {
return 0; return 0;
} }
int llama_tokenize( int llama_tokenize_with_model(
struct llama_context * ctx, const struct llama_model * model,
const char * text, const char * text,
llama_token * tokens, llama_token * tokens,
int n_max_tokens, int n_max_tokens,
bool add_bos) { bool add_bos) {
auto res = llama_tokenize(ctx->vocab, text, add_bos); auto res = llama_tokenize(model->vocab, text, add_bos);
if (n_max_tokens < (int) res.size()) { if (n_max_tokens < (int) res.size()) {
fprintf(stderr, "%s: too many tokens\n", __func__); fprintf(stderr, "%s: too many tokens\n", __func__);
@@ -3558,8 +3557,29 @@ int llama_tokenize(
     return res.size();
 }
 
+int llama_tokenize(
+        struct llama_context * ctx,
+                  const char * text,
+                 llama_token * tokens,
+                         int   n_max_tokens,
+                        bool   add_bos) {
+    return llama_tokenize_with_model(&ctx->model, text, tokens, n_max_tokens, add_bos);
+}
+
+int llama_n_vocab_from_model(const struct llama_model * model) {
+    return model->vocab.id_to_token.size();
+}
+
+int llama_n_ctx_from_model(const struct llama_model * model) {
+    return model->hparams.n_ctx;
+}
+
+int llama_n_embd_from_model(const struct llama_model * model) {
+    return model->hparams.n_embd;
+}
+
 int llama_n_vocab(const struct llama_context * ctx) {
-    return ctx->vocab.id_to_token.size();
+    return ctx->model.vocab.id_to_token.size();
 }
 
 int llama_n_ctx(const struct llama_context * ctx) {
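With llama_tokenize_with_model, tokenization no longer requires a llama_context. A hedged caller-side sketch of the new entry point (the model path and buffer size are placeholders; llama_load_model_from_file and llama_free_model are assumed from the same API revision):

#include <cstdio>
#include <vector>
#include "llama.h"

int main() {
    llama_context_params lparams = llama_context_default_params();
    // placeholder path - substitute a real ggml model file
    llama_model * model = llama_load_model_from_file("models/7B/ggml-model-q4_0.bin", lparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load model\n");
        return 1;
    }

    const char * text = "Hello, llama!";
    std::vector<llama_token> tokens(1024);   // arbitrary buffer size for this sketch
    const int n = llama_tokenize_with_model(model, text, tokens.data(), (int) tokens.size(), /*add_bos=*/ true);
    if (n < 0) {
        fprintf(stderr, "token buffer too small\n");
    } else {
        for (int i = 0; i < n; ++i) {
            printf("token[%d] = %d\n", i, tokens[i]);
        }
    }

    llama_free_model(model);
    return 0;
}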
@@ -3570,17 +3590,25 @@ int llama_n_embd(const struct llama_context * ctx) {
     return ctx->model.hparams.n_embd;
 }
 
+int llama_get_vocab_from_model(
+        const struct llama_model * model,
+        const char * * strings,
+        float * scores,
+        int capacity) {
+    int n = std::min(capacity, (int) model->vocab.id_to_token.size());
+    for (int i = 0; i<n; ++i) {
+        strings[i] = model->vocab.id_to_token[i].tok.c_str();
+        scores[i]  = model->vocab.id_to_token[i].score;
+    }
+    return n;
+}
+
 int llama_get_vocab(
         const struct llama_context * ctx,
         const char * * strings,
         float * scores,
         int capacity) {
-    int n = std::min(capacity, (int) ctx->vocab.id_to_token.size());
-    for (int i = 0; i<n; ++i) {
-        strings[i] = ctx->vocab.id_to_token[i].tok.c_str();
-        scores[i]  = ctx->vocab.id_to_token[i].score;
-    }
-    return n;
+    return llama_get_vocab_from_model(&ctx->model, strings, scores, capacity);
 }
 
 float * llama_get_logits(struct llama_context * ctx) {
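llama_get_vocab_from_model makes the vocabulary readable from a bare llama_model, and the context variant now simply forwards to it. A small hedged helper along these lines could dump the first entries of a model loaded as in the previous sketch (n_show is assumed positive):

#include <cstdio>
#include <vector>
#include "llama.h"

// Print up to n_show vocabulary entries (token string + score) straight from the model.
void dump_vocab(const llama_model * model, int n_show) {
    std::vector<const char *> strings(n_show);
    std::vector<float>        scores(n_show);
    const int n = llama_get_vocab_from_model(model, strings.data(), scores.data(), n_show);
    for (int i = 0; i < n; ++i) {
        printf("%6d  %-24s  score = %.3f\n", i, strings[i], scores[i]);
    }
}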
@@ -3591,12 +3619,16 @@ float * llama_get_embeddings(struct llama_context * ctx) {
     return ctx->embedding.data();
 }
 
-const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) {
-    if (token >= llama_n_vocab(ctx)) {
+const char * llama_token_to_str_with_model(const struct llama_model * model, llama_token token) {
+    if (token >= llama_n_vocab_from_model(model)) {
         return nullptr;
     }
 
-    return ctx->vocab.id_to_token[token].tok.c_str();
+    return model->vocab.id_to_token[token].tok.c_str();
+}
+
+const char * llama_token_to_str(const struct llama_context * ctx, llama_token token) {
+    return llama_token_to_str_with_model(&ctx->model, token);
 }
 
 llama_token llama_token_bos() {
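Since llama_token_to_str now delegates to llama_token_to_str_with_model, token ids can be rendered back to text without a context as well. A hedged sketch of such a caller (print_tokens is a hypothetical helper, not part of the API):

#include <cstdio>
#include "llama.h"

// Print a token sequence as text using only the model's vocabulary.
void print_tokens(const llama_model * model, const llama_token * tokens, int n_tokens) {
    for (int i = 0; i < n_tokens; ++i) {
        const char * piece = llama_token_to_str_with_model(model, tokens[i]);
        fputs(piece != NULL ? piece : "<invalid>", stdout);   // NULL means the id is out of range
    }
    fputc('\n', stdout);
}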

llama.h
@@ -270,10 +270,21 @@ extern "C" {
                              int   n_max_tokens,
                             bool   add_bos);
 
+    LLAMA_API int llama_tokenize_with_model(
+        const struct llama_model * model,
+                      const char * text,
+                     llama_token * tokens,
+                             int   n_max_tokens,
+                            bool   add_bos);
+
     LLAMA_API int llama_n_vocab(const struct llama_context * ctx);
     LLAMA_API int llama_n_ctx  (const struct llama_context * ctx);
     LLAMA_API int llama_n_embd (const struct llama_context * ctx);
 
+    LLAMA_API int llama_n_vocab_from_model(const struct llama_model * model);
+    LLAMA_API int llama_n_ctx_from_model  (const struct llama_model * model);
+    LLAMA_API int llama_n_embd_from_model (const struct llama_model * model);
+
     // Get the vocabulary as output parameters.
     // Returns number of results.
     LLAMA_API int llama_get_vocab(
@@ -282,6 +293,12 @@ extern "C" {
                           float * scores,
                             int   capacity);
 
+    LLAMA_API int llama_get_vocab_from_model(
+              const struct llama_model * model,
+                            const char * * strings,
+                                  float * scores,
+                                    int   capacity);
+
     // Token logits obtained from the last call to llama_eval()
     // The logits for the last token are stored in the last row
     // Can be mutated in order to change the probabilities of the next token
@@ -294,7 +311,13 @@ extern "C" {
     LLAMA_API float * llama_get_embeddings(struct llama_context * ctx);
 
     // Token Id -> String. Uses the vocabulary in the provided context
-    LLAMA_API const char * llama_token_to_str(const struct llama_context * ctx, llama_token token);
+    LLAMA_API const char * llama_token_to_str(
+            const struct llama_context * ctx,
+                             llama_token   token);
+
+    LLAMA_API const char * llama_token_to_str_with_model(
+              const struct llama_model * model,
+                           llama_token   token);
 
     // Special tokens
     LLAMA_API llama_token llama_token_bos();  // beginning-of-sentence
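Taken together, the new *_from_model declarations let a tool inspect a model's basic shape without ever allocating a context. A hedged sketch under the same assumptions as before (placeholder default path; llama_load_model_from_file and llama_free_model taken from the same header):

#include <cstdio>
#include "llama.h"

int main(int argc, char ** argv) {
    const char * path = argc > 1 ? argv[1] : "model.bin";   // placeholder default path
    llama_context_params lparams = llama_context_default_params();
    llama_model * model = llama_load_model_from_file(path, lparams);
    if (model == NULL) {
        fprintf(stderr, "failed to load %s\n", path);
        return 1;
    }

    printf("n_vocab = %d\n", llama_n_vocab_from_model(model));
    printf("n_ctx   = %d\n", llama_n_ctx_from_model(model));
    printf("n_embd  = %d\n", llama_n_embd_from_model(model));

    llama_free_model(model);
    return 0;
}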