# Source: /home/ubuntu/codegamaai-test/broker_bot/similarity_search/similarity_search_v2.py
from sentence_transformers import SentenceTransformer, util
import json
# Load the stock universe. Each entry is read for 'name' and 'symbol' below
# (and for 'industry' in find_closest_stock).
# JSON text is UTF-8, so decode explicitly rather than relying on the
# platform's default locale encoding, which may mis-decode non-ASCII names.
with open(r'transformed_stock.json', 'r', encoding='utf-8') as file:
    stocks_data = json.load(file)

# Initialize the model
model = SentenceTransformer('all-MiniLM-L6-v2')

# Embed every name and every ticker in a single batch. The first
# len(stock_names) rows are names and the remaining rows are tickers, both in
# stocks_data order; find_closest_stock relies on this layout to map a
# matching row back to its stock entry.
stock_names = [stock['name'] for stock in stocks_data]
stock_tickers = [stock['symbol'] for stock in stocks_data]
stock_list = stock_names + stock_tickers
stock_embeddings = model.encode(stock_list, convert_to_tensor=True)
print("len ", len(stock_embeddings))
print(stock_embeddings)
def find_closest_stock(query, stocks_data=stocks_data, stock_names=stock_names, model=model, stock_embeddings=stock_embeddings):
    """Return the stock whose name or ticker is semantically closest to *query*.

    *query* is embedded with *model* and compared by cosine similarity against
    every row of *stock_embeddings*. By default those rows are the module-level
    embeddings of all stock names followed by all tickers. A best match in the
    ticker half is therefore shifted back by ``len(stock_names)`` to locate its
    entry in *stocks_data*.

    Returns a dict with the matched entry's ``symbol``, ``name`` and
    ``industry``, plus the best cosine ``score`` as a Python number.
    """
    query_vec = model.encode(query, convert_to_tensor=True)
    # Row 0 of the similarity matrix holds the scores for the single query.
    scores = util.pytorch_cos_sim(query_vec, stock_embeddings)[0]
    print("cosine_scores", scores)

    # Presumably a 0-d tensor from argmax; it is used directly as a list index.
    best = scores.argmax()
    print("highest_score_index", best)

    name_count = len(stock_names)
    # Rows past the name block belong to tickers; undo that offset.
    entry = stocks_data[best if best < name_count else best - name_count]
    print("stock_entry", entry)

    return dict(
        symbol=entry['symbol'],
        name=entry['name'],
        industry=entry['industry'],
        score=scores[best].item(),
    )
# query = "Tencent"
# result = find_closest_stock(query)
# print(result)
if __name__ == "__main__":
    # Interactive lookup: read ONE line per iteration and search for exactly
    # that line. Stop on "exit" or when stdin is exhausted.
    while True:
        try:
            query = input("Enter a stock name: ")
        except EOFError:
            # stdin closed (Ctrl-D, or piped input ran out): stop cleanly
            # instead of crashing with a traceback.
            break
        # Tolerate stray surrounding whitespace in both the exit command and
        # the search text.
        query = query.strip()
        if query == "exit":
            break
        if not query:
            # Nothing to search for; prompt again.
            continue
        result = find_closest_stock(query)
        print(result)
# Back to Directory
# File Manager