llama : add Mixtral support (#4406)
* convert : support Mixtral as LLAMA arch
* convert : fix n_ff typo
* llama : model loading
* ggml : sync latest ggml_mul_mat_id
* llama : update graph to support MoE
* llama : fix cur -> cur_expert
* llama : first working version
* llama : fix expert weighting in the FFN
* ggml : ggml_get_rows support 2D indexing [n_tokens, n_experts] (cpu only)
* ggml : add n_as argument to ggml_mul_mat_id
* ggml : fix ggml_get_rows to take into account ne02 / ne11
* metal : add more general support for ggml_get_rows + tests
* llama : add basic support for offloading moe with CUDA
* metal : add/mul/div use general kernel when src1 not cont
* metal : reduce the kernel launches for ggml_mul_mat_id
* ggml : get_rows : support non-contiguos tensors with gaps, generalize up to 3D
* ggml : update get_rows f16 and q
* cuda : support non-contiguous src1 in get_rows
* llama : offload missing ffn_moe_silu
* metal : fix ggml_get_rows to work with non-cont src1
* metal : add indirect mat-vec kernels for all quantization types
* llama : do not quantize expert gating tensors
* llama : add n_expert and n_expert_used to hparams + change quants
* test-backend-ops : add moe test
* cuda : fix get_rows when ncols is odd
* convert : determine n_ctx correctly
* metal : fix ggml_mul_mat_id for F32
* test-backend-ops : make experts more evenly probable (test_moe)
* test-backend-ops : cleanup, add moe test for batches
* test-backend-ops : add cpy from f32 -> all types test
* test-backend-ops : fix dequantize block offset
* llama : fix hard-coded number of experts
* test-backend-ops : simplify and disable slow tests to avoid CI timeout
* test-backend-ops : disable MOE test with thread sanitizer
* cuda : fix mul_mat_id with multi gpu
* convert : use 1e6 rope_freq_base for mixtral
* convert : fix style
* convert : support safetensors format
* gguf-py : bump version
* metal : add cpy f16 -> f32 kernel
* metal : fix binary ops for ne10 % 4 != 0
* test-backend-ops : add one more sum_rows test
* ggml : do not use BLAS with ggml_mul_mat_id
* convert-hf : support for mixtral-instruct (#4428)
* convert : typo fix, add additional hyperparameters, use LLaMA arch for Mixtral-instruct
* convert : use sentencepiece tokenizer for Mixtral-instruct
* convert : make flake8 happy
* metal : fix soft_max kernels
ref: 1914017863
* metal : limit kernels to not use more than the allowed threads
---------
Co-authored-by: Georgi Gerganov <ggerganov@gmail.com>
Co-authored-by: Radek Pilar <github@mrkva.eu>
This commit is contained in:
parent
fecac45658
commit
799a1cb13b
14 changed files with 2370 additions and 395 deletions
74
convert.py
74
convert.py
|
@ -42,6 +42,7 @@ NDArray: TypeAlias = 'np.ndarray[Any, Any]'
|
|||
ARCH = gguf.MODEL_ARCH.LLAMA
|
||||
|
||||
DEFAULT_CONCURRENCY = 8
|
||||
|
||||
#
|
||||
# data types
|
||||
#
|
||||
|
@ -62,10 +63,10 @@ class UnquantizedDataType(DataType):
|
|||
pass
|
||||
|
||||
|
||||
DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
|
||||
DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
|
||||
DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
|
||||
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
|
||||
DT_F16 = UnquantizedDataType('F16', dtype = np.dtype(np.float16), valid_conversions = ['F32', 'Q8_0'])
|
||||
DT_F32 = UnquantizedDataType('F32', dtype = np.dtype(np.float32), valid_conversions = ['F16', 'Q8_0'])
|
||||
DT_I32 = UnquantizedDataType('I32', dtype = np.dtype(np.int16), valid_conversions = [])
|
||||
DT_BF16 = UnquantizedDataType('BF16', dtype = np.dtype(np.uint16), valid_conversions = ['F32', 'F16', 'Q8_0'])
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
@ -151,14 +152,16 @@ GGML_FILE_TYPE_TO_DATA_TYPE: dict[GGMLFileType, DataType] = {
|
|||
|
||||
@dataclass
|
||||
class Params:
|
||||
n_vocab: int
|
||||
n_embd: int
|
||||
n_layer: int
|
||||
n_ctx: int
|
||||
n_ff: int
|
||||
n_head: int
|
||||
n_head_kv: int
|
||||
f_norm_eps: float
|
||||
n_vocab: int
|
||||
n_embd: int
|
||||
n_layer: int
|
||||
n_ctx: int
|
||||
n_ff: int
|
||||
n_head: int
|
||||
n_head_kv: int
|
||||
n_experts: int | None = None
|
||||
n_experts_used: int | None = None
|
||||
f_norm_eps: float | None = None
|
||||
|
||||
rope_scaling_type: gguf.RopeScalingType | None = None
|
||||
f_rope_freq_base: float | None = None
|
||||
|
@ -233,6 +236,13 @@ class Params:
|
|||
raise Exception("failed to guess 'n_ctx'. This model is unknown or unsupported.\n"
|
||||
"Suggestion: provide 'config.json' of the model in the same directory containing model files.")
|
||||
|
||||
n_experts = None
|
||||
n_experts_used = None
|
||||
|
||||
if "num_local_experts" in config:
|
||||
n_experts = config["num_local_experts"]
|
||||
n_experts_used = config["num_experts_per_tok"]
|
||||
|
||||
return Params(
|
||||
n_vocab = config["vocab_size"],
|
||||
n_embd = config["hidden_size"],
|
||||
|
@ -241,6 +251,8 @@ class Params:
|
|||
n_ff = config["intermediate_size"],
|
||||
n_head = (n_head := config["num_attention_heads"]),
|
||||
n_head_kv = config.get("num_key_value_heads", n_head),
|
||||
n_experts = n_experts,
|
||||
n_experts_used = n_experts_used,
|
||||
f_norm_eps = config["rms_norm_eps"],
|
||||
f_rope_freq_base = config.get("rope_theta"),
|
||||
rope_scaling_type = rope_scaling_type,
|
||||
|
@ -255,8 +267,15 @@ class Params:
|
|||
def loadOriginalParamsJson(model: LazyModel, config_path: Path) -> Params:
|
||||
config = json.load(open(config_path))
|
||||
|
||||
n_experts = None
|
||||
n_experts_used = None
|
||||
f_rope_freq_base = None
|
||||
|
||||
# hack to determine LLaMA v1 vs v2 vs CodeLlama
|
||||
if config.get("rope_theta") == 1000000:
|
||||
if config.get("moe"):
|
||||
# Mixtral
|
||||
n_ctx = 32768
|
||||
elif config.get("rope_theta") == 1000000:
|
||||
# CodeLlama
|
||||
n_ctx = 16384
|
||||
elif config["norm_eps"] == 1e-05:
|
||||
|
@ -266,16 +285,27 @@ class Params:
|
|||
# LLaMA v1
|
||||
n_ctx = 2048
|
||||
|
||||
if "layers.0.feed_forward.w1.weight" in model:
|
||||
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0]
|
||||
|
||||
if config.get("moe"):
|
||||
n_ff = model["layers.0.feed_forward.experts.0.w1.weight"].shape[0]
|
||||
n_experts = config["moe"]["num_experts"]
|
||||
n_experts_used = config["moe"]["num_experts_per_tok"]
|
||||
f_rope_freq_base = 1e6
|
||||
|
||||
return Params(
|
||||
n_vocab = model["tok_embeddings.weight"].shape[0],
|
||||
n_embd = config["dim"],
|
||||
n_layer = config["n_layers"],
|
||||
n_ctx = n_ctx,
|
||||
n_ff = model["layers.0.feed_forward.w1.weight"].shape[0],
|
||||
n_ff = n_ff,
|
||||
n_head = (n_head := config["n_heads"]),
|
||||
n_head_kv = config.get("n_kv_heads", n_head),
|
||||
n_experts = n_experts,
|
||||
n_experts_used = n_experts_used,
|
||||
f_norm_eps = config["norm_eps"],
|
||||
f_rope_freq_base = config.get("rope_theta"),
|
||||
f_rope_freq_base = config.get("rope_theta", f_rope_freq_base),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
|
@ -832,7 +862,17 @@ class OutputFile:
|
|||
self.gguf.add_rope_dimension_count(params.n_embd // params.n_head)
|
||||
self.gguf.add_head_count (params.n_head)
|
||||
self.gguf.add_head_count_kv (params.n_head_kv)
|
||||
self.gguf.add_layer_norm_rms_eps (params.f_norm_eps)
|
||||
|
||||
if params.n_experts:
|
||||
self.gguf.add_expert_count(params.n_experts)
|
||||
|
||||
if params.n_experts_used:
|
||||
self.gguf.add_expert_used_count(params.n_experts_used)
|
||||
|
||||
if params.f_norm_eps:
|
||||
self.gguf.add_layer_norm_rms_eps(params.f_norm_eps)
|
||||
else:
|
||||
raise ValueError('f_norm_eps is None')
|
||||
|
||||
if params.f_rope_freq_base is not None:
|
||||
self.gguf.add_rope_freq_base(params.f_rope_freq_base)
|
||||
|
@ -956,7 +996,7 @@ class OutputFile:
|
|||
|
||||
|
||||
def pick_output_type(model: LazyModel, output_type_str: str | None) -> GGMLFileType:
|
||||
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) +".weight"].data_type
|
||||
wq_type = model[gguf.TENSOR_NAMES[gguf.MODEL_TENSOR.ATTN_Q].format(bid=0) + ".weight"].data_type
|
||||
|
||||
if output_type_str == "f32" or (output_type_str is None and wq_type == DT_F32):
|
||||
return GGMLFileType.AllF32
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue