merge PowerInfer impl from the internal codebase

Holden 2023-12-12 11:05:32 +08:00
parent 6bb4908a17
commit a3c295a2ae
17 changed files with 4515 additions and 169 deletions

.gitignore (vendored): 2 changes

@ -98,3 +98,5 @@ tests/test-tokenizer-0-llama
tests/test-tokenizer-0-falcon
tests/test-tokenizer-1-llama
tests/test-tokenizer-1-bpe
build-info.h


@ -471,6 +471,12 @@ bool gpt_params_parse_ex(int argc, char ** argv, gpt_params & params) {
break;
}
params.lora_base = argv[i];
} else if (arg == "--mlp-adapter") {
if (++i >= argc) {
invalid_param = true;
break;
}
params.mlp_adapter = argv[i];
} else if (arg == "--mmproj") {
if (++i >= argc) {
invalid_param = true;
@ -950,8 +956,26 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
return std::make_tuple(nullptr, nullptr);
}
if (llama_use_sparse_inference(model)) {
fprintf(stderr, "%s: postprocessing PowerInfer model '%s'\n", __func__, params.model.c_str());
if (!params.mlp_adapter.empty()) {
fprintf(stderr, "%s: warning: --mlp-adapter is deprecated and has no effect\n", __func__);
int err = llama_model_apply_mlp_from_file(model, params.mlp_adapter.c_str(), true);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply mlp adapter\n", __func__);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
}
if (llama_model_apply_augmentation(model) != 0) {
fprintf(stderr, "%s: error: failed to apply augmentation\n", __func__);
llama_free_model(model);
return std::make_tuple(nullptr, nullptr);
}
}
auto cparams = llama_context_params_from_gpt_params(params);
llama_context * lctx = llama_new_context_with_model(model, cparams);
if (lctx == NULL) {
fprintf(stderr, "%s: error: failed to create context with model '%s'\n", __func__, params.model.c_str());
@ -981,6 +1005,8 @@ std::tuple<struct llama_model *, struct llama_context *> llama_init_from_gpt_par
params.sparams.logit_bias[llama_token_eos(model)] = -INFINITY;
}
{
LOG("warming up the model with an empty run\n");
@ -1320,6 +1346,7 @@ void dump_non_result_info_yaml(FILE * stream, const gpt_params & params, const l
fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la)); fprintf(stream, " - %s: %f\n", std::get<0>(la).c_str(), std::get<1>(la));
} }
fprintf(stream, "lora_base: %s\n", params.lora_base.c_str()); fprintf(stream, "lora_base: %s\n", params.lora_base.c_str());
fprintf(stream, "mlp_adapter: %s\n", params.mlp_adapter.c_str());
fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu); fprintf(stream, "main_gpu: %d # default: 0\n", params.main_gpu);
fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false"); fprintf(stream, "memory_f32: %s # default: false\n", !params.memory_f16 ? "true" : "false");
fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat); fprintf(stream, "mirostat: %d # default: 0 (disabled)\n", sparams.mirostat);


@ -90,6 +90,8 @@ struct gpt_params {
std::vector<std::tuple<std::string, float>> lora_adapter; // lora adapter path with user defined scale
std::string lora_base = ""; // base model path for the lora adapter
std::string mlp_adapter = ""; // sparse activation mlp adapter path
int ppl_stride = 0; // stride for perplexity calculations. If left at 0, the pre-existing approach will be used.
int ppl_output_type = 0; // = 0 -> ppl output is as usual, = 1 -> ppl output is num_tokens, ppl, one per line
// (which is more convenient to use for plotting)


@ -0,0 +1,601 @@
#!/usr/bin/env python3
from __future__ import annotations
from abc import ABC, abstractmethod
import argparse
import contextlib
import json
import os
import re
import struct
import sys
from enum import IntEnum
from pathlib import Path
from typing import TYPE_CHECKING, Any, ContextManager, Iterator, Optional, cast
import numpy as np
import torch
import torch.nn as tnn
if TYPE_CHECKING:
from torch import Tensor
if "NO_LOCAL_GGUF" not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / "gguf-py"))
import gguf
###### MODEL DEFINITIONS ######
class SentencePieceTokenTypes(IntEnum):
NORMAL = 1
UNKNOWN = 2
CONTROL = 3
USER_DEFINED = 4
UNUSED = 5
BYTE = 6
class ReluMLP(tnn.Module):
def __init__(self, input_dim: int, hidden_dim: int, output_dim: int):
super(ReluMLP, self).__init__()
self.fc1 = tnn.Linear(input_dim, hidden_dim, bias=False)
self.relu = tnn.ReLU()
self.fc2 = tnn.Linear(hidden_dim, output_dim, bias=False)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
@staticmethod
def from_file(model_file: Path):
model = torch.load(model_file, map_location="cpu")
hidden_size, input_size = model.get("fc1.weight").shape
output_size, _ = model.get("fc2.weight").shape
mlp = ReluMLP(input_size, hidden_size, output_size)
mlp.load_state_dict(model)
return mlp
class Model(ABC):
"""Base class for model conversion"""
def __init__(
self,
dir_model: Path,
dir_mlp_pred: Path,
ftype: int,
fname_out: Path,
is_big_endian: bool,
):
self.dir_model = dir_model
self.dir_mlp_pred = dir_mlp_pred
self.ftype = ftype
self.fname_out = fname_out
self.is_big_endian = is_big_endian
self.endianess = (
gguf.GGUFEndian.BIG if is_big_endian else gguf.GGUFEndian.LITTLE
)
self.is_safetensors = self._is_model_safetensors()
self.num_parts = Model.count_model_parts(
self.dir_model, ".safetensors" if self.is_safetensors else ".bin"
)
self.part_names = self._get_part_names()
self.hparams = Model.load_hparams(self.dir_model)
self.model_arch = self._get_model_architecture()
self.gguf_writer = gguf.GGUFWriter(
fname_out, gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file = False
)
def set_vocab(self):
self._set_vocab_gpt2()
def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
for model_layer, part_name in self._get_mlp_part_layer_names():
print(f"gguf: loading mlp part '{part_name}'")
mlp_model = ReluMLP.from_file(self.dir_mlp_pred / part_name)
for name, data in mlp_model.state_dict().items():
yield f"blk.{model_layer}.{name}", data
for part_name in self.part_names:
print(f"gguf: loading model part '{part_name}'")
ctx: ContextManager[Any]
if self.is_safetensors:
from safetensors import safe_open
ctx = cast(
ContextManager[Any],
safe_open(self.dir_model / part_name, framework="pt", device="cpu"),
)
else:
ctx = contextlib.nullcontext(
torch.load(self.dir_model / part_name, map_location="cpu")
)
with ctx as model_part:
for name in model_part.keys():
data = (
model_part.get_tensor(name)
if self.is_safetensors
else model_part[name]
)
yield name, data
@abstractmethod
def set_gguf_parameters(self):
pass
# self.gguf_writer.add_name(self.dir_model.name)
# self.gguf_writer.add_block_count(
# self.hparams.get(
# "n_layers",
# self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
# )
# )
# if (n_ctx := self.hparams.get("max_position_embeddings")) is not None:
# self.gguf_writer.add_context_length(n_ctx)
# if (n_embd := self.hparams.get("hidden_size")) is not None:
# self.gguf_writer.add_embedding_length(n_embd)
# if (n_ff := self.hparams.get("intermediate_size")) is not None:
# self.gguf_writer.add_feed_forward_length(n_ff)
# if (n_head := self.hparams.get("num_attention_head")) is not None:
# self.gguf_writer.add_head_count(n_head)
# self.gguf_writer.add_parallel_residual(
# self.hparams.get("use_parallel_residual", True)
# )
@abstractmethod
def write_tensors(self):
pass
def write(self):
self.write_tensors()
self.gguf_writer.write_header_to_file()
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.write_tensors_to_file()
self.gguf_writer.close()
def write_vocab(self):
self.gguf_writer.write_header_to_file()
self.gguf_writer.write_kv_data_to_file()
self.gguf_writer.close()
@staticmethod
def count_model_parts(dir_model: Path, prefix: str) -> int:
num_parts = 0
for filename in os.listdir(dir_model):
if filename.endswith(prefix):
num_parts += 1
return num_parts
@staticmethod
def load_hparams(dir_model):
with open(dir_model / "config.json", "r", encoding="utf-8") as f:
return json.load(f)
@staticmethod
def from_model_architecture(model_architecture):
if model_architecture in ("FalconForCausalLM", "RWForCausalLM"):
return FalconModel
if model_architecture == "LlamaForCausalLM":
return LlamaModel
raise NotImplementedError(f'Architecture "{model_architecture}" not supported!')
def _is_model_safetensors(self) -> bool:
return Model.count_model_parts(self.dir_model, ".safetensors") > 0
def _get_mlp_part_layer_names(self):
"""Returns a generator of (index, name) for MLP predictors of each model layer"""
n_mlp_parts = Model.count_model_parts(self.dir_mlp_pred, ".pt")
return ((n, f"model_{n}.pt") for n in range(n_mlp_parts))
def _get_part_names(self):
if self.is_safetensors:
if self.num_parts == 1: # there's only one .safetensors file
return ("model.safetensors",)
return (
f"model-{n:05}-of-{self.num_parts:05}.safetensors"
for n in range(1, self.num_parts + 1)
)
if self.num_parts == 1: # there's only one .bin file
return ("pytorch_model.bin",)
return (
f"pytorch_model-{n:05}-of-{self.num_parts:05}.bin"
for n in range(1, self.num_parts + 1)
)
def _get_model_architecture(self) -> gguf.MODEL_ARCH:
arch = self.hparams["architectures"][0]
if arch == "FalconForCausalLM":
return gguf.MODEL_ARCH.FALCON
if arch == "RWForCausalLM" or arch == "LlamaForCausalLM":
return gguf.MODEL_ARCH.LLAMA
raise NotImplementedError(f'Architecture "{arch}" not supported!')
def _translate_tensor_key(
self, key: str, try_suffixes=(".weight", ".bias")
) -> Optional[str]:
block_count = self.hparams.get(
"n_layers",
self.hparams.get("num_hidden_layers", self.hparams.get("n_layer")),
)
tensor_map = gguf.get_tensor_name_map(self.model_arch, block_count)
arch_tensor_key = tensor_map.get_name(key, try_suffixes=try_suffixes)
if arch_tensor_key is not None:
return arch_tensor_key
# check and handle ReluMLP layers
mlp_match = re.match(r"^blk\.\d+\.fc\d\.weight$", key)
if mlp_match:
return mlp_match.group(0)
return None
def _set_vocab_gpt2(self):
dir_model = self.dir_model
hparams = self.hparams
tokens: list[bytearray] = []
toktypes: list[int] = []
from transformers import AutoTokenizer # type: ignore[attr-defined]
tokenizer = AutoTokenizer.from_pretrained(dir_model)
vocab_size = hparams.get("vocab_size", len(tokenizer.vocab))
assert max(tokenizer.vocab.values()) < vocab_size
reverse_vocab = {
id_: encoded_tok for encoded_tok, id_ in tokenizer.vocab.items()
}
added_vocab = tokenizer.get_added_vocab()
for i in range(vocab_size):
if i not in reverse_vocab:
pad_token = f"[PAD{i}]".encode("utf-8")
tokens.append(bytearray(pad_token))
toktypes.append(gguf.TokenType.USER_DEFINED)
elif reverse_vocab[i] in added_vocab:
tokens.append(reverse_vocab[i])
if tokenizer.added_tokens_decoder[i].special:
toktypes.append(gguf.TokenType.CONTROL)
else:
toktypes.append(gguf.TokenType.USER_DEFINED)
else:
tokens.append(reverse_vocab[i])
toktypes.append(gguf.TokenType.NORMAL)
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(dir_model, load_merges=True)
special_vocab.add_to_gguf(self.gguf_writer)
def _set_vocab_sentencepiece(self):
from sentencepiece import SentencePieceProcessor
tokenizer_path = self.dir_model / "tokenizer.model"
tokens: list[bytes] = []
scores: list[float] = []
toktypes: list[int] = []
if not tokenizer_path.is_file():
print(f"Error: Missing {tokenizer_path}", file=sys.stderr)
sys.exit(1)
tokenizer = SentencePieceProcessor(str(tokenizer_path))
vocab_size = self.hparams.get("vocab_size", tokenizer.vocab_size())
for token_id in range(vocab_size):
piece = tokenizer.id_to_piece(token_id)
text = piece.encode("utf-8")
score = tokenizer.get_score(token_id)
toktype = SentencePieceTokenTypes.NORMAL
if tokenizer.is_unknown(token_id):
toktype = SentencePieceTokenTypes.UNKNOWN
elif tokenizer.is_control(token_id):
toktype = SentencePieceTokenTypes.CONTROL
elif tokenizer.is_unused(token_id):
toktype = SentencePieceTokenTypes.UNUSED
elif tokenizer.is_byte(token_id):
toktype = SentencePieceTokenTypes.BYTE
tokens.append(text)
scores.append(score)
toktypes.append(toktype)
added_tokens_file = self.dir_model / "added_tokens.json"
if added_tokens_file.is_file():
with open(added_tokens_file, "r", encoding="utf-8") as f:
added_tokens_json = json.load(f)
for key in added_tokens_json:
tokens.append(key.encode("utf-8"))
scores.append(-1000.0)
toktypes.append(SentencePieceTokenTypes.USER_DEFINED)
self.gguf_writer.add_tokenizer_model("llama")
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_scores(scores)
self.gguf_writer.add_token_types(toktypes)
special_vocab = gguf.SpecialVocab(self.dir_model, n_vocab=len(tokens))
special_vocab.add_to_gguf(self.gguf_writer)
class LlamaModel(Model):
def set_vocab(self):
self._set_vocab_sentencepiece()
def set_gguf_parameters(self):
self.gguf_writer.add_name("Llama")
self.gguf_writer.add_context_length(2048) # not in config.json
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_block_count(self.hparams["num_hidden_layers"])
self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(
self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
)
self.gguf_writer.add_head_count(self.hparams["num_attention_heads"])
self.gguf_writer.add_head_count_kv(self.hparams["num_key_value_heads"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
self.gguf_writer.add_file_type(self.ftype)
def write_tensors(self):
for name, data_torch in self.get_tensors():
# we don't need these
if name.endswith(
(
".attention.masked_bias",
".attention.bias",
".attention.rotary_emb.inv_freq",
)
):
continue
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
data = data_torch.squeeze().numpy()
# map tensor names
new_name = self._translate_tensor_key(name)
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
# We need to transpose the weight matrices for the FFN Down layers to support the
# Axpy operation in PowerInfer. So we don't need to transpose them at runtime.
if "ffn_down" in new_name:
new_name = new_name.replace("ffn_down", "ffn_down_t")
data = data.T
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if (
self.ftype == 1
and data_dtype == np.float32
and name.endswith(".weight")
and n_dims == 2
):
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
class FalconModel(Model):
def set_gguf_parameters(self):
block_count = self.hparams.get("num_hidden_layers")
if block_count is None:
block_count = self.hparams["n_layer"] # old name
n_head = self.hparams.get("num_attention_heads")
if n_head is None:
n_head = self.hparams["n_head"] # old name
n_head_kv = self.hparams.get("num_kv_heads")
if n_head_kv is None:
n_head_kv = self.hparams.get("n_head_kv", 1) # old name
self.gguf_writer.add_name("Falcon")
self.gguf_writer.add_context_length(2048) # not in config.json
self.gguf_writer.add_tensor_data_layout("jploski") # qkv tensor transform
self.gguf_writer.add_embedding_length(self.hparams["hidden_size"])
self.gguf_writer.add_feed_forward_length(4 * self.hparams["hidden_size"])
self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_head_count(n_head)
self.gguf_writer.add_head_count_kv(n_head_kv)
self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_epsilon"])
self.gguf_writer.add_file_type(self.ftype)
def write_tensors(self):
n_head = self.hparams.get("num_attention_heads")
if n_head is None:
n_head = self.hparams["n_head"] # old name
n_head_kv = self.hparams.get("num_kv_heads")
if n_head_kv is None:
n_head_kv = self.hparams.get("n_head_kv", 1) # old name
head_dim = self.hparams["hidden_size"] // n_head
for name, data_torch in self.get_tensors():
old_dtype = data_torch.dtype
# convert any unsupported data types to float32
if data_torch.dtype not in (torch.float16, torch.float32):
data_torch = data_torch.to(torch.float32)
# QKV tensor transform
# The original query_key_value tensor contains n_head_kv "kv groups",
# each consisting of n_head/n_head_kv query weights followed by one key
# and one value weight (shared by all query heads in the kv group).
# This layout makes it a big pain to work with in GGML.
# So we rearrange them here, so that we have n_head query weights
# followed by n_head_kv key weights followed by n_head_kv value weights,
# in contiguous fashion.
# ref: https://github.com/jploski/ggml/blob/falcon40b/examples/falcon/convert-hf-to-ggml.py
if "query_key_value" in name:
qkv = data_torch.view(
n_head_kv, n_head // n_head_kv + 2, head_dim, head_dim * n_head
)
q = qkv[:, :-2].reshape(n_head * head_dim, head_dim * n_head)
k = qkv[:, [-2]].reshape(n_head_kv * head_dim, head_dim * n_head)
v = qkv[:, [-1]].reshape(n_head_kv * head_dim, head_dim * n_head)
data_torch = torch.cat((q, k, v)).reshape_as(data_torch)
data = data_torch.squeeze().numpy()
# map tensor names
new_name = self._translate_tensor_key(name)
if new_name is None:
print(f"Can not map tensor {name!r}")
sys.exit()
# We need to transpose the weight matrices for the FFN Down layers to support the
# Axpy operation in PowerInfer. So we don't need to transpose them at runtime.
if "ffn_down" in new_name:
new_name = new_name.replace("ffn_down", "ffn_down_t")
data = data.T
n_dims = len(data.shape)
data_dtype = data.dtype
# if f32 desired, convert any float16 to float32
if self.ftype == 0 and data_dtype == np.float16:
data = data.astype(np.float32)
# TODO: Why can't we use these float16 as-is? There should be no reason to store float16 as float32
if self.ftype == 1 and data_dtype == np.float16 and n_dims == 1:
data = data.astype(np.float32)
# if f16 desired, convert any float32 2-dim weight tensors to float16
if (
self.ftype == 1
and data_dtype == np.float32
and name.endswith(".weight")
and n_dims == 2
):
data = data.astype(np.float16)
print(f"{new_name}, n_dims = {n_dims}, {old_dtype} --> {data.dtype}")
self.gguf_writer.add_tensor(new_name, data)
###### CONVERSION LOGIC ######
def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Convert a huggingface model to a GGML compatible file"
)
parser.add_argument(
"--vocab-only",
action="store_true",
help="extract only the vocab",
)
parser.add_argument(
"--outfile",
type=Path,
help="path to write to; default: based on input",
)
parser.add_argument(
"--outtype",
type=str,
choices=["f32", "f16"],
default="f16",
help="output format - use f32 for float32, f16 for float16",
)
parser.add_argument(
"--bigendian",
action="store_true",
help="model is executed on big endian machine",
)
parser.add_argument(
"model",
type=Path,
help="directory containing model file",
)
parser.add_argument(
"mlp_predictors",
type=Path,
help="directory containing MLP predictors for model",
)
return parser.parse_args()
args = parse_args()
dir_model = args.model
dir_mlp_pred = args.mlp_predictors
if not dir_model.is_dir():
print(f"Error: {args.model} is not a directory", file=sys.stderr)
sys.exit(1)
if not dir_mlp_pred.is_dir():
print(f"Error: {args.mlp_predictors} is not a directory", file=sys.stderr)
sys.exit(1)
ftype_map = {
"f32": gguf.GGMLQuantizationType.F32,
"f16": gguf.GGMLQuantizationType.F16,
}
if args.outfile is not None:
fname_out = args.outfile
else:
# output in the same directory as the model by default
fname_out = dir_model / f"ggml-model-{args.outtype}.gguf"
print(f"Loading model: {dir_model.name}")
hparams = Model.load_hparams(dir_model)
model_class = Model.from_model_architecture(hparams["architectures"][0])
model_instance = model_class(
dir_model, dir_mlp_pred, ftype_map[args.outtype], fname_out, args.bigendian
)
print("Set model parameters")
model_instance.set_gguf_parameters()
print("Set model tokenizer")
model_instance.set_vocab()
if args.vocab_only:
print(f"Exporting model vocab to '{fname_out}'")
model_instance.write_vocab()
else:
print(f"Exporting model to '{fname_out}'")
model_instance.write()
# post-process: write another unique file header to distinguish from the original GGUF file
with open(fname_out, "r+b") as fout:
POWERINFER_MAGIC = int.from_bytes(b"PWRI", "little")
fout.write(struct.pack("<I", POWERINFER_MAGIC))
print(f"Model successfully exported to '{fname_out}'")


@ -509,6 +509,14 @@ class LazyTensor:
def load() -> Tensor:
return self.load().astype(data_type)
return LazyTensor(load, self.shape, data_type, f'convert({data_type}) {self.description}')
def transposed(self) -> LazyTensor:
def load() -> Tensor:
loaded = self.load()
assert isinstance(loaded, UnquantizedTensor), f'Cannot transpose {loaded}'
loaded.ndarray = loaded.ndarray.T
return loaded
return LazyTensor(load, self.shape[::-1], self.data_type, f'transpose {self.description}')
def validate_conversion_to(self, data_type: DataType) -> None:
if data_type != self.data_type and data_type.name not in self.data_type.valid_conversions:
@ -571,7 +579,8 @@ def merge_multifile_models(models_plus: list[ModelPlus]) -> ModelPlus:
except StopIteration:
vocab = None
if any("model.embed_tokens.weight" in mp.model for mp in models_plus) or \
any("model.layers.0.fc1.weight" in mp.model for mp in models_plus):
# Transformers models put different tensors in different files, but
# don't split individual tensors between files.
model: LazyModel = {}
@ -992,6 +1001,18 @@ def convert_model_names(model: LazyModel, params: Params) -> LazyModel:
return out
def postprocess_transpose(model: LazyModel) -> LazyModel:
"""Transpose ffn_down matrices for Axpy ops."""
out: LazyModel = {}
for name, lazy_tensor in model.items():
if name.endswith(".ffn_down.weight"):
out[name.replace("ffn_down", "ffn_down_t")] = lazy_tensor.transposed()
else:
out[name] = lazy_tensor
return out
def nth_multifile_path(path: Path, n: int) -> Path | None:
'''Given any path belonging to a multi-file model (e.g. foo.bin.1), return
the nth path in the model.
@ -1003,7 +1024,9 @@ def nth_multifile_path(path: Path, n: int) -> Path | None:
# - x-00001-of-00002.bin, x-00002-of-00002.bin, etc.
(r'-[0-9]{5}-of-(.*)$', fr'-{n:05}-of-\1'),
# x.bin, x.bin.1, etc.
(r'(\.[0-9]+)?$', r'\1' if n == 0 else fr'\1.{n}'),
# x_0.pt, x_1.pt, etc.
(r'(_[0-9]+)?\.pt$', fr'_{n}.pt'),
]
for regex, replacement in patterns:
if re.search(regex, path.name):
@ -1057,6 +1080,25 @@ def load_some_model(path: Path) -> ModelPlus:
model_plus = merge_multifile_models(models_plus)
return model_plus
def load_mlp_model(path: Path) -> ModelPlus:
'''Load MLP models for sparse attention from directory.'''
assert path.is_dir(), f"MLP model path {path} is not a directory"
first_model_path = path / "model_0.pt"
assert first_model_path.exists(), f"MLP model path {path} does not contain model_0.pt"
model_paths = find_multifile_paths(first_model_path)
models_plus: list[ModelPlus] = []
for model_path in model_paths:
# find number in model_path
model_layer = int(re.search(r'model_(\d+).pt', str(model_path)).group(1))
print(f"Loading MLP model file {model_path}")
mlp_model = lazy_load_file(model_path)
mlp_model.model = {f"model.layers.{model_layer}.{name}": tensor for name, tensor in mlp_model.model.items()}
models_plus.append(mlp_model)
return merge_multifile_models(models_plus)
def load_vocab(path: Path, vocabtype: str | None) -> Vocab:
# Be extra-friendly and accept either a file or a directory. Also, if it's
@ -1125,6 +1167,7 @@ def main(args_in: list[str] | None = None) -> None:
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file") parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input") parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)") parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin, *.safetensors)")
parser.add_argument("mlp_model", type=Path, help="MLP model for sparse attention")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm") parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)") parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY) parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
@ -1138,6 +1181,8 @@ def main(args_in: list[str] | None = None) -> None:
if not args.vocab_only:
model_plus = load_some_model(args.model)
mlp_predictor_plus = load_mlp_model(args.mlp_model)
model_plus = merge_multifile_models([model_plus, mlp_predictor_plus])
else:
model_plus = ModelPlus(model = {}, paths = [args.model / 'dummy'], format = 'none', vocab = None)
@ -1192,6 +1237,7 @@ def main(args_in: list[str] | None = None) -> None:
model = model_plus.model
model = convert_model_names(model, params)
model = postprocess_transpose(model)
ftype = pick_output_type(model, args.outtype)
model = convert_to_output_type(model, ftype)
outfile = args.outfile or default_outfile(model_plus.paths, ftype)
@ -1202,6 +1248,11 @@ def main(args_in: list[str] | None = None) -> None:
OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess)
print(f"Wrote {outfile}")
# post-process: write another unique file header to distinguish from the original GGUF file
with open(outfile, "r+b") as fout:
POWERINFER_MAGIC = int.from_bytes(b"PWRI", "little")
fout.write(struct.pack("<I", POWERINFER_MAGIC))
if __name__ == '__main__':
main()
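The postprocess_transpose step above renames ffn_down to ffn_down_t and stores the matrix transposed so that the runtime can compute the down-projection as an accumulation of scaled rows (axpy), touching only rows whose activations survived the ReLU. A small sketch of that equivalence, with made-up names and shapes, not code from this commit:

// Illustration: the down-projection y = W_down * h computed as axpy over rows of the
// transposed matrix; only rows whose activation h[j] is non-zero are ever touched.
#include <cstddef>
#include <vector>

// w_down_t mirrors the ffn_down_t layout: w_down_t[j] is row j, of length n_embd.
static std::vector<float> ffn_down_sparse(const std::vector<std::vector<float>> & w_down_t,
                                          const std::vector<float> & h) {
    const std::size_t n_ff   = w_down_t.size();
    const std::size_t n_embd = w_down_t.empty() ? 0 : w_down_t[0].size();
    std::vector<float> y(n_embd, 0.0f);
    for (std::size_t j = 0; j < n_ff; ++j) {
        const float alpha = h[j];
        if (alpha == 0.0f) {
            continue; // inactive neuron: its row is skipped entirely
        }
        for (std::size_t i = 0; i < n_embd; ++i) {
            y[i] += alpha * w_down_t[j][i]; // one axpy per active neuron
        }
    }
    return y;
}

The dense product gives the same vector; the sparse form only reads the rows selected by the activation predictor, which is where the transposed storage pays off.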


@ -11,7 +11,7 @@ int main(int argc, char ** argv) {
gpt_params params;
if (argc == 1 || argv[1][0] == '-') {
printf("usage: %s MODEL_PATH [PROMPT] [PARALLEL] [LEN] [NGL] [N_THREAD] [MLP_PATH]\n" , argv[0]);
return 1 ;
}
@ -44,6 +44,17 @@ int main(int argc, char ** argv) {
n_gpu_layers = std::atoi(argv[5]);
}
if (argc >= 7) {
params.n_threads = std::atoi(argv[6]);
}
if (argc >= 8) {
params.mlp_adapter = argv[7];
}
printf("params: model = %s, prompt = %s, n_parallel = %d, n_len = %d, n_gpu_layers = %d, n_threads = %d, mlp_adapter = %s\n",
params.model.c_str(), params.prompt.c_str(), n_parallel, n_len, n_gpu_layers, params.n_threads, params.mlp_adapter.c_str());
if (params.prompt.empty()) {
params.prompt = "Hello my name is";
}
@ -65,6 +76,21 @@ int main(int argc, char ** argv) {
return 1;
}
if (!params.mlp_adapter.empty()) {
int err = llama_model_apply_mlp_from_file(model, params.mlp_adapter.c_str(), true);
if (err != 0) {
fprintf(stderr, "%s: error: failed to apply mlp adapter\n", __func__);
llama_free_model(model);
return 1;
}
}
if (llama_model_apply_augmentation(model) != 0) {
fprintf(stderr, "%s: error: failed to apply model augmentation\n", __func__);
llama_free_model(model);
return 1;
}
// tokenize the prompt
std::vector<llama_token> tokens_list;

File diff suppressed because it is too large


@ -29,7 +29,11 @@ GGML_API void ggml_cuda_host_free(void * ptr);
GGML_API bool ggml_cuda_can_mul_mat(const struct ggml_tensor * src0, const struct ggml_tensor * src1, struct ggml_tensor * dst);
GGML_API void ggml_cuda_set_tensor_split(const float * tensor_split);
GGML_API void ggml_cuda_transform_tensor(void * data, struct ggml_tensor * tensor);
GGML_API void ggml_cuda_alloc_tensor(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_free_data(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_cpy_1d(struct ggml_tensor * dst, const struct ggml_tensor * src);
GGML_API bool debug_equal(short *a, short *b);
GGML_API void **ggml_cuda_get_data_pp(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers(struct ggml_tensor * tensor);
GGML_API void ggml_cuda_assign_buffers_no_scratch(struct ggml_tensor * tensor);


@ -2423,6 +2423,78 @@ static inline __m128i get_scale_shuffle(int i) {
}
#endif
void ggml_axpy_q4_0_q8_0(const int n, const void * restrict vx, const void * restrict vy, const void * restrict vz, int8_t alpha, ggml_fp16_t scale) {
const int qk = QK8_0;
const int nb = n / qk;
assert(n % qk == 0);
assert(nb % 2 == 0);
const block_q4_0 * restrict x = vx;
// Initialize accumulator with zeros
__m256 acc = _mm256_setzero_ps();
__m256i alpha_v = _mm256_set1_epi16((short)alpha);
// Main loop
for (int i = 0; i < nb; ++i) {
/* Compute combined scale for the block */
const __m256 d = _mm256_set1_ps( GGML_FP16_TO_FP32(x[i].d) * GGML_FP16_TO_FP32(scale) );
__m256i bx = bytes_from_nibbles_32(x[i].qs);
// Now we have a vector with bytes in [ 0 .. 15 ] interval. Offset them into [ -8 .. +7 ] interval.
const __m256i off = _mm256_set1_epi8( 8 );
bx = _mm256_sub_epi8( bx, off );
// process 16 quantized values at a time
__m128i m_a = _mm256_extracti128_si256(bx, 0);
__m256i m_x = _mm256_cvtepi8_epi16(m_a); //16 elements
m_x = _mm256_mullo_epi16(m_x, alpha_v);
__m128i x_0 = _mm256_extracti128_si256(m_x, 0);
__m256i x0_32 = _mm256_cvtepi16_epi32(x_0);
__m256 fx0 = _mm256_cvtepi32_ps(x0_32);
fx0 = _mm256_mul_ps(fx0, d);
__m256 by = _mm256_loadu_ps((const __m256 *)(vy+i*128));
by = _mm256_add_ps(by, fx0);
_mm256_storeu_ps((__m256*)(vz + i*128), by);
//second phase
x_0 = _mm256_extracti128_si256(m_x, 1);
x0_32 = _mm256_cvtepi16_epi32(x_0);
fx0 = _mm256_cvtepi32_ps(x0_32);
fx0 = _mm256_mul_ps(fx0, d);
by = _mm256_loadu_ps((const __m256 *)(vy+i*128+32));
by = _mm256_add_ps(by, fx0);
_mm256_storeu_ps((__m256*)(vz + i*128+32), by);
//third phase
m_a = _mm256_extracti128_si256(bx, 1);
m_x = _mm256_cvtepi8_epi16(m_a);
m_x = _mm256_mullo_epi16(m_x, alpha_v);
x_0 = _mm256_extracti128_si256(m_x, 0);
x0_32 = _mm256_cvtepi16_epi32(x_0);
fx0 = _mm256_cvtepi32_ps(x0_32);
fx0 = _mm256_mul_ps(fx0, d);
by = _mm256_loadu_ps((const __m256 *)(vy+i*128+64));
by = _mm256_add_ps(by, fx0);
_mm256_storeu_ps((__m256*)(vz + i*128+64), by);
//fourth phase
x_0 = _mm256_extracti128_si256(m_x, 1);
x0_32 = _mm256_cvtepi16_epi32(x_0);
fx0 = _mm256_cvtepi32_ps(x0_32);
fx0 = _mm256_mul_ps(fx0, d);
by = _mm256_loadu_ps((const __m256 *)(vy+i*128+96));
by = _mm256_add_ps(by, fx0);
_mm256_storeu_ps((__m256*)(vz + i*128+96), by);
}
}
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
const int qk = QK8_0;
const int nb = n / qk;
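For readers who would rather not trace the AVX2 intrinsics, ggml_axpy_q4_0_q8_0 above is, in scalar terms, roughly the following. This is a sketch only, assuming the usual q4_0 block layout (one fp16 scale plus 32 packed nibbles, low nibbles first) and float input/output buffers; the fp16 fields are shown as float so the sketch stays self-contained.

// Scalar equivalent of ggml_axpy_q4_0_q8_0 (illustrative only). ggml stores d and scale as
// ggml_fp16_t; they are plain float here to keep the sketch free of ggml internals.
#include <cassert>
#include <cstdint>

struct block_q4_0_sketch {
    float   d;      // per-block scale
    uint8_t qs[16]; // 32 x 4-bit quants, two per byte (low nibbles are elements 0..15)
};

static void axpy_q4_0_ref(int n, const block_q4_0_sketch * x,
                          const float * y, float * z, int8_t alpha, float scale) {
    const int qk = 32;
    assert(n % qk == 0);
    for (int i = 0; i < n / qk; ++i) {
        const float d = x[i].d * scale; // combined per-block scale, as in the AVX2 path
        for (int j = 0; j < qk / 2; ++j) {
            const int lo = (x[i].qs[j] & 0x0F) - 8; // element j
            const int hi = (x[i].qs[j] >> 4)   - 8; // element j + 16
            z[i*qk + j]        = y[i*qk + j]        + d * (float)(lo * (int)alpha);
            z[i*qk + j + qk/2] = y[i*qk + j + qk/2] + d * (float)(hi * (int)alpha);
        }
    }
}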


@ -210,6 +210,8 @@ void dequantize_row_q5_K(const block_q5_K * restrict x, float * restrict y, int
void dequantize_row_q6_K(const block_q6_K * restrict x, float * restrict y, int k);
void dequantize_row_q8_K(const block_q8_K * restrict x, float * restrict y, int k);
void ggml_axpy_q4_0_q8_0(const int n, const void * restrict vx, const void * restrict vy, const void * restrict vz, int8_t alpha, ggml_fp16_t scale);
// Dot product
void ggml_vec_dot_q4_0_q8_0(int n, float * restrict s, const void * restrict vx, const void * restrict vy);
void ggml_vec_dot_q4_1_q8_1(int n, float * restrict s, const void * restrict vx, const void * restrict vy);

ggml.c: 1492 changes

File diff suppressed because it is too large

ggml.h: 69 changes

@ -207,6 +207,14 @@
#include <stdint.h>
#include <stddef.h>
#include <stdbool.h>
#ifdef __cplusplus
#include <atomic>
using std::atomic_int;
using std::memory_order;
using std::memory_order_acquire;
#else /* not __cplusplus */
#include <stdatomic.h>
#endif /* __cplusplus */
#define GGML_FILE_MAGIC 0x67676d6c // "ggml"
#define GGML_FILE_VERSION 1
@ -232,6 +240,7 @@
#define GGML_EXIT_ABORTED 1
#define GGUF_MAGIC "GGUF"
#define GGUF_POWERINFER_MAGIC "PWRI"
#define GGUF_VERSION 3
@ -336,6 +345,11 @@ extern "C" {
GGML_BACKEND_GPU_SPLIT = 20,
};
enum ggml_sparse_deriv {
GGML_DENSE_INFERENCE = 0,
GGML_SPARSE_INFERENCE = 1,
};
// model file types
enum ggml_ftype {
GGML_FTYPE_UNKNOWN = -1,
@ -382,6 +396,7 @@ extern "C" {
GGML_OP_GROUP_NORM,
GGML_OP_MUL_MAT,
GGML_OP_AXPY,
GGML_OP_OUT_PROD,
GGML_OP_SCALE,
@ -504,6 +519,7 @@ extern "C" {
struct ggml_tensor * src[GGML_MAX_SRC];
// performance
atomic_int is_finish;
int perf_runs;
int64_t perf_cycles;
int64_t perf_time_us;
@ -520,6 +536,9 @@ extern "C" {
char padding[12];
};
static const int64_t GGML_NE_WILDCARD = -1;
static const size_t GGML_TENSOR_SIZE = sizeof(struct ggml_tensor);
// the compute plan that needs to be prepared for ggml_graph_compute()
@ -573,6 +592,22 @@ extern "C" {
void * data;
};
struct ggml_context {
size_t mem_size;
void * mem_buffer;
bool mem_buffer_owned;
bool no_alloc;
bool no_alloc_save; // this is used to save the no_alloc state when using scratch buffers
int n_objects;
struct ggml_object * objects_begin;
struct ggml_object * objects_end;
struct ggml_scratch scratch;
struct ggml_scratch scratch_save;
};
struct ggml_init_params {
// memory pool
size_t mem_size; // bytes
@ -600,6 +635,7 @@ extern "C" {
// work buffer for all threads
size_t wsize;
void * wdata;
atomic_int *aic;
};
// misc
@ -618,6 +654,8 @@ extern "C" {
GGML_API void ggml_print_object (const struct ggml_object * obj);
GGML_API void ggml_print_objects(const struct ggml_context * ctx);
GGML_API
GGML_API int64_t ggml_nelements (const struct ggml_tensor * tensor);
GGML_API int64_t ggml_nrows (const struct ggml_tensor * tensor);
GGML_API size_t ggml_nbytes (const struct ggml_tensor * tensor);
@ -727,6 +765,7 @@ extern "C" {
GGML_API void * ggml_get_data (const struct ggml_tensor * tensor);
GGML_API float * ggml_get_data_f32(const struct ggml_tensor * tensor);
GGML_API int32_t * ggml_get_data_i32(const struct ggml_tensor * tensor);
GGML_API enum ggml_unary_op ggml_get_unary_op(const struct ggml_tensor * tensor);
@ -735,6 +774,9 @@ extern "C" {
GGML_ATTRIBUTE_FORMAT(2, 3)
GGML_API struct ggml_tensor * ggml_format_name( struct ggml_tensor * tensor, const char * fmt, ...);
GGML_API void ggml_set_backend(struct ggml_tensor * tensor, enum ggml_backend_type backend);
//
// operations on tensors with backpropagation
//
@ -753,6 +795,12 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor *ggml_add_idx(
struct ggml_context *ctx,
struct ggml_tensor *a,
struct ggml_tensor *b,
struct ggml_tensor *idx);
GGML_API struct ggml_tensor * ggml_add_inplace(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -1027,6 +1075,25 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b);
GGML_API struct ggml_tensor *ggml_mul_mat_idx(
struct ggml_context *ctx,
struct ggml_tensor *a,
struct ggml_tensor *b,
struct ggml_tensor *idx,
struct ggml_tensor *d);
GGML_API struct ggml_tensor *ggml_mul_mat_special(
struct ggml_context *ctx,
struct ggml_tensor *a,
struct ggml_tensor *b,
struct ggml_tensor *idx,
struct ggml_tensor *d,
struct ggml_tensor *ref);
GGML_API struct ggml_tensor *ggml_axpy(
struct ggml_context *ctx,
struct ggml_tensor *a,
struct ggml_tensor *b,
struct ggml_tensor *c,
struct ggml_tensor *d);
// A: m columns, n rows,
// B: p columns, n rows,
@ -2013,6 +2080,7 @@ extern "C" {
};
GGML_API struct gguf_context * gguf_init_empty(void);
GGML_API struct gguf_context * gguf_init_empty_sparse(void);
GGML_API struct gguf_context * gguf_init_from_file(const char * fname, struct gguf_init_params params);
//GGML_API struct gguf_context * gguf_init_from_buffer(..);
@ -2049,6 +2117,7 @@ extern "C" {
GGML_API const void * gguf_get_arr_data(const struct gguf_context * ctx, int key_id);
GGML_API const char * gguf_get_arr_str (const struct gguf_context * ctx, int key_id, int i);
GGML_API enum ggml_sparse_deriv gguf_get_sparse_deriv(const struct gguf_context * ctx);
GGML_API int gguf_get_n_tensors (const struct gguf_context * ctx);
GGML_API int gguf_find_tensor (const struct gguf_context * ctx, const char * name);
GGML_API size_t gguf_get_tensor_offset(const struct gguf_context * ctx, int i);


@ -115,6 +115,10 @@ class MODEL_TENSOR(IntEnum):
FFN_NORM = auto()
ATTN_Q_NORM = auto()
ATTN_K_NORM = auto()
FFN_DOWN_T = auto()
FC_1 = auto()
FC_2 = auto()
MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
@ -155,6 +159,9 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.FFN_GATE: "blk.{bid}.ffn_gate",
MODEL_TENSOR.FFN_DOWN: "blk.{bid}.ffn_down",
MODEL_TENSOR.FFN_UP: "blk.{bid}.ffn_up",
MODEL_TENSOR.FFN_DOWN_T: "blk.{bid}.ffn_down_t",
MODEL_TENSOR.FC_1: "blk.{bid}.fc1",
MODEL_TENSOR.FC_2: "blk.{bid}.fc2",
}
MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
@ -173,6 +180,9 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_DOWN_T,
MODEL_TENSOR.FC_1,
MODEL_TENSOR.FC_2,
],
MODEL_ARCH.GPTNEOX: [
MODEL_TENSOR.TOKEN_EMBD,


@ -194,6 +194,14 @@ class TensorNameMap:
MODEL_TENSOR.ROPE_FREQS: (
"language_model.encoder.layers.{bid}.self_attention.rotary_emb.inv_freq", # persimmon
),
MODEL_TENSOR.FC_1: (
"model.layers.{bid}.fc1",
),
MODEL_TENSOR.FC_2: (
"model.layers.{bid}.fc2",
),
}
mapping: dict[str, tuple[MODEL_TENSOR, str]]

llama.cpp: 916 changes

File diff suppressed because it is too large


@ -293,6 +293,7 @@ extern "C" {
LLAMA_API int llama_n_ctx (const struct llama_context * ctx);
LLAMA_API enum llama_vocab_type llama_vocab_type(const struct llama_model * model);
LLAMA_API bool llama_use_sparse_inference(const struct llama_model * model);
LLAMA_API int llama_n_vocab (const struct llama_model * model);
LLAMA_API int llama_n_ctx_train(const struct llama_model * model);
@ -340,6 +341,13 @@ extern "C" {
const char * path_base_model,
int n_threads);
LLAMA_API int llama_model_apply_mlp_from_file(
struct llama_model * model,
const char * path_mlp,
bool use_mmap);
LLAMA_API int llama_model_apply_augmentation(struct llama_model * model);
//
// KV cache
//

scripts/export-gpu-split.py (new file): 142 changes

@ -0,0 +1,142 @@
#!/usr/bin/env python3
import argparse
import torch
import torch.nn as tnn
from pathlib import Path
import os
import re
import struct
from typing import Any, BinaryIO
import numpy as np
import pickle
class ReluMLP(tnn.Module):
def __init__(self, input_dim, hidden_dim, output_dim):
super(ReluMLP, self).__init__()
self.fc1 = tnn.Linear(input_dim, hidden_dim, bias=False)
self.relu = tnn.ReLU()
self.fc2 = tnn.Linear(hidden_dim, output_dim, bias=False)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
def _load_mlp_model(model_file: Path):
model = torch.load(model_file)
# hidden_size, input_size = model.get("fc1.weight").shape
# output_size, _ = model.get("fc2.weight").shape
# mlp = ReluMLP(input_size, hidden_size, output_size)
# mlp.load_state_dict(model)
return model
def load_mlp_predictors(models_base: Path):
# TODO: might need a specification file to indicate which models to load.
# But for now, let's assume it is a plain directory of models_{0, ... , n_layers - 1}.pt
*_, files = next(os.walk(models_base))
return [_load_mlp_model(models_base / f"activation_{i}.pt") for i in range(len(files))]
def write_file_header(fout: BinaryIO, n_tensors: int) -> None:
fout.write(b"gglp"[::-1]) # magic (GGml mLP)
fout.write(struct.pack("i", 1)) # file version
# TODO: If we found we need more common parameters, we can add them here.
fout.write(struct.pack("i", n_tensors))
def write_tensor_header(
fout: BinaryIO, key: str, shape: tuple[int, ...], dtype: np.dtype
) -> None:
_NUMPY_TYPE_TO_FTYPE: dict[str, int] = {"float32": 0, "float16": 1, "int32": 18}
bkey = key.encode("utf-8")
fout.write(
struct.pack("iii", len(shape), len(bkey), _NUMPY_TYPE_TO_FTYPE[dtype.name])
)
fout.write(struct.pack("i" * len(shape), *shape))
fout.write(bkey)
# Aligns to 32 bytes
fout.seek((fout.tell() + 31) & -32)
# TODO: need to add more details in key name to indicate the network, layer number, etc.
def _translate_mlp_key(key: str) -> str:
match = re.match(r"^(fc\d+).weight$", key)
if not match or len(match.groups()) != 1:
raise ValueError(f"Unexpected key: {key}")
return f"{match.group(1)}.weight.mlp"
def append_mlp_model(fout: BinaryIO, model: ReluMLP) -> None:
model_dict = model.state_dict()
for k, v in model_dict.items():
key = _translate_mlp_key(k)
# torch.nn.Linear stores the weight matrix as (output_dim, input_dim), so does GGML.
weights = v.half().detach().numpy()
# the GGML tensor header lists dimensions as (input_dim, output_dim), so reverse the shape
dims = weights.shape[::-1]
print(
f"{k} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
)
# TODO: add option to write in float32
write_tensor_header(fout, key, dims, np.dtype("float16"))
weights.tofile(fout)
def append_gpu_idx(fout: BinaryIO, activation, select_count) -> None:
values, indices = torch.topk(activation, k=int(select_count))
gpu_idx = torch.zeros_like(activation)
gpu_idx[indices] = 1
gpu_idx = gpu_idx.numpy().astype(np.int32)
weights = gpu_idx
dims = gpu_idx.shape[::-1]
key = "gpu_idx"
print(
f"{key} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
)
write_tensor_header(fout, key, dims, np.dtype("int32"))
weights.tofile(fout)
indices = indices.numpy().astype(np.int32)
weights = indices
dims = weights.shape[::-1]
key = "gpu_bucket"
print(
f"{key} => {key} {weights.shape} {weights.dtype} {weights.nbytes/1024/1024} MiB"
)
write_tensor_header(fout, key, dims, np.dtype("int32"))
weights = np.sort(weights)
weights.tofile(fout)
def main(predictors_path: str, output_path: str, solver_path: str):
predictors = load_mlp_predictors(Path(predictors_path)) # predictor => activation count
n_tensors = len(predictors) * 2 # gpu_idx and gpu_bucket
print(f"found {len(predictors)} MLP adapters with {n_tensors} tensors")
with open(solver_path, "rb") as f:
loaded_lst = pickle.load(f)
# print(f"check solver {loaded_lst}")
with open(output_path, "wb") as fout:
fout.truncate()
write_file_header(fout, n_tensors=n_tensors)
for i, activation in enumerate(predictors):
print(f"appending gpu idx layer-{i}")
append_gpu_idx(fout, activation, loaded_lst[i])
# append_gpu_idx(fout, activation, (32768*0.0))
print(f"converted MLP adapters from {predictors_path} to {output_path}")
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("predictors_path", help="path to the MLP predictors")
parser.add_argument(
"output_path",
help="path to the output GGML adapter",
default="./ggml-mlp-adapters.bin",
)
parser.add_argument("solver", help="path to the solver")
args = parser.parse_args()
main(args.predictors_path, args.output_path, args.solver)
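To make the adapter layout emitted by write_file_header and write_tensor_header concrete, here is a small reader sketch. It is illustrative only, not the loader used by llama.cpp, and assumes exactly the fields written above: the reversed "gglp" magic, an int32 version and tensor count, then per tensor the (n_dims, name_len, ftype) triple, the int32 shape, the name bytes, 32-byte alignment, and the raw data.

// Reader sketch for the adapter file written above (illustrative, not the llama.cpp loader).
#include <cstdint>
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>

static bool dump_gpu_split_file(const char * path) {
    std::FILE * f = std::fopen(path, "rb");
    if (f == NULL) {
        return false;
    }
    char magic[4];
    int32_t version = 0;
    int32_t n_tensors = 0;
    bool ok = std::fread(magic, 1, 4, f) == 4 && std::memcmp(magic, "plgg", 4) == 0 // b"gglp"[::-1]
           && std::fread(&version,   sizeof(version),   1, f) == 1
           && std::fread(&n_tensors, sizeof(n_tensors), 1, f) == 1;
    for (int32_t t = 0; ok && t < n_tensors; ++t) {
        int32_t n_dims = 0, name_len = 0, ftype = 0;
        ok = std::fread(&n_dims,   4, 1, f) == 1
          && std::fread(&name_len, 4, 1, f) == 1
          && std::fread(&ftype,    4, 1, f) == 1;
        if (!ok) break;
        std::vector<int32_t> shape(n_dims);
        std::string name(name_len, '\0');
        ok = std::fread(shape.data(), 4, n_dims, f) == (size_t) n_dims
          && std::fread(&name[0], 1, name_len, f) == (size_t) name_len;
        if (!ok) break;
        std::fseek(f, (std::ftell(f) + 31) & -32, SEEK_SET);  // writer aligns data to 32 bytes
        int64_t n_elem = 1;
        for (int32_t dim : shape) n_elem *= dim;
        const int64_t elem_size = (ftype == 1) ? 2 : 4;       // 0: f32, 1: f16, 18: i32
        std::printf("%s: ftype=%d, %lld elements\n", name.c_str(), ftype, (long long) n_elem);
        std::fseek(f, (long)(n_elem * elem_size), SEEK_CUR);  // skip the tensor data
    }
    std::fclose(f);
    return ok;
}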