adds test file for gpt2

This commit is contained in:
EC2 Default User 2023-12-13 13:54:49 +00:00
parent 8a7b2fa528
commit d6a9242df3

9
tests/test_gpt2.py Normal file
View file

@ -0,0 +1,9 @@
from transformers import set_seed
from transformers import AutoModelForCausalLM, pipeline, AutoTokenizer
model = AutoModelForCausalLM.from_pretrained("gpt2", cache_dir="models")
tokenizer = AutoTokenizer.from_pretrained("gpt2", cache_dir="models")
generator = pipeline('text-generation', model=model, tokenizer=tokenizer)
set_seed(42)
print(generator("Hello, I'm a language model,", max_length=30, num_return_sequences=5))