Review fixes, persimmon fixes

This commit is contained in:
Galunid 2023-11-01 02:32:49 +01:00
parent 3ec89dcc69
commit 4fdd7cdf2b
2 changed files with 12 additions and 6 deletions

View file

@ -2,7 +2,6 @@ import os
import re import re
import sys import sys
import json import json
import gguf
import torch import torch
import contextlib import contextlib
import numpy as np import numpy as np
@ -11,6 +10,12 @@ from enum import IntEnum
from pathlib import Path from pathlib import Path
from typing import TypeAlias, Any, Generator from typing import TypeAlias, Any, Generator
if 'NO_LOCAL_GGUF' not in os.environ:
sys.path.insert(1, str(Path(__file__).parent / 'gguf-py' / 'gguf'))
import gguf
NDArray: TypeAlias = 'np.ndarray[Any, Any]' NDArray: TypeAlias = 'np.ndarray[Any, Any]'
@ -160,7 +165,7 @@ class Model:
def set_vocab(self): def set_vocab(self):
self._set_vocab_gpt2() self._set_vocab_gpt2()
def get_tensors(self) -> Generator[str, Any]: def get_tensors(self) -> Generator[str, Any, None]:
for part_name in self.part_names: for part_name in self.part_names:
print("gguf: loading model part '" + part_name + "'") print("gguf: loading model part '" + part_name + "'")
if self.is_safetensors: if self.is_safetensors:
@ -789,12 +794,13 @@ class PersimmonModel(Model):
self.gguf_writer.add_name('persimmon-8b-chat') self.gguf_writer.add_name('persimmon-8b-chat')
self.gguf_writer.add_embedding_length(hidden_size) self.gguf_writer.add_embedding_length(hidden_size)
self.gguf_writer.add_block_count(block_count) self.gguf_writer.add_block_count(block_count)
self.gguf_writer.add_feed_forward_length(self.hparams["ffn_hidden_size"]) self.gguf_writer.add_feed_forward_length(self.hparams["intermediate_size"])
self.gguf_writer.add_rope_dimension_count(hidden_size // head_count) self.gguf_writer.add_rope_dimension_count(hidden_size // head_count)
self.gguf_writer.add_head_count(head_count) self.gguf_writer.add_head_count(head_count)
self.gguf_writer.add_head_count_kv(head_count_kv) self.gguf_writer.add_head_count_kv(head_count_kv)
self.gguf_writer.add_rope_freq_base(self.hparams["rotary_emb_base"]) self.gguf_writer.add_rope_freq_base(self.hparams["rope_theta"])
self.gguf_writer.add_layer_norm_eps(self.hparams["layernorm_epsilon"]) self.gguf_writer.add_layer_norm_eps(self.hparams["layer_norm_eps"])
self.gguf_writer.add_layer_norm_rms_eps(self.hparams["rms_norm_eps"])
def set_vocab(self): def set_vocab(self):
self._set_vocab_sentencepiece() self._set_vocab_sentencepiece()

View file

@ -3,7 +3,7 @@ import argparse
from pathlib import Path from pathlib import Path
def parse_args() -> argparse.Namespace: def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(description="Convert a stablelm model to a GGML compatible file") parser = argparse.ArgumentParser(description="Convert a huggingface model to a GGML compatible file")
parser.add_argument( parser.add_argument(
"--vocab-only", action="store_true", "--vocab-only", action="store_true",
help="extract only the vocab", help="extract only the vocab",