import os

import mysql.connector
import pandas as pd
import torch
from transformers import BertForSequenceClassification, BertTokenizer

def fetch_products_from_database():
    # Connect to your MySQL database. Credentials are read from the
    # environment (DB_USER/DB_PASSWORD are assumed variable names) rather
    # than hardcoded in the script.
    connection = mysql.connector.connect(
        host="localhost",
        user=os.environ["DB_USER"],
        password=os.environ["DB_PASSWORD"],
        database="almexx_test",
        unix_socket="/mnt/mysql/fatcow/mysql.sock",
    )

    query = "SELECT product_name, product_description FROM products"
    # pandas warns about raw DBAPI connections here but still works; see the
    # SQLAlchemy alternative just below.
    df = pd.read_sql(query, connection)

    # Close the database connection
    connection.close()

    return df
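
# Hedged alternative (not called below): pandas.read_sql emits a UserWarning
# when handed a raw DBAPI connection like mysql.connector's. A SQLAlchemy
# engine is the supported interface; this sketch assumes sqlalchemy is
# installed and reuses the DB_USER/DB_PASSWORD variables assumed above.
def fetch_products_via_sqlalchemy():
    from urllib.parse import quote_plus
    from sqlalchemy import create_engine  # optional dependency

    url = "mysql+mysqlconnector://{}:{}@localhost/almexx_test".format(
        quote_plus(os.environ["DB_USER"]), quote_plus(os.environ["DB_PASSWORD"])
    )
    engine = create_engine(url, connect_args={"unix_socket": "/mnt/mysql/fatcow/mysql.sock"})
    return pd.read_sql("SELECT product_name, product_description FROM products", engine)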

# Initialize the BERT tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

def tokenize_product_pair(product1, product2, tokenizer, max_length=128):
    encoding = tokenizer(
        product1, product2,
        padding='max_length',  # Pad to the max length
        truncation=True,       # Truncate if longer than max_length
        max_length=max_length,
        return_tensors='pt'    # Return PyTorch tensors
    )
    
    # With return_tensors='pt' the tokenizer already returns tensors shaped
    # [1, seq_length]; this guard only adds the batch dimension if a tensor
    # somehow arrives without one.
    encoding = {key: value.unsqueeze(0) if value.dim() == 1 else value for key, value in encoding.items()}
    
    return encoding
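
# Quick sanity check (optional; the second product string is hypothetical):
# with padding='max_length' and return_tensors='pt', every tensor comes back
# with shape [1, 128], so the defensive unsqueeze above is rarely triggered.
_enc = tokenize_product_pair("Jumbo Drinkyoghurt aardbei 0,5L", "Drinkyoghurt aardbei 500 ml", tokenizer)
assert _enc["input_ids"].shape == (1, 128)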

# Load the fine-tuned BERT model
model = BertForSequenceClassification.from_pretrained('./trained_model')
model.eval()  # Set the model to evaluation mode

def compute_similarity(product1, product2, model, tokenizer):
    # Tokenize the products
    encoding = tokenize_product_pair(product1, product2, tokenizer)

    # Pass the tokenized inputs to the model
    with torch.no_grad():  # Disable gradient computation for inference
        output = model(**encoding)

    # Get the logits (raw predictions) from the output
    logits = output.logits

    # Apply softmax to get probabilities
    probabilities = torch.nn.functional.softmax(logits, dim=-1)

    # Probability of class 1 ("match"); class 0 is "no match"
    match_probability = probabilities[0, 1].item()

    return match_probability
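
# Hedged optimization sketch (not used by find_best_match below): score many
# candidates against one product in a single forward pass instead of looping
# pair by pair. Assumes the same binary head where class 1 means "match";
# chunk `candidates` if the table is too large for one batch.
def compute_similarity_batch(new_product, candidates, model, tokenizer, max_length=128):
    encoding = tokenizer(
        [new_product] * len(candidates),  # left side repeated for each pair
        candidates,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors='pt',
    )
    with torch.no_grad():
        logits = model(**encoding).logits
    # Probability of class 1 ("match") for every candidate, in input order
    return torch.nn.functional.softmax(logits, dim=-1)[:, 1].tolist()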

def find_best_match(new_product, database_df, model, tokenizer):
    best_match = None
    best_match_score = -1  # A low score for initialization

    for index, row in database_df.iterrows():
        product_name = row['product_name']
        product_description = row['product_description']
        product_pair = f"{product_name} {product_description}"

        # Compute similarity score
        match_probability = compute_similarity(new_product, product_pair, model, tokenizer)
        print(f"i: {index} product: {product_pair} score: {match_probability}")

        # If the similarity score is higher than the best score so far, update
        if match_probability > best_match_score:
            best_match_score = match_probability
            best_match = row  # Keep the whole row for further details

    return best_match, best_match_score
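
# Usage sketch for the batched scorer above: equivalent to find_best_match
# (up to numerical noise) but with one forward pass instead of a Python loop.
def find_best_match_batched(new_product, database_df, model, tokenizer):
    texts = (database_df['product_name'] + " " + database_df['product_description']).tolist()
    scores = compute_similarity_batch(new_product, texts, model, tokenizer)
    best_idx = max(range(len(scores)), key=scores.__getitem__)
    return database_df.iloc[best_idx], scores[best_idx]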

# Example: New product you want to find a match for
new_product = "Jumbo Drinkyoghurt aardbei 0,5L"

# Fetch all products from the database
database_df = fetch_products_from_database()

# Find the best match
best_match, best_match_score = find_best_match(new_product, database_df, model, tokenizer)

# Display the result
if best_match is not None:
    print(f"Best match for {new_product}: {best_match['product_name']} with score: {best_match_score:.2f}")
else:
    print("No match found.")