From f2959f61ae7c5fbc142268a296c9c6f9f8ebd918 Mon Sep 17 00:00:00 2001 From: thijn Date: Tue, 1 Apr 2025 09:41:06 +0200 Subject: [PATCH] Update train.py --- train.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/train.py b/train.py index b62c493..ba4ce00 100644 --- a/train.py +++ b/train.py @@ -82,7 +82,11 @@ for batch in train_dataloader: batch = {key: value.to(device) for key, value in batch.items()} # Move batch to GPU # Train Model -trainer.train(resume_from_checkpoint=True) +import os +if (os.path.isdir('./results')): + trainer.train(resume_from_checkpoint=True) +else: + trainer.train() # Save Model model.save_pretrained('./trained_model')