# Viewing File: /home/ubuntu/codegamaai-test/tax_bot/help_functions/intent_class.py

import torch
import torch.nn.functional as F
import pandas as pd
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset
import os

# Root directory of the fine-tuned intent classifier, taken from the
# INTENT_MODEL_DIR environment variable (KeyError at import time if unset).
output_dir = os.environ['INTENT_MODEL_DIR']
# The model weights are read from the "checkpoint-186" subdirectory.
checkpoint_dir = os.path.join(output_dir, "checkpoint-186")

# The model comes from the checkpoint directory, the tokenizer from the root.
model = BertForSequenceClassification.from_pretrained(checkpoint_dir)
tokenizer = BertTokenizer.from_pretrained(output_dir)


def predict_with_probabilities(text, model, tokenizer):
    """Classify ``text`` and return the top class with its probability.

    Args:
        text: Input handed to ``tokenizer``; a single string or a list of
            strings (one prediction is returned per input).
        model: Classifier exposing ``eval()``. It is called with the
            tokenizer output as keyword arguments and must return an object
            with a ``logits`` attribute (assumed shape: batch x classes).
        tokenizer: Callable that converts ``text`` into the model's inputs.

    Returns:
        A ``(predictions, probabilities)`` tuple of two equally long lists:
        the index of the highest-scoring class for each input, and the
        softmax probability of that class.
    """
    model.eval()  # Set model to evaluation mode

    with torch.no_grad():  # Inference only; no gradients needed
        inputs = tokenizer(text, padding=True, truncation=True, max_length=128, return_tensors="pt")
        logits = model(**inputs).logits
        probs = F.softmax(logits, dim=1)
        # A single max call yields both the winning probability and its index.
        top_probs, top_classes = torch.max(probs, dim=1)

    # tolist() handles any batch size; item() raised for more than one input.
    return top_classes.tolist(), top_probs.tolist()

# Class-name -> class-index mapping; assumed to match the one used in training.
label_dict = {
    "Tax-related": 0,
    "Non-tax-related": 1,
}
# Reverse lookup: class index -> class name.
label_dict_inv = {index: name for name, index in label_dict.items()}

def get_intent(text, threshold=0.7):
    """Label ``text`` as "Tax-related" or "Non-tax-related".

    The module-level ``model`` and ``tokenizer`` are used for the
    prediction. The input is reported as "Non-tax-related" only when that
    class wins with a probability strictly above ``threshold``; every other
    outcome is reported as "Tax-related".

    Args:
        text: The text to classify. If several predictions come back, only
            the first one is used.
        threshold: Minimum probability (exclusive) required to accept a
            "Non-tax-related" prediction. Defaults to 0.7.

    Returns:
        Either "Non-tax-related" or "Tax-related".
    """
    predictions, probabilities = predict_with_probabilities(text, model, tokenizer)

    # Read the first result directly. Zipping the results with ``text``
    # iterated over the characters of a string, so an empty string produced
    # no iterations and the function fell through, returning None.
    predicted_class, confidence = predictions[0], probabilities[0]

    if predicted_class == label_dict["Non-tax-related"] and confidence > threshold:
        return "Non-tax-related"
    return "Tax-related"
# Back to Directory File Manager