support --bigendian option for s390x

1. verified with baichuan7b-chat with float16 on s390x
2. verified with baichuan7b-chat
3. verified with chinese-alpaca-2-13b-f16

parent fa62c8c73a
commit 1ce890a7c0

3 changed files with 65 additions and 38 deletions
convert-baichuan-hf-to-gguf.py

@@ -73,6 +73,7 @@ def parse_args() -> argparse.Namespace:
         "ftype", type=int, choices=[0, 1], default=1, nargs='?',
         help="output format - use 0 for float32, 1 for float16",
     )
+    parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
     return parser.parse_args()
 
 args = parse_args()

@@ -83,6 +84,10 @@ if not dir_model.is_dir():
     print(f'Error: {args.model} is not a directory', file = sys.stderr)
     sys.exit(1)
 
+endianess = gguf.GGUFEndian.LITTLE
+if args.bigendian:
+    endianess = gguf.GGUFEndian.BIG
+print(f"gguf: Conversion Endianess {endianess}")
 # possible tensor data types
 # ftype == 0 -> float32
 # ftype == 1 -> float16

@@ -110,7 +115,7 @@ if hparams["architectures"][0] != "BaichuanForCausalLM":
 num_parts = count_model_parts(dir_model)
 print(f"num_parts:{num_parts}\n")
 ARCH=gguf.MODEL_ARCH.BAICHUAN
-gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
+gguf_writer = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
 print("gguf: get model metadata")
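The three hunks above add the flag, derive the target byte order from it, and hand that order to the writer. A minimal standalone sketch of the flag-to-enum step, with GGUFEndian redefined locally so the snippet runs without the gguf package:

import argparse
from enum import IntEnum

class GGUFEndian(IntEnum):  # local stand-in for the enum added in gguf.py below
    LITTLE = 0
    BIG = 1

parser = argparse.ArgumentParser()
parser.add_argument("--bigendian", action="store_true",
                    help="model is executed on big endian machine")
args = parser.parse_args(["--bigendian"])  # simulate a run on s390x

endianess = GGUFEndian.BIG if args.bigendian else GGUFEndian.LITTLE
print(f"gguf: Conversion Endianess {endianess}")  # gguf: Conversion Endianess GGUFEndian.BIG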
convert.py (23 changes)
@@ -818,8 +818,8 @@ def check_vocab_size(params: Params, vocab: Vocab) -> None:
 
 
 class OutputFile:
-    def __init__(self, fname_out: Path) -> None:
-        self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH])
+    def __init__(self, fname_out: Path, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
+        self.gguf = gguf.GGUFWriter(fname_out, gguf.MODEL_ARCH_NAMES[ARCH], endianess=endianess)
 
     def add_meta_arch(self, params: Params) -> None:
         name = "LLaMA"
@@ -890,10 +890,10 @@ class OutputFile:
         self.gguf.close()
 
     @staticmethod
-    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab) -> None:
+    def write_vocab_only(fname_out: Path, params: Params, vocab: Vocab, svocab: gguf.SpecialVocab, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
         check_vocab_size(params, vocab)
 
-        of = OutputFile(fname_out)
+        of = OutputFile(fname_out, endianess=endianess)
 
         # meta data
         of.add_meta_arch(params)
@@ -918,10 +918,10 @@ class OutputFile:
             return dt.quantize(arr)
 
     @staticmethod
-    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
+    def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, svocab: gguf.SpecialVocab, concurrency: int = DEFAULT_CONCURRENCY, endianess: gguf.GGUFEndian = gguf.GGUFEndian.LITTLE) -> None:
         check_vocab_size(params, vocab)
 
-        of = OutputFile(fname_out)
+        of = OutputFile(fname_out, endianess=endianess)
 
         # meta data
         of.add_meta_arch(params)
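Both static methods gain the same trailing keyword argument with a LITTLE default, so every existing call site keeps producing little-endian files and only callers that pass endianess explicitly opt in. The pattern in isolation (write_all here is a stripped-down stand-in, not the real signature):

from enum import IntEnum

class GGUFEndian(IntEnum):
    LITTLE = 0
    BIG = 1

def write_all(fname_out: str, endianess: GGUFEndian = GGUFEndian.LITTLE) -> None:
    print(f"{fname_out}: {endianess}")

write_all("out.gguf")                            # old callers: LITTLE, unchanged
write_all("out.gguf", endianess=GGUFEndian.BIG)  # new opt-in path for s390x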
@@ -947,7 +947,8 @@ class OutputFile:
             elapsed = time.time() - start
             size = ' x '.join(f"{dim:6d}" for dim in lazy_tensor.shape)
             padi = len(str(len(model)))
-            ndarray.byteswap(inplace=True)
+            if endianess == gguf.GGUFEndian.BIG:
+                ndarray.byteswap(inplace=True)
             print(f"[{i+1:{padi}d}/{len(model)}] Writing tensor {name:38s} | size {size:16} | type {lazy_tensor.data_type.name:4} | T+{int(elapsed):4}")
             of.gguf.write_tensor_data(ndarray)
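The new guard matters because numpy's byteswap(inplace=True) rewrites the raw buffer while keeping the dtype, so the previously unconditional swap corrupted output on little-endian hosts. What the BIG branch does, shown on a little-endian machine:

import numpy as np

t = np.array([1.0, 2.0], dtype=np.float32)
print(t.tobytes().hex(" "))  # 00 00 80 3f 00 00 00 40  (little-endian layout)

t.byteswap(inplace=True)     # what write_all now does only for GGUFEndian.BIG
print(t.tobytes().hex(" "))  # 3f 80 00 00 40 00 00 00  (big-endian layout)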
@@ -1139,8 +1140,9 @@ def main(args_in: list[str] | None = None) -> None:
     parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
     parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
     parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
+    parser.add_argument("--bigendian", action="store_true", help="model is executed on big endian machine")
     args = parser.parse_args(args_in)
 
     if args.dump_single:
         model_plus = lazy_load_file(args.model)
         do_dump_model(model_plus)
@@ -1154,6 +1156,9 @@ def main(args_in: list[str] | None = None) -> None:
     if args.dump:
         do_dump_model(model_plus)
         return
+    endianess = gguf.GGUFEndian.LITTLE
+    if args.bigendian:
+        endianess = gguf.GGUFEndian.BIG
 
     params = Params.load(model_plus)
     if params.n_ctx == -1:
@@ -1201,7 +1206,7 @@ def main(args_in: list[str] | None = None) -> None:
     params.ftype = ftype
     print(f"Writing {outfile}, format {ftype}")
 
-    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency)
+    OutputFile.write_all(outfile, ftype, params, model, vocab, special_vocab, concurrency = args.concurrency, endianess=endianess)
     print(f"Wrote {outfile}")

gguf-py/gguf/gguf.py
@@ -429,6 +429,11 @@ class GGMLQuantizationType(IntEnum):
     Q6_K = 14
     Q8_K = 15
 
+
+class GGUFEndian(IntEnum):
+    LITTLE = 0
+    BIG = 1
+
 
 class GGUFValueType(IntEnum):
     UINT8 = 0
     INT8 = 1
@@ -475,18 +480,41 @@ class GGUFWriter:
     temp_file: tempfile.SpooledTemporaryFile[bytes] | None = None
     tensors: list[tuple[np.ndarray[Any, Any], int]]
 
-    def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True):
+    def get_pack_prefix(self):
+        if self.endianess == GGUFEndian.LITTLE:
+            return "<"
+        else:
+            return ">"
+
+    def __init__(self, path: os.PathLike[str] | str, arch: str, use_temp_file = True, endianess=GGUFEndian.LITTLE):
         self.fout = open(path, "wb")
         self.arch = arch
+        self.endianess = endianess
+        self._simple_value_packing = {
+            GGUFValueType.UINT8: f"{self.get_pack_prefix()}B",
+            GGUFValueType.INT8: f"{self.get_pack_prefix()}b",
+            GGUFValueType.UINT16: f"{self.get_pack_prefix()}H",
+            GGUFValueType.INT16: f"{self.get_pack_prefix()}h",
+            GGUFValueType.UINT32: f"{self.get_pack_prefix()}I",
+            GGUFValueType.INT32: f"{self.get_pack_prefix()}i",
+            GGUFValueType.FLOAT32: f"{self.get_pack_prefix()}f",
+            GGUFValueType.UINT64: f"{self.get_pack_prefix()}Q",
+            GGUFValueType.INT64: f"{self.get_pack_prefix()}q",
+            GGUFValueType.FLOAT64: f"{self.get_pack_prefix()}d",
+            GGUFValueType.BOOL: "?",
+        }
         self.add_architecture()
         self.use_temp_file = use_temp_file
         self.tensors = []
+        print(f"This gguf file is for {self.endianess} only")
 
     def write_header_to_file(self):
-        self.fout.write(struct.pack(">I", GGUF_MAGIC))
-        self.fout.write(struct.pack(">I", GGUF_VERSION))
-        self.fout.write(struct.pack(">Q", self.ti_data_count))
-        self.fout.write(struct.pack(">Q", self.kv_data_count))
+        self.fout.write(struct.pack(f"{self.get_pack_prefix()}I", GGUF_MAGIC))
+        self.fout.write(struct.pack(f"{self.get_pack_prefix()}I", GGUF_VERSION))
+        self.fout.write(struct.pack(f"{self.get_pack_prefix()}Q", self.ti_data_count))
+        self.fout.write(struct.pack(f"{self.get_pack_prefix()}Q", self.kv_data_count))
         self.flush()
         # print("tensors " + str(self.ti_data_count) + " kv " + str(self.kv_data_count))
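The pack prefix is the whole mechanism: struct treats a leading "<" as little-endian and ">" as big-endian. Taking GGUF_MAGIC = 0x46554747 (the value gguf.py uses in this era), the chosen byte order is directly visible in the first four bytes of the output file:

import struct

GGUF_MAGIC = 0x46554747

print(struct.pack("<I", GGUF_MAGIC))  # b'GGUF' -> little-endian file
print(struct.pack(">I", GGUF_MAGIC))  # b'FUGG' -> big-endian file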
@@ -558,25 +586,13 @@ class GGUFWriter:
         self.add_key(key)
         self.add_val(val, GGUFValueType.ARRAY)
 
-    _simple_value_packing = {
-        GGUFValueType.UINT8: f"{GGUF_ENDIANESS}B",
-        GGUFValueType.INT8: f"{GGUF_ENDIANESS}b",
-        GGUFValueType.UINT16: f"{GGUF_ENDIANESS}H",
-        GGUFValueType.INT16: ">h",
-        GGUFValueType.UINT32: ">I",
-        GGUFValueType.INT32: ">i",
-        GGUFValueType.FLOAT32: ">f",
-        GGUFValueType.UINT64: ">Q",
-        GGUFValueType.INT64: ">q",
-        GGUFValueType.FLOAT64: ">d",
-        GGUFValueType.BOOL: "?",
-    }
-
     def add_val(self, val: Any, vtype: GGUFValueType | None = None, add_vtype: bool = True):
         if vtype is None:
             vtype = GGUFValueType.get_type(val)
 
         if add_vtype:
-            self.kv_data += struct.pack(">I", vtype)
+            self.kv_data += struct.pack(f"{self.get_pack_prefix()}I", vtype)
             self.kv_data_count += 1
 
         pack_fmt = self._simple_value_packing.get(vtype)
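The deleted block is the old class-level packing table. It has to live in __init__ now because a class attribute is evaluated once, when the class is defined, and therefore cannot depend on the endianess chosen per writer instance. A small illustration of the difference (hypothetical classes, not from the patch):

class PerClass:
    prefix = ">"            # fixed when the class body runs
    fmt = f"{prefix}I"      # every instance shares ">I"

class PerInstance:
    def __init__(self, big: bool):
        prefix = ">" if big else "<"
        self.fmt = f"{prefix}I"  # chosen when each writer is created

print(PerClass().fmt, PerInstance(False).fmt, PerInstance(True).fmt)  # >I <I >I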
@@ -584,14 +600,14 @@ class GGUFWriter:
             self.kv_data += struct.pack(pack_fmt, val)
         elif vtype == GGUFValueType.STRING:
             encoded_val = val.encode("utf8") if isinstance(val, str) else val
-            self.kv_data += struct.pack(">Q", len(encoded_val))
+            self.kv_data += struct.pack(f"{self.get_pack_prefix()}Q", len(encoded_val))
             self.kv_data += encoded_val
         elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0:
             ltype = GGUFValueType.get_type(val[0])
             if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
                 raise ValueError("All items in a GGUF array should be of the same type")
-            self.kv_data += struct.pack(">I", ltype)
-            self.kv_data += struct.pack(">Q", len(val))
+            self.kv_data += struct.pack(f"{self.get_pack_prefix()}I", ltype)
+            self.kv_data += struct.pack(f"{self.get_pack_prefix()}Q", len(val))
             for item in val:
                 self.add_val(item, add_vtype=False)
         else:
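The same prefix flows into every length field, so even a UTF-8 string key-value pair serializes differently under the two byte orders. A sketch of the STRING branch above, using a hypothetical pack_string helper:

import struct

def pack_string(s: str, prefix: str) -> bytes:
    encoded = s.encode("utf8")
    return struct.pack(f"{prefix}Q", len(encoded)) + encoded

print(pack_string("llama", "<").hex(" "))  # 05 00 00 00 00 00 00 00 6c 6c 61 6d 61
print(pack_string("llama", ">").hex(" "))  # 00 00 00 00 00 00 00 05 6c 6c 61 6d 61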
@@ -605,23 +621,24 @@ class GGUFWriter:
         assert raw_dtype is not None or tensor_dtype in (np.float32, np.float16), "Only F32 and F16 tensors are supported for now"
 
         encoded_name = name.encode("utf8")
-        self.ti_data += struct.pack(">Q", len(encoded_name))
+        self.ti_data += struct.pack(f"{self.get_pack_prefix()}Q", len(encoded_name))
         self.ti_data += encoded_name
         n_dims = len(tensor_shape)
-        self.ti_data += struct.pack(">I", n_dims)
+        self.ti_data += struct.pack(f"{self.get_pack_prefix()}I", n_dims)
         for i in range(n_dims):
-            self.ti_data += struct.pack(">Q", tensor_shape[n_dims - 1 - i])
+            self.ti_data += struct.pack(f"{self.get_pack_prefix()}Q", tensor_shape[n_dims - 1 - i])
         if raw_dtype is None:
             dtype = GGMLQuantizationType.F32 if tensor_dtype == np.float32 else GGMLQuantizationType.F16
         else:
             dtype = raw_dtype
-        self.ti_data += struct.pack(">I", dtype)
-        self.ti_data += struct.pack(">Q", self.offset_tensor)
+        self.ti_data += struct.pack(f"{self.get_pack_prefix()}I", dtype)
+        self.ti_data += struct.pack(f"{self.get_pack_prefix()}Q", self.offset_tensor)
         self.offset_tensor += GGUFWriter.ggml_pad(tensor_nbytes, self.data_alignment)
         self.ti_data_count += 1
 
     def add_tensor(self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, raw_dtype: GGMLQuantizationType | None = None):
-        tensor.byteswap(inplace=True)
+        if self.endianess == GGUFEndian.BIG:
+            tensor.byteswap(inplace=True)
         if self.use_temp_file and self.temp_file is None:
             fp = tempfile.SpooledTemporaryFile(mode="w+b", max_size=256*1024*1024)
             fp.seek(0)
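Because every scalar now follows one declared byte order, and the writer prints that the file is for that endianness only, a reader can tell the two variants apart from the magic alone. A small verification helper, assuming only that a GGUF file begins with the uint32 magic:

def gguf_byte_order(path: str) -> str:
    with open(path, "rb") as f:
        magic = f.read(4)
    if magic == b"GGUF":
        return "little"  # header was packed with "<"
    if magic == b"FUGG":
        return "big"     # header was packed with ">"
    raise ValueError(f"{path} does not look like a GGUF file")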