Second pass
This commit is contained in:
parent
f82aec99a4
commit
4a3d783d3e
5 changed files with 31 additions and 25 deletions
|
@ -34,8 +34,7 @@ def bytes_to_unicode():
|
||||||
bs.append(b)
|
bs.append(b)
|
||||||
cs.append(2**8+n)
|
cs.append(2**8+n)
|
||||||
n += 1
|
n += 1
|
||||||
cs = [chr(n) for n in cs]
|
return dict(zip(bs, (chr(n) for n in cs)))
|
||||||
return dict(zip(bs, cs))
|
|
||||||
|
|
||||||
|
|
||||||
def count_model_parts(dir_model: str) -> int:
|
def count_model_parts(dir_model: str) -> int:
|
||||||
|
|
|
@ -75,7 +75,7 @@ class Tensor:
|
||||||
self.dims = ()
|
self.dims = ()
|
||||||
self.dtype = None
|
self.dtype = None
|
||||||
self.start_offset = 0
|
self.start_offset = 0
|
||||||
self.len_bytes = 0
|
self.len_bytes = np.int64(0)
|
||||||
|
|
||||||
def load(self, data, offset):
|
def load(self, data, offset):
|
||||||
orig_offset = offset
|
orig_offset = offset
|
||||||
|
|
|
@ -60,7 +60,7 @@ def write_file_header(fout: BinaryIO, params: Dict[str, Any]) -> None:
|
||||||
|
|
||||||
|
|
||||||
def write_tensor_header(
|
def write_tensor_header(
|
||||||
self, name: str, shape: Sequence[int], data_type: np.dtype
|
self, name: str, shape: Sequence[int], data_type: np.dtype[Any]
|
||||||
) -> None:
|
) -> None:
|
||||||
sname = name.encode("utf-8")
|
sname = name.encode("utf-8")
|
||||||
fout.write(
|
fout.write(
|
||||||
|
|
27
convert.py
27
convert.py
|
@ -25,7 +25,7 @@ import numpy as np
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, TypeVar, Union)
|
from typing import (IO, TYPE_CHECKING, Any, Callable, Dict, Generator, Iterable, List, Literal, Optional, Sequence, Set, Tuple, Type, TypeVar, Union)
|
||||||
from sentencepiece import SentencePieceProcessor # type: ignore
|
from sentencepiece import SentencePieceProcessor # type: ignore
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -447,7 +447,7 @@ class Tensor(metaclass=ABCMeta):
|
||||||
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
|
def to_ggml(self) -> 'GGMLCompatibleTensor': ...
|
||||||
|
|
||||||
|
|
||||||
def bf16_to_fp32(bf16_arr: np.ndarray) -> NDArray:
|
def bf16_to_fp32(bf16_arr: np.ndarray[Any, np.dtype[np.uint16]]) -> NDArray:
|
||||||
assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
|
assert bf16_arr.dtype == np.uint16, f"Input array should be of dtype uint16, but got {bf16_arr.dtype}"
|
||||||
fp32_arr = bf16_arr.astype(np.uint32) << 16
|
fp32_arr = bf16_arr.astype(np.uint32) << 16
|
||||||
return fp32_arr.view(np.float32)
|
return fp32_arr.view(np.float32)
|
||||||
|
@ -658,7 +658,7 @@ class LazyUnpickler(pickle.Unpickler):
|
||||||
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
|
description = f'storage data_type={data_type} path-in-zip={filename} path={self.zip_file.filename}'
|
||||||
return LazyStorage(load=load, kind=pid[1], description=description)
|
return LazyStorage(load=load, kind=pid[1], description=description)
|
||||||
|
|
||||||
# @staticmethod
|
@staticmethod
|
||||||
def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
|
def lazy_rebuild_tensor_v2(storage: Any, storage_offset: Any, size: Any, stride: Any,
|
||||||
# pyright: ignore[reportSelfClsParameterName]
|
# pyright: ignore[reportSelfClsParameterName]
|
||||||
requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
|
requires_grad: Any, backward_hooks: Any, metadata: Any = None) -> LazyTensor:
|
||||||
|
@ -670,13 +670,15 @@ class LazyUnpickler(pickle.Unpickler):
|
||||||
description = f'pickled storage_offset={storage_offset} in {storage.description}'
|
description = f'pickled storage_offset={storage_offset} in {storage.description}'
|
||||||
return LazyTensor(load, list(size), storage.kind.data_type, description)
|
return LazyTensor(load, list(size), storage.kind.data_type, description)
|
||||||
|
|
||||||
# @staticmethod
|
@staticmethod
|
||||||
def rebuild_from_type_v2(func, new_type, args, state):
|
def rebuild_from_type_v2(func, new_type, args, state):
|
||||||
return func(*args)
|
return func(*args)
|
||||||
|
|
||||||
CLASSES: Dict[Any, Any] = {
|
CLASSES: Dict[Tuple[str, str], Any] = {
|
||||||
('torch._tensor', '_rebuild_from_type_v2'): rebuild_from_type_v2,
|
# getattr used here as a workaround for mypy not being smart enough to detrmine
|
||||||
('torch._utils', '_rebuild_tensor_v2'): lazy_rebuild_tensor_v2,
|
# the staticmethods have a __func__ attribute.
|
||||||
|
('torch._tensor', '_rebuild_from_type_v2'): getattr(rebuild_from_type_v2, '__func__'),
|
||||||
|
('torch._utils', '_rebuild_tensor_v2'): getattr(lazy_rebuild_tensor_v2, '__func__'),
|
||||||
('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
|
('torch', 'BFloat16Storage'): LazyStorageKind(DT_BF16),
|
||||||
('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
|
('torch', 'HalfStorage'): LazyStorageKind(DT_F16),
|
||||||
('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
|
('torch', 'FloatStorage'): LazyStorageKind(DT_F32),
|
||||||
|
@ -752,7 +754,7 @@ def lazy_load_file(path: Path) -> ModelPlus:
|
||||||
In = TypeVar('In')
|
In = TypeVar('In')
|
||||||
Out = TypeVar('Out')
|
Out = TypeVar('Out')
|
||||||
|
|
||||||
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, factory: Callable = ThreadPoolExecutor) -> Iterable[Out]:
|
def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], concurrency: int, max_workers: Optional[int] = None, use_processpool_executor: bool = False) -> Iterable[Out]:
|
||||||
'''Parallel map, but with backpressure. If the caller doesn't call `next`
|
'''Parallel map, but with backpressure. If the caller doesn't call `next`
|
||||||
fast enough, this will stop calling `func` at some point rather than
|
fast enough, this will stop calling `func` at some point rather than
|
||||||
letting results pile up in memory. Specifically, there is a max of one
|
letting results pile up in memory. Specifically, there is a max of one
|
||||||
|
@ -761,7 +763,12 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
|
||||||
yield from map(func, iterable)
|
yield from map(func, iterable)
|
||||||
# Not reached.
|
# Not reached.
|
||||||
iterable = iter(iterable)
|
iterable = iter(iterable)
|
||||||
with factory(max_workers = max_workers) as executor:
|
executor_class: Union[Type[ThreadPoolExecutor], Type[ProcessPoolExecutor]]
|
||||||
|
if use_processpool_executor:
|
||||||
|
executor_class = ProcessPoolExecutor
|
||||||
|
else:
|
||||||
|
executor_class = ThreadPoolExecutor
|
||||||
|
with executor_class(max_workers = max_workers) as executor:
|
||||||
futures: List[concurrent.futures.Future[Out]] = []
|
futures: List[concurrent.futures.Future[Out]] = []
|
||||||
done = False
|
done = False
|
||||||
for _ in range(concurrency):
|
for _ in range(concurrency):
|
||||||
|
@ -913,7 +920,7 @@ class OutputFile:
|
||||||
# tensor data
|
# tensor data
|
||||||
ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
|
ndarrays_inner = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
|
||||||
if ftype == GGMLFileType.MostlyQ8_0:
|
if ftype == GGMLFileType.MostlyQ8_0:
|
||||||
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
|
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays_inner, concurrency = concurrency, max_workers = concurrency, use_processpool_executor = True)
|
||||||
else:
|
else:
|
||||||
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
|
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays_inner)
|
||||||
|
|
||||||
|
|
|
@ -8,7 +8,6 @@ import json
|
||||||
import os
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import collections.abc as collections_abc
|
|
||||||
from enum import IntEnum, auto
|
from enum import IntEnum, auto
|
||||||
from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union
|
from typing import Any, BinaryIO, Callable, IO, Dict, List, Optional, Sequence, Tuple, Union
|
||||||
|
|
||||||
|
@ -512,7 +511,7 @@ class GGUFWriter:
|
||||||
self.add_val(val, GGUFValueType.STRING)
|
self.add_val(val, GGUFValueType.STRING)
|
||||||
|
|
||||||
def add_array(self, key: str, val: Sequence[Any]):
|
def add_array(self, key: str, val: Sequence[Any]):
|
||||||
if not isinstance(val, collections_abc.Sequence):
|
if not isinstance(val, Sequence):
|
||||||
raise ValueError("Value must be a sequence for array type")
|
raise ValueError("Value must be a sequence for array type")
|
||||||
|
|
||||||
self.add_key(key)
|
self.add_key(key)
|
||||||
|
@ -546,15 +545,16 @@ class GGUFWriter:
|
||||||
encoded_val = val.encode("utf8") if isinstance(val, str) else val
|
encoded_val = val.encode("utf8") if isinstance(val, str) else val
|
||||||
self.kv_data += struct.pack("<Q", len(encoded_val))
|
self.kv_data += struct.pack("<Q", len(encoded_val))
|
||||||
self.kv_data += encoded_val
|
self.kv_data += encoded_val
|
||||||
elif vtype == GGUFValueType.ARRAY:
|
elif vtype == GGUFValueType.ARRAY and isinstance(val, Sequence) and len(val) > 0:
|
||||||
ltype = set([GGUFValueType.get_type(item) for item in val])
|
ltype = GGUFValueType.get_type(val[0])
|
||||||
assert len(ltype) == 1, "All items in a GGUF array should be of the same type"
|
if not all(GGUFValueType.get_type(i) is ltype for i in val[1:]):
|
||||||
self.kv_data += struct.pack("<I", list(ltype)[0])
|
raise ValueError("All items in a GGUF array should be of the same type")
|
||||||
|
self.kv_data += struct.pack("<I", ltype)
|
||||||
self.kv_data += struct.pack("<Q", len(val))
|
self.kv_data += struct.pack("<Q", len(val))
|
||||||
for item in val:
|
for item in val:
|
||||||
self.add_val(item, add_vtype=False)
|
self.add_val(item, add_vtype=False)
|
||||||
else:
|
else:
|
||||||
raise ValueError("Invalid GGUF metadata value type")
|
raise ValueError("Invalid GGUF metadata value type or value")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def ggml_pad(x: int, n: int) -> int:
|
def ggml_pad(x: int, n: int) -> int:
|
||||||
|
@ -614,7 +614,7 @@ class GGUFWriter:
|
||||||
|
|
||||||
self.write_padding(self.fout, self.fout.tell())
|
self.write_padding(self.fout, self.fout.tell())
|
||||||
|
|
||||||
if not self.use_temp_file:
|
if self.temp_file is None:
|
||||||
for (currtensor, currpad) in self.tensors:
|
for (currtensor, currpad) in self.tensors:
|
||||||
currtensor.tofile(self.fout)
|
currtensor.tofile(self.fout)
|
||||||
if currpad != 0:
|
if currpad != 0:
|
||||||
|
@ -808,14 +808,14 @@ class SpecialVocab:
|
||||||
def add_to_gguf(self, gw: GGUFWriter):
|
def add_to_gguf(self, gw: GGUFWriter):
|
||||||
# FIXME: Don't always include merges (possibly also don't even load them).
|
# FIXME: Don't always include merges (possibly also don't even load them).
|
||||||
if len(self.merges) > 0:
|
if len(self.merges) > 0:
|
||||||
print(f'SpecialVocab: Adding {len(self.merges)} merge(s).')
|
print(f'gguf: Adding {len(self.merges)} merge(s).')
|
||||||
gw.add_token_merges(self.merges)
|
gw.add_token_merges(self.merges)
|
||||||
for typ, tokid in self.special_token_ids.items():
|
for typ, tokid in self.special_token_ids.items():
|
||||||
handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
|
handler: Optional[Callable[[int], None]] = getattr(gw, f'add_{typ}_token_id', None)
|
||||||
if handler is None:
|
if handler is None:
|
||||||
print(f'SpecialVocab: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
|
print(f'gguf: WARNING: No handler for special token type {typ} with id {tokid} - skipping')
|
||||||
continue
|
continue
|
||||||
print(f'SpecialVocab: Setting special token type {typ} to {tokid}')
|
print(f'gguf: Setting special token type {typ} to {tokid}')
|
||||||
handler(tokid)
|
handler(tokid)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue