refactor: Apply huggingface_hub api to CLI

This commit is contained in:
teleprint-me 2024-05-25 04:16:10 -04:00
parent 63c3410492
commit 6c1b0111a1
No known key found for this signature in database
GPG key ID: B0D11345E65C4D48

View file

@ -31,8 +31,7 @@ if (
): ):
sys.path.insert(0, str(Path(__file__).parent.parent)) sys.path.insert(0, str(Path(__file__).parent.parent))
from gguf.constants import MODEL_ARCH, MODEL_ARCH_NAMES from gguf.huggingface_hub import HFHubModel, HFHubTokenizer
from gguf.huggingface_hub import HFHub, HFTokenizer
logger = logging.getLogger(Path(__file__).stem) logger = logging.getLogger(Path(__file__).stem)
@ -47,17 +46,16 @@ def main():
"-v", "--verbose", action="store_true", help="Increase output verbosity." "-v", "--verbose", action="store_true", help="Increase output verbosity."
) )
parser.add_argument( parser.add_argument(
"-m",
"--model-path", "--model-path",
default="models", default="models",
help="The models storage path. Default is 'models/'.", help="The models storage path. Default is 'models/'.",
) )
parser.add_argument( parser.add_argument(
"--vocab-type", "--vocab-name",
const="BPE", const="BPE",
nargs="?", nargs="?",
choices=["SPM", "BPE", "WPM"], choices=["SPM", "BPE", "WPM"],
help="The type of vocab. Default is 'BPE'.", help="The name of the vocab type. Default is 'BPE'.",
) )
args = parser.parse_args() args = parser.parse_args()
@ -66,11 +64,25 @@ def main():
else: else:
logging.basicConfig(level=logging.INFO) logging.basicConfig(level=logging.INFO)
vocab_request = HFModel(args.auth_token, args.model_path, logger) hub_model = HFHubModel(
vocab_type = HFTokenizer.get_vocab_enum(args.vocab_type) auth_token=args.auth_token,
tokenizer = vocab_request.tokenizer model_path=args.model_path,
vocab_request.get_all_vocab_files(args.model_repo, vocab_type) logger=logger,
tokenizer.log_tokenizer_json_info(args.model_repo) )
hub_tokenizer = HFHubTokenizer(
model_path=args.model_path,
logger=logger,
)
vocab_type = HFHubTokenizer.get_vocab_type(args.vocab_name)
hub_model.download_all_vocab_files(
model_repo=args.model_repo,
vocab_type=vocab_type,
)
hub_model.download_all_vocab_files(args.model_repo, vocab_type)
hub_tokenizer.log_tokenizer_json_info(args.model_repo)
if __name__ == "__main__": if __name__ == "__main__":