Update train.py
This commit is contained in:
parent
07e8ef7951
commit
f2959f61ae
4
train.py
4
train.py
@ -82,7 +82,11 @@ for batch in train_dataloader:
|
|||||||
batch = {key: value.to(device) for key, value in batch.items()} # Move batch to GPU
|
batch = {key: value.to(device) for key, value in batch.items()} # Move batch to GPU
|
||||||
|
|
||||||
# Train Model
|
# Train Model
|
||||||
|
import os
|
||||||
|
if (os.path.isdir('./results')):
|
||||||
trainer.train(resume_from_checkpoint=True)
|
trainer.train(resume_from_checkpoint=True)
|
||||||
|
else:
|
||||||
|
trainer.train()
|
||||||
|
|
||||||
# Save Model
|
# Save Model
|
||||||
model.save_pretrained('./trained_model')
|
model.save_pretrained('./trained_model')
|
||||||
|
Loading…
x
Reference in New Issue
Block a user