add dequantize

Eddie-Wang1120 2024-06-19 21:48:04 +08:00
parent 89c7e4c1dd
commit fcf2da4621
6 changed files with 42 additions and 34 deletions

convert-hf-to-gguf.py

@@ -1420,40 +1420,23 @@ class BitnetModel(Model):
         return weight.type(dtype), scale.type(torch.float32)
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        # transform weight into 1/0/-1 (in fp32)
-        if name.endswith(("q_proj.weight", "k_proj.weight", "v_proj.weight",
-                          "down_proj.weight", "up_proj.weight", "gate_proj.weight",
-                          "o_proj.weight")):
+        new_name = self.map_tensor_name(name)
+
+        if any(self.match_model_tensor_name(new_name, key, bid) for key in [
+            gguf.MODEL_TENSOR.ATTN_Q,
+            gguf.MODEL_TENSOR.ATTN_K,
+            gguf.MODEL_TENSOR.ATTN_V,
+            gguf.MODEL_TENSOR.ATTN_OUT,
+            gguf.MODEL_TENSOR.FFN_UP,
+            gguf.MODEL_TENSOR.FFN_DOWN,
+            gguf.MODEL_TENSOR.FFN_GATE,
+        ]):
+            # transform weight into 1/0/-1 (in fp32)
             weight_torch, scale_torch = self.weight_quant(data_torch)
-        tensors: list[tuple[str, Tensor]] = []
-        if name.endswith("q_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_Q, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("k_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_K, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("v_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("o_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_OUT, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("up_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("down_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_DOWN, bid, suffix=".scale"), scale_torch))
-        elif name.endswith("gate_proj.weight"):
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid), weight_torch))
-            tensors.append((self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE, bid, suffix=".scale"), scale_torch))
-        if len(tensors) == 0:
-            tensors.append((self.map_tensor_name(name), data_torch))
-        return tensors
+
+            yield (new_name, weight_torch)
+            yield (new_name.removesuffix(".weight") + ".scale", scale_torch)
+        else:
+            yield (new_name, data_torch)
 
 @Model.register("GrokForCausalLM")
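With this change the converter yields two tensors per ternary weight: the quantized weight under its mapped name and a per-tensor scale under the same name with a ".scale" suffix. The diff only shows the return statement of weight_quant, so the sketch below is a minimal illustration of the BitNet-style ternarization it relies on, not the branch's exact implementation; the helper name ternarize is made up for this example.

import torch

def ternarize(weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # BitNet b1.58-style sketch: scale by the mean absolute value,
    # then round and clamp into {-1, 0, +1} (kept in fp32).
    w = weight.float()
    scale = w.abs().mean().clamp(min=1e-5)
    ternary = (w / scale).round().clamp(-1, 1)
    return ternary, scale.reshape(1)

# modify_tensors() above then emits the pair, e.g.:
#   yield (new_name, ternary)
#   yield (new_name.removesuffix(".weight") + ".scale", scale)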

ggml-quants.c

@@ -1545,6 +1545,26 @@ void quantize_row_q8_1(const float * restrict x, void * restrict vy, int64_t k)
 #endif
 }
 
+void dequantize_row_q2_2(const block_q2_2 * restrict x, float * restrict y, int64_t k) {
+    static const int qk = QK2_2;
+
+    assert(k % qk == 0);
+
+    const int nb = k / qk;
+
+    for (int i = 0; i < nb; i++) {
+        for (int j = 0; j < qk/4; ++j) {
+            const int8_t * q = (const int8_t *) (q22_grid + x[i].qs[j]);
+            *y++ = (float) q[0];
+            *y++ = (float) q[1];
+            *y++ = (float) q[2];
+            *y++ = (float) q[3];
+        }
+    }
+}
+
 void dequantize_row_q4_0(const block_q4_0 * restrict x, float * restrict y, int64_t k) {
     static const int qk = QK4_0;
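The new dequantize_row_q2_2 expands each byte of qs into four weights with a single q22_grid table lookup: a Q2_2 block packs QK2_2 = 32 two-bit codes into 8 bytes. The Python sketch below unpacks one block the slow way for illustration; the code-to-value mapping and the bit order within a byte are assumptions here, since the actual q22_grid table is not part of this hunk.

QK2_2 = 32                     # weights per block
BYTES_PER_BLOCK = QK2_2 // 4   # four 2-bit codes per byte -> 8 bytes

# Assumed 2-bit code -> weight mapping, for illustration only.
CODE_TO_VALUE = {0: 0.0, 1: 1.0, 2: -1.0, 3: 0.0}

def dequantize_block_q2_2(qs: bytes) -> list[float]:
    assert len(qs) == BYTES_PER_BLOCK
    out: list[float] = []
    for byte in qs:
        # q22_grid replaces this inner loop with one 256-entry lookup
        # that yields the four decoded values for the whole byte.
        for shift in (0, 2, 4, 6):
            out.append(CODE_TO_VALUE[(byte >> shift) & 0x3])
    return out

print(dequantize_block_q2_2(bytes(8)))  # 32 zeros for an all-zero block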

ggml-quants.h

@@ -55,6 +55,7 @@ void quantize_row_iq3_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y,
 void quantize_row_iq2_s (const float * GGML_RESTRICT x, void * GGML_RESTRICT y, int64_t k);
 
 // Dequantization
+void dequantize_row_q2_2(const block_q2_2 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 void dequantize_row_q4_0(const block_q4_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 void dequantize_row_q4_1(const block_q4_1 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);
 void dequantize_row_q5_0(const block_q5_0 * GGML_RESTRICT x, float * GGML_RESTRICT y, int64_t k);

ggml.c

@@ -620,6 +620,7 @@ static const ggml_type_traits_t type_traits[GGML_TYPE_COUNT] = {
         .type_name                = "q2_2",
         .blck_size                = QK2_2,
         .type_size                = sizeof(block_q2_2),
+        .to_float                 = (ggml_to_float_t) dequantize_row_q2_2,
         .is_quantized             = true,
         .from_float               = quantize_row_q2_2,
         .from_float_reference     = (ggml_from_float_t) quantize_row_q2_2_reference,

gguf-py/gguf/constants.py

@@ -923,6 +923,7 @@ class GGMLQuantizationType(IntEnum):
     F64     = 28
     IQ1_M   = 29
     BF16    = 30
+    Q2_2    = 31
 
 # TODO: add GGMLFileType from ggml_ftype in ggml.h
@@ -964,6 +965,7 @@ class LlamaFileType(IntEnum):
     MOSTLY_IQ4_XS    = 30  # except 1d tensors
     MOSTLY_IQ1_M     = 31  # except 1d tensors
     MOSTLY_BF16      = 32  # except 1d tensors
+    MOSTLY_Q2_2      = 33  # except 1d tensors
 
     GUESSED          = 1024  # not specified in the model file
@@ -1010,6 +1012,7 @@ QK_K = 256
 GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
     GGMLQuantizationType.F32:  (1, 4),
     GGMLQuantizationType.F16:  (1, 2),
+    GGMLQuantizationType.Q2_2: (32, 8),
     GGMLQuantizationType.Q4_0: (32, 2 + 16),
     GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
     GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
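The new GGML_QUANT_SIZES entry records a block size of 32 elements stored in 8 bytes, i.e. 2 bits per weight, with the per-tensor scale carried separately in the ".scale" tensor emitted by the converter. A quick sanity check of that arithmetic (independent of gguf-py; the helper name is illustrative):

block_size, type_size = 32, 8                 # the (32, 8) entry for Q2_2
bits_per_weight = type_size * 8 / block_size
assert bits_per_weight == 2.0

def q2_2_tensor_bytes(n_elements: int) -> int:
    # Packed size in bytes of a Q2_2 tensor (scale tensor excluded).
    assert n_elements % block_size == 0
    return n_elements // block_size * type_size

print(q2_2_tensor_bytes(4096 * 4096))  # 4 MiB for a 4096x4096 weight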

llama.cpp

@@ -3885,6 +3885,7 @@ static std::string llama_model_ftype_name(llama_ftype ftype) {
         case LLAMA_FTYPE_ALL_F32:     return "all F32";
         case LLAMA_FTYPE_MOSTLY_F16:  return "F16";
         case LLAMA_FTYPE_MOSTLY_BF16: return "BF16";
+        case LLAMA_FTYPE_MOSTLY_Q2_2: return "Q2_2";
         case LLAMA_FTYPE_MOSTLY_Q4_0: return "Q4_0";
         case LLAMA_FTYPE_MOSTLY_Q4_1: return "Q4_1";
         case LLAMA_FTYPE_MOSTLY_Q4_1_SOME_F16:
@@ -11705,7 +11706,6 @@ struct llm_build_context {
             cb(cur, "ffn_gate", il);
 
             cur = ggml_silu(ctx0, cur);
             cb(cur, "ffn_silu", il);