diff --git a/train.py b/train.py index b62c493..ba4ce00 100644 --- a/train.py +++ b/train.py @@ -82,7 +82,11 @@ for batch in train_dataloader: batch = {key: value.to(device) for key, value in batch.items()} # Move batch to GPU # Train Model -trainer.train(resume_from_checkpoint=True) +import os +if (os.path.isdir('./results')): + trainer.train(resume_from_checkpoint=True) +else: + trainer.train() # Save Model model.save_pretrained('./trained_model')