Add --concurrency option

Minor improvements to help text

Clean up bounded_parallel_map function a bit
This commit is contained in:
KerfuffleV2 2023-08-23 19:36:55 -06:00
parent f54daa0735
commit 3efcbb8f59

View file

@@ -39,6 +39,7 @@ NDArray: 'TypeAlias' = 'np.ndarray[Any, Any]'
ARCH=gguf.MODEL_ARCH.LLAMA
NAMES=gguf.MODEL_TENSOR_NAMES[ARCH]
DEFAULT_CONCURRENCY = 8
#
# data types
#
@@ -722,21 +723,21 @@ def bounded_parallel_map(func: Callable[[In], Out], iterable: Iterable[In], conc
with factory(max_workers = max_workers) as executor:
futures: List[concurrent.futures.Future[Out]] = []
done = False
for i in range(concurrency):
for _ in range(concurrency):
try:
nexti = next(iterable)
futures.append(executor.submit(func, next(iterable)))
except StopIteration:
done = True
break
futures.append(executor.submit(func, nexti))
while not done or futures:
while futures:
result = futures.pop(0).result()
while len(futures) < concurrency:
while not done and len(futures) < concurrency:
try:
nexti = next(iterable)
futures.append(executor.submit(func, next(iterable)))
except StopIteration:
done = True
break
futures.append(executor.submit(func, nexti))
yield result
def check_vocab_size(params: Params, vocab: Vocab) -> None:
@@ -857,13 +858,13 @@ class OutputFile:
return (lazy_tensor.data_type, tensor.ndarray)
@staticmethod
def maybe_do_quant(item: Tuple[DataType, NDArray]) -> NDArray:
def maybe_do_quantize(item: Tuple[DataType, NDArray]) -> NDArray:
if item[0] == DT_Q8_0:
return quantize_array_q8_0(item[1])
return item[1]
@staticmethod
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab) -> None:
def write_all(fname_out: Path, ftype: GGMLFileType, params: Params, model: LazyModel, vocab: Vocab, concurrency: int = DEFAULT_CONCURRENCY) -> None:
check_vocab_size(params, vocab)
of = OutputFile(fname_out)
@@ -880,11 +881,11 @@ class OutputFile:
of.write_tensor_info()
# tensor data
ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = 8)
ndarrays = bounded_parallel_map(OutputFile.do_item, model.items(), concurrency = concurrency)
if ftype == GGMLFileType.MostlyQ8_0:
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quant, ndarrays, concurrency = 8, max_workers = 8, factory = ProcessPoolExecutor)
ndarrays = bounded_parallel_map(OutputFile.maybe_do_quantize, ndarrays, concurrency = concurrency, max_workers = concurrency, factory = ProcessPoolExecutor)
else:
ndarrays = map(OutputFile.maybe_do_quant, ndarrays)
ndarrays = map(OutputFile.maybe_do_quantize, ndarrays)
start = time.time()
for i, ((name, lazy_tensor), ndarray) in enumerate(zip(model.items(), ndarrays)):
@@ -1080,12 +1081,13 @@ def main(args_in: Optional[List[str]] = None) -> None:
parser.add_argument("--dump", action="store_true", help="don't convert, just show what's in the model")
parser.add_argument("--dump-single", action="store_true", help="don't convert, just show what's in a single model file")
parser.add_argument("--vocab-only", action="store_true", help="extract only the vocab")
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format (default: based on input)")
parser.add_argument("--outtype", choices=["f32", "f16", "q8_0"], help="output format - note: q8_0 may be very slow (default: f16 or f32 based on input)")
parser.add_argument("--vocab-dir", type=Path, help="directory containing tokenizer.model, if separate from model file")
parser.add_argument("--outfile", type=Path, help="path to write to; default: based on input")
parser.add_argument("model", type=Path, help="directory containing model file, or model file itself (*.pth, *.pt, *.bin)")
parser.add_argument("--vocabtype", choices=["spm", "bpe"], help="vocab format (default: spm)", default="spm")
parser.add_argument("--ctx", type=int, help="model training context (default: based on input)")
parser.add_argument("--concurrency", type=int, help=f"concurrency used for conversion (default: {DEFAULT_CONCURRENCY})", default = DEFAULT_CONCURRENCY)
args = parser.parse_args(args_in)
if args.dump_single:
@@ -1139,7 +1141,7 @@ def main(args_in: Optional[List[str]] = None) -> None:
params.ftype = ftype
print(f"Writing {outfile}, format {ftype}")
OutputFile.write_all(outfile, ftype, params, model, vocab)
OutputFile.write_all(outfile, ftype, params, model, vocab, concurrency = args.concurrency)
print(f"Wrote {outfile}")