Viewing File: /home/ubuntu/codegamaai-test/broker_bot/similarity_search/similarity_check.py

import json
from fuzzywuzzy import process
import pandas as pd
import yahooquery as yq
import os
from sentence_transformers import SentenceTransformer, util
import json


def read_stock_json():
    """Load ``transformed_stock.json`` from the ``SIMILARITY_SEARCH_DIR`` directory.

    Returns:
        The decoded JSON document on success. On failure a message is
        printed, and the return value is ``[]`` if the file does not exist or
        ``None`` if it exists but is not valid JSON.

    Raises:
        KeyError: if ``SIMILARITY_SEARCH_DIR`` is not set in the environment.
    """
    target = os.path.join(os.environ['SIMILARITY_SEARCH_DIR'], 'transformed_stock.json')

    try:
        with open(target, 'r') as handle:
            document = json.loads(handle.read())
    except FileNotFoundError:
        print(f"File not found at location: {target}")
        return []
    except json.JSONDecodeError:
        print(f"Error decoding JSON from file at location: {target}")
        return None
    return document

def read_unity_data():
    """Load ``instruments_id_info.json`` from the ``UNITY_DIR`` directory.

    Returns:
        The decoded JSON document on success. On failure a message is
        printed, and the return value is ``[]`` when the file is missing or
        ``None`` when its contents are not valid JSON.

    Raises:
        KeyError: if ``UNITY_DIR`` is not set in the environment.
    """
    source = os.path.join(os.environ['UNITY_DIR'], "instruments_id_info.json")

    try:
        with open(source, 'r') as stream:
            payload = json.loads(stream.read())
    except FileNotFoundError:
        print(f"File not found at location: {source}")
        return []
    except json.JSONDecodeError:
        print(f"Error decoding JSON from file at location: {source}")
        return None
    return payload



# Build the in-memory search index once at import time: load the embedding
# model, read the instrument list from $UNITY_DIR and embed every stock's
# name, ticker and code. The embedding rows are laid out as
# [all names..., all tickers..., all codes...]; find_closest_stock maps a hit
# back to its entry using exactly that ordering.
model = SentenceTransformer('all-MiniLM-L6-v2')
unity_data = read_unity_data()
stocks_data = unity_data['items']
stock_names, stock_tickers, stocks_code = (
    [entry[field] for entry in stocks_data] for field in ('name', 'ticker', 'code')
)
stock_list = [*stock_names, *stock_tickers, *stocks_code]
stock_embeddings = model.encode(stock_list, convert_to_tensor=True)
# Mutable per-process state; find_closest_stock keeps its call counter here.
user_session = {}

    
def update_json(data):
    """Append *data* to the JSON list stored in ``$SIMILARITY_SEARCH_DIR/transformed_stock.json``.

    A missing file is treated as an empty list. The updated list is written
    to a sibling ``.tmp`` file and then moved over the target with
    ``os.replace``. A failure while serializing or writing therefore leaves
    the previous file contents intact instead of a truncated file.

    Args:
        data: Iterable of JSON-serializable records to append. The existing
            file is assumed to hold a JSON array.

    Raises:
        KeyError: if ``SIMILARITY_SEARCH_DIR`` is not set.
        json.JSONDecodeError: if the existing file is not valid JSON. The file
            is left untouched rather than overwritten.
        TypeError: if a record is not JSON-serializable. The existing file is
            left untouched.
    """
    file_path = os.path.join(os.environ['SIMILARITY_SEARCH_DIR'], 'transformed_stock.json')

    try:
        with open(file_path, 'r', encoding='utf-8') as json_file:
            existing_data = json.load(json_file)
    except FileNotFoundError:
        # First write: start a fresh list.
        existing_data = []

    existing_data.extend(data)

    # Opening the real file with 'w' would truncate it before json.dump runs,
    # losing all stored records if serialization fails part-way. Write to a
    # temporary sibling and atomically swap it in instead.
    tmp_path = file_path + '.tmp'
    try:
        with open(tmp_path, 'w', encoding='utf-8') as json_file:
            json.dump(existing_data, json_file)
        os.replace(tmp_path, file_path)
    except BaseException:
        # Best-effort removal of the partial temp file; the error is re-raised.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise

def find_closest_stock(query):
    """Return the stock entry whose name, ticker or code best matches *query*.

    The query is embedded with the module-level ``model`` and compared by
    cosine similarity against ``stock_embeddings``. Its rows hold the
    embeddings of every stock name, then every ticker, then every code, in
    that order. The segment the best hit falls into decides the acceptance
    threshold:

    * names need a score above 0.85;
    * tickers need a score above 0.98;
    * codes need a score above 0.95.

    Every 100th call re-reads the instrument data via ``read_unity_data`` and
    rebuilds the embeddings. If that read fails, the current index is kept.

    Args:
        query: Free-text stock name, ticker or code.

    Returns:
        A shallow copy of the matching entry from ``stocks_data`` with an
        added ``'score'`` key (the cosine similarity as a float). Returns the
        string ``"No Symbol Found"`` when nothing clears its threshold or the
        index is empty.
    """
    # Count lookups; .get also copes with a user_session that already holds
    # other keys but no counter yet.
    user_session['interaction_count'] = user_session.get('interaction_count', 0) + 1
    if user_session['interaction_count'] % 100 == 0:
        _reload_stock_index()

    if not stocks_data:
        # argmax over an empty score vector would raise.
        return "No Symbol Found"

    query_embedding = model.encode(query, convert_to_tensor=True)
    cosine_scores = util.pytorch_cos_sim(query_embedding, stock_embeddings)[0]
    best_index = int(cosine_scores.argmax())
    score = cosine_scores[best_index].item()

    # Map the flat row index back to its entry and pick the threshold for
    # the segment (names / tickers / codes) it came from.
    n_names = len(stock_names)
    n_tickers = len(stock_tickers)
    if best_index < n_names:
        entry_index, threshold = best_index, 0.85
    elif best_index < n_names + n_tickers:
        entry_index, threshold = best_index - n_names, 0.98
    else:
        entry_index, threshold = best_index - n_names - n_tickers, 0.95

    if score > threshold:
        # Return a copy so the shared index entries never accumulate a
        # 'score' key and callers cannot mutate stocks_data through it.
        return {**stocks_data[entry_index], 'score': score}
    return "No Symbol Found"


def _reload_stock_index():
    """Re-read the instrument file and rebuild the name/ticker/code embeddings.

    The current index is left unchanged if ``read_unity_data`` does not return
    a dict with an ``'items'`` key (it returns ``[]``/``None`` on read errors).
    """
    global stock_embeddings, stocks_data, stock_names, stock_tickers, stocks_code

    unity_data = read_unity_data()
    if not isinstance(unity_data, dict) or 'items' not in unity_data:
        # A transient read error should neither crash the lookup nor wipe the
        # working index.
        return

    items = unity_data['items']
    names = [stock['name'] for stock in items]
    tickers = [stock['ticker'] for stock in items]
    codes = [stock['code'] for stock in items]
    embeddings = model.encode(names + tickers + codes, convert_to_tensor=True)

    # Publish only after everything succeeded, so the lists and the embedding
    # rows always describe the same snapshot.
    stocks_data, stock_names, stock_tickers, stocks_code = items, names, tickers, codes
    stock_embeddings = embeddings

    


    


def get_symbol(query):
    """Look up *query* with ``yahooquery.search`` and describe its first quote.

    Args:
        query: Free-text company name or ticker.

    Returns:
        A dict with keys ``'code'`` (the quote's ``symbol``), ``'name'``
        (``longname``, falling back to ``shortname``), ``'industry'`` and
        ``'exchange'`` (``exchDisp``). Optional fields missing from the quote
        become ``None``. Returns ``'No Symbol Found'`` when the search fails
        to decode, returns nothing, has no quotes, or the first quote has no
        symbol.
    """
    try:
        data = yq.search(query)
    except ValueError:  # Also catches json.JSONDecodeError (a ValueError subclass).
        print(query)
        return 'No Symbol Found'

    if not data:
        return 'No Symbol Found'

    # Responses may omit 'quotes', and individual quotes (e.g. non-equity
    # results) may omit longname/industry/exchDisp; use .get rather than
    # crashing with KeyError.
    quotes = data.get('quotes') or []
    if not quotes:
        return 'No Symbol Found'

    top = quotes[0]
    symbol = top.get('symbol')
    if not symbol:
        return 'No Symbol Found'

    return {
        'code': symbol,
        'name': top.get('longname') or top.get('shortname'),
        'industry': top.get('industry'),
        'exchange': top.get('exchDisp'),
    }


def get_stock_ticker(stock_entity):
    """Resolve *stock_entity* to a stock record.

    The local embedding index (``find_closest_stock``) is tried first. On a
    miss, the Yahoo search (``get_symbol``) is used, and any record it finds
    is appended to ``transformed_stock.json`` via ``update_json``.

    NOTE(review): ``find_closest_stock`` refreshes its index from
    ``read_unity_data``, not from ``transformed_stock.json``. Records saved
    here are therefore not reused by later embedding lookups in this file.

    Returns:
        The matched record, or the string ``"No Symbol Found"``.
    """
    match = find_closest_stock(stock_entity)
    if match != "No Symbol Found":
        return match

    lookup = get_symbol(stock_entity)
    if lookup == "No Symbol Found":
        return "No Symbol Found"

    update_json([lookup])
    return lookup

# while True:
#     if query.lower() == 'exit':
#         break
#     else:
#         query = input("Enter the stock name: ")
#         result = get_stock_ticker(query)
#         print(result)
    
    
# def find_closest_stock(query, stocks_data, stock_names):
#     closest_match, score = process.extractOne(query, stock_names)
#     stock_entry = next((stock for stock in stocks_data if stock['name'] == closest_match), None)
#     if score > 80:  # Adjust the threshold score as needed
#         return {
#             "symbol": stock_entry['symbol'],
#             "name": stock_entry['name'],
#             "industry": stock_entry['industry'],
#             "score": score
#         }
#     else:
#         return "No Symbol Found"
    
Back to Directory File Manager