Update train.py

thijn 2025-04-01 10:18:39 +02:00
parent f2959f61ae
commit e3e5deb32f


@@ -1,60 +1,33 @@
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from torch.utils.data import DataLoader
from datasets import load_dataset
# Check if CUDA (GPU support) is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
# Load CSV data into a pandas DataFrame
def load_data_from_csv(file_path):
return pd.read_csv(file_path)
# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# Dataset Class
class ProductMatchDataset(torch.utils.data.Dataset):
def __init__(self, product_pairs, labels, tokenizer, max_length=128):
self.product_pairs = product_pairs
self.labels = labels
self.tokenizer = tokenizer
self.max_length = max_length
def __len__(self):
return len(self.product_pairs)
def __getitem__(self, idx):
product1, product2 = self.product_pairs[idx]
# Tokenize (keep tensors on CPU)
encoding = self.tokenizer(
product1, product2,
padding='max_length',
truncation=True,
max_length=self.max_length,
return_tensors='pt'
)
encoding = {key: value.squeeze(0) for key, value in encoding.items()}
encoding['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
return encoding
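# Illustrative usage of this dataset class (hypothetical values, not part of the training flow):
#   pairs = [("haribo goudberen 250g", "Haribo Gold Bears 250 g")]
#   ds = ProductMatchDataset(pairs, [1], tokenizer)
#   ds[0]  # dict of 1-D tensors 'input_ids', 'token_type_ids', 'attention_mask' (length 128) plus a scalar 'labels'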
# Load dataset in streaming mode so the full CSV never has to be held in memory
file_path = 'SnoepAll.csv'
dataset = load_dataset('csv', data_files=file_path, split='train', streaming=True)
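# Note: streaming=True returns an IterableDataset that reads rows lazily; the CSV is expected to
# have 'product_name_1', 'product_name_2' and 'label' columns, as used by tokenize_function below.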
# Function to prepare data
def prepare_data_for_finetuning(file_path):
df = load_data_from_csv(file_path)
product_pairs = list(zip(df['product_name_1'], df['product_name_2']))
labels = df['label'].tolist()
return ProductMatchDataset(product_pairs, labels, tokenizer)
# Tokenization function for Hugging Face dataset
def tokenize_function(example):
encoding = tokenizer(
example['product_name_1'], example['product_name_2'],
padding='max_length',
truncation=True,
max_length=128
)
encoding['labels'] = int(example['label']) # Convert label to integer
return encoding
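# Without return_tensors the tokenizer yields plain Python lists, which is the format
# datasets' map expects; the Trainer's default collator turns each batch into tensors.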
# Load dataset
file_path = 'Snoep.csv'
dataset = prepare_data_for_finetuning(file_path)
# Apply tokenization to streamed dataset
dataset = dataset.map(tokenize_function, remove_columns=['product_name_1', 'product_name_2', 'label'])
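# On a streamed dataset, map is applied lazily as rows are read; remove_columns drops the raw
# text fields so only model inputs ('input_ids', 'token_type_ids', 'attention_mask', 'labels') remain.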
# Create DataLoader
train_dataloader = DataLoader(dataset, batch_size=16, shuffle=True, num_workers=4, pin_memory=True)
# DataLoader automatically handled by Trainer with streaming dataset
# Load pre-trained BERT model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)
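# Optional sanity check (illustrative only): run one dummy pair through the untrained model
# enc = tokenizer("product a", "product b", return_tensors='pt').to(device)
# with torch.no_grad():
#     print(model(**enc).logits.shape)  # expected: torch.Size([1, 2])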
@@ -64,6 +37,7 @@ training_args = TrainingArguments(
output_dir='./results',
num_train_epochs=3,
per_device_train_batch_size=16,
max_steps=200000,  # Step-based stopping point; needed because the streamed dataset has no known length (max_steps overrides num_train_epochs)
warmup_steps=500,
weight_decay=0.01,
logging_dir='./logs',
@@ -77,17 +51,9 @@ trainer = Trainer(
train_dataset=dataset
)
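# Note: the Trainer builds its own DataLoader around the streamed dataset; if shuffling is wanted,
# it would be done beforehand with dataset.shuffle(buffer_size=...).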
# No manual GPU transfer loop is needed: the Trainer moves each batch to `device` during training
# Train Model
import os
if os.path.isdir('./results'):
    # Resume from the checkpoints left in ./results by an interrupted run
    trainer.train(resume_from_checkpoint=True)
else:
    trainer.train()
# Save Model
model.save_pretrained('./trained_model')
tokenizer.save_pretrained('./trained_model')
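# Illustrative follow-up (a sketch for a separate inference script, assuming label 1 marks a matching pair):
# model = BertForSequenceClassification.from_pretrained('./trained_model').to(device).eval()
# tokenizer = BertTokenizer.from_pretrained('./trained_model')
# enc = tokenizer("product name a", "product name b", return_tensors='pt',
#                 truncation=True, max_length=128).to(device)
# with torch.no_grad():
#     is_match = model(**enc).logits.argmax(dim=-1).item()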