Clean up QK and file and tensor types
This commit is contained in:
parent
3525899277
commit
39f91e3f6e
9 changed files with 277 additions and 305 deletions
|
@ -7,7 +7,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from numba import njit
|
from numba import njit
|
||||||
from tqdm.auto import tqdm
|
from tqdm.auto import tqdm
|
||||||
|
from ggml import *
|
||||||
|
|
||||||
def read_header(fin):
|
def read_header(fin):
|
||||||
values = struct.unpack("i" * 9, fin.read(4 * 9))
|
values = struct.unpack("i" * 9, fin.read(4 * 9))
|
||||||
|
@ -37,9 +37,8 @@ def read_tokens(fin, vocab_size):
|
||||||
|
|
||||||
@njit
|
@njit
|
||||||
def dequantize_weights_numba(fin_data, n_rows, n_cols):
|
def dequantize_weights_numba(fin_data, n_rows, n_cols):
|
||||||
qk = 32
|
qk = GGML_BLCK_SIZE[GGML_TYPE.Q4_0]
|
||||||
nb = n_cols // qk
|
nb = n_cols // qk
|
||||||
bs = 4 + (qk // 2)
|
|
||||||
|
|
||||||
weights = np.zeros((n_rows, n_cols), dtype=np.float32)
|
weights = np.zeros((n_rows, n_cols), dtype=np.float32)
|
||||||
data_pos = 0
|
data_pos = 0
|
||||||
|
@ -63,9 +62,7 @@ def dequantize_weights_numba(fin_data, n_rows, n_cols):
|
||||||
|
|
||||||
|
|
||||||
def dequantize_weights(fin, n_rows, n_cols):
|
def dequantize_weights(fin, n_rows, n_cols):
|
||||||
qk = 32
|
data_size = n_rows * n_cols // GGML_BLCK_SIZE[GGML_TYPE.Q4_0] * GGML_TYPE_SIZE[GGML_TYPE.Q4_0]
|
||||||
nb = n_cols // qk
|
|
||||||
data_size = n_rows * n_cols // 2 + n_rows * nb * 4
|
|
||||||
fin_data = fin.read(data_size)
|
fin_data = fin.read(data_size)
|
||||||
return dequantize_weights_numba(fin_data, n_rows, n_cols)
|
return dequantize_weights_numba(fin_data, n_rows, n_cols)
|
||||||
|
|
||||||
|
@ -89,16 +86,16 @@ def read_variables(fin):
|
||||||
tensor_data_offset = (tensor_data_offset + 31) & -32
|
tensor_data_offset = (tensor_data_offset + 31) & -32
|
||||||
fin.seek(tensor_data_offset)
|
fin.seek(tensor_data_offset)
|
||||||
|
|
||||||
if ftype_cur == 2:
|
if ftype_cur == GGML_FILE.Q4_0:
|
||||||
# 4-bit quantized weights
|
# 4-bit quantized weights
|
||||||
dtype = np.uint8
|
dtype = np.uint8
|
||||||
data = dequantize_weights(fin, shape[0], shape[1])
|
data = dequantize_weights(fin, shape[0], shape[1])
|
||||||
data = data.reshape(shape)
|
data = data.reshape(shape)
|
||||||
elif ftype_cur == 0:
|
elif ftype_cur == GGML_FILE.F32:
|
||||||
dtype = np.float32
|
dtype = np.float32
|
||||||
data_size = np.prod(shape)
|
data_size = np.prod(shape)
|
||||||
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
|
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
|
||||||
elif ftype_cur == 1:
|
elif ftype_cur == GGML_FILE.F16:
|
||||||
dtype = np.float16
|
dtype = np.float16
|
||||||
data_size = np.prod(shape)
|
data_size = np.prod(shape)
|
||||||
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
|
data = np.fromfile(fin, dtype=dtype, count=data_size).reshape(shape)
|
||||||
|
@ -269,6 +266,7 @@ def main():
|
||||||
|
|
||||||
fin = open(ggml_files[0], "rb")
|
fin = open(ggml_files[0], "rb")
|
||||||
hparams, ftype = read_header(fin)
|
hparams, ftype = read_header(fin)
|
||||||
|
GGML_FILE(ftype) # raise ValueError on invalid file type
|
||||||
tokens = read_tokens(fin, hparams["vocab_size"])
|
tokens = read_tokens(fin, hparams["vocab_size"])
|
||||||
model = read_variables(fin)
|
model = read_variables(fin)
|
||||||
|
|
||||||
|
|
|
@ -12,6 +12,7 @@ import os
|
||||||
import struct
|
import struct
|
||||||
import sys
|
import sys
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
from ggml import *
|
||||||
|
|
||||||
HPARAMS = keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
|
HPARAMS = keys = ["vocab_size", "dim", "multiple_of", "n_heads", "n_layers"]
|
||||||
|
|
||||||
|
@ -32,6 +33,7 @@ def write_header(f_out, header):
|
||||||
|
|
||||||
if magic != 0x67676d6c:
|
if magic != 0x67676d6c:
|
||||||
raise Exception('Invalid file magic. Must be an old style ggml file.')
|
raise Exception('Invalid file magic. Must be an old style ggml file.')
|
||||||
|
GGML_FILE(ftype) # raise ValueError on invalid file type
|
||||||
|
|
||||||
values = [
|
values = [
|
||||||
0x67676d66, # magic: ggml in hex
|
0x67676d66, # magic: ggml in hex
|
||||||
|
|
|
@ -9,6 +9,7 @@ import struct
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
from ggml import *
|
||||||
|
|
||||||
if len(sys.argv) != 4:
|
if len(sys.argv) != 4:
|
||||||
print("Usage: convert-gptq-to-ggml.py llamaXXb-4bit.pt tokenizer.model out.bin\n")
|
print("Usage: convert-gptq-to-ggml.py llamaXXb-4bit.pt tokenizer.model out.bin\n")
|
||||||
|
@ -143,7 +144,7 @@ def convert_q4(src_name, dst_name, permute=False):
|
||||||
.reshape(blob.shape))
|
.reshape(blob.shape))
|
||||||
|
|
||||||
# header
|
# header
|
||||||
write_header(shape, dst_name, 3) # ftype = Q4_1
|
write_header(shape, dst_name, GGML_FILE.Q4_1)
|
||||||
|
|
||||||
# data
|
# data
|
||||||
blob.tofile(fout)
|
blob.tofile(fout)
|
||||||
|
|
|
@ -23,43 +23,7 @@ import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from sentencepiece import SentencePieceProcessor
|
from sentencepiece import SentencePieceProcessor
|
||||||
|
from ggml import *
|
||||||
QK = 32
|
|
||||||
|
|
||||||
GGML_TYPE_Q4_0 = 0
|
|
||||||
GGML_TYPE_Q4_1 = 1
|
|
||||||
GGML_TYPE_I8 = 2
|
|
||||||
GGML_TYPE_I16 = 3
|
|
||||||
GGML_TYPE_I32 = 4
|
|
||||||
GGML_TYPE_F16 = 5
|
|
||||||
GGML_TYPE_F32 = 6
|
|
||||||
|
|
||||||
WTYPES = {
|
|
||||||
0: GGML_TYPE_F32,
|
|
||||||
1: GGML_TYPE_F16,
|
|
||||||
2: GGML_TYPE_Q4_0,
|
|
||||||
3: GGML_TYPE_Q4_1,
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_BLCK_SIZE = {
|
|
||||||
GGML_TYPE_Q4_0: QK,
|
|
||||||
GGML_TYPE_Q4_1: QK,
|
|
||||||
GGML_TYPE_I8: 1,
|
|
||||||
GGML_TYPE_I16: 1,
|
|
||||||
GGML_TYPE_I32: 1,
|
|
||||||
GGML_TYPE_F16: 1,
|
|
||||||
GGML_TYPE_F32: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_TYPE_SIZE = {
|
|
||||||
GGML_TYPE_Q4_0: 4 + QK//2,
|
|
||||||
GGML_TYPE_Q4_1: 4*2 + QK//2,
|
|
||||||
GGML_TYPE_I8: 1,
|
|
||||||
GGML_TYPE_I16: 2,
|
|
||||||
GGML_TYPE_I32: 4,
|
|
||||||
GGML_TYPE_F16: 2,
|
|
||||||
GGML_TYPE_F32: 4,
|
|
||||||
}
|
|
||||||
|
|
||||||
def ggml_nelements(shape):
|
def ggml_nelements(shape):
|
||||||
r = 1
|
r = 1
|
||||||
|
@ -69,7 +33,7 @@ def ggml_nelements(shape):
|
||||||
|
|
||||||
def ggml_nbytes(shape, ftype):
|
def ggml_nbytes(shape, ftype):
|
||||||
x = ggml_nelements(shape)
|
x = ggml_nelements(shape)
|
||||||
t = WTYPES[ftype]
|
t = ggml_type_from_ftype[ftype]
|
||||||
x *= GGML_TYPE_SIZE[t]
|
x *= GGML_TYPE_SIZE[t]
|
||||||
x //= GGML_BLCK_SIZE[t]
|
x //= GGML_BLCK_SIZE[t]
|
||||||
return x
|
return x
|
||||||
|
@ -155,8 +119,8 @@ def process_and_write_variables(fout, model, ftype, part_id, n_parts):
|
||||||
print(" Converting to float32")
|
print(" Converting to float32")
|
||||||
data = data.astype(np.float32)
|
data = data.astype(np.float32)
|
||||||
ftype_cur = 0
|
ftype_cur = 0
|
||||||
blck_size = GGML_BLCK_SIZE[WTYPES[ftype_cur]]
|
blck_size = GGML_BLCK_SIZE[ggml_type_from_ftype[ftype_cur]]
|
||||||
type_size = GGML_TYPE_SIZE[WTYPES[ftype_cur]]
|
type_size = GGML_TYPE_SIZE[ggml_type_from_ftype[ftype_cur]]
|
||||||
|
|
||||||
# determine dimension along which multipart tensor is sharded
|
# determine dimension along which multipart tensor is sharded
|
||||||
#
|
#
|
||||||
|
@ -199,7 +163,7 @@ def process_and_write_variables(fout, model, ftype, part_id, n_parts):
|
||||||
|
|
||||||
# ensure tensor data is aligned
|
# ensure tensor data is aligned
|
||||||
tensor_data_offset = fout.tell()
|
tensor_data_offset = fout.tell()
|
||||||
while tensor_data_offset % QK != 0:
|
while tensor_data_offset % 32 != 0:
|
||||||
fout.write(struct.pack("B", 0))
|
fout.write(struct.pack("B", 0))
|
||||||
tensor_data_offset += 1
|
tensor_data_offset += 1
|
||||||
|
|
||||||
|
@ -234,8 +198,7 @@ def process_and_write_variables(fout, model, ftype, part_id, n_parts):
|
||||||
def main():
|
def main():
|
||||||
args = parse_args()
|
args = parse_args()
|
||||||
dir_model = args.dir_model
|
dir_model = args.dir_model
|
||||||
ftype = args.ftype
|
ftype = GGML_FILE(args.ftype)
|
||||||
ftype_str = ["f32", "f16"]
|
|
||||||
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
|
hparams, tokenizer = load_hparams_and_tokenizer(dir_model)
|
||||||
|
|
||||||
print(args)
|
print(args)
|
||||||
|
@ -252,7 +215,7 @@ def main():
|
||||||
return
|
return
|
||||||
|
|
||||||
n_parts = get_n_parts(hparams["dim"])
|
n_parts = get_n_parts(hparams["dim"])
|
||||||
fname_out = f"{dir_model}/ggml-model-{ftype_str[ftype]}.bin"
|
fname_out = f"{dir_model}/ggml-model-{ftype.name.lower()}.bin"
|
||||||
|
|
||||||
# we output a single file for ggml
|
# we output a single file for ggml
|
||||||
with open(fname_out, "wb") as fout:
|
with open(fname_out, "wb") as fout:
|
||||||
|
|
335
ggml.c
335
ggml.c
|
@ -423,8 +423,6 @@ static const size_t CACHE_LINE_SIZE_F32 = CACHE_LINE_SIZE/sizeof(float);
|
||||||
// quantization
|
// quantization
|
||||||
//
|
//
|
||||||
|
|
||||||
#define QK 32
|
|
||||||
|
|
||||||
// AVX routines provided by GH user Const-me
|
// AVX routines provided by GH user Const-me
|
||||||
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
|
// ref: https://github.com/ggerganov/ggml/pull/27#issuecomment-1464934600
|
||||||
#if __AVX2__ || __AVX512F__
|
#if __AVX2__ || __AVX512F__
|
||||||
|
@ -499,34 +497,36 @@ static inline __m128i packNibbles( __m128i bytes1, __m128i bytes2 )
|
||||||
// method 5
|
// method 5
|
||||||
// blocks of QK elements
|
// blocks of QK elements
|
||||||
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
|
// represented with a single float (delta) and QK/2 8-bit ints (i.e QK 4-bit signed integer factors)
|
||||||
|
#define QK_4_0 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
float d; // delta
|
float d; // delta
|
||||||
uint8_t qs[QK / 2]; // nibbles / quants
|
uint8_t qs[QK_4_0 / 2]; // nibbles / quants
|
||||||
} block_q4_0;
|
} block_q4_0;
|
||||||
static_assert(sizeof(block_q4_0) == sizeof(float) + QK / 2, "wrong q4_0 block size/padding");
|
static_assert(sizeof(block_q4_0) == sizeof(float) + QK_4_0 / 2, "wrong q4_0 block size/padding");
|
||||||
|
|
||||||
// method 4
|
// method 4
|
||||||
// blocks of QK elements
|
// blocks of QK elements
|
||||||
// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
|
// represented with 2 floats (delta + min) and QK/2 8-bit ints (i.e QK 4-bit unsigned integer factors)
|
||||||
|
#define QK_4_1 32
|
||||||
typedef struct {
|
typedef struct {
|
||||||
float d;
|
float d;
|
||||||
float m;
|
float m;
|
||||||
uint8_t qs[QK / 2]; // nibbles / quants
|
uint8_t qs[QK_4_1 / 2]; // nibbles / quants
|
||||||
} block_q4_1;
|
} block_q4_1;
|
||||||
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK / 2, "wrong q4_1 block size/padding");
|
static_assert(sizeof(block_q4_1) == sizeof(float) * 2 + QK_4_1 / 2, "wrong q4_1 block size/padding");
|
||||||
|
|
||||||
// reference implementation for deterministic creation of model files
|
// reference implementation for deterministic creation of model files
|
||||||
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
|
static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * restrict y, int k) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK_4_0 == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK_4_0;
|
||||||
|
|
||||||
uint8_t pp[QK/2];
|
uint8_t pp[QK_4_0/2];
|
||||||
|
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
float amax = 0.0f; // absolute max
|
float amax = 0.0f; // absolute max
|
||||||
|
|
||||||
for (int l = 0; l < QK; l++) {
|
for (int l = 0; l < QK_4_0; l++) {
|
||||||
const float v = x[i*QK + l];
|
const float v = x[i*QK_4_0 + l];
|
||||||
amax = MAX(amax, fabsf(v));
|
amax = MAX(amax, fabsf(v));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -535,9 +535,9 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
|
||||||
|
|
||||||
y[i].d = d;
|
y[i].d = d;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l += 2) {
|
for (int l = 0; l < QK_4_0; l += 2) {
|
||||||
const float v0 = x[i*QK + l + 0]*id;
|
const float v0 = x[i*QK_4_0 + l + 0]*id;
|
||||||
const float v1 = x[i*QK + l + 1]*id;
|
const float v1 = x[i*QK_4_0 + l + 1]*id;
|
||||||
|
|
||||||
const uint8_t vi0 = (int8_t)roundf(v0) + 8;
|
const uint8_t vi0 = (int8_t)roundf(v0) + 8;
|
||||||
const uint8_t vi1 = (int8_t)roundf(v1) + 8;
|
const uint8_t vi1 = (int8_t)roundf(v1) + 8;
|
||||||
|
@ -553,8 +553,8 @@ static void quantize_row_q4_0_reference(const float * restrict x, block_q4_0 * r
|
||||||
}
|
}
|
||||||
|
|
||||||
static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
|
static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int k) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK_4_0 == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK_4_0;
|
||||||
|
|
||||||
block_q4_0 * restrict y = vy;
|
block_q4_0 * restrict y = vy;
|
||||||
|
|
||||||
|
@ -807,19 +807,19 @@ static void quantize_row_q4_0(const float * restrict x, void * restrict vy, int
|
||||||
}
|
}
|
||||||
|
|
||||||
static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
|
static void quantize_row_q4_1_reference(const float * restrict x, void * restrict vy, int k) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK_4_1 == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK_4_1;
|
||||||
|
|
||||||
block_q4_1 * restrict y = vy;
|
block_q4_1 * restrict y = vy;
|
||||||
|
|
||||||
uint8_t pp[QK/2];
|
uint8_t pp[QK_4_1/2];
|
||||||
|
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
float min = FLT_MAX;
|
float min = FLT_MAX;
|
||||||
float max = -FLT_MAX;
|
float max = -FLT_MAX;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l++) {
|
for (int l = 0; l < QK_4_1; l++) {
|
||||||
const float v = x[i*QK + l];
|
const float v = x[i*QK_4_1 + l];
|
||||||
if (v < min) min = v;
|
if (v < min) min = v;
|
||||||
if (v > max) max = v;
|
if (v > max) max = v;
|
||||||
}
|
}
|
||||||
|
@ -830,9 +830,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
|
||||||
y[i].d = d;
|
y[i].d = d;
|
||||||
y[i].m = min;
|
y[i].m = min;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l += 2) {
|
for (int l = 0; l < QK_4_1; l += 2) {
|
||||||
const float v0 = (x[i*QK + l + 0] - min)*id;
|
const float v0 = (x[i*QK_4_1 + l + 0] - min)*id;
|
||||||
const float v1 = (x[i*QK + l + 1] - min)*id;
|
const float v1 = (x[i*QK_4_1 + l + 1] - min)*id;
|
||||||
|
|
||||||
const uint8_t vi0 = roundf(v0);
|
const uint8_t vi0 = roundf(v0);
|
||||||
const uint8_t vi1 = roundf(v1);
|
const uint8_t vi1 = roundf(v1);
|
||||||
|
@ -848,9 +848,9 @@ static void quantize_row_q4_1_reference(const float * restrict x, void * restric
|
||||||
}
|
}
|
||||||
|
|
||||||
static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
|
static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int k) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK_4_1 == 0);
|
||||||
|
|
||||||
const int nb = k / QK;
|
const int nb = k / QK_4_1;
|
||||||
|
|
||||||
block_q4_1 * restrict y = vy;
|
block_q4_1 * restrict y = vy;
|
||||||
|
|
||||||
|
@ -970,8 +970,8 @@ static void quantize_row_q4_1(const float * restrict x, void * restrict vy, int
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, int k) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK_4_0 == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK_4_0;
|
||||||
|
|
||||||
const block_q4_0 * restrict x = vx;
|
const block_q4_0 * restrict x = vx;
|
||||||
|
|
||||||
|
@ -982,7 +982,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
||||||
|
|
||||||
const uint8_t * restrict pp = x[i].qs;
|
const uint8_t * restrict pp = x[i].qs;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l += 32) {
|
for (int l = 0; l < QK_4_0; l += 32) { // loop is done once, keep for easy experimenting with QK
|
||||||
// Load 32x4-bit integers into 32x8-bit integers
|
// Load 32x4-bit integers into 32x8-bit integers
|
||||||
__m256i vx8 = bytesFromNibbles(pp+l/2);
|
__m256i vx8 = bytesFromNibbles(pp+l/2);
|
||||||
|
|
||||||
|
@ -1004,7 +1004,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
||||||
// Scale and store
|
// Scale and store
|
||||||
for (int j = 0; j < 4; j++) {
|
for (int j = 0; j < 4; j++) {
|
||||||
const __m256 result = _mm256_mul_ps(vf[j], d_v);
|
const __m256 result = _mm256_mul_ps(vf[j], d_v);
|
||||||
_mm256_storeu_ps(y + i * QK + l + j*8, result);
|
_mm256_storeu_ps(y + i * QK_4_0 + l + j*8, result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1014,7 +1014,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
||||||
|
|
||||||
const uint8_t * restrict pp = x[i].qs;
|
const uint8_t * restrict pp = x[i].qs;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l += 16) {
|
for (int l = 0; l < QK_4_0; l += 16) {
|
||||||
// Load 16x4-bit integers into 8x8-bit integers
|
// Load 16x4-bit integers into 8x8-bit integers
|
||||||
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
||||||
|
|
||||||
|
@ -1053,10 +1053,10 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
||||||
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
const float32x4_t r3 = vmulq_f32(vf_3, vd);
|
||||||
|
|
||||||
// Store
|
// Store
|
||||||
vst1q_f32(y + i*QK + l + 0, r0);
|
vst1q_f32(y + i*QK_4_0 + l + 0, r0);
|
||||||
vst1q_f32(y + i*QK + l + 4, r1);
|
vst1q_f32(y + i*QK_4_0 + l + 4, r1);
|
||||||
vst1q_f32(y + i*QK + l + 8, r2);
|
vst1q_f32(y + i*QK_4_0 + l + 8, r2);
|
||||||
vst1q_f32(y + i*QK + l + 12, r3);
|
vst1q_f32(y + i*QK_4_0 + l + 12, r3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
@ -1066,7 +1066,7 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
||||||
|
|
||||||
const uint8_t * restrict pp = x[i].qs;
|
const uint8_t * restrict pp = x[i].qs;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l += 2) {
|
for (int l = 0; l < QK_4_0; l += 2) {
|
||||||
const uint8_t vi = pp[l/2];
|
const uint8_t vi = pp[l/2];
|
||||||
|
|
||||||
const int8_t vi0 = vi & 0xf;
|
const int8_t vi0 = vi & 0xf;
|
||||||
|
@ -1077,19 +1077,19 @@ static void dequantize_row_q4_0(const void * restrict vx, float * restrict y, in
|
||||||
|
|
||||||
//printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
|
//printf("d = %f, vi = %d, vi0 = %d, vi1 = %d, v0 = %f, v1 = %f\n", d, vi, vi0, vi1, v0, v1);
|
||||||
|
|
||||||
y[i*QK + l + 0] = v0;
|
y[i*QK_4_0 + l + 0] = v0;
|
||||||
y[i*QK + l + 1] = v1;
|
y[i*QK_4_0 + l + 1] = v1;
|
||||||
|
|
||||||
assert(!isnan(y[i*QK + l + 0]));
|
assert(!isnan(y[i*QK_4_0 + l + 0]));
|
||||||
assert(!isnan(y[i*QK + l + 1]));
|
assert(!isnan(y[i*QK_4_0 + l + 1]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
}
|
}
|
||||||
|
|
||||||
static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
|
static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, int k) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK_4_1 == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK_4_1;
|
||||||
|
|
||||||
const block_q4_1 * restrict x = vx;
|
const block_q4_1 * restrict x = vx;
|
||||||
|
|
||||||
|
@ -1100,7 +1100,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||||
|
|
||||||
const uint8_t * restrict pp = x[i].qs;
|
const uint8_t * restrict pp = x[i].qs;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l += 32) {
|
for (int l = 0; l < QK_4_1; l += 32) { // loop is done once, keep for easy experimenting with QK
|
||||||
// Load 32x4-bit integers into 32x8-bit integers
|
// Load 32x4-bit integers into 32x8-bit integers
|
||||||
__m256i vx8 = bytesFromNibbles(pp+l/2);
|
__m256i vx8 = bytesFromNibbles(pp+l/2);
|
||||||
|
|
||||||
|
@ -1119,7 +1119,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||||
// Scale, add m and store
|
// Scale, add m and store
|
||||||
for (int j = 0; j < 4; j++) {
|
for (int j = 0; j < 4; j++) {
|
||||||
const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
|
const __m256 result = _mm256_add_ps(_mm256_mul_ps(vf[j], d_v), d_m);
|
||||||
_mm256_storeu_ps(y + i * QK + l + j*8, result);
|
_mm256_storeu_ps(y + i * QK_4_1 + l + j*8, result);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1130,7 +1130,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||||
|
|
||||||
const uint8_t * restrict pp = x[i].qs;
|
const uint8_t * restrict pp = x[i].qs;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l += 16) {
|
for (int l = 0; l < QK_4_1; l += 16) {
|
||||||
// Load 16x4-bit integers into 8x8-bit integers
|
// Load 16x4-bit integers into 8x8-bit integers
|
||||||
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
const uint8x8_t v8 = vld1_u8(pp + l/2);
|
||||||
|
|
||||||
|
@ -1161,10 +1161,10 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||||
const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
|
const float32x4_t r3 = vmlaq_f32(vm, vf_3, vd);
|
||||||
|
|
||||||
// Store
|
// Store
|
||||||
vst1q_f32(y + i*QK + l + 0, r0);
|
vst1q_f32(y + i*QK_4_1 + l + 0, r0);
|
||||||
vst1q_f32(y + i*QK + l + 4, r1);
|
vst1q_f32(y + i*QK_4_1 + l + 4, r1);
|
||||||
vst1q_f32(y + i*QK + l + 8, r2);
|
vst1q_f32(y + i*QK_4_1 + l + 8, r2);
|
||||||
vst1q_f32(y + i*QK + l + 12, r3);
|
vst1q_f32(y + i*QK_4_1 + l + 12, r3);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#else
|
#else
|
||||||
|
@ -1174,7 +1174,7 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||||
|
|
||||||
const uint8_t * restrict pp = x[i].qs;
|
const uint8_t * restrict pp = x[i].qs;
|
||||||
|
|
||||||
for (int l = 0; l < QK; l += 2) {
|
for (int l = 0; l < QK_4_1; l += 2) {
|
||||||
const uint8_t vi = pp[l/2];
|
const uint8_t vi = pp[l/2];
|
||||||
|
|
||||||
const int8_t vi0 = vi & 0xf;
|
const int8_t vi0 = vi & 0xf;
|
||||||
|
@ -1183,11 +1183,11 @@ static void dequantize_row_q4_1(const void * restrict vx, float * restrict y, in
|
||||||
const float v0 = vi0*d + m;
|
const float v0 = vi0*d + m;
|
||||||
const float v1 = vi1*d + m;
|
const float v1 = vi1*d + m;
|
||||||
|
|
||||||
y[i*QK + l + 0] = v0;
|
y[i*QK_4_1 + l + 0] = v0;
|
||||||
y[i*QK + l + 1] = v1;
|
y[i*QK_4_1 + l + 1] = v1;
|
||||||
|
|
||||||
assert(!isnan(y[i*QK + l + 0]));
|
assert(!isnan(y[i*QK_4_1 + l + 0]));
|
||||||
assert(!isnan(y[i*QK + l + 1]));
|
assert(!isnan(y[i*QK_4_1 + l + 1]));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#endif
|
#endif
|
||||||
|
@ -1757,7 +1757,7 @@ inline static void ggml_vec_dot_f32(const int n, float * restrict s, const float
|
||||||
*s = sumf;
|
*s = sumf;
|
||||||
}
|
}
|
||||||
|
|
||||||
#if __AVX512F__ && QK == 32
|
#if __AVX512F__ && QK_4_0 == 32
|
||||||
static inline __m512 dot_q4_0_oneblock_avx512(
|
static inline __m512 dot_q4_0_oneblock_avx512(
|
||||||
__m512 acc,
|
__m512 acc,
|
||||||
const block_q4_0 * restrict x,
|
const block_q4_0 * restrict x,
|
||||||
|
@ -1825,9 +1825,9 @@ inline static void ggml_vec_dot_f16(const int n, float * restrict s, ggml_fp16_t
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
const int nb = n / QK;
|
const int nb = n / QK_4_0;
|
||||||
|
|
||||||
assert(n % QK == 0);
|
assert(n % QK_4_0 == 0);
|
||||||
assert(nb % 2 == 0);
|
assert(nb % 2 == 0);
|
||||||
|
|
||||||
const block_q4_0 * restrict x = vx;
|
const block_q4_0 * restrict x = vx;
|
||||||
|
@ -2140,7 +2140,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
||||||
const uint8_t * restrict p0 = x[i].qs;
|
const uint8_t * restrict p0 = x[i].qs;
|
||||||
const uint8_t * restrict p1 = y[i].qs;
|
const uint8_t * restrict p1 = y[i].qs;
|
||||||
|
|
||||||
for (int j = 0; j < QK/2; j++) {
|
for (int j = 0; j < QK_4_0/2; j++) {
|
||||||
const uint8_t v0 = p0[j];
|
const uint8_t v0 = p0[j];
|
||||||
const uint8_t v1 = p1[j];
|
const uint8_t v1 = p1[j];
|
||||||
|
|
||||||
|
@ -2159,7 +2159,7 @@ static void ggml_vec_dot_q4_0(const int n, float * restrict s, const void * rest
|
||||||
}
|
}
|
||||||
|
|
||||||
static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
|
||||||
const int nb = n / QK;
|
const int nb = n / QK_4_1;
|
||||||
|
|
||||||
const block_q4_1 * restrict x = vx;
|
const block_q4_1 * restrict x = vx;
|
||||||
const block_q4_1 * restrict y = vy;
|
const block_q4_1 * restrict y = vy;
|
||||||
|
@ -2236,7 +2236,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
||||||
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
res = _mm_add_ps( res, _mm_movehl_ps( res, res ) );
|
||||||
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
res = _mm_add_ss( res, _mm_movehdup_ps( res ) );
|
||||||
|
|
||||||
sumf = _mm_cvtss_f32( res ) + acc_offset * QK;
|
sumf = _mm_cvtss_f32( res ) + acc_offset * QK_4_1;
|
||||||
#elif defined(__ARM_NEON)
|
#elif defined(__ARM_NEON)
|
||||||
float sum00 = 0.0f;
|
float sum00 = 0.0f;
|
||||||
float sum01 = 0.0f;
|
float sum01 = 0.0f;
|
||||||
|
@ -2275,7 +2275,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
||||||
sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
|
sum11 += x0->d*y0->d*vaddvq_u16(vaddq_u16(pl0, ph0));
|
||||||
}
|
}
|
||||||
|
|
||||||
sumf = QK*sum00 + sum01 + sum10 + sum11;
|
sumf = QK_4_1*sum00 + sum01 + sum10 + sum11;
|
||||||
#else
|
#else
|
||||||
// scalar
|
// scalar
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
|
@ -2288,7 +2288,7 @@ static void ggml_vec_dot_q4_1(const int n, float * restrict s, const void * rest
|
||||||
const uint8_t * restrict p0 = x[i].qs;
|
const uint8_t * restrict p0 = x[i].qs;
|
||||||
const uint8_t * restrict p1 = y[i].qs;
|
const uint8_t * restrict p1 = y[i].qs;
|
||||||
|
|
||||||
for (int j = 0; j < QK/2; j++) {
|
for (int j = 0; j < QK_4_1/2; j++) {
|
||||||
const uint8_t v0 = p0[j];
|
const uint8_t v0 = p0[j];
|
||||||
const uint8_t v1 = p1[j];
|
const uint8_t v1 = p1[j];
|
||||||
|
|
||||||
|
@ -2547,118 +2547,113 @@ inline static void ggml_vec_norm_inv_f32(const int n, float * s, const float * x
|
||||||
// data types
|
// data types
|
||||||
//
|
//
|
||||||
|
|
||||||
static const int GGML_BLCK_SIZE[GGML_TYPE_COUNT] = {
|
static const int GGML_BLCK_SIZE[] = {
|
||||||
QK,
|
[GGML_TYPE_Q4_0] = QK_4_0,
|
||||||
QK,
|
[GGML_TYPE_Q4_1] = QK_4_1,
|
||||||
1,
|
[GGML_TYPE_I8] = 1,
|
||||||
1,
|
[GGML_TYPE_I16] = 1,
|
||||||
1,
|
[GGML_TYPE_I32] = 1,
|
||||||
1,
|
[GGML_TYPE_F16] = 1,
|
||||||
1,
|
[GGML_TYPE_F32] = 1,
|
||||||
};
|
};
|
||||||
|
static_assert(sizeof(GGML_BLCK_SIZE)/sizeof(*GGML_BLCK_SIZE) == GGML_TYPE_COUNT, "GGML_BLCK_SIZE incomplete");
|
||||||
|
|
||||||
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
|
static const size_t GGML_TYPE_SIZE[] = {
|
||||||
|
[GGML_TYPE_Q4_0] = sizeof(block_q4_0),
|
||||||
static const size_t GGML_TYPE_SIZE[GGML_TYPE_COUNT] = {
|
[GGML_TYPE_Q4_1] = sizeof(block_q4_1),
|
||||||
sizeof(block_q4_0),
|
[GGML_TYPE_I8] = sizeof(int8_t),
|
||||||
sizeof(block_q4_1),
|
[GGML_TYPE_I16] = sizeof(int16_t),
|
||||||
sizeof(int8_t ),
|
[GGML_TYPE_I32] = sizeof(int32_t),
|
||||||
sizeof(int16_t),
|
[GGML_TYPE_F16] = sizeof(ggml_fp16_t),
|
||||||
sizeof(int32_t),
|
[GGML_TYPE_F32] = sizeof(float),
|
||||||
sizeof(ggml_fp16_t),
|
|
||||||
sizeof(float ),
|
|
||||||
};
|
};
|
||||||
|
static_assert(sizeof(GGML_TYPE_SIZE)/sizeof(*GGML_TYPE_SIZE) == GGML_TYPE_COUNT, "GGML_TYPE_SIZE incomplete");
|
||||||
|
|
||||||
// don't forget to update the array above when adding new types
|
static const char * GGML_OP_LABEL[] = {
|
||||||
static_assert(GGML_TYPE_COUNT == 7, "GGML_TYPE_COUNT != 5");
|
[GGML_OP_NONE] = "NONE",
|
||||||
|
|
||||||
static const char * GGML_OP_LABEL[GGML_OP_COUNT] = {
|
[GGML_OP_DUP] = "DUP",
|
||||||
"NONE",
|
[GGML_OP_ADD] = "ADD",
|
||||||
|
[GGML_OP_SUB] = "SUB",
|
||||||
|
[GGML_OP_MUL] = "MUL",
|
||||||
|
[GGML_OP_DIV] = "DIV",
|
||||||
|
[GGML_OP_SQR] = "SQR",
|
||||||
|
[GGML_OP_SQRT] = "SQRT",
|
||||||
|
[GGML_OP_SUM] = "SUM",
|
||||||
|
[GGML_OP_MEAN] = "MEAN",
|
||||||
|
[GGML_OP_REPEAT] = "REPEAT",
|
||||||
|
[GGML_OP_ABS] = "ABS",
|
||||||
|
[GGML_OP_SGN] = "SGN",
|
||||||
|
[GGML_OP_NEG] = "NEG",
|
||||||
|
[GGML_OP_STEP] = "STEP",
|
||||||
|
[GGML_OP_RELU] = "RELU",
|
||||||
|
[GGML_OP_GELU] = "GELU",
|
||||||
|
[GGML_OP_SILU] = "SILU",
|
||||||
|
[GGML_OP_NORM] = "NORM",
|
||||||
|
[GGML_OP_RMS_NORM] = "RMS_NORM",
|
||||||
|
|
||||||
"DUP",
|
[GGML_OP_MUL_MAT] = "MUL_MAT",
|
||||||
"ADD",
|
|
||||||
"SUB",
|
|
||||||
"MUL",
|
|
||||||
"DIV",
|
|
||||||
"SQR",
|
|
||||||
"SQRT",
|
|
||||||
"SUM",
|
|
||||||
"MEAN",
|
|
||||||
"REPEAT",
|
|
||||||
"ABS",
|
|
||||||
"SGN",
|
|
||||||
"NEG",
|
|
||||||
"STEP",
|
|
||||||
"RELU",
|
|
||||||
"GELU",
|
|
||||||
"SILU",
|
|
||||||
"NORM",
|
|
||||||
"RMS_NORM",
|
|
||||||
|
|
||||||
"MUL_MAT",
|
[GGML_OP_SCALE] = "SCALE",
|
||||||
|
[GGML_OP_CPY] = "CPY",
|
||||||
|
[GGML_OP_RESHAPE] = "RESHAPE",
|
||||||
|
[GGML_OP_VIEW] = "VIEW",
|
||||||
|
[GGML_OP_PERMUTE] = "PERMUTE",
|
||||||
|
[GGML_OP_TRANSPOSE] = "TRANSPOSE",
|
||||||
|
[GGML_OP_GET_ROWS] = "GET_ROWS",
|
||||||
|
[GGML_OP_DIAG_MASK_INF] = "DIAG_MASK_INF",
|
||||||
|
[GGML_OP_SOFT_MAX] = "SOFT_MAX",
|
||||||
|
[GGML_OP_ROPE] = "ROPE",
|
||||||
|
[GGML_OP_CONV_1D_1S] = "CONV_1D_1S",
|
||||||
|
[GGML_OP_CONV_1D_2S] = "CONV_1D_2S",
|
||||||
|
|
||||||
"SCALE",
|
[GGML_OP_FLASH_ATTN] = "FLASH_ATTN",
|
||||||
"CPY",
|
[GGML_OP_FLASH_FF] = "FLASH_FF",
|
||||||
"RESHAPE",
|
|
||||||
"VIEW",
|
|
||||||
"PERMUTE",
|
|
||||||
"TRANSPOSE",
|
|
||||||
"GET_ROWS",
|
|
||||||
"DIAG_MASK_INF",
|
|
||||||
"SOFT_MAX",
|
|
||||||
"ROPE",
|
|
||||||
"CONV_1D_1S",
|
|
||||||
"CONV_1D_2S",
|
|
||||||
|
|
||||||
"FLASH_ATTN",
|
|
||||||
"FLASH_FF",
|
|
||||||
};
|
};
|
||||||
|
static_assert(sizeof(GGML_OP_LABEL)/sizeof(*GGML_OP_LABEL) == GGML_OP_COUNT, "GGML_OP_LABEL incomplete");
|
||||||
|
|
||||||
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
|
static const char * GGML_OP_SYMBOL[] = {
|
||||||
|
[GGML_OP_NONE] = "none",
|
||||||
|
|
||||||
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
|
[GGML_OP_DUP] = "x",
|
||||||
"none",
|
[GGML_OP_ADD] = "x+y",
|
||||||
|
[GGML_OP_SUB] = "x-y",
|
||||||
|
[GGML_OP_MUL] = "x*y",
|
||||||
|
[GGML_OP_DIV] = "x/y",
|
||||||
|
[GGML_OP_SQR] = "x^2",
|
||||||
|
[GGML_OP_SQRT] = "√x",
|
||||||
|
[GGML_OP_SUM] = "Σx",
|
||||||
|
[GGML_OP_MEAN] = "Σx/n",
|
||||||
|
[GGML_OP_REPEAT] = "repeat(x)",
|
||||||
|
[GGML_OP_ABS] = "abs(x)",
|
||||||
|
[GGML_OP_SGN] = "sgn(x)",
|
||||||
|
[GGML_OP_NEG] = "-x",
|
||||||
|
[GGML_OP_STEP] = "step(x)",
|
||||||
|
[GGML_OP_RELU] = "relu(x)",
|
||||||
|
[GGML_OP_GELU] = "gelu(x)",
|
||||||
|
[GGML_OP_SILU] = "silu(x)",
|
||||||
|
[GGML_OP_NORM] = "norm(x)",
|
||||||
|
[GGML_OP_RMS_NORM] = "rms_norm(x)",
|
||||||
|
|
||||||
"x",
|
[GGML_OP_MUL_MAT] = "X*Y",
|
||||||
"x+y",
|
|
||||||
"x-y",
|
|
||||||
"x*y",
|
|
||||||
"x/y",
|
|
||||||
"x^2",
|
|
||||||
"√x",
|
|
||||||
"Σx",
|
|
||||||
"Σx/n",
|
|
||||||
"repeat(x)",
|
|
||||||
"abs(x)",
|
|
||||||
"sgn(x)",
|
|
||||||
"-x",
|
|
||||||
"step(x)",
|
|
||||||
"relu(x)",
|
|
||||||
"gelu(x)",
|
|
||||||
"silu(x)",
|
|
||||||
"norm(x)",
|
|
||||||
"rms_norm(x)",
|
|
||||||
|
|
||||||
"X*Y",
|
[GGML_OP_SCALE] = "x*v",
|
||||||
|
[GGML_OP_CPY] = "x-\\>y",
|
||||||
|
[GGML_OP_RESHAPE] = "reshape(x)",
|
||||||
|
[GGML_OP_VIEW] = "view(x)",
|
||||||
|
[GGML_OP_PERMUTE] = "permute(x)",
|
||||||
|
[GGML_OP_TRANSPOSE] = "transpose(x)",
|
||||||
|
[GGML_OP_GET_ROWS] = "get_rows(x)",
|
||||||
|
[GGML_OP_DIAG_MASK_INF] = "diag_mask_inf(x)",
|
||||||
|
[GGML_OP_SOFT_MAX] = "soft_max(x)",
|
||||||
|
[GGML_OP_ROPE] = "rope(x)",
|
||||||
|
[GGML_OP_CONV_1D_1S] = "conv_1d_1s(x)",
|
||||||
|
[GGML_OP_CONV_1D_2S] = "conv_1d_2s(x)",
|
||||||
|
|
||||||
"x*v",
|
[GGML_OP_FLASH_ATTN] = "flash_attn(x)",
|
||||||
"x-\\>y",
|
[GGML_OP_FLASH_FF] = "flash_ff(x)",
|
||||||
"reshape(x)",
|
|
||||||
"view(x)",
|
|
||||||
"permute(x)",
|
|
||||||
"transpose(x)",
|
|
||||||
"get_rows(x)",
|
|
||||||
"diag_mask_inf(x)",
|
|
||||||
"soft_max(x)",
|
|
||||||
"rope(x)",
|
|
||||||
"conv_1d_1s(x)",
|
|
||||||
"conv_1d_2s(x)",
|
|
||||||
|
|
||||||
"flash_attn(x)",
|
|
||||||
"flash_ff(x)",
|
|
||||||
};
|
};
|
||||||
|
static_assert(sizeof(GGML_OP_SYMBOL)/sizeof(*GGML_OP_SYMBOL) == GGML_OP_COUNT, "GGML_OP_SYMBOL incomplete");
|
||||||
static_assert(GGML_OP_COUNT == 35, "GGML_OP_COUNT != 35");
|
|
||||||
|
|
||||||
//
|
//
|
||||||
// ggml object
|
// ggml object
|
||||||
|
@ -6686,7 +6681,7 @@ static void ggml_compute_forward_mul_mat_q_f32(
|
||||||
|
|
||||||
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
float * dst_col = (float *) ((char *) dst->data + (i0*nb0 + 0*nb1 + i2*nb2 + i3*nb3));
|
||||||
|
|
||||||
assert(ne00 % 32 == 0);
|
assert(ne00 % GGML_BLCK_SIZE[type] == 0);
|
||||||
|
|
||||||
for (int ic = 0; ic < ne11; ++ic) {
|
for (int ic = 0; ic < ne11; ++ic) {
|
||||||
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
vec_dot_q(ne00, &dst_col[ic*ne0], src0_row, (void *) (src1_col + ic*row_size));
|
||||||
|
@ -10496,16 +10491,16 @@ enum ggml_opt_result ggml_opt(
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t * hist) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK_4_0 == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK_4_0;
|
||||||
|
|
||||||
for (int j = 0; j < n; j += k) {
|
for (int j = 0; j < n; j += k) {
|
||||||
block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK;
|
block_q4_0 * restrict y = (block_q4_0 *)dst + j/QK_4_0;
|
||||||
|
|
||||||
quantize_row_q4_0_reference(src + j, y, k);
|
quantize_row_q4_0_reference(src + j, y, k);
|
||||||
|
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
for (int l = 0; l < QK; l += 2) {
|
for (int l = 0; l < QK_4_0; l += 2) {
|
||||||
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
||||||
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
||||||
|
|
||||||
|
@ -10515,20 +10510,20 @@ size_t ggml_quantize_q4_0(const float * src, void * dst, int n, int k, int64_t *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return (n/QK*sizeof(block_q4_0));
|
return (n/QK_4_0*sizeof(block_q4_0));
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
|
size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t * hist) {
|
||||||
assert(k % QK == 0);
|
assert(k % QK_4_1 == 0);
|
||||||
const int nb = k / QK;
|
const int nb = k / QK_4_1;
|
||||||
|
|
||||||
for (int j = 0; j < n; j += k) {
|
for (int j = 0; j < n; j += k) {
|
||||||
block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK;
|
block_q4_1 * restrict y = (block_q4_1 *)dst + j/QK_4_1;
|
||||||
|
|
||||||
quantize_row_q4_1_reference(src + j, y, k);
|
quantize_row_q4_1_reference(src + j, y, k);
|
||||||
|
|
||||||
for (int i = 0; i < nb; i++) {
|
for (int i = 0; i < nb; i++) {
|
||||||
for (int l = 0; l < QK; l += 2) {
|
for (int l = 0; l < QK_4_1; l += 2) {
|
||||||
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
const uint8_t vi0 = y[i].qs[l/2] & 0xF;
|
||||||
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
const uint8_t vi1 = y[i].qs[l/2] >> 4;
|
||||||
|
|
||||||
|
@ -10538,7 +10533,7 @@ size_t ggml_quantize_q4_1(const float * src, void * dst, int n, int k, int64_t *
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return (n/QK*sizeof(block_q4_1));
|
return (n/QK_4_1*sizeof(block_q4_1));
|
||||||
}
|
}
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
50
ggml.py
Normal file
50
ggml.py
Normal file
|
@ -0,0 +1,50 @@
|
||||||
|
from enum import IntEnum
|
||||||
|
|
||||||
|
|
||||||
|
class GGML_TYPE(IntEnum):
|
||||||
|
"""Tensor types, corresponding to enum ggml_type in ggml.h"""
|
||||||
|
|
||||||
|
Q4_0 = 0
|
||||||
|
Q4_1 = 1
|
||||||
|
I8 = 2
|
||||||
|
I16 = 3
|
||||||
|
I32 = 4
|
||||||
|
F16 = 5
|
||||||
|
F32 = 6
|
||||||
|
|
||||||
|
|
||||||
|
class GGML_FILE(IntEnum):
|
||||||
|
"""File types, corresponding to enum e_ftype in llama.cpp"""
|
||||||
|
|
||||||
|
F32 = 0
|
||||||
|
F16 = 1
|
||||||
|
Q4_0 = 2
|
||||||
|
Q4_1 = 3
|
||||||
|
|
||||||
|
|
||||||
|
ggml_type_from_ftype = {
|
||||||
|
GGML_FILE.F32: GGML_TYPE.F32,
|
||||||
|
GGML_FILE.F16: GGML_TYPE.F16,
|
||||||
|
GGML_FILE.Q4_0: GGML_TYPE.Q4_0,
|
||||||
|
GGML_FILE.Q4_1: GGML_TYPE.Q4_1,
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_BLCK_SIZE = {
|
||||||
|
GGML_TYPE.Q4_0: 32,
|
||||||
|
GGML_TYPE.Q4_1: 32,
|
||||||
|
GGML_TYPE.I8: 1,
|
||||||
|
GGML_TYPE.I16: 1,
|
||||||
|
GGML_TYPE.I32: 1,
|
||||||
|
GGML_TYPE.F16: 1,
|
||||||
|
GGML_TYPE.F32: 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
GGML_TYPE_SIZE = {
|
||||||
|
GGML_TYPE.Q4_0: 4 + GGML_BLCK_SIZE[GGML_TYPE.Q4_0] // 2,
|
||||||
|
GGML_TYPE.Q4_1: 4 * 2 + GGML_BLCK_SIZE[GGML_TYPE.Q4_1] // 2,
|
||||||
|
GGML_TYPE.I8: 1,
|
||||||
|
GGML_TYPE.I16: 2,
|
||||||
|
GGML_TYPE.I32: 4,
|
||||||
|
GGML_TYPE.F16: 2,
|
||||||
|
GGML_TYPE.F32: 4,
|
||||||
|
}
|
52
llama.cpp
52
llama.cpp
|
@ -54,6 +54,15 @@ enum e_model {
|
||||||
MODEL_65B,
|
MODEL_65B,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// model file types
|
||||||
|
enum e_ftype {
|
||||||
|
FTYPE_F32 = 0,
|
||||||
|
FTYPE_F16 = 1,
|
||||||
|
FTYPE_Q4_0 = 2,
|
||||||
|
FTYPE_Q4_1 = 3,
|
||||||
|
};
|
||||||
|
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1" };
|
||||||
|
|
||||||
static const size_t MB = 1024*1024;
|
static const size_t MB = 1024*1024;
|
||||||
|
|
||||||
// computed for n_ctx == 2048
|
// computed for n_ctx == 2048
|
||||||
|
@ -100,7 +109,7 @@ struct llama_hparams {
|
||||||
int32_t n_head = 32;
|
int32_t n_head = 32;
|
||||||
int32_t n_layer = 32;
|
int32_t n_layer = 32;
|
||||||
int32_t n_rot = 64;
|
int32_t n_rot = 64;
|
||||||
int32_t f16 = 1;
|
int32_t f16 = FTYPE_F16;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct llama_layer {
|
struct llama_layer {
|
||||||
|
@ -508,10 +517,10 @@ static bool llama_model_load(
|
||||||
// wtype is for per-layer weights, while vtype is for other weights
|
// wtype is for per-layer weights, while vtype is for other weights
|
||||||
ggml_type wtype, vtype;
|
ggml_type wtype, vtype;
|
||||||
switch (model.hparams.f16) {
|
switch (model.hparams.f16) {
|
||||||
case 0: wtype = vtype = GGML_TYPE_F32; break;
|
case FTYPE_F32: wtype = vtype = GGML_TYPE_F32; break;
|
||||||
case 1: wtype = vtype = GGML_TYPE_F16; break;
|
case FTYPE_F16: wtype = vtype = GGML_TYPE_F16; break;
|
||||||
case 2: wtype = vtype = GGML_TYPE_Q4_0; break;
|
case FTYPE_Q4_0: wtype = vtype = GGML_TYPE_Q4_0; break;
|
||||||
case 3: wtype = vtype = GGML_TYPE_Q4_1; break;
|
case FTYPE_Q4_1: wtype = vtype = GGML_TYPE_Q4_1; break;
|
||||||
case 4: wtype = GGML_TYPE_Q4_1; vtype = GGML_TYPE_F16; break;
|
case 4: wtype = GGML_TYPE_Q4_1; vtype = GGML_TYPE_F16; break;
|
||||||
default:
|
default:
|
||||||
{
|
{
|
||||||
|
@ -684,16 +693,15 @@ static bool llama_model_load(
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
if (0) {
|
if (0) {
|
||||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
|
||||||
fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]);
|
fprintf(stderr, "%24s - [%5d, %5d], type = %6s\n", name.data(), ne[0], ne[1], ftype_str[ftype]);
|
||||||
}
|
}
|
||||||
|
|
||||||
switch (ftype) {
|
switch (ftype) {
|
||||||
case 0: // f32
|
case FTYPE_F32:
|
||||||
case 1: // f16
|
case FTYPE_F16:
|
||||||
break;
|
break;
|
||||||
case 2: // q4_0
|
case FTYPE_Q4_0:
|
||||||
case 3: // q4_1
|
case FTYPE_Q4_1:
|
||||||
assert(ne[0] % 64 == 0);
|
assert(ne[0] % 64 == 0);
|
||||||
break;
|
break;
|
||||||
default:
|
default:
|
||||||
|
@ -1273,20 +1281,15 @@ static llama_vocab::id llama_sample_top_p_top_k(
|
||||||
//
|
//
|
||||||
|
|
||||||
// TODO: reuse code from the llama_model_load() somehow
|
// TODO: reuse code from the llama_model_load() somehow
|
||||||
static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, int itype) {
|
static bool llama_model_quantize_internal(const std::string & fname_inp, const std::string & fname_out, enum e_ftype itype) {
|
||||||
ggml_type type = GGML_TYPE_Q4_1;
|
ggml_type type;
|
||||||
|
|
||||||
switch (itype) {
|
switch (itype) {
|
||||||
case 2: type = GGML_TYPE_Q4_0; break;
|
case FTYPE_Q4_0: type = GGML_TYPE_Q4_0; break;
|
||||||
case 3: type = GGML_TYPE_Q4_1; break;
|
case FTYPE_Q4_1: type = GGML_TYPE_Q4_1; break;
|
||||||
default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return 1;
|
default: fprintf(stderr, "%s: invalid quantization type %d\n", __func__, itype); return false;
|
||||||
};
|
};
|
||||||
|
|
||||||
if (type != GGML_TYPE_Q4_0 && type != GGML_TYPE_Q4_1) {
|
|
||||||
fprintf(stderr, "%s: invalid quantization type %d\n", __func__, type);
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
llama_vocab vocab;
|
llama_vocab vocab;
|
||||||
|
|
||||||
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
printf("%s: loading model from '%s'\n", __func__, fname_inp.c_str());
|
||||||
|
@ -1438,7 +1441,6 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
}
|
}
|
||||||
|
|
||||||
{
|
{
|
||||||
static const char * ftype_str[] = { "f32", "f16", "q4_0", "q4_1", };
|
|
||||||
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
|
printf("%48s - [%5d, %5d], type = %6s ", name.data(), ne[0], ne[1], ftype_str[ftype]);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1459,12 +1461,12 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
quantize &= (n_dims == 2);
|
quantize &= (n_dims == 2);
|
||||||
|
|
||||||
if (quantize) {
|
if (quantize) {
|
||||||
if (ftype != 0 && ftype != 1) {
|
if (ftype != FTYPE_F32 && ftype != FTYPE_F16) {
|
||||||
fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
|
fprintf(stderr, "%s: unsupported ftype %d for integer quantization\n", __func__, ftype);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
if (ftype == 1) {
|
if (ftype == FTYPE_F16) {
|
||||||
data_f16.resize(nelements);
|
data_f16.resize(nelements);
|
||||||
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
|
finp.read(reinterpret_cast<char *>(data_f16.data()), nelements * sizeof(ggml_fp16_t));
|
||||||
data_f32.resize(nelements);
|
data_f32.resize(nelements);
|
||||||
|
@ -1478,7 +1480,7 @@ static bool llama_model_quantize_internal(const std::string & fname_inp, const s
|
||||||
|
|
||||||
ftype = itype;
|
ftype = itype;
|
||||||
} else {
|
} else {
|
||||||
const int bpe = (ftype == 0) ? sizeof(float) : sizeof(uint16_t);
|
const int bpe = (ftype == FTYPE_F32) ? sizeof(float) : sizeof(uint16_t);
|
||||||
|
|
||||||
data_u8.resize(nelements*bpe);
|
data_u8.resize(nelements*bpe);
|
||||||
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
|
finp.read(reinterpret_cast<char *>(data_u8.data()), nelements * bpe);
|
||||||
|
@ -1660,7 +1662,7 @@ int llama_model_quantize(
|
||||||
const char * fname_inp,
|
const char * fname_inp,
|
||||||
const char * fname_out,
|
const char * fname_out,
|
||||||
int itype) {
|
int itype) {
|
||||||
if (!llama_model_quantize_internal(fname_inp, fname_out, itype)) {
|
if (!llama_model_quantize_internal(fname_inp, fname_out, (enum e_ftype)itype)) {
|
||||||
fprintf(stderr, "%s: failed to quantize\n", __func__);
|
fprintf(stderr, "%s: failed to quantize\n", __func__);
|
||||||
return 1;
|
return 1;
|
||||||
}
|
}
|
||||||
|
|
|
@ -54,50 +54,7 @@ import sys
|
||||||
import json
|
import json
|
||||||
import struct
|
import struct
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
from ggml import *
|
||||||
QK = 32
|
|
||||||
|
|
||||||
GGML_TYPE_Q4_0 = 0
|
|
||||||
GGML_TYPE_Q4_1 = 1
|
|
||||||
GGML_TYPE_I8 = 2
|
|
||||||
GGML_TYPE_I16 = 3
|
|
||||||
GGML_TYPE_I32 = 4
|
|
||||||
GGML_TYPE_F16 = 5
|
|
||||||
GGML_TYPE_F32 = 6
|
|
||||||
|
|
||||||
WTYPE_NAMES = {
|
|
||||||
0: "F32",
|
|
||||||
1: "F16",
|
|
||||||
2: "Q4_0",
|
|
||||||
3: "Q4_1",
|
|
||||||
}
|
|
||||||
|
|
||||||
WTYPES = {
|
|
||||||
0: GGML_TYPE_F32,
|
|
||||||
1: GGML_TYPE_F16,
|
|
||||||
2: GGML_TYPE_Q4_0,
|
|
||||||
3: GGML_TYPE_Q4_1,
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_BLCK_SIZE = {
|
|
||||||
GGML_TYPE_Q4_0: QK,
|
|
||||||
GGML_TYPE_Q4_1: QK,
|
|
||||||
GGML_TYPE_I8: 1,
|
|
||||||
GGML_TYPE_I16: 1,
|
|
||||||
GGML_TYPE_I32: 1,
|
|
||||||
GGML_TYPE_F16: 1,
|
|
||||||
GGML_TYPE_F32: 1,
|
|
||||||
}
|
|
||||||
|
|
||||||
GGML_TYPE_SIZE = {
|
|
||||||
GGML_TYPE_Q4_0: 4 + QK//2,
|
|
||||||
GGML_TYPE_Q4_1: 4*2 + QK//2,
|
|
||||||
GGML_TYPE_I8: 1,
|
|
||||||
GGML_TYPE_I16: 2,
|
|
||||||
GGML_TYPE_I32: 4,
|
|
||||||
GGML_TYPE_F16: 2,
|
|
||||||
GGML_TYPE_F32: 4,
|
|
||||||
}
|
|
||||||
|
|
||||||
HPARAMS = [
|
HPARAMS = [
|
||||||
'magic', # int32
|
'magic', # int32
|
||||||
|
@ -150,7 +107,7 @@ def ggml_nelements(shape):
|
||||||
|
|
||||||
def ggml_nbytes(shape, ftype):
|
def ggml_nbytes(shape, ftype):
|
||||||
x = ggml_nelements(shape)
|
x = ggml_nelements(shape)
|
||||||
t = WTYPES[ftype]
|
t = ggml_type_from_ftype[ftype]
|
||||||
x *= GGML_TYPE_SIZE[t]
|
x *= GGML_TYPE_SIZE[t]
|
||||||
x //= GGML_BLCK_SIZE[t]
|
x //= GGML_BLCK_SIZE[t]
|
||||||
return x
|
return x
|
||||||
|
@ -177,10 +134,10 @@ def copy_tensors(fin, fout, part_id, n_parts):
|
||||||
name = fin.read(length)
|
name = fin.read(length)
|
||||||
data = fin.read(ggml_nbytes(partshape, ftype))
|
data = fin.read(ggml_nbytes(partshape, ftype))
|
||||||
|
|
||||||
blck_size = GGML_BLCK_SIZE[WTYPES[ftype]]
|
blck_size = GGML_BLCK_SIZE[ggml_type_from_ftype[ftype]]
|
||||||
type_size = GGML_TYPE_SIZE[WTYPES[ftype]]
|
type_size = GGML_TYPE_SIZE[ggml_type_from_ftype[ftype]]
|
||||||
|
|
||||||
print(f"Processing tensor {name} with shape: {partshape} and type: {WTYPE_NAMES[ftype]}")
|
print(f"Processing tensor {name} with shape: {partshape} and type: {GGML_FILE(ftype).name}")
|
||||||
|
|
||||||
# determine dimension along which multipart tensor is sharded
|
# determine dimension along which multipart tensor is sharded
|
||||||
#
|
#
|
||||||
|
@ -222,7 +179,7 @@ def copy_tensors(fin, fout, part_id, n_parts):
|
||||||
|
|
||||||
# ensure tensor data is aligned
|
# ensure tensor data is aligned
|
||||||
tensor_data_offset = fout.tell()
|
tensor_data_offset = fout.tell()
|
||||||
while tensor_data_offset % QK != 0:
|
while tensor_data_offset % 32 != 0:
|
||||||
fout.write(struct.pack("B", 0))
|
fout.write(struct.pack("B", 0))
|
||||||
tensor_data_offset += 1
|
tensor_data_offset += 1
|
||||||
|
|
||||||
|
|
|
@ -3,28 +3,32 @@
|
||||||
#include <assert.h>
|
#include <assert.h>
|
||||||
#include <math.h>
|
#include <math.h>
|
||||||
|
|
||||||
|
#define MAX(a, b) ((a) > (b) ? (a) : (b))
|
||||||
|
|
||||||
int main(void) {
|
int main(void) {
|
||||||
#define QK 32
|
const int qk0 = ggml_blck_size(GGML_TYPE_Q4_0);
|
||||||
float src[QK];
|
const int qk1 = ggml_blck_size(GGML_TYPE_Q4_1);
|
||||||
|
const int qk_max = MAX(qk0, qk1);
|
||||||
|
float src[qk_max];
|
||||||
uint8_t dst[24];
|
uint8_t dst[24];
|
||||||
int64_t hist[16];
|
int64_t hist[16];
|
||||||
|
|
||||||
for (int i = 0; i < QK; i++) {
|
for (int i = 0; i < qk_max; i++) {
|
||||||
src[i] = (float)(i + 1);
|
src[i] = (float)(i + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
size_t size = ggml_quantize_q4_0(src, dst, QK, QK, hist);
|
size_t size = ggml_quantize_q4_0(src, dst, qk0, qk0, hist);
|
||||||
assert(size == 20);
|
assert(size == 20);
|
||||||
float max_result = ((float *)dst)[0];
|
float max_result = ((float *)dst)[0];
|
||||||
float max_expected = src[31] / ((1 << 3) - 1);
|
float max_expected = src[31] / ((1 << 3) - 1);
|
||||||
assert(max_result == max_expected);
|
assert(max_result == max_expected);
|
||||||
for (int i = 0; i < QK; i++) {
|
for (int i = 0; i < qk0; i++) {
|
||||||
uint8_t q4_result = (i % 2) ? (dst[sizeof(float) + i/2] >> 4) : (dst[sizeof(float) + i/2] & 0xF);
|
uint8_t q4_result = (i % 2) ? (dst[sizeof(float) + i/2] >> 4) : (dst[sizeof(float) + i/2] & 0xF);
|
||||||
uint8_t q4_expected = roundf(src[i] / max_expected) + 8;
|
uint8_t q4_expected = roundf(src[i] / max_expected) + 8;
|
||||||
assert(q4_result == q4_expected);
|
assert(q4_result == q4_expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
size = ggml_quantize_q4_1(src, dst, QK, QK, hist);
|
size = ggml_quantize_q4_1(src, dst, qk1, qk1, hist);
|
||||||
assert(size == 24);
|
assert(size == 24);
|
||||||
float delta_result = ((float *)dst)[0];
|
float delta_result = ((float *)dst)[0];
|
||||||
float delta_expected = (src[31] - src[0]) / ((1 << 4) - 1);
|
float delta_expected = (src[31] - src[0]) / ((1 << 4) - 1);
|
||||||
|
@ -32,7 +36,7 @@ int main(void) {
|
||||||
float min_result = ((float *)dst)[1];
|
float min_result = ((float *)dst)[1];
|
||||||
float min_expected = src[0];
|
float min_expected = src[0];
|
||||||
assert(min_result == min_expected);
|
assert(min_result == min_expected);
|
||||||
for (int i = 0; i < QK; i++) {
|
for (int i = 0; i < qk1; i++) {
|
||||||
uint8_t q4_result = (i % 2) ? (dst[sizeof(float)*2 + i/2] >> 4) : (dst[sizeof(float)*2 + i/2] & 0xF);
|
uint8_t q4_result = (i % 2) ? (dst[sizeof(float)*2 + i/2] >> 4) : (dst[sizeof(float)*2 + i/2] & 0xF);
|
||||||
uint8_t q4_expected = roundf((src[i] - min_expected) / delta_expected);
|
uint8_t q4_expected = roundf((src[i] - min_expected) / delta_expected);
|
||||||
assert(q4_result == q4_expected);
|
assert(q4_result == q4_expected);
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue