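# File: /home/ubuntu/codegamaai-test/broker_bot/archive/intent_class_sub0.py
# Intent classification for the broker bot: a fine-tuned BERT classifier
# (predict_intent_bert) and a keyword-rule classifier (classify_user_query).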
import torch
from transformers import BertTokenizer, BertForSequenceClassification
import joblib
# Pick the compute device up front: GPU when CUDA is available, else CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fine-tuned sequence classifier and its tokenizer, both loaded from the same
# local directory (a relative path, so it is resolved against the current
# working directory).
model = BertForSequenceClassification.from_pretrained('./bert_intent_classifier')
model.to(device)
model.eval()  # inference mode (e.g. disables dropout)
tokenizer = BertTokenizer.from_pretrained('./bert_intent_classifier')

# Decodes the model's predicted class index back into an intent label.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only load a label_encoder.pkl that comes from a trusted source.
label_encoder = joblib.load('label_encoder.pkl')
def predict_intent_bert(query):
    """Predict the intent of *query* with the fine-tuned BERT classifier.

    The text is tokenized to exactly 256 tokens (truncated or padded) and run
    through the model without gradient tracking.

    Args:
        query: The user's text to classify.

    Returns:
        A ``(label, confidence)`` tuple. ``label`` is the highest-scoring
        class, decoded through ``label_encoder``. ``confidence`` is that
        class's softmax probability as a Python float.
    """
    encoded = tokenizer(
        query,
        return_tensors="pt",
        max_length=256,
        truncation=True,
        padding="max_length",
    )
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        scores = model(ids, attention_mask=mask).logits

    # A single query is encoded, so row 0 of the batch holds its scores.
    best = torch.argmax(scores, dim=1).item()
    probabilities = torch.nn.functional.softmax(scores, dim=1)
    label = label_encoder.inverse_transform([best])[0]
    return label, probabilities[0][best].item()
# Ordered keyword rules: (required keywords, intent, fixed confidence).
# The first rule whose keywords ALL occur in the lower-cased query wins, so
# order matters (e.g. "sell"/"buy" take precedence over the stock rules).
_KEYWORD_RULES = (
    (("sell",), "sell_stocks", 0.96),
    (("buy",), "buy_stocks", 0.96),
    (("stock", "price"), "stock_info", 0.96),
    (("stock", "history"), "historical_stock_data", 0.97),
    (("current", "stocks"), "view_stocks", 0.97),
    (("stock", "info"), "stock_info", 0.97),
    (("support",), "human_support", 0.98),
)


def classify_user_query(query):
    """Classify *query* with fixed keyword rules (the BERT model is not used).

    Matching is case-insensitive substring matching, so "stocks" also
    satisfies "stock" and "seller" also satisfies "sell".

    Args:
        query: The user's text (must support ``.lower()``).

    Returns:
        An ``(intent, confidence)`` tuple. The confidence is a constant
        attached to the matching rule, not a model probability. Queries that
        match no rule yield ``("general_query", 0.95)``.
    """
    text = query.lower()
    for keywords, intent, confidence in _KEYWORD_RULES:
        # Require every keyword. The previous form `"stock" and "price" in q`
        # only tested the last word, because a non-empty string literal is
        # truthy.
        if all(word in text for word in keywords):
            return intent, confidence
    return "general_query", 0.95
# Interactive test loop for predict_intent_bert (uncomment to try it from a console):
# while True:
#     query = input("Enter a query ('exit' to quit): ")
#     if query.lower() == 'exit':
#         break
#     label, confidence = predict_intent_bert(query)
#     print(f"Intent: {label}, Confidence: {confidence:.4f}")
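# (Trailing file-viewer navigation text "Back to Directory" / "File Manager"
# from the original paste, kept here as a comment so the module parses.)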