101 lines
3.4 KiB
Python
101 lines
3.4 KiB
Python
import pandas as pd
|
|
import mysql.connector
|
|
|
|
def fetch_products_from_database(
    query="SELECT product_name, product_description FROM products",
):
    """Load product rows from the MySQL catalogue into a DataFrame.

    Connection settings can be overridden with the environment variables
    ``ALMEXX_DB_HOST``, ``ALMEXX_DB_USER``, ``ALMEXX_DB_PASSWORD``,
    ``ALMEXX_DB_NAME`` and ``ALMEXX_DB_SOCKET``. When a variable is unset, the
    previously hard-coded value is used, so existing behaviour is unchanged.

    Args:
        query: SQL statement to execute. The default selects the
            ``product_name`` and ``product_description`` columns from the
            ``products`` table, which are the columns ``find_best_match``
            reads from each row.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_sql`` for ``query``.

    Raises:
        Errors from ``mysql.connector.connect`` propagate unchanged.
        Errors from ``pd.read_sql`` propagate after the connection has
        been closed.
    """
    import os

    # NOTE(security): a live password is committed in source here. Provide it
    # via ALMEXX_DB_PASSWORD instead, drop this fallback, and rotate it.
    connection = mysql.connector.connect(
        host=os.environ.get("ALMEXX_DB_HOST", "localhost"),
        user=os.environ.get("ALMEXX_DB_USER", "almexx_test"),
        password=os.environ.get("ALMEXX_DB_PASSWORD", "!TpfNGIBU28G(TbW"),
        database=os.environ.get("ALMEXX_DB_NAME", "almexx_test"),
        unix_socket=os.environ.get(
            "ALMEXX_DB_SOCKET", "/mnt/mysql/fatcow/mysql.sock"
        ),
    )
    try:
        # pandas officially supports SQLAlchemy connectables (and sqlite3)
        # here; a raw mysql.connector connection may trigger a UserWarning.
        return pd.read_sql(query, connection)
    finally:
        # Close even when the query fails, so the connection is not leaked.
        connection.close()


from transformers import BertTokenizer
|
|
|
|
# Checkpoint name the BERT tokenizer vocabulary is loaded from.
# NOTE(review): the classifier is loaded separately from './trained_model';
# if that directory was saved with its own tokenizer, confirm they match.
_TOKENIZER_CHECKPOINT = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(_TOKENIZER_CHECKPOINT)


def tokenize_product_pair(product1, product2, tokenizer, max_length=128):
    """Encode two product texts as one sequence-pair input for the model.

    Args:
        product1: First text of the pair (the tokenizer's first sequence).
        product2: Second text of the pair.
        tokenizer: Callable tokenizer, such as the module-level
            ``BertTokenizer``, accepting Hugging Face-style keyword arguments.
        max_length: Length every sequence is padded or truncated to.

    Returns:
        A plain ``dict`` mapping each field the tokenizer returned to its
        PyTorch tensor. Any 1-D tensor gets a leading batch axis of size 1.
    """
    raw = tokenizer(
        product1,
        product2,
        max_length=max_length,
        truncation=True,        # cut inputs longer than max_length
        padding='max_length',   # pad shorter inputs up to max_length
        return_tensors='pt',    # request PyTorch tensors
    )

    batched = {}
    for field, tensor in raw.items():
        # The model expects [batch_size, seq_length]. A 1-D tensor lacks the
        # batch axis, so prepend one; tensors of other rank pass through.
        batched[field] = tensor.unsqueeze(0) if tensor.dim() == 1 else tensor
    return batched

import torch
|
|
from transformers import BertForSequenceClassification
|
|
|
|
# Fine-tuned sequence-pair classifier, read from the './trained_model'
# directory. The path is relative, so it depends on the working directory
# the script is started from.
# Module.eval() switches the model to inference mode (e.g. disables dropout)
# and returns the model itself, so loading and the mode switch chain together.
model = BertForSequenceClassification.from_pretrained('./trained_model').eval()


def compute_similarity(product1, product2, model, tokenizer):
    """Score how likely two product texts refer to the same product.

    The pair is encoded with ``tokenize_product_pair`` and classified by
    ``model`` without gradient tracking. The logits are turned into class
    probabilities with a softmax over the last axis.

    Args:
        product1: First product text.
        product2: Second product text.
        model: Sequence-pair classifier whose output exposes ``.logits``.
        tokenizer: Tokenizer forwarded to ``tokenize_product_pair``.

    Returns:
        Probability (a Python ``float``) of class index 1 for the first item
        in the batch. Index 1 is treated as "match" and 0 as "no match".
        Confirm this against the labels used during fine-tuning.
    """
    inputs = tokenize_product_pair(product1, product2, tokenizer)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        logits = model(**inputs).logits

    class_probs = torch.nn.functional.softmax(logits, dim=-1)
    # Row 0 is the single pair scored here; column 1 is the "match" class.
    return class_probs[0, 1].item()


def _catalogue_text(value):
    """Return *value* as text, mapping SQL NULL / pandas NaN to ''."""
    return "" if pd.isna(value) else str(value)


def find_best_match(
    new_product,
    database_df,
    model,
    tokenizer,
    *,
    verbose=True,
    similarity=None,
):
    """Return the catalogue row that scores highest against *new_product*.

    Each row of *database_df* becomes a candidate text
    ``"<product_name> <product_description>"`` that is scored against
    *new_product*. Missing values (SQL NULL / NaN) are treated as empty text
    rather than being rendered as the literal ``"None"`` / ``"nan"``.

    Args:
        new_product: Query text to match.
        database_df: DataFrame with ``product_name`` and
            ``product_description`` columns.
        model: Classifier forwarded to the scoring function.
        tokenizer: Tokenizer forwarded to the scoring function.
        verbose: When true (the default), print each row's score as it is
            computed.
        similarity: Scoring callable
            ``(query, candidate, model, tokenizer) -> float``. Defaults to
            ``compute_similarity``.

    Returns:
        ``(best_row, best_score)``, where ``best_row`` is the winning row as
        yielded by ``DataFrame.iterrows``. On ties the earliest row wins.
        For an empty DataFrame the result is ``(None, -1)``.
    """
    score_pair = compute_similarity if similarity is None else similarity

    best_match = None
    # Lower than any probability in [0, 1], so the first scored row wins.
    best_match_score = -1

    for index, row in database_df.iterrows():
        name = _catalogue_text(row['product_name'])
        description = _catalogue_text(row['product_description'])
        # Skip empty parts so a NULL field leaves no dangling space.
        product_pair = " ".join(part for part in (name, description) if part)

        match_probability = score_pair(new_product, product_pair, model, tokenizer)
        if verbose:
            print(f"i: {index} product: {product_pair} score: {match_probability}")

        # Strict '>' keeps the earliest row on ties.
        if match_probability > best_match_score:
            best_match_score = match_probability
            best_match = row  # keep the whole row for further details

    return best_match, best_match_score


# Demo query: the product to locate in the catalogue.
new_product = "Jumbo Drinkyoghurt aardbei 0,5L"

# Load the whole catalogue, then score every entry against the query.
database_df = fetch_products_from_database()
best_match, best_match_score = find_best_match(new_product, database_df, model, tokenizer)

# Report the winning row, or say that nothing was found.
if best_match is None:
    print("No match found.")
else:
    print(
        f"Best match for {new_product}: {best_match['product_name']} "
        f"with score: {best_match_score:.2f}"
    )