Upload files to "/"

thijn 2025-03-31 21:40:15 +02:00
commit 67f9ee48a8
5 changed files with 150684 additions and 0 deletions

149197 Snoep.csv Normal file

File diff suppressed because it is too large

1293 Zuivel-200.csv Normal file

File diff suppressed because it is too large

5 requirements.txt Normal file

@@ -0,0 +1,5 @@
mysql_connector_repackaged
pandas
torch
transformers
transformers[torch]

89 train.py Normal file

@@ -0,0 +1,89 @@
import pandas as pd
import torch
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments

# Check if CUDA (GPU support) is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load CSV data into a pandas DataFrame
def load_data_from_csv(file_path):
    return pd.read_csv(file_path)

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Dataset class: each item is a tokenized pair of product names plus a match label
class ProductMatchDataset(torch.utils.data.Dataset):
    def __init__(self, product_pairs, labels, tokenizer, max_length=128):
        self.product_pairs = product_pairs
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.product_pairs)

    def __getitem__(self, idx):
        product1, product2 = self.product_pairs[idx]
        # Tokenize (keep tensors on CPU; the Trainer moves each batch to the device)
        encoding = self.tokenizer(
            product1, product2,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        # Drop the batch dimension added by return_tensors='pt'
        encoding = {key: value.squeeze(0) for key, value in encoding.items()}
        encoding['labels'] = torch.tensor(self.labels[idx], dtype=torch.long)
        return encoding

# Prepare the dataset from a CSV with columns product_name_1, product_name_2 and label
def prepare_data_for_finetuning(file_path):
    df = load_data_from_csv(file_path)
    product_pairs = list(zip(df['product_name_1'], df['product_name_2']))
    labels = df['label'].tolist()
    return ProductMatchDataset(product_pairs, labels, tokenizer)

# Load dataset
file_path = 'Snoep.csv'
dataset = prepare_data_for_finetuning(file_path)

# Load pre-trained BERT model
model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2).to(device)

# Define training arguments. The Trainer builds its own DataLoader from the dataset
# and moves batches to the GPU automatically, so no manual DataLoader or loop is needed.
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=3,
    per_device_train_batch_size=16,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir='./logs',
    dataloader_num_workers=4,
    dataloader_pin_memory=True,
    fp16=torch.cuda.is_available()  # Mixed precision training only when a GPU is available
)

# Initialize Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset
)

# Train model (pass resume_from_checkpoint=True only if a checkpoint already exists in output_dir)
trainer.train()

# Save model and tokenizer
model.save_pretrained('./trained_model')
tokenizer.save_pretrained('./trained_model')
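
Note: train.py assumes Snoep.csv has product_name_1, product_name_2 and label columns. A minimal sketch of that layout, using hypothetical example rows purely for illustration (the real CSV contents are not shown in this commit):

# Hypothetical sample matching the column names train.py reads.
import pandas as pd

sample = pd.DataFrame({
    'product_name_1': ["Haribo Goudberen 250g", "Haribo Goudberen 250g"],
    'product_name_2': ["Haribo Gouden Beren zak 250 gram", "Katja Apekoppen 300g"],
    'label': [1, 0]  # 1 = same product, 0 = different product
})
sample.to_csv('Snoep_sample.csv', index=False)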

100 use.py Normal file

@@ -0,0 +1,100 @@
import pandas as pd
import mysql.connector
import torch
from transformers import BertTokenizer, BertForSequenceClassification

def fetch_products_from_database():
    # Connect to the MySQL database (credentials are hard-coded here)
    connection = mysql.connector.connect(
        host="localhost",
        user="almexx_test",
        password="!TpfNGIBU28G(TbW",
        database="almexx_test",
        unix_socket="/mnt/mysql/fatcow/mysql.sock"
    )
    query = "SELECT product_name, product_description FROM products"
    df = pd.read_sql(query, connection)
    # Close the database connection
    connection.close()
    return df

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_product_pair(product1, product2, tokenizer, max_length=128):
    encoding = tokenizer(
        product1, product2,
        padding='max_length',   # Pad to the max length
        truncation=True,        # Truncate if longer than max_length
        max_length=max_length,
        return_tensors='pt'     # Return PyTorch tensors
    )
    # return_tensors='pt' already yields a batch dimension ([1, seq_length]);
    # add one only if it is somehow missing
    encoding = {key: value.unsqueeze(0) if value.dim() == 1 else value for key, value in encoding.items()}
    return encoding

# Load the fine-tuned BERT model
model = BertForSequenceClassification.from_pretrained('./trained_model')
model.eval()  # Set the model to evaluation mode

def compute_similarity(product1, product2, model, tokenizer):
    # Tokenize the products
    encoding = tokenize_product_pair(product1, product2, tokenizer)
    # Pass the tokenized inputs to the model
    with torch.no_grad():  # Disable gradient computation for inference
        output = model(**encoding)
    # Get the logits (raw predictions) from the output
    logits = output.logits
    # Apply softmax to get probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)
    # Probability of the pair being a match (class 1 = match, class 0 = no match)
    match_probability = probabilities[0, 1].item()
    return match_probability

def find_best_match(new_product, database_df, model, tokenizer):
    best_match = None
    best_match_score = -1  # A low score for initialization
    for index, row in database_df.iterrows():
        product_name = row['product_name']
        product_description = row['product_description']
        product_pair = f"{product_name} {product_description}"
        # Compute similarity score
        match_probability = compute_similarity(new_product, product_pair, model, tokenizer)
        print(f"i: {index} product: {product_pair} score: {match_probability}")
        # If the similarity score is higher than the best score so far, update
        if match_probability > best_match_score:
            best_match_score = match_probability
            best_match = row  # Keep the whole row for further details
    return best_match, best_match_score

# Example: new product to find a match for
new_product = "Jumbo Drinkyoghurt aardbei 0,5L"

# Fetch all products from the database
database_df = fetch_products_from_database()

# Find the best match
best_match, best_match_score = find_best_match(new_product, database_df, model, tokenizer)

# Display the result
if best_match is not None:
    print(f"Best match for {new_product}: {best_match['product_name']} with score: {best_match_score:.2f}")
else:
    print("No match found.")