Update train.py
This commit is contained in:
parent
07e8ef7951
commit
f2959f61ae
6
train.py
6
train.py
@ -82,7 +82,11 @@ for batch in train_dataloader:
|
||||
batch = {key: value.to(device) for key, value in batch.items()} # Move batch to GPU
|
||||
|
||||
# Train Model
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
import os
|
||||
if (os.path.isdir('./results')):
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
else:
|
||||
trainer.train()
|
||||
|
||||
# Save Model
|
||||
model.save_pretrained('./trained_model')
|
||||
|
Loading…
x
Reference in New Issue
Block a user