First draft of SqueezeLLM PR
commit c6b0ebbe1b (parent ec2a24fedf)

7 changed files with 99 additions and 13 deletions
Makefile (11 changed lines)

@@ -298,6 +298,12 @@ ifdef LLAMA_QKK_64
 endif
 endif
 
+ifndef LLAMA_NO_SQLLM
+	MK_CPPFLAGS += -DGGML_USE_SQLLM
+	OBJS += sqllm.o
+endif
+
+
 ifndef LLAMA_NO_ACCELERATE
 	# Mac OS - include Accelerate framework.
 	# `-framework Accelerate` works both with Apple Silicon and Mac Intel
@@ -441,6 +447,11 @@ k_quants.o: k_quants.c k_quants.h
 	$(CC) $(CFLAGS) -c $< -o $@
 endif # LLAMA_NO_K_QUANTS
 
+ifndef LLAMA_NO_SQLLM
+sqllm.o: sqllm.c sqllm.h
+	$(CC) $(CFLAGS) -c $< -o $@
+endif # LLAMA_NO_SQLLM
+
 # combine build flags with cmdline overrides
 override CFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CFLAGS) $(CFLAGS)
 override CXXFLAGS := $(MK_CPPFLAGS) $(CPPFLAGS) $(MK_CXXFLAGS) $(CXXFLAGS)
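Note that SQLLM support is on by default: the new `ifndef LLAMA_NO_SQLLM` guards follow the existing `LLAMA_NO_K_QUANTS` convention, so defining the variable (e.g. `make LLAMA_NO_SQLLM=1`) skips both the `-DGGML_USE_SQLLM` define and the `sqllm.o` object. The sqllm.c / sqllm.h sources themselves are not part of this diff.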
convert-llama-ggml-to-gguf.py (filename inferred from content; the viewer dropped this file's header)

@@ -33,6 +33,7 @@ GGML_QUANT_SIZES = {
     gguf.GGMLQuantizationType.Q5_K : (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
     gguf.GGMLQuantizationType.Q6_K : (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
     gguf.GGMLQuantizationType.Q8_K : (256, 4 + QK_K + QK_K // 8),
+    gguf.GGMLQuantizationType.Q4_SQ : (1, 4),
 }
 
 class GGMLFormat(IntEnum):
@@ -58,6 +59,7 @@ class GGMLFType(IntEnum):
     MOSTLY_Q5_K_S = 16
     MOSTLY_Q5_K_M = 17
     MOSTLY_Q6_K = 18
+    MOSTLY_Q4_SQ = 19
 
 class Hyperparameters:
     def __init__(self):
@@ -120,7 +122,7 @@ class Tensor:
         self.len_bytes = np.int64(0)
         self.use_padding = use_padding
 
-    def load(self, data, offset):
+    def load(self, data, offset, squeezellm=False):
         orig_offset = offset
         (n_dims, name_len, dtype) = struct.unpack('<3I', data[offset:offset + 12])
         assert n_dims >= 0 and n_dims <= 4, f'Invalid tensor dimensions {n_dims}'
@@ -137,6 +139,9 @@
         pad = ((offset + 31) & ~31) - offset if self.use_padding else 0
         offset += pad
         n_elems = np.prod(self.dims)
+        if squeezellm and n_dims > 1 and dtype == gguf.GGMLQuantizationType.Q4_SQ:
+            n_elems = n_elems / 8
+            n_elems += self.dims[1] * 8 # add 16 fp16 elements per row
         n_bytes = np.int64(np.int64(n_elems) * np.int64(tysize)) // np.int64(blksize)
         self.start_offset = offset
         self.len_bytes = n_bytes
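For reference, the new Q4_SQ entry above declares block size 1 and type size 4, so this branch counts packed int32 words rather than weights. A 4096x4096 tensor, for example, has 16,777,216 weights, which pack into 2,097,152 words (eight 4-bit indices per int32), plus 4096 * 8 = 32,768 extra words for the per-row tables: 2,129,920 words * 4 bytes = 8,519,680 bytes, i.e. 2,080 bytes per row (32 bytes of fp16 lookup table plus 4096/2 bytes of indices). This matches the `16*2 + (ne[0]/2)` row allocation added to ggml.c below.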
@@ -186,19 +191,20 @@ class GGMLModel:
         if len(err) > 0:
             raise ValueError(f'{err} Sorry, your {self.file_format.name}v{self.format_version} file of type {ftype.name} is not eligible for conversion.')
 
-    def load(self, data, offset):
+    def load(self, data, offset, squeezellm=False):
         offset += self.validate_header(data, offset)
         hp = Hyperparameters()
         offset += hp.load(data, offset)
         print(f'* File format: {self.file_format.name}v{self.format_version} with ftype {hp.ftype.name}')
-        self.validate_conversion(hp.ftype)
+        if not squeezellm:
+            self.validate_conversion(hp.ftype)
         vocab = Vocab(load_scores = self.file_format > GGMLFormat.GGML)
         offset += vocab.load(data, offset, hp.n_vocab)
         tensors: list[Tensor] = []
         tensor_map = {}
         while offset < len(data):
             tensor = Tensor(use_padding = self.file_format > GGMLFormat.GGMF)
-            offset += tensor.load(data, offset)
+            offset += tensor.load(data, offset, squeezellm=squeezellm)
             tensor_map[tensor.name] = len(tensors)
             tensors.append(tensor)
         self.hyperparameters = hp
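With --squeezellm set, the loader skips validate_conversion (presumably because the stock eligibility check does not know the new MOSTLY_Q4_SQ ftype) and threads the flag down to Tensor.load, so Q4_SQ tensor payloads are sized with the lookup-table-aware formula above.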
@@ -414,6 +420,7 @@ def handle_args():
         help="directory containing tokenizer.model, if separate from model file - only meaningful with --model-metadata-dir")
     parser.add_argument("--vocabtype", choices=["spm", "bpe"], default="spm",
         help="vocab format - only meaningful with --model-metadata-dir and/or --vocab-dir (default: spm)")
+    parser.add_argument("--squeezellm", action="store_true", help="Convert to SQLLM")
    return parser.parse_args()
 
 def main():
@@ -425,7 +432,7 @@ def main():
     data = np.memmap(cfg.input, mode = 'r')
     model = GGMLModel()
     print('* Scanning GGML input file')
-    offset = model.load(data, 0)
+    offset = model.load(data, 0, cfg.squeezellm)
     print(f'* GGML model hyperparameters: {model.hyperparameters}')
     vocab_override = None
     params_override = None
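A typical invocation would then look something like `python convert-llama-ggml-to-gguf.py --input model.ggml --output model.gguf --squeezellm`, where the input GGML file already contains SqueezeLLM-packed Q4_SQ tensors (the file names here are made up, and --input/--output are assumed to be the script's existing flags).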
ggml.c (71 changed lines)

@@ -6,6 +6,10 @@
 #include "k_quants.h"
 #endif
 
+#ifdef GGML_USE_SQLLM
+#include "sqllm.h"
+#endif
+
 #if defined(_MSC_VER) || defined(__MINGW32__)
 #include <malloc.h> // using malloc.h with MSC/MINGW
 #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__)
@@ -1777,6 +1781,19 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .type_size = sizeof(block_q8_K),
         .is_quantized = true,
         .from_float = quantize_row_q8_K,
     },
 #endif
+#ifdef GGML_USE_SQLLM
+    [GGML_TYPE_Q4_SQ] = {
+        .type_name = "q4_sq",
+        .blck_size = 1,
+        .type_size = sizeof(int32_t),
+        .is_quantized = true,
+        .to_float = NULL,
+        .from_float = NULL,
+        .from_float_reference = NULL,
+        .vec_dot = ggml_vec_dot_q4_sq_fp16,
+        .vec_dot_type = GGML_TYPE_F32,
+    }
+#endif
 };
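The traits entry registers ggml_vec_dot_q4_sq_fp16 from sqllm.c, which this diff does not include, with blck_size 1 and type_size sizeof(int32_t): one packed word per eight weights. Below is a minimal scalar sketch of what such a kernel plausibly computes, assuming the row layout implied by the size math elsewhere in this PR (16 fp16 lookup entries followed by n/8 packed words); the function name, nibble order, and layout are assumptions, not the PR's actual kernel:

#include <stdint.h>
#include "ggml.h" // ggml_fp16_t, ggml_fp16_to_fp32()

// Hypothetical scalar reference for a Q4_SQ dot product. Assumed row layout:
// 16 fp16 dequantization values (SqueezeLLM's non-uniform 4-bit lookup table),
// then n/8 uint32 words, each packing eight 4-bit indices into that table.
static float vec_dot_q4_sq_sketch(const int n, const void * x_row, const ggml_fp16_t * y) {
    const ggml_fp16_t * lut = (const ggml_fp16_t *) x_row;   // 16 table entries
    const uint32_t    * qs  = (const uint32_t *) (lut + 16); // packed indices
    float sum = 0.0f;
    for (int i = 0; i < n/8; ++i) {
        const uint32_t w = qs[i];
        for (int j = 0; j < 8; ++j) {
            const int idx = (w >> (4*j)) & 0xF; // assumed little-endian nibble order
            sum += ggml_fp16_to_fp32(lut[idx]) * ggml_fp16_to_fp32(y[8*i + j]);
        }
    }
    return sum;
}

Dequantization is thus a table lookup rather than a scale-and-offset; the PR leaves to_float and from_float NULL, so Q4_SQ rows can only be consumed by this dot product, not expanded or produced by ggml's generic quantization helpers.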
@@ -4403,6 +4420,7 @@ enum ggml_type ggml_ftype_to_ggml_type(enum ggml_ftype ftype) {
         case GGML_FTYPE_MOSTLY_Q4_K: wtype = GGML_TYPE_Q4_K; break;
         case GGML_FTYPE_MOSTLY_Q5_K: wtype = GGML_TYPE_Q5_K; break;
         case GGML_FTYPE_MOSTLY_Q6_K: wtype = GGML_TYPE_Q6_K; break;
+        case GGML_FTYPE_MOSTLY_Q4_SQ: wtype = GGML_TYPE_Q4_SQ; break;
         case GGML_FTYPE_UNKNOWN: wtype = GGML_TYPE_COUNT; break;
         case GGML_FTYPE_MOSTLY_Q4_1_SOME_F16: wtype = GGML_TYPE_COUNT; break;
     }
@@ -4777,7 +4795,13 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         view_src = view_src->view_src;
     }
 
-    size_t data_size = ggml_type_size(type)*(ne[0]/ggml_blck_size(type));
+    size_t data_size = 0;
+    if (type == GGML_TYPE_Q4_SQ) { //SQLLM
+        data_size += 16*2 + (ne[0]/2);
+    } else {
+        data_size += ggml_type_size(type)*(ne[0]/ggml_blck_size(type));
+    }
+
     for (int i = 1; i < n_dims; i++) {
         data_size *= ne[i];
     }
@@ -4845,8 +4869,13 @@ static struct ggml_tensor * ggml_new_tensor_impl(
         result->ne[i] = ne[i];
     }
 
-    result->nb[0] = ggml_type_size(type);
-    result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
+    if (type == GGML_TYPE_Q4_SQ) { //SQLLM
+        result->nb[0] = ggml_type_size(type);
+        result->nb[1] = result->nb[0]*(16/2 + result->ne[0]/8);
+    } else {
+        result->nb[0] = ggml_type_size(type);
+        result->nb[1] = result->nb[0]*(result->ne[0]/ggml_blck_size(type));
+    }
     for (int i = 2; i < GGML_MAX_DIMS; i++) {
         result->nb[i] = result->nb[i - 1]*result->ne[i - 1];
     }
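These strides agree with the converter's byte count: nb[0] is sizeof(int32_t) = 4, so nb[1] = 4*(16/2 + ne[0]/8) = 32 + ne[0]/2 bytes per row, i.e. 16 fp16 lookup entries followed by ne[0] packed 4-bit indices.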
@@ -9028,6 +9057,7 @@ static void ggml_compute_forward_add(
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q4_SQ:
             {
                 ggml_compute_forward_add_q_f32(params, src0, src1, dst);
             } break;
@@ -9292,6 +9322,7 @@ static void ggml_compute_forward_add1(
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q4_SQ:
             {
                 ggml_compute_forward_add1_q_f32(params, src0, src1, dst);
             } break;
@@ -9407,6 +9438,7 @@ static void ggml_compute_forward_acc(
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q4_SQ:
         default:
             {
                 GGML_ASSERT(false);
@@ -11319,7 +11351,7 @@ static void ggml_compute_forward_mul_mat(
         }
 #endif
 
-    if (params->type == GGML_TASK_INIT) {
+    if (params->type == GGML_TASK_INIT && src0->type != GGML_TYPE_Q4_SQ) {
         if (src1->type != vec_dot_type) {
             char * wdata = params->wdata;
             const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
@@ -11334,6 +11366,21 @@ static void ggml_compute_forward_mul_mat(
             }
         }
 
         return;
+    } else if (params->type == GGML_TASK_INIT) { //SQLLM - copy fp32 vec over
+        ggml_fp16_t * wdata = params->wdata;
+        float * srcvec;
+        for (int64_t i13 = 0; i13 < ne13; ++i13) {
+            for (int64_t i12 = 0; i12 < ne12; ++i12) {
+                for (int64_t i11 = 0; i11 < ne11; ++i11) {
+                    srcvec = (float *)((char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
+                    for (int64_t i10 = 0; i10 < ne10; ++i10) {
+                        *wdata = ggml_fp32_to_fp16(srcvec[i10]);
+                        wdata += 1;
+                    }
+                }
+            }
+        }
+        return;
     }
 
@@ -11341,8 +11388,15 @@ static void ggml_compute_forward_mul_mat(
         return;
     }
 
-    const void * wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
-    const size_t row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
+    void * wdata;
+    size_t row_size;
+    if (src0->type != GGML_TYPE_Q4_SQ) {
+        row_size = ne10*ggml_type_size(vec_dot_type)/ggml_blck_size(vec_dot_type);
+        wdata = (src1->type == vec_dot_type) ? src1->data : params->wdata;
+    } else {
+        row_size = ne10*sizeof(int16_t); // for fp16 row
+        wdata = params->wdata;
+    }
 
     const int64_t nr0 = ne01; // src0 rows
     const int64_t nr1 = ne11*ne12*ne13; // src1 rows
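Taken together, the mul_mat changes stage Q4_SQ inputs specially: the registered vec_dot_type is GGML_TYPE_F32, so the generic INIT path would leave src1 as fp32, but ggml_vec_dot_q4_sq_fp16 evidently wants fp16 activations, hence the explicit fp32-to-fp16 copy into wdata above and the ne10*sizeof(int16_t) row stride here (sizeof(int16_t) standing in for the 2-byte ggml_fp16_t).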
@@ -11406,7 +11460,7 @@ static void ggml_compute_forward_mul_mat(
                 // the original src1 data pointer, so we should index using the indices directly
                 // TODO: this is a bit of a hack, we should probably have a better way to handle this
                 const char * src1_col = (const char *) wdata +
-                    (src1_cont || src1->type != vec_dot_type
+                    (src1_cont || src1->type != vec_dot_type || src0->type == GGML_TYPE_Q4_SQ
                         ? (i11 + i12*ne11 + i13*ne12*ne11)*row_size
                         : (i11*nb11 + i12*nb12 + i13*nb13));
 
@@ -11724,6 +11778,7 @@ static void ggml_compute_forward_set(
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q4_SQ:
         default:
             {
                 GGML_ASSERT(false);
@@ -11894,6 +11949,7 @@ static void ggml_compute_forward_get_rows(
         case GGML_TYPE_Q4_K:
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
+        case GGML_TYPE_Q4_SQ:
             {
                 ggml_compute_forward_get_rows_q(params, src0, src1, dst);
             } break;
@@ -12523,6 +12579,7 @@ static void ggml_compute_forward_alibi(
         case GGML_TYPE_Q5_K:
         case GGML_TYPE_Q6_K:
         case GGML_TYPE_Q8_K:
+        case GGML_TYPE_Q4_SQ:
         case GGML_TYPE_I8:
         case GGML_TYPE_I16:
         case GGML_TYPE_I32:
ggml.h (2 changed lines)

@@ -304,6 +304,7 @@ extern "C" {
         GGML_TYPE_Q5_K = 13,
         GGML_TYPE_Q6_K = 14,
         GGML_TYPE_Q8_K = 15,
+        GGML_TYPE_Q4_SQ = 16,
         GGML_TYPE_I8,
         GGML_TYPE_I16,
         GGML_TYPE_I32,
@@ -332,6 +333,7 @@ extern "C" {
         GGML_FTYPE_MOSTLY_Q4_K = 12, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q5_K = 13, // except 1d tensors
         GGML_FTYPE_MOSTLY_Q6_K = 14, // except 1d tensors
+        GGML_FTYPE_MOSTLY_Q4_SQ = 16, // except 1d tensors
     };
 
     // available tensor operations:
gguf-py/gguf/gguf.py (filename inferred from content; the viewer dropped this file's header)

@@ -389,6 +389,7 @@ class GGMLQuantizationType(IntEnum):
     Q5_K = 13
     Q6_K = 14
     Q8_K = 15
+    Q4_SQ = 16
 
 
 class GGUFValueType(IntEnum):
llama.cpp (filename inferred from content; the viewer dropped this file's header)

@@ -1301,6 +1301,7 @@ struct llama_model_loader {
             case GGML_TYPE_Q4_K: ftype = LLAMA_FTYPE_MOSTLY_Q4_K_M; break;
             case GGML_TYPE_Q5_K: ftype = LLAMA_FTYPE_MOSTLY_Q5_K_M; break;
             case GGML_TYPE_Q6_K: ftype = LLAMA_FTYPE_MOSTLY_Q6_K; break;
+            case GGML_TYPE_Q4_SQ: ftype = LLAMA_FTYPE_MOSTLY_Q4_SQ; break;
             default:
                 {
                     LLAMA_LOG_WARN("%s: unknown type %s\n", __func__, ggml_type_name(type_max));
@@ -1566,6 +1567,9 @@ std::string llama_model_ftype_name(enum llama_ftype ftype) {
         case LLAMA_FTYPE_MOSTLY_Q5_K_M: return "mostly Q5_K - Medium";
         case LLAMA_FTYPE_MOSTLY_Q6_K: return "mostly Q6_K";
+
+        //SQLLM
+        case LLAMA_FTYPE_MOSTLY_Q4_SQ: return "mostly Q4_SQ";
 
         default: return "unknown, may not work";
     }
 }
@@ -2950,7 +2954,7 @@ static bool llama_eval_internal(
     // TODO: this is mostly important for Apple Silicon where CBLAS is still performing very well
     // we still need some threads to process all non-mul_mat ops, but not too much to avoid interfering
     // with the BLAS calls. need a better solution
-    if (N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas()) {
+    if (N >= 32 && ggml_cpu_has_blas() && !ggml_cpu_has_gpublas() && !(model.ftype == LLAMA_FTYPE_MOSTLY_Q4_SQ)) {
         n_threads = std::min(4, n_threads);
     }
 
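The extra condition presumably keeps the full thread count for Q4_SQ models: the 4-thread clamp only pays off when BLAS takes over the big matrix multiplications, and the BLAS path dequantizes weights to fp32 first, which Q4_SQ cannot do since its to_float is NULL.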
@@ -4721,6 +4725,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
         case LLAMA_FTYPE_MOSTLY_Q5_K_S:
         case LLAMA_FTYPE_MOSTLY_Q5_K_M: quantized_type = GGML_TYPE_Q5_K; break;
         case LLAMA_FTYPE_MOSTLY_Q6_K: quantized_type = GGML_TYPE_Q6_K; break;
 #endif
+#ifdef GGML_USE_SQLLM
+        case LLAMA_FTYPE_MOSTLY_Q4_SQ: quantized_type = GGML_TYPE_Q4_SQ; break;
+#endif
         default: throw std::runtime_error(format("invalid output file type %d\n", ftype));
     }
llama.h (1 changed line)

@@ -104,6 +104,7 @@ extern "C" {
         LLAMA_FTYPE_MOSTLY_Q5_K_S = 16, // except 1d tensors
        LLAMA_FTYPE_MOSTLY_Q5_K_M = 17, // except 1d tensors
         LLAMA_FTYPE_MOSTLY_Q6_K = 18, // except 1d tensors
+        LLAMA_FTYPE_MOSTLY_Q4_SQ = 19, // except 1d tensors
 
         LLAMA_FTYPE_GUESSED = 1024, // not specified in the model file
     };