Update train.py

This commit is contained in:
thijn 2025-04-01 09:41:06 +02:00
parent 07e8ef7951
commit f2959f61ae

View File

@ -82,7 +82,11 @@ for batch in train_dataloader:
batch = {key: value.to(device) for key, value in batch.items()} # Move batch to GPU
# Train Model
trainer.train(resume_from_checkpoint=True)
import os
if (os.path.isdir('./results')):
trainer.train(resume_from_checkpoint=True)
else:
trainer.train()
# Save Model
model.save_pretrained('./trained_model')