Update train.py
This commit is contained in:
parent
f2959f61ae
commit
e3e5deb32f
74
train.py
74
train.py
@ -1,60 +1,33 @@
|
||||
import pandas as pd
|
||||
import torch
|
||||
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
|
||||
from torch.utils.data import DataLoader
|
||||
from datasets import load_dataset
|
||||
|
||||
# Check if CUDA (GPU support) is available
|
||||
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
print(f"Using device: {device}")
|
||||
|
||||
# Load CSV data into a pandas DataFrame
|
||||
def load_data_from_csv(file_path):
|
||||
return pd.read_csv(file_path)
|
||||
|
||||
# Initialize the BERT tokenizer
|
||||
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
|
||||
|
||||
# Dataset Class
|
||||
class ProductMatchDataset(torch.utils.data.Dataset):
|
||||
def __init__(self, product_pairs, labels, tokenizer, max_length=128):
|
||||
self.product_pairs = product_pairs
|
||||
self.labels = labels
|
||||
self.tokenizer = tokenizer
|
||||
self.max_length = max_length
|
||||
|
||||
def __len__(self):
|
||||
return len(self.product_pairs)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
product1, product2 = self.product_pairs[idx]
|
||||
|
||||
# Tokenize (keep tensors on CPU)
|
||||
encoding = self.tokenizer(
|
||||
product1, product2,
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=self.max_length,
|
||||
return_tensors='pt'
|
||||
)
|
||||
|
||||
encoding = {key: value.squeeze(0) for key, value in encoding.items()}
|
||||
encoding['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
|
||||
|
||||
return encoding
|
||||
# Load dataset in streaming mode (no memory overload)
|
||||
file_path = 'SnoepAll.csv'
|
||||
dataset = load_dataset('csv', data_files=file_path, split='train', streaming=True)
|
||||
|
||||
# Function to prepare data
|
||||
def prepare_data_for_finetuning(file_path):
|
||||
df = load_data_from_csv(file_path)
|
||||
product_pairs = list(zip(df['product_name_1'], df['product_name_2']))
|
||||
labels = df['label'].tolist()
|
||||
return ProductMatchDataset(product_pairs, labels, tokenizer)
|
||||
# Tokenization function for Hugging Face dataset
|
||||
def tokenize_function(example):
|
||||
encoding = tokenizer(
|
||||
example['product_name_1'], example['product_name_2'],
|
||||
padding='max_length',
|
||||
truncation=True,
|
||||
max_length=128
|
||||
)
|
||||
encoding['labels'] = int(example['label']) # Convert label to integer
|
||||
return encoding
|
||||
|
||||
# Load dataset
|
||||
file_path = 'Snoep.csv'
|
||||
dataset = prepare_data_for_finetuning(file_path)
|
||||
# Apply tokenization to streamed dataset
|
||||
dataset = dataset.map(tokenize_function, remove_columns=['product_name_1', 'product_name_2', 'label'])
|
||||
|
||||
# Create DataLoader
|
||||
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
|
||||
# DataLoader automatically handled by Trainer with streaming dataset
|
||||
|
||||
# Load pre-trained BERT model
|
||||
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)
|
||||
@ -64,6 +37,7 @@ training_args = TrainingArguments(
|
||||
output_dir='./results',
|
||||
num_train_epochs=3,
|
||||
per_device_train_batch_size=16,
|
||||
max_steps=200000, # Define a step-based stopping point
|
||||
warmup_steps=500,
|
||||
weight_decay=0.01,
|
||||
logging_dir='./logs',
|
||||
@ -77,17 +51,9 @@ trainer = Trainer(
|
||||
train_dataset=dataset
|
||||
)
|
||||
|
||||
# Move batches to GPU inside training loop
|
||||
for batch in train_dataloader:
|
||||
batch = {key: value.to(device) for key, value in batch.items()} # Move batch to GPU
|
||||
|
||||
# Train Model
|
||||
import os
|
||||
if (os.path.isdir('./results')):
|
||||
trainer.train(resume_from_checkpoint=True)
|
||||
else:
|
||||
trainer.train()
|
||||
trainer.train()
|
||||
|
||||
# Save Model
|
||||
model.save_pretrained('./trained_model')
|
||||
tokenizer.save_pretrained('./trained_model')
|
||||
tokenizer.save_pretrained('./trained_model')
|
Loading…
x
Reference in New Issue
Block a user