Update app.py
Browse files
app.py
CHANGED
@@ -29,58 +29,6 @@ from flask import Flask, request, render_template
|
|
29 |
from flask_cors import CORS
|
30 |
from flask_socketio import SocketIO, emit
|
31 |
|
32 |
-
|
33 |
-
# GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
34 |
-
GROQ_API_KEY = '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'
|
35 |
-
# load_dotenv(".env")
|
36 |
-
USER_AGENT = os.getenv("USER_AGENT")
|
37 |
-
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
38 |
-
SECRET_KEY = os.getenv("SECRET_KEY")
|
39 |
-
|
40 |
-
|
41 |
-
# Set environment variables
|
42 |
-
os.environ['USER_AGENT'] = USER_AGENT
|
43 |
-
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
|
44 |
-
os.environ["TOKENIZERS_PARALLELISM"] = 'true'
|
45 |
-
|
46 |
-
# Initialize Flask app and SocketIO with CORS
|
47 |
-
app = Flask(__name__)
|
48 |
-
CORS(app)
|
49 |
-
app.config['MAX_CONTENT_LENGTH'] = 1e7
|
50 |
-
app.config['SESSION_COOKIE_SECURE'] = True # Use HTTPS
|
51 |
-
app.config['SESSION_COOKIE_HTTPONLY'] = True
|
52 |
-
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
|
53 |
-
socketio = SocketIO(app, cors_allowed_origins="*", logger=True, max_http_buffer_size=1e7)
|
54 |
-
app.config['SECRET_KEY'] = SECRET_KEY
|
55 |
-
|
56 |
-
import pandas as pd
|
57 |
-
from PIL import Image
|
58 |
-
import numpy as np
|
59 |
-
import os
|
60 |
-
|
61 |
-
import torch
|
62 |
-
import torch.nn.functional as F
|
63 |
-
|
64 |
-
# from src.data.embs import ImageDataset
|
65 |
-
from src.model.blip_embs import blip_embs
|
66 |
-
from src.data.transforms import transform_test
|
67 |
-
|
68 |
-
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer
|
69 |
-
import gradio as gr
|
70 |
-
# import spaces
|
71 |
-
|
72 |
-
from langchain.chains import ConversationChain
|
73 |
-
from langchain_community.chat_message_histories import ChatMessageHistory
|
74 |
-
from langchain_core.runnables import RunnableWithMessageHistory
|
75 |
-
from langchain_core.output_parsers import StrOutputParser
|
76 |
-
from langchain_core.prompts import ChatPromptTemplate
|
77 |
-
from langchain_groq import ChatGroq
|
78 |
-
|
79 |
-
from dotenv import load_dotenv
|
80 |
-
from flask import Flask, request, render_template
|
81 |
-
from flask_cors import CORS
|
82 |
-
from flask_socketio import SocketIO, emit
|
83 |
-
|
84 |
import json
|
85 |
from openai import OpenAI
|
86 |
|
@@ -100,12 +48,14 @@ os.environ["TOKENIZERS_PARALLELISM"] = 'true'
|
|
100 |
# Initialize Flask app and SocketIO with CORS
|
101 |
app = Flask(__name__)
|
102 |
CORS(app)
|
103 |
-
|
104 |
app.config['SESSION_COOKIE_SECURE'] = True # Use HTTPS
|
105 |
app.config['SESSION_COOKIE_HTTPONLY'] = True
|
106 |
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
|
|
|
107 |
app.config['SECRET_KEY'] = SECRET_KEY
|
108 |
|
|
|
109 |
# Initialize LLM
|
110 |
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2)
|
111 |
|
@@ -179,15 +129,11 @@ model = blip_embs(
|
|
179 |
|
180 |
model = model.to(device)
|
181 |
model.eval()
|
182 |
-
print("Model Loaded !")
|
183 |
-
print("="*50)
|
184 |
|
185 |
transform = transform_test(384)
|
186 |
|
187 |
-
print("Loading Data")
|
188 |
df = pd.read_json("my_recipes.json")
|
189 |
|
190 |
-
print("Loading Target Embedding")
|
191 |
tar_img_feats = []
|
192 |
for _id in df["id_"].tolist():
|
193 |
tar_img_feats.append(torch.load("./datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0))
|
@@ -234,7 +180,6 @@ class Chat:
|
|
234 |
|
235 |
|
236 |
chat = Chat(model,transform,df,tar_img_feats, device)
|
237 |
-
print("Chat Initialized !")
|
238 |
|
239 |
|
240 |
def answer_generator(formated_input, session_id):
|
@@ -516,7 +461,7 @@ def handle_message(data):
|
|
516 |
context = "No data available"
|
517 |
session_id = request.sid
|
518 |
if session_id not in session_store:
|
519 |
-
session_store[session_id] = {'image_data':
|
520 |
|
521 |
if 'message' in data:
|
522 |
session_store[session_id]['message'] = data['message']
|
@@ -614,9 +559,6 @@ def handle_message(data):
|
|
614 |
emit('response', response, room=session_id)
|
615 |
return response
|
616 |
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
# Home route
|
621 |
@app.route("/")
|
622 |
def index_view():
|
|
|
29 |
from flask_cors import CORS
|
30 |
from flask_socketio import SocketIO, emit
|
31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
32 |
import json
|
33 |
from openai import OpenAI
|
34 |
|
|
|
48 |
# Initialize Flask app and SocketIO with CORS
|
49 |
app = Flask(__name__)
|
50 |
CORS(app)
|
51 |
+
app.config['MAX_CONTENT_LENGTH'] = 1e7
|
52 |
app.config['SESSION_COOKIE_SECURE'] = True # Use HTTPS
|
53 |
app.config['SESSION_COOKIE_HTTPONLY'] = True
|
54 |
app.config['SESSION_COOKIE_SAMESITE'] = 'Lax'
|
55 |
+
socketio = SocketIO(app, cors_allowed_origins="*", logger=True, max_http_buffer_size=1e7)
|
56 |
app.config['SECRET_KEY'] = SECRET_KEY
|
57 |
|
58 |
+
|
59 |
# Initialize LLM
|
60 |
llm = ChatGroq(model="llama-3.1-8b-instant", temperature=0, max_tokens=1024, max_retries=2)
|
61 |
|
|
|
129 |
|
130 |
model = model.to(device)
|
131 |
model.eval()
|
|
|
|
|
132 |
|
133 |
transform = transform_test(384)
|
134 |
|
|
|
135 |
df = pd.read_json("my_recipes.json")
|
136 |
|
|
|
137 |
tar_img_feats = []
|
138 |
for _id in df["id_"].tolist():
|
139 |
tar_img_feats.append(torch.load("./datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0))
|
|
|
180 |
|
181 |
|
182 |
chat = Chat(model,transform,df,tar_img_feats, device)
|
|
|
183 |
|
184 |
|
185 |
def answer_generator(formated_input, session_id):
|
|
|
461 |
context = "No data available"
|
462 |
session_id = request.sid
|
463 |
if session_id not in session_store:
|
464 |
+
session_store[session_id] = {'image_data': "", 'message': None, 'image_received': False}
|
465 |
|
466 |
if 'message' in data:
|
467 |
session_store[session_id]['message'] = data['message']
|
|
|
559 |
emit('response', response, room=session_id)
|
560 |
return response
|
561 |
|
|
|
|
|
|
|
562 |
# Home route
|
563 |
@app.route("/")
|
564 |
def index_view():
|