sa_ai_training/train.py
2025-04-01 11:06:12 +02:00

60 lines
1.8 KiB
Python

import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
# Select the compute device: the GPU when PyTorch reports CUDA, else the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

# Tokenizer belonging to the 'bert-base-uncased' checkpoint.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Open the CSV as a streamed dataset so that records are read lazily instead
# of being loaded into memory all at once.
file_path = 'SnoepAll.csv'
dataset = load_dataset(
    'csv',
    data_files=file_path,
    split='train',
    streaming=True,
)
def _as_text(value):
    """Return *value* as a string, mapping a missing (None) cell to ''."""
    return "" if value is None else str(value)


# Tokenization function for Hugging Face dataset
def tokenize_function(example):
    """Encode one product-name pair and attach its integer label.

    Args:
        example: A single dataset record providing the keys
            'product_name_1', 'product_name_2' and 'label'.

    Returns:
        The tokenizer output for the name pair, padded/truncated to a fixed
        length of 128 tokens, with an extra 'labels' entry holding the label
        converted to ``int``.

    Raises:
        TypeError, ValueError: If 'label' is missing or cannot be converted
            to an integer. A row without a usable label is not repaired here.
    """
    # An empty CSV cell may arrive as None. Passed straight to the tokenizer,
    # a None first text raises, while a None second text is taken to mean
    # "no pair" and silently yields a single-sequence encoding. Coercing both
    # fields to text keeps every row encoded as a genuine pair.
    first_name = _as_text(example['product_name_1'])
    second_name = _as_text(example['product_name_2'])

    encoding = tokenizer(
        first_name,
        second_name,
        padding='max_length',
        truncation=True,
        max_length=128,
    )
    encoding['labels'] = int(example['label'])  # Convert label to integer
    return encoding
# Tokenize records lazily as they are streamed, dropping the raw input
# columns so only the tokenizer output and 'labels' are kept.
raw_columns = ['product_name_1', 'product_name_2', 'label']
dataset = dataset.map(tokenize_function, remove_columns=raw_columns)

# No explicit DataLoader is built here; the Trainer takes care of batching
# the streamed dataset.

# Pre-trained BERT with a two-label classification head, moved to the
# selected device.
model = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased',
    num_labels=2,
)
model = model.to(device)
# Define Training Arguments
training_args = TrainingArguments(
    output_dir='./results',
    # The training data is streamed, so the run is bounded by a step count.
    # A positive max_steps takes precedence over num_train_epochs, which is
    # why no epoch count is passed here.
    max_steps=200000,
    per_device_train_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    save_total_limit=5,
    # Mixed precision requires a GPU. This script explicitly supports a CPU
    # fallback, and requesting fp16 there makes TrainingArguments raise, so
    # enable it only when CUDA is actually available.
    fp16=torch.cuda.is_available(),
)
# Bundle the model, the training configuration and the streamed training
# data into a Trainer.
trainer = Trainer(model=model, args=training_args, train_dataset=dataset)

# Run the fine-tuning loop.
trainer.train()

# Write the fine-tuned model and its tokenizer to the same directory so the
# two can be reloaded together.
trained_model_dir = './trained_model'
model.save_pretrained(trained_model_dir)
tokenizer.save_pretrained(trained_model_dir)