merge PowerInfer impl from the internal codebase
parent 6bb4908a17
commit a3c295a2ae
17 changed files with 4515 additions and 169 deletions
2  .gitignore (vendored)
@@ -98,3 +98,5 @@ tests/test-tokenizer-0-llama
tests/test-tokenizer-0-falcon
tests/test-tokenizer-1-llama
tests/test-tokenizer-1-bpe

build-info.h
@@ -471,6 +471,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
                break;
            }
            params.lora_base = argv[i];
        } else if (arg == "--mlp-adapter") {
            if (++i >= argc) {
                invalid_param = true;
                break;
            }
            params.mlp_adapter = argv[i];
        } else if (arg == "--mmproj") {
            if (++i >= argc) {
                invalid_param = true;
@@ -950,8 +956,26 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
        return std::make_tuple(nullptr, nullptr);
    }

    auto cparams = llama_context_params_from_gpt_params(params);
    if (llama_use_sparse_inference(model)) {
        fprintf(stderr, "%s: postprocessing PowerInfer model '%s'\n", __func__, params.model.c_str());
        if (!params.mlp_adapter.empty()) {
            fprintf(stderr, "%s: warning: --mlp-adapter is deprecated and has no effect\n", __func__);
            int err = llama_model_apply_mlp_from_file(model, params.mlp_adapter.c_str(), true);
            if (err != 0) {
                fprintf(stderr, "%s: error: failed to apply mlp adapter\n", __func__);
                llama_free_model(model);
                return std::make_tuple(nullptr, nullptr);
            }
        }

        if (llama_model_apply_augmentation(model) != 0) {
            fprintf(stderr, "%s: error: failed to apply augmentation\n", __func__);
            llama_free_model(model);
            return std::make_tuple(nullptr, nullptr);
        }
    }

    auto cparams = llama_context_params_from_gpt_params(params);
    llama_context * lctx = llama_new_context_with_model(model, cparams);
    if (lctx == NULL) {
        fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
@@ -981,6 +1005,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
        params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
    }

    {
        LOG("warming up the model with an empty run\n");
@@ -1320,6 +1346,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
        fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
    }
    fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
    fprintf(stream, "mlp_adapter: %s\n", params.mlp_adapter.c_str());
    fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
    fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
    fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);
@@ -90,6 +90,8 @@ struct gpt_params {
    std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
    std::string lora_base = "";   // base model path for the lora adapter

    std::string mlp_adapter = ""; // sparse activation mlp adapter path

    int ppl_stride      = 0;      // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
    int ppl_output_type = 0;      // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
                                  //                                       (which is more convenient to use for plotting)
601  convert-hf-to-powerinfer-gguf.py (new file)
@@ -0,0 +1,601 @@
#!/usr/bin/env python3

from __future__ import annotations
from abc import ABC, abstractmethod

import argparse
import contextlib
import json
import os
import re
import struct
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Optional, cast

import numpy as np
import torch
import torch.nn as tnn

if TYPE_CHECKING:
    from torch import Tensor

if "NO_LOCAL_GGUF" not in os.environ:
    sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))
import gguf


###### MODEL DEFINITIONS ######


class SentencePieceTokenTypes(IntEnum):
    NORMAL = 1
    UNKNOWN = 2
    CONTROL = 3
    USER_DEFINED = 4
    UNUSED = 5
    BYTE = 6


class ReluMLP(tnn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
        super(ReluMLP, self).__init__()
        self.fc1 = tnn.Linear(input_dim, hidden_dim, bias=False)
        self.relu = tnn.ReLU()
        self.fc2 = tnn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

    @staticmethod
    def from_file(model_file: Path):
        model = torch.load(model_file, map_location="cpu")
        hidden_size, input_size = model.get("fc1.weight").shape
        output_size, _ = model.get("fc2.weight").shape
        mlp = ReluMLP(input_size, hidden_size, output_size)
        mlp.load_state_dict(model)
        return mlp


class Model(ABC):
    """Base class for model conversion"""

    def __init__(
        self,
        dir_model: Path,
        dir_mlp_pred: Path,
        ftype: int,
        fname_out: Path,
        is_big_endian: bool,
    ):
        self.dir_model = dir_model
        self.dir_mlp_pred = dir_mlp_pred
        self.ftype = ftype
        self.fname_out = fname_out
        self.is_big_endian = is_big_endian
        self.endianess = (
            gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
        )
        self.is_safetensors = self._is_model_safetensors()
        self.num_parts = Model.count_model_parts(
            self.dir_model, ".safetensors" if self.is_safetensors else ".bin"
        )
        self.part_names = self._get_part_names()
        self.hparams = Model.load_hparams(self.dir_model)
        self.model_arch = self._get_model_architecture()
        self.gguf_writer = gguf.GGUFWriter(
            fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=False
        )

    def set_vocab(self):
        self._set_vocab_gpt2()

    def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
        for model_layer, part_name in self._get_mlp_part_layer_names():
            print(f"gguf: loading mlp part '{part_name}'")
            mlp_model = ReluMLP.from_file(self.dir_mlp_pred / part_name)
            for name, data in mlp_model.state_dict().items():
                yield f"blk.{model_layer}.{name}", data

        for part_name in self.part_names:
            print(f"gguf: loading model part '{part_name}'")
            ctx: ContextManager[Any]
            if self.is_safetensors:
                from safetensors import safe_open

                ctx = cast(
                    ContextManager[Any],
                    safe_open(self.dir_model / part_name, framework="pt", device="cpu"),
                )
            else:
                ctx = contextlib.nullcontext(
                    torch.load(self.dir_model / part_name, map_location="cpu")
                )

            with ctx as model_part:
                for name in model_part.keys():
                    data = (
                        model_part.get_tensor(name)
                        if self.is_safetensors
                        else model_part[name]
                    )
                    yield name, data

    @abstractmethod
    def set_gguf_parameters(self):
        pass
        # self.gguf_writer.add_name(self.dir_model.name)
        # self.gguf_writer.add_block_count(
        #     self.hparams.get(
        #         "n_layers",
        #         self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
        #     )
        # )
        # if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
        #     self.gguf_writer.add_context_length(n_ctx)
        # if (n_embd := self.hparams.get("hidden_size")) is not None:
        #     self.gguf_writer.add_embedding_length(n_embd)
        # if (n_ff := self.hparams.get("intermediate_size")) is not None:
        #     self.gguf_writer.add_feed_forward_length(n_ff)
        # if (n_head := self.hparams.get("num_attention_head")) is not None:
        #     self.gguf_writer.add_head_count(n_head)
        # self.gguf_writer.add_parallel_residual(
        #     self.hparams.get("use_parallel_residual", True)
        # )

    @abstractmethod
    def write_tensors(self):
        pass

    def write(self):
        self.write_tensors()
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.write_tensors_to_file()
        self.gguf_writer.close()

    def write_vocab(self):
        self.gguf_writer.write_header_to_file()
        self.gguf_writer.write_kv_data_to_file()
        self.gguf_writer.close()

    @staticmethod
    def count_model_parts(dir_model: Path, prefix: str) -> int:
        num_parts = 0
        for filename in os.listdir(dir_model):
            if filename.endswith(prefix):
                num_parts += 1

        return num_parts

    @staticmethod
    def load_hparams(dir_model):
        with open(dir_model / "config.json", "r", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def from_model_architecture(model_architecture):
        if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
            return FalconModel
        if model_architecture == "LlamaForCausalLM":
            return LlamaModel

        raise NotImplementedError(f'Architecture "{model_architecture}" not supported!')

    def _is_model_safetensors(self) -> bool:
        return Model.count_model_parts(self.dir_model, ".safetensors") > 0

    def _get_mlp_part_layer_names(self):
        """Returns a generator of (index, name) for MLP predictors of each model layer"""
        n_mlp_parts = Model.count_model_parts(self.dir_mlp_pred, ".pt")
        return ((n, f"model_{n}.pt") for n in range(n_mlp_parts))

    def _get_part_names(self):
        if self.is_safetensors:
            if self.num_parts == 1:  # there's only one .safetensors file
                return ("model.safetensors",)
            return (
                f"model-{n:05}-of-{self.num_parts:05}.safetensors"
                for n in range(1, self.num_parts + 1)
            )

        if self.num_parts == 1:  # there's only one .bin file
            return ("pytorch_model.bin",)
        return (
            f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin"
            for n in range(1, self.num_parts + 1)
        )

    def _get_model_architecture(self) -> gguf.MODEL_ARCH:
        arch = self.hparams["architectures"][0]
        if arch == "FalconForCausalLM":
            return gguf.MODEL_ARCH.FALCON
        if arch == "RWForCausalLM" or arch == "LlamaForCausalLM":
            return gguf.MODEL_ARCH.LLAMA

        raise NotImplementedError(f'Architecture "{arch}" not supported!')

    def _translate_tensor_key(
        self, key: str, try_suffixes=(".weight", ".bias")
    ) -> Optional[str]:
        block_count = self.hparams.get(
            "n_layers",
            self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
        )
        tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
        arch_tensor_key = tensor_map.get_name(key, try_suffixes=try_suffixes)
        if arch_tensor_key is not None:
            return arch_tensor_key
        # check and handle ReluMLP layers
        mlp_match = re.match(r"^blk\.\d+\.fc\d\.weight$", key)
        if mlp_match:
            return mlp_match.group(0)
        return None

    def _set_vocab_gpt2(self):
        dir_model = self.dir_model
        hparams = self.hparams
        tokens: list[bytearray] = []
        toktypes: list[int] = []

        from transformers import AutoTokenizer  # type: ignore[attr-defined]

        tokenizer = AutoTokenizer.from_pretrained(dir_model)
        vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
        assert max(tokenizer.vocab.values()) < vocab_size

        reverse_vocab = {
            id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()
        }
        added_vocab = tokenizer.get_added_vocab()

        for i in range(vocab_size):
            if i not in reverse_vocab:
                pad_token = f"[PAD{i}]".encode("utf-8")
                tokens.append(bytearray(pad_token))
                toktypes.append(gguf.TokenType.USER_DEFINED)
            elif reverse_vocab[i] in added_vocab:
                tokens.append(reverse_vocab[i])
                if tokenizer.added_tokens_decoder[i].special:
                    toktypes.append(gguf.TokenType.CONTROL)
                else:
                    toktypes.append(gguf.TokenType.USER_DEFINED)
            else:
                tokens.append(reverse_vocab[i])
                toktypes.append(gguf.TokenType.NORMAL)

        self.gguf_writer.add_tokenizer_model("gpt2")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
        special_vocab.add_to_gguf(self.gguf_writer)

    def _set_vocab_sentencepiece(self):
        from sentencepiece import SentencePieceProcessor

        tokenizer_path = self.dir_model / "tokenizer.model"

        tokens: list[bytes] = []
        scores: list[float] = []
        toktypes: list[int] = []

        if not tokenizer_path.is_file():
            print(f"Error: Missing {tokenizer_path}", file=sys.stderr)
            sys.exit(1)

        tokenizer = SentencePieceProcessor(str(tokenizer_path))
        vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size())

        for token_id in range(vocab_size):
            piece = tokenizer.id_to_piece(token_id)
            text = piece.encode("utf-8")
            score = tokenizer.get_score(token_id)

            toktype = SentencePieceTokenTypes.NORMAL
            if tokenizer.is_unknown(token_id):
                toktype = SentencePieceTokenTypes.UNKNOWN
            elif tokenizer.is_control(token_id):
                toktype = SentencePieceTokenTypes.CONTROL
            elif tokenizer.is_unused(token_id):
                toktype = SentencePieceTokenTypes.UNUSED
            elif tokenizer.is_byte(token_id):
                toktype = SentencePieceTokenTypes.BYTE

            tokens.append(text)
            scores.append(score)
            toktypes.append(toktype)

        added_tokens_file = self.dir_model / "added_tokens.json"
        if added_tokens_file.is_file():
            with open(added_tokens_file, "r", encoding="utf-8") as f:
                added_tokens_json = json.load(f)

                for key in added_tokens_json:
                    tokens.append(key.encode("utf-8"))
                    scores.append(-1000.0)
                    toktypes.append(SentencePieceTokenTypes.USER_DEFINED)

        self.gguf_writer.add_tokenizer_model("llama")
        self.gguf_writer.add_token_list(tokens)
        self.gguf_writer.add_token_scores(scores)
        self.gguf_writer.add_token_types(toktypes)

        special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
        special_vocab.add_to_gguf(self.gguf_writer)


class LlamaModel(Model):
    def set_vocab(self):
        self._set_vocab_sentencepiece()

    def set_gguf_parameters(self):
        self.gguf_writer.add_name("Llama")
        self.gguf_writer.add_context_length(2048)  # not in config.json
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
        self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
        self.gguf_writer.add_rope_dimension_count(
            self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
        )
        self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
        self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
        self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
        self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        for name, data_torch in self.get_tensors():
            # we don't need these
            if name.endswith(
                (
                    ".attention.masked_bias",
                    ".attention.bias",
                    ".attention.rotary_emb.inv_freq",
                )
            ):
                continue

            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = self._translate_tensor_key(name)
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            # We need to transpose the weight matrices for the FFN Down layers to support the
            # Axpy operation in PowerInfer. So we don't need to transpose them at runtime.
            if "ffn_down" in new_name:
                new_name = new_name.replace("ffn_down", "ffn_down_t")
                data = data.T

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if (
                self.ftype == 1
                and data_dtype == np.float32
                and name.endswith(".weight")
                and n_dims == 2
            ):
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


class FalconModel(Model):
    def set_gguf_parameters(self):
        block_count = self.hparams.get("num_hidden_layers")
        if block_count is None:
            block_count = self.hparams["n_layer"]  # old name

        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        self.gguf_writer.add_name("Falcon")
        self.gguf_writer.add_context_length(2048)  # not in config.json
        self.gguf_writer.add_tensor_data_layout("jploski")  # qkv tensor transform
        self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
        self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
        self.gguf_writer.add_block_count(block_count)
        self.gguf_writer.add_head_count(n_head)
        self.gguf_writer.add_head_count_kv(n_head_kv)
        self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
        self.gguf_writer.add_file_type(self.ftype)

    def write_tensors(self):
        n_head = self.hparams.get("num_attention_heads")
        if n_head is None:
            n_head = self.hparams["n_head"]  # old name

        n_head_kv = self.hparams.get("num_kv_heads")
        if n_head_kv is None:
            n_head_kv = self.hparams.get("n_head_kv", 1)  # old name

        head_dim = self.hparams["hidden_size"] // n_head

        for name, data_torch in self.get_tensors():
            old_dtype = data_torch.dtype

            # convert any unsupported data types to float32
            if data_torch.dtype not in (torch.float16, torch.float32):
                data_torch = data_torch.to(torch.float32)

            # QKV tensor transform
            # The original query_key_value tensor contains n_head_kv "kv groups",
            # each consisting of n_head/n_head_kv query weights followed by one key
            # and one value weight (shared by all query heads in the kv group).
            # This layout makes it a big pain to work with in GGML.
            # So we rearrange them here, so that we have n_head query weights
            # followed by n_head_kv key weights followed by n_head_kv value weights,
            # in contiguous fashion.
            # ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py

            if "query_key_value" in name:
                qkv = data_torch.view(
                    n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head
                )
                q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
                k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
                v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
                data_torch = torch.cat((q, k, v)).reshape_as(data_torch)

            data = data_torch.squeeze().numpy()

            # map tensor names
            new_name = self._translate_tensor_key(name)
            if new_name is None:
                print(f"Can not map tensor {name!r}")
                sys.exit()

            # We need to transpose the weight matrices for the FFN Down layers to support the
            # Axpy operation in PowerInfer. So we don't need to transpose them at runtime.
            if "ffn_down" in new_name:
                new_name = new_name.replace("ffn_down", "ffn_down_t")
                data = data.T

            n_dims = len(data.shape)
            data_dtype = data.dtype

            # if f32 desired, convert any float16 to float32
            if self.ftype == 0 and data_dtype == np.float16:
                data = data.astype(np.float32)

            # TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
            if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
                data = data.astype(np.float32)

            # if f16 desired, convert any float32 2-dim weight tensors to float16
            if (
                self.ftype == 1
                and data_dtype == np.float32
                and name.endswith(".weight")
                and n_dims == 2
            ):
                data = data.astype(np.float16)

            print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")

            self.gguf_writer.add_tensor(new_name, data)


###### CONVERSION LOGIC ######


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Convert a huggingface model to a GGML compatible file"
    )
    parser.add_argument(
        "--vocab-only",
        action="store_true",
        help="extract only the vocab",
    )
    parser.add_argument(
        "--outfile",
        type=Path,
        help="path to write to; default: based on input",
    )
    parser.add_argument(
        "--outtype",
        type=str,
        choices=["f32", "f16"],
        default="f16",
        help="output format - use f32 for float32, f16 for float16",
    )
    parser.add_argument(
        "--bigendian",
        action="store_true",
        help="model is executed on big endian machine",
    )
    parser.add_argument(
        "model",
        type=Path,
        help="directory containing model file",
    )
    parser.add_argument(
        "mlp_predictors",
        type=Path,
        help="directory containing MLP predictors for model",
    )

    return parser.parse_args()


args = parse_args()

dir_model = args.model
dir_mlp_pred = args.mlp_predictors
if not dir_model.is_dir():
    print(f"Error: {args.model} is not a directory", file=sys.stderr)
    sys.exit(1)
if not dir_mlp_pred.is_dir():
    print(f"Error: {args.mlp_predictors} is not a directory", file=sys.stderr)
    sys.exit(1)

ftype_map = {
    "f32": gguf.GGMLQuantizationType.F32,
    "f16": gguf.GGMLQuantizationType.F16,
}

if args.outfile is not None:
    fname_out = args.outfile
else:
    # output in the same directory as the model by default
    fname_out = dir_model / f"ggml-model-{args.outtype}.gguf"

print(f"Loading model: {dir_model.name}")

hparams = Model.load_hparams(dir_model)

model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(
    dir_model, dir_mlp_pred, ftype_map[args.outtype], fname_out, args.bigendian
)

print("Set model parameters")
model_instance.set_gguf_parameters()

print("Set model tokenizer")
model_instance.set_vocab()

if args.vocab_only:
    print(f"Exporting model vocab to '{fname_out}'")
    model_instance.write_vocab()
else:
    print(f"Exporting model to '{fname_out}'")
    model_instance.write()

# post-process: write another unique file header to distinguish from the original GGUF file
with open(fname_out, "r+b") as fout:
    POWERINFER_MAGIC = int.from_bytes(b"PWRI", "little")
    fout.write(struct.pack("<I", POWERINFER_MAGIC))

print(f"Model successfully exported to '{fname_out}'")
55  convert.py
@@ -509,6 +509,14 @@ class LazyTensor:
        def load() -> Tensor:
            return self.load().astype(data_type)
        return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')

    def transposed(self) -> LazyTensor:
        def load() -> Tensor:
            loaded = self.load()
            assert isinstance(loaded, UnquantizedTensor), f'Cannot transpose {loaded}'
            loaded.ndarray = loaded.ndarray.T
            return loaded
        return LazyTensor(load, self.shape[::-1], self.data_type, f'transpose {self.description}')

    def validate_conversion_to(self, data_type: DataType) -> None:
        if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:

@@ -571,7 +579,8 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
    except StopIteration:
        vocab = None

    if any("model.embed_tokens.weight" in mp.model for mp in models_plus):
    if any("model.embed_tokens.weight" in mp.model for mp in models_plus) or \
        any("model.layers.0.fc1.weight" in mp.model for mp in models_plus):
        # Transformers models put different tensors in different files, but
        # don't split individual tensors between files.
        model: LazyModel = {}
@@ -992,6 +1001,18 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:

    return out

def postprocess_transpose(model: LazyModel) -> LazyModel:
    """Transpose ffn_down matrices for Axpy ops."""
    out: LazyModel = {}

    for name, lazy_tensor in model.items():
        if name.endswith(".ffn_down.weight"):
            out[name.replace("ffn_down", "ffn_down_t")] = lazy_tensor.transposed()
        else:
            out[name] = lazy_tensor

    return out

def nth_multifile_path(path: Path, n: int) -> Path | None:
    '''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
    the nth path in the model.

@@ -1003,7 +1024,9 @@ def nth_multifile_path(path: Path, n: int) -> Path | None:
        # - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
        (r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'),
        # x.bin, x.bin.1, etc.
        (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}')
        (r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}'),
        # x_0.pt, x_1.pt, etc.
        (r'(_[0-9]+)?\.pt$', fr'_{n}.pt'),
    ]
    for regex, replacement in patterns:
        if re.search(regex, path.name):
@@ -1057,6 +1080,25 @@ def load_some_model(path: Path) -> ModelPlus:
        model_plus = merge_multifile_models(models_plus)
    return model_plus

def load_mlp_model(path: Path) -> ModelPlus:
    '''Load MLP models for sparse attention from directory.'''
    assert path.is_dir(), f"MLP model path {path} is not a directory"

    first_model_path = path / "model_0.pt"
    assert first_model_path.resolve(), f"MLP model path {path} does not contain model_0.pt"

    model_paths = find_multifile_paths(first_model_path)
    models_plus: list[ModelPlus] = []
    for model_path in model_paths:
        # find number in model_path
        model_layer = int(re.search(r'model_(\d+).pt', str(model_path)).group(1))
        print(f"Loading MLP model file {model_path}")
        mlp_model = lazy_load_file(model_path)
        mlp_model.model = {f"model.layers.{model_layer}.{name}": tensor for name, tensor in mlp_model.model.items()}
        models_plus.append(mlp_model)

    return merge_multifile_models(models_plus)


def load_vocab(path: Path, vocabtype: str | None) -> Vocab:
    # Be extra-friendly and accept either a file or a directory. Also, if it's

@@ -1125,6 +1167,7 @@ def main(args_in: list[str] | None = None) -> None:
    parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
    parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
    parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)")
    parser.add_argument("mlp_model", type=Path, help="MLP model for sparse attention")
    parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
    parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
    parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)

@@ -1138,6 +1181,8 @@ def main(args_in: list[str] | None = None) -> None:

    if not args.vocab_only:
        model_plus = load_some_model(args.model)
        mlp_predictor_plus = load_mlp_model(args.mlp_model)
        model_plus = merge_multifile_models([model_plus, mlp_predictor_plus])
    else:
        model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)

@@ -1192,6 +1237,7 @@ def main(args_in: list[str] | None = None) -> None:

    model = model_plus.model
    model = convert_model_names(model, params)
    model = postprocess_transpose(model)
    ftype = pick_output_type(model, args.outtype)
    model = convert_to_output_type(model, ftype)
    outfile = args.outfile or default_outfile(model_plus.paths, ftype)

@@ -1202,6 +1248,11 @@ def main(args_in: list[str] | None = None) -> None:
    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess)
    print(f"Wrote {outfile}")

    # post-process: write another unique file header to distinguish from the original GGUF file
    with open(outfile, "r+b") as fout:
        POWERINFER_MAGIC = int.from_bytes(b"PWRI", "little")
        fout.write(struct.pack("<I", POWERINFER_MAGIC))


if __name__ == '__main__':
    main()
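
Reviewer note (not part of the diff): postprocess_transpose stores ffn_down transposed because the Axpy path accumulates one weight column per active FFN neuron; with the transpose written at conversion time, each active neuron corresponds to a contiguous row at runtime. A small numpy illustration of the equivalence, with made-up sizes and indices:

import numpy as np

n_embd, n_ff = 8, 32
W = np.random.randn(n_embd, n_ff).astype(np.float32)   # ffn_down weight, as in the checkpoint
W_t = np.ascontiguousarray(W.T)                         # what gets stored as ffn_down_t
h = np.random.randn(n_ff).astype(np.float32)            # FFN hidden activations
active = [3, 7, 19]                                     # neurons a predictor marked active

y = np.zeros(n_embd, dtype=np.float32)
for j in active:
    y += h[j] * W_t[j]                                  # axpy over contiguous rows

assert np.allclose(y, W[:, active] @ h[active], atol=1e-5)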
@@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
    gpt_params params;

    if (argc == 1 || argv[1][0] == '-') {
        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL]\n" , argv[0]);
        printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL] [N_THREAD] [MLP_PATH]\n" , argv[0]);
        return 1 ;
    }

@@ -44,6 +44,17 @@ int main(int argc, char ** argv) {
        n_gpu_layers = std::atoi(argv[5]);
    }

    if (argc >= 7) {
        params.n_threads = std::atoi(argv[6]);
    }

    if (argc >= 8) {
        params.mlp_adapter = argv[7];
    }

    printf("params: model = %s, prompt = %s, n_parallel = %d, n_len = %d, n_gpu_layers = %d, n_threads = %d, mlp_adapter = %s\n",
        params.model.c_str(), params.prompt.c_str(), n_parallel, n_len, n_gpu_layers, params.n_threads, params.mlp_adapter.c_str());

    if (params.prompt.empty()) {
        params.prompt = "Hello my name is";
    }

@@ -65,6 +76,21 @@ int main(int argc, char ** argv) {
        return 1;
    }

    if (!params.mlp_adapter.empty()) {
        int err = llama_model_apply_mlp_from_file(model, params.mlp_adapter.c_str(), true);
        if (err != 0) {
            fprintf(stderr, "%s: error: failed to apply mlp adapter\n", __func__);
            llama_free_model(model);
            return 1;
        }
    }

    if (llama_model_apply_augmentation(model) != 0) {
        fprintf(stderr, "%s: error: failed to apply model augmentation\n", __func__);
        llama_free_model(model);
        return 1;
    }

    // tokenize the prompt

    std::vector<llama_token> tokens_list;
1244  ggml-cuda.cu (diff suppressed because it is too large)
@@ -29,7 +29,11 @@ GGML_API void ggml_cuda_host_free(void * ptr);
GGML_API bool   ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void   ggml_cuda_set_tensor_split(const float * tensor_split);
GGML_API void   ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
GGML_API void   ggml_cuda_alloc_tensor(struct ggml_tensor * tensor);
GGML_API void   ggml_cuda_free_data(struct ggml_tensor * tensor);
GGML_API void   ggml_cuda_cpy_1d(struct ggml_tensor * dst, const struct ggml_tensor * src);
GGML_API bool   debug_equal(short *a, short *b);
GGML_API void **ggml_cuda_get_data_pp(struct ggml_tensor * tensor);

GGML_API void   ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
GGML_API void   ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);
@@ -2423,6 +2423,78 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif

void ggml_axpy_q4_0_q8_0(const int n, const void * restrict vx, const void * restrict vy, const void * restrict vz, int8_t alpha, ggml_fp16_t scale) {
    const int qk = QK8_0;
    const int nb = n / qk;
    assert(n % qk == 0);
    assert(nb % 2 == 0);

    const block_q4_0 * restrict x = vx;

    // Initialize accumulator with zeros
    __m256 acc = _mm256_setzero_ps();

    __m256i alpha_v = _mm256_set1_epi16((short)alpha);
    // Main loop
    for (int i = 0; i < nb; ++i) {
        /* Compute combined scale for the block */
        const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(scale) );
        __m256i bx = bytes_from_nibbles_32(x[i].qs);

        // Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
        const __m256i off = _mm256_set1_epi8( 8 );
        bx = _mm256_sub_epi8( bx, off );
        // process 16 values at a time
        __m128i m_a = _mm256_extracti128_si256(bx, 0);
        __m256i m_x = _mm256_cvtepi8_epi16(m_a); // 16 elements
        m_x = _mm256_mullo_epi16(m_x, alpha_v);
        __m128i x_0 = _mm256_extracti128_si256(m_x, 0);
        __m256i x0_32 = _mm256_cvtepi16_epi32(x_0);
        __m256 fx0 = _mm256_cvtepi32_ps(x0_32);
        fx0 = _mm256_mul_ps(fx0, d);

        __m256 by = _mm256_loadu_ps((const __m256 *)(vy+i*128));

        by = _mm256_add_ps(by, fx0);
        _mm256_storeu_ps((__m256*)(vz + i*128), by);

        // second phase
        x_0 = _mm256_extracti128_si256(m_x, 1);
        x0_32 = _mm256_cvtepi16_epi32(x_0);
        fx0 = _mm256_cvtepi32_ps(x0_32);
        fx0 = _mm256_mul_ps(fx0, d);
        by = _mm256_loadu_ps((const __m256 *)(vy+i*128+32));
        by = _mm256_add_ps(by, fx0);
        _mm256_storeu_ps((__m256*)(vz + i*128+32), by);

        // third phase
        m_a = _mm256_extracti128_si256(bx, 1);
        m_x = _mm256_cvtepi8_epi16(m_a);
        m_x = _mm256_mullo_epi16(m_x, alpha_v);
        x_0 = _mm256_extracti128_si256(m_x, 0);
        x0_32 = _mm256_cvtepi16_epi32(x_0);
        fx0 = _mm256_cvtepi32_ps(x0_32);
        fx0 = _mm256_mul_ps(fx0, d);
        by = _mm256_loadu_ps((const __m256 *)(vy+i*128+64));

        by = _mm256_add_ps(by, fx0);
        _mm256_storeu_ps((__m256*)(vz + i*128+64), by);

        // fourth phase
        x_0 = _mm256_extracti128_si256(m_x, 1);
        x0_32 = _mm256_cvtepi16_epi32(x_0);
        fx0 = _mm256_cvtepi32_ps(x0_32);
        fx0 = _mm256_mul_ps(fx0, d);
        by = _mm256_loadu_ps((const __m256 *)(vy+i*128+96));
        by = _mm256_add_ps(by, fx0);
        _mm256_storeu_ps((__m256*)(vz + i*128+96), by);

    }

}


void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
    const int qk = QK8_0;
    const int nb = n / qk;
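
Reviewer note: per 32-quant Q4_0 block, the AVX2 kernel above computes z = y + (alpha * scale * d) * q, where q are the 4-bit quants shifted into [-8, 7] and d is the block scale. A plain numpy restatement of that block-level math (a sketch for reference only; it ignores the exact nibble ordering produced by bytes_from_nibbles_32):

import numpy as np

def axpy_q4_0_block(d_block: float, nibbles: np.ndarray, y_block: np.ndarray,
                    alpha: int, scale: float) -> np.ndarray:
    # nibbles: the 32 unsigned 4-bit quants of one block, values in [0, 15]
    q = nibbles.astype(np.int32) - 8                 # offset into [-8 .. +7]
    d = np.float32(d_block) * np.float32(scale)      # combined block scale
    return y_block + d * (alpha * q).astype(np.float32)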
@@ -210,6 +210,8 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);

void ggml_axpy_q4_0_q8_0(const int n, const void * restrict vx, const void * restrict vy, const void * restrict vz, int8_t alpha, ggml_fp16_t scale);

// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
69  ggml.h
@@ -207,6 +207,14 @@
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#ifdef __cplusplus
#include <atomic>
using std::atomic_int;
using std::memory_order;
using std::memory_order_acquire;
#else /* not __cplusplus */
#include <stdatomic.h>
#endif /* __cplusplus */

#define GGML_FILE_MAGIC   0x67676d6c // "ggml"
#define GGML_FILE_VERSION 1

@@ -232,6 +240,7 @@
#define GGML_EXIT_ABORTED 1

#define GGUF_MAGIC "GGUF"
#define GGUF_POWERINFER_MAGIC "PWRI"

#define GGUF_VERSION 3

@@ -336,6 +345,11 @@ extern "C" {
        GGML_BACKEND_GPU_SPLIT = 20,
    };

    enum ggml_sparse_deriv {
        GGML_DENSE_INFERENCE  = 0,
        GGML_SPARSE_INFERENCE = 1,
    };

    // model file types
    enum ggml_ftype {
        GGML_FTYPE_UNKNOWN = -1,

@@ -382,6 +396,7 @@ extern "C" {
        GGML_OP_GROUP_NORM,

        GGML_OP_MUL_MAT,
        GGML_OP_AXPY,
        GGML_OP_OUT_PROD,

        GGML_OP_SCALE,

@@ -504,6 +519,7 @@ extern "C" {
        struct ggml_tensor * src[GGML_MAX_SRC];

        // performance
        atomic_int is_finish;
        int     perf_runs;
        int64_t perf_cycles;
        int64_t perf_time_us;

@@ -520,6 +536,9 @@ extern "C" {
        char padding[12];
    };

    static const int64_t GGML_NE_WILDCARD = -1;

    static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);

    // the compute plan that needs to be prepared for ggml_graph_compute()

@@ -573,6 +592,22 @@ extern "C" {
        void * data;
    };

    struct ggml_context {
        size_t mem_size;
        void * mem_buffer;
        bool   mem_buffer_owned;
        bool   no_alloc;
        bool   no_alloc_save; // this is used to save the no_alloc state when using scratch buffers

        int    n_objects;

        struct ggml_object * objects_begin;
        struct ggml_object * objects_end;

        struct ggml_scratch scratch;
        struct ggml_scratch scratch_save;
    };

    struct ggml_init_params {
        // memory pool
        size_t mem_size;   // bytes

@@ -600,6 +635,7 @@ extern "C" {
        // work buffer for all threads
        size_t wsize;
        void * wdata;
        atomic_int *aic;
    };

    // misc

@@ -618,6 +654,8 @@ extern "C" {
    GGML_API void    ggml_print_object (const struct ggml_object * obj);
    GGML_API void    ggml_print_objects(const struct ggml_context * ctx);

    GGML_API

    GGML_API int64_t ggml_nelements   (const struct ggml_tensor * tensor);
    GGML_API int64_t ggml_nrows       (const struct ggml_tensor * tensor);
    GGML_API size_t  ggml_nbytes      (const struct ggml_tensor * tensor);

@@ -727,6 +765,7 @@ extern "C" {

    GGML_API void *    ggml_get_data    (const struct ggml_tensor * tensor);
    GGML_API float *   ggml_get_data_f32(const struct ggml_tensor * tensor);
    GGML_API int32_t * ggml_get_data_i32(const struct ggml_tensor * tensor);

    GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);

@@ -735,6 +774,9 @@ extern "C" {
    GGML_ATTRIBUTE_FORMAT(2, 3)
    GGML_API struct ggml_tensor * ggml_format_name(      struct ggml_tensor * tensor, const char * fmt, ...);

    GGML_API void ggml_set_backend(struct ggml_tensor * tensor, enum ggml_backend_type backend);

    //
    // operations on tensors with backpropagation
    //

@@ -753,6 +795,12 @@ extern "C" {
            struct ggml_tensor * a,
            struct ggml_tensor * b);

    GGML_API struct ggml_tensor *ggml_add_idx(
            struct ggml_context *ctx,
            struct ggml_tensor *a,
            struct ggml_tensor *b,
            struct ggml_tensor *idx);

    GGML_API struct ggml_tensor * ggml_add_inplace(
            struct ggml_context * ctx,
            struct ggml_tensor * a,

@@ -1027,6 +1075,25 @@ extern "C" {
            struct ggml_context * ctx,
            struct ggml_tensor * a,
            struct ggml_tensor * b);
    GGML_API struct ggml_tensor *ggml_mul_mat_idx(
            struct ggml_context *ctx,
            struct ggml_tensor *a,
            struct ggml_tensor *b,
            struct ggml_tensor *idx,
            struct ggml_tensor *d);
    GGML_API struct ggml_tensor *ggml_mul_mat_special(
            struct ggml_context *ctx,
            struct ggml_tensor *a,
            struct ggml_tensor *b,
            struct ggml_tensor *idx,
            struct ggml_tensor *d,
            struct ggml_tensor *ref);
    GGML_API struct ggml_tensor *ggml_axpy(
            struct ggml_context *ctx,
            struct ggml_tensor *a,
            struct ggml_tensor *b,
            struct ggml_tensor *c,
            struct ggml_tensor *d);

    // A: m columns, n rows,
    // B: p columns, n rows,

@@ -2013,6 +2080,7 @@ extern "C" {
    };

    GGML_API struct gguf_context * gguf_init_empty(void);
    GGML_API struct gguf_context * gguf_init_empty_sparse(void);
    GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
    //GGML_API struct gguf_context * gguf_init_from_buffer(..);

@@ -2049,6 +2117,7 @@ extern "C" {
    GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
    GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);

    GGML_API enum ggml_sparse_deriv gguf_get_sparse_deriv(const struct gguf_context * ctx);
    GGML_API int    gguf_get_n_tensors    (const struct gguf_context * ctx);
    GGML_API int    gguf_find_tensor      (const struct gguf_context * ctx, const char * name);
    GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);
@@ -115,6 +115,10 @@ class MODEL_TENSOR(IntEnum):
    FFN_NORM = auto()
    ATTN_Q_NORM = auto()
    ATTN_K_NORM = auto()
    FFN_DOWN_T = auto()
    FC_1 = auto()
    FC_2 = auto()


MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {

@@ -155,6 +159,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
    MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
    MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
    MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
    MODEL_TENSOR.FFN_DOWN_T: "blk.{bid}.ffn_down_t",
    MODEL_TENSOR.FC_1: "blk.{bid}.fc1",
    MODEL_TENSOR.FC_2: "blk.{bid}.fc2",
}

MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {

@@ -173,6 +180,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
        MODEL_TENSOR.FFN_GATE,
        MODEL_TENSOR.FFN_DOWN,
        MODEL_TENSOR.FFN_UP,
        MODEL_TENSOR.FFN_DOWN_T,
        MODEL_TENSOR.FC_1,
        MODEL_TENSOR.FC_2,
    ],
    MODEL_ARCH.GPTNEOX: [
        MODEL_TENSOR.TOKEN_EMBD,

@@ -194,6 +194,14 @@ class TensorNameMap:
        MODEL_TENSOR.ROPE_FREQS: (
            "language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq",  # persimmon
        ),

        MODEL_TENSOR.FC_1: (
            "model.layers.{bid}.fc1",
        ),

        MODEL_TENSOR.FC_2: (
            "model.layers.{bid}.fc2",
        ),
    }

    mapping: dict[str, tuple[MODEL_TENSOR, str]]
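
Note: these mappings are what let the converter route the predictor tensors (fc1/fc2) and the transposed ffn_down through the regular gguf-py name map. An illustrative use, mirroring _translate_tensor_key in convert-hf-to-powerinfer-gguf.py (the block count and layer index are arbitrary examples):

import gguf

tmap = gguf.get_tensor_name_map(gguf.MODEL_ARCH.LLAMA, 32)

# predictor weight coming from the MLP predictor files
print(tmap.get_name("model.layers.0.fc1.weight", try_suffixes=(".weight", ".bias")))
# expected: blk.0.fc1.weight

# regular FFN down projection; the converter afterwards renames it to blk.0.ffn_down_t and transposes it
print(tmap.get_name("model.layers.0.mlp.down_proj.weight", try_suffixes=(".weight", ".bias")))
# expected: blk.0.ffn_down.weight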
8  llama.h
@@ -293,6 +293,7 @@ extern "C" {
    LLAMA_API int llama_n_ctx      (const struct llama_context * ctx);

    LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
    LLAMA_API bool llama_use_sparse_inference(const struct llama_model * model);

    LLAMA_API int llama_n_vocab    (const struct llama_model * model);
    LLAMA_API int llama_n_ctx_train(const struct llama_model * model);

@@ -340,6 +341,13 @@ extern "C" {
            const char * path_base_model,
            int n_threads);

    LLAMA_API int llama_model_apply_mlp_from_file(
            struct llama_model * model,
            const char * path_mlp,
            bool use_mmap);

    LLAMA_API int llama_model_apply_augmentation(struct llama_model * model);

    //
    // KV cache
    //
142  scripts/export-gpu-split.py (new file)
@@ -0,0 +1,142 @@
#!/usr/bin/env python3

import argparse
import torch
import torch.nn as tnn
from pathlib import Path
import os
import re
import struct
from typing import Any, BinaryIO
import numpy as np
import pickle

class ReluMLP(tnn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(ReluMLP, self).__init__()
        self.fc1 = tnn.Linear(input_dim, hidden_dim, bias=False)
        self.relu = tnn.ReLU()
        self.fc2 = tnn.Linear(hidden_dim, output_dim, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


def _load_mlp_model(model_file: Path):
    model = torch.load(model_file)
    # hidden_size, input_size = model.get("fc1.weight").shape
    # output_size, _ = model.get("fc2.weight").shape
    # mlp = ReluMLP(input_size, hidden_size, output_size)
    # mlp.load_state_dict(model)
    return model


def load_mlp_predictors(models_base: Path):
    # TODO: might need a specification file to indicate which models to load.
    # But for now, let's assume it is a plain directory of models_{0, ... , n_layers - 1}.pt
    *_, files = next(os.walk(models_base))
    return [_load_mlp_model(models_base / f"activation_{i}.pt") for i in range(len(files))]


def write_file_header(fout: BinaryIO, n_tensors: int) -> None:
    fout.write(b"gglp"[::-1])  # magic (GGml mLP)
    fout.write(struct.pack("i", 1))  # file version
    # TODO: If we found we need more common parameters, we can add them here.
    fout.write(struct.pack("i", n_tensors))


def write_tensor_header(
    fout: BinaryIO, key: str, shape: tuple[int, ...], dtype: np.dtype
) -> None:
    _NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1, "int32": 18}
    bkey = key.encode("utf-8")
    fout.write(
        struct.pack("iii", len(shape), len(bkey), _NUMPY_TYPE_TO_FTYPE[dtype.name])
    )
    fout.write(struct.pack("i" * len(shape), *shape))
    fout.write(bkey)
    # Aligns to 32 bytes
    fout.seek((fout.tell() + 31) & -32)


# TODO: need to add more details in key name to indicate the network, layer number, etc.
def _translate_mlp_key(key: str) -> str:
    match = re.match(r"^(fc\d+).weight$", key)
    if not match or len(match.groups()) != 1:
        raise ValueError(f"Unexpected key: {key}")
    return f"{match.group(1)}.weight.mlp"


def append_mlp_model(fout: BinaryIO, model: ReluMLP) -> None:
    model_dict = model.state_dict()
    for k, v in model_dict.items():
        key = _translate_mlp_key(k)
        # torch.nn.Linear stores the weight matrix as (output_dim, input_dim), so does GGML.
        weights = v.half().detach().numpy()
        # GGML stores the weight matrix as (input_dim, output_dim)
        dims = weights.shape[::-1]
        print(
            f"{k} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
        )
        # TODO: add option to write in float32
        write_tensor_header(fout, key, dims, np.dtype("float16"))
        weights.tofile(fout)

def append_gpu_idx(fout: BinaryIO, activation, select_count) -> None:
    values, indices = torch.topk(activation, k=int(select_count))
    gpu_idx = torch.zeros_like(activation)
    gpu_idx[indices] = 1
    gpu_idx = gpu_idx.numpy().astype(np.int32)
    weights = gpu_idx
    dims = gpu_idx.shape[::-1]
    key = "gpu_idx"
    print(
        f"{key} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
    )
    write_tensor_header(fout, key, dims, np.dtype("int32"))
    weights.tofile(fout)

    indices = indices.numpy().astype(np.int32)
    weights = indices
    dims = weights.shape[::-1]
    key = "gpu_bucket"
    print(
        f"{key} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
    )
    write_tensor_header(fout, key, dims, np.dtype("int32"))
    weights = np.sort(weights)
    weights.tofile(fout)

def main(predictors_path: str, output_path: str, solver_path: str):
    predictors = load_mlp_predictors(Path(predictors_path))  # predictor => activation count
    n_tensors = len(predictors) * 2  # gpu_idx and gpu_bucket
    print(f"found {len(predictors)} MLP adapters with {n_tensors} tensors")
    with open(solver_path, "rb") as f:
        loaded_lst = pickle.load(f)
        # print(f"check solver {loaded_lst}")
    with open(output_path, "wb") as fout:
        fout.truncate()
        write_file_header(fout, n_tensors=n_tensors)
        for i, activation in enumerate(predictors):
            print(f"appending gpu idx layer-{i}")
            append_gpu_idx(fout, activation, loaded_lst[i])
            # append_gpu_idx(fout, activation, (32768*0.0))

    print(f"converted MLP adapters from {predictors_path} to {output_path}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("predictors_path", help="path to the MLP predictors")
    parser.add_argument(
        "output_path",
        help="path to the output GGML adapter",
        default="./ggml-mlp-adapters.bin",
    )
    parser.add_argument("solver", help="path to the solver")

    args = parser.parse_args()
    main(args.predictors_path, args.output_path, args.solver)
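
Reviewer note: append_gpu_idx boils down to a top-k selection over the per-neuron activation counts; gpu_idx is the resulting 0/1 mask and gpu_bucket is the sorted list of selected neuron indices. A tiny worked example of just that selection (a sketch only, independent of the adapter file format written above):

import numpy as np
import torch

activation = torch.tensor([0.9, 0.1, 0.7, 0.3, 0.8, 0.2])  # per-neuron activation counts
select_count = 3

_, indices = torch.topk(activation, k=select_count)
gpu_idx = torch.zeros_like(activation)
gpu_idx[indices] = 1                              # -> [1, 0, 1, 0, 1, 0]
gpu_bucket = np.sort(indices.numpy())             # -> [0, 2, 4]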