Spaces:
Sleeping
Sleeping
import gradio as gr | |
import spaces | |
from sentence_transformers import SentenceTransformer | |
from sentence_transformers.util import cos_sim | |
from sentence_transformers.quantization import quantize_embeddings | |
import pymssql | |
import os | |
import pandas as pd | |
from openai import OpenAI | |
from pydantic import BaseModel, Field | |
import json | |
from sentence_transformers import CrossEncoder | |
from torch import nn | |
import time | |
# SQL Server connection settings. All four are mandatory: a missing variable
# raises KeyError as soon as the module is imported.
SqlServer = os.environ["SQL_SERVER"]
SqlDatabase = os.environ["SQL_DB"]
SqlUser = os.environ["SQL_USER"]
SqlPass = os.environ["SQL_PASS"]

# Settings for the OpenAI-compatible chat API. The key may be absent (None);
# the base URL falls back to the Google Generative Language endpoint.
OpenaiApiKey = os.environ.get("OPENAI_API_KEY")
OpenaiBaseUrl = os.environ.get(
    "OPENAI_BASE_URL",
    "https://generativelanguage.googleapis.com/v1beta/openai",
)
def sql(query, db=SqlDatabase, login_timeout=120, onConnectionError=None):
    """Run ``query`` on the configured SQL Server and return the rows as dicts.

    The connection is retried in a loop: each attempt is made with pymssql's
    own ``login_timeout=5``, followed by a 1 second pause, until an attempt
    succeeds or more than ``login_timeout`` seconds have passed since the
    first attempt.

    Args:
        query: T-SQL text to execute.
        db: Name of the database to connect to.
        login_timeout: Overall retry budget, in seconds, for connecting.
        onConnectionError: Optional callable invoked with the exception after
            every failed connection attempt.

    Returns:
        A list with one dict per fetched row, keyed by column name. An empty
        list is returned when the statement produced no result set.

    Raises:
        TimeoutError: If no connection could be established within the budget.
    """
    start_time = time.time()
    while True:
        try:
            cnxn = pymssql.connect(SqlServer, SqlUser, SqlPass, db, login_timeout=5)
            break
        except Exception as e:
            if onConnectionError:
                onConnectionError(e)
            if time.time() - start_time > login_timeout:
                raise TimeoutError("SQL Connection Timeout") from e
            # Wait one second before trying again.
            time.sleep(1)

    # Always release the connection, even when executing or fetching fails;
    # otherwise every call leaves an open connection behind.
    try:
        cursor = cnxn.cursor()
        cursor.execute(query)
        # DB-API cursors expose description=None when there is no result set.
        if cursor.description is None:
            return []
        columns = [column[0] for column in cursor.description]
        return [dict(zip(columns, row)) for row in cursor.fetchall()]
    finally:
        cnxn.close()
def embed(text):
    """Encode ``text`` with the module-level ``Embedder`` and return it as a list."""
    return Embedder.encode(text).tolist()
def rerank(query, documents, **kwargs):
    """Rank ``documents`` against ``query`` using the module-level ``Reranker``.

    Any extra keyword arguments are forwarded unchanged to ``Reranker.rank``,
    whose result is returned as-is.
    """
    ranking = Reranker.rank(query, documents, **kwargs)
    return ranking
# Single shared client for the OpenAI-compatible chat endpoint.
ClientOpenai = OpenAI(base_url=OpenaiBaseUrl, api_key=OpenaiApiKey)
def llm(messages, ResponseFormat=None, **kwargs):
    """Call the chat-completions API through ``ClientOpenai``.

    Args:
        messages: List of chat message dicts (``role`` / ``content``).
        ResponseFormat: Optional response schema. When given, the request goes
            through ``beta.chat.completions.parse`` with this value as
            ``response_format``; otherwise ``chat.completions.create`` is used.
        **kwargs: Extra request parameters. They are applied last, so they
            override the defaults built here (model, ``n``, messages, ...).

    Returns:
        The raw response object when ``stream`` is truthy in the final
        parameters, otherwise the first element of ``response.choices``.
    """
    params = {
        'model': "gemini-2.0-flash",
        'n': 1,
        'messages': messages,
    }

    if ResponseFormat:
        fn = ClientOpenai.beta.chat.completions.parse
        params['response_format'] = ResponseFormat
    else:
        # Leave ``response_format`` out entirely instead of sending an
        # explicit null, so the backend simply applies its default.
        fn = ClientOpenai.chat.completions.create

    params.update(kwargs)
    response = fn(**params)

    if params.get('stream'):
        return response
    return response.choices[0]
def ai(system, user, schema, **kwargs):
    """Send one system prompt and one user message through ``llm``.

    ``schema`` is passed as the response format and ``kwargs`` are forwarded
    unchanged; the return value is whatever ``llm`` returns.
    """
    conversation = [
        {'role': "system", 'content': system},
        {'role': "user", 'content': user},
    ]
    return llm(conversation, schema, **kwargs)
def search(text, top=10, onConnectionError=None):
    """Return the ``top`` rows of ``Scripts`` closest to ``text`` by cosine distance.

    The text is embedded with ``embed`` and compared in SQL against the
    ``embeddings`` column using ``vector_distance('cosine', ...)``.

    Args:
        text: Search text to embed.
        top: Maximum number of rows to return; must be convertible to ``int``.
        onConnectionError: Optional callback forwarded to ``sql``; it receives
            each connection error.

    Returns:
        The list of row dicts produced by ``sql``, ordered by ascending
        ``CosDistance``, with the keys ``RelPath``, ``Similaridade``,
        ``ScriptContent``, ``ContentLength`` and ``CosDistance``.
    """
    # ``top`` is interpolated straight into the T-SQL text, so force it to an
    # integer to keep anything else out of the statement.
    top = int(top)
    embeddings = embed(text)

    query = f"""
declare @search vector(1024) = '{embeddings}'
select top {top}
*
from (
select
RelPath
,Similaridade = 1-CosDistance
,ScriptContent = ChunkContent
,ContentLength = LEN(ChunkContent)
,CosDistance
from
(
select
*
,CosDistance = vector_distance('cosine',embeddings,@search)
from
Scripts
) C
) v
order by
CosDistance
"""
    return sql(query, onConnectionError=onConnectionError)
# Both models are loaded once, at import time, and used by embed() / rerank().
print("Loading embedding model")
Embedder = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")

print("Loading reranker")
# nn.Sigmoid() is handed to the cross-encoder as its activation function.
Reranker = CrossEncoder(
    "mixedbread-ai/mxbai-rerank-large-v1",
    activation_fn=nn.Sigmoid(),
)
# Response schema for the translation step: English text plus the language
# the LLM reported for the original message.
# (Kept as a comment on purpose: a class docstring would become part of the
# JSON schema that pydantic generates for this model.)
class rfTranslatedText(BaseModel):
    text: str = Field(description="Translated text")
    lang: str = Field(description="source language")
# Response schema for LLM calls that return a single piece of text.
# (Kept as a comment on purpose: a class docstring would become part of the
# JSON schema that pydantic generates for this model.)
class rfGenericText(BaseModel):
    text: str = Field(description="The text result")
def ChatFunc(message, history, LangMode, ChooseLang):
    """Chat handler: translate, search, rerank and explain scripts for ``message``.

    This is a generator. Every yielded value is ``[messages, table]``, where
    ``messages`` is the list of ``gr.ChatMessage`` objects built so far and
    ``table`` is the ranking ``pandas.DataFrame`` (``None`` until the search
    and rerank steps have finished).

    Args:
        message: Text typed by the user.
        history: Previous chat turns passed in by the caller (not used).
        LangMode: ``"auto"`` answers in the language reported by the LLM for
            the user's text; any other value uses ``ChooseLang`` instead.
        ChooseLang: Language code used when ``LangMode`` is not ``"auto"``.

    Yields:
        ``[messages, table]`` snapshots as the work progresses.
    """
    # Decide whether the user wants a new search (currently always the case).
    IsNewSearch = True
    messages = []
    CurrentTable = None

    def ChatBotOutput():
        # Snapshot of the chat messages plus the current results table.
        return [messages, CurrentTable]

    class BotMessage:
        """One chat message that is appended to ``messages`` and grown in place."""

        def __init__(self, *args, **kwargs):
            self.Message = gr.ChatMessage(*args, **kwargs)
            self.LastContent = None
            messages.append(self.Message)

        def __call__(self, content, noNewLine=False):
            # Append ``content`` (plus a newline unless ``noNewLine``). Empty
            # content changes nothing, but a valid snapshot is still returned
            # so callers can always yield the result.
            if content:
                self.Message.content += content
                self.LastContent = None
                if not noNewLine:
                    self.Message.content += "\n"
            return ChatBotOutput()

        def update(self, content):
            # Replace the text added by the previous update() instead of
            # appending: LastContent remembers the content before the first one.
            if not self.LastContent:
                self.LastContent = self.Message.content
            self.Message.content = self.LastContent + " " + content + "\n"
            return ChatBotOutput()

        def done(self):
            self.Message.metadata['status'] = "done"
            return ChatBotOutput()

    def Reply(msg):
        # Add a new, separate chat message containing ``msg``.
        BotMessage(msg)
        return ChatBotOutput()

    m = BotMessage("", metadata={"title": "Searching scripts...", "status": "pending"})

    def OnConnError(err):
        print("Sql connection error:", err)

    consumer_closed = False
    try:
        if IsNewSearch:
            yield m("Enhancing the prompt...")

            LLMResult = ai("""
Translate the user's message to English.
The message is a question related to a SQL Server T-SQL script that the user is searching for.
You must do following actions:
- Identify the language of user text, using BCP 47 code (example: pt-BR, en-US, ja-JP, etc.)
- Generate translated user text to english
Return both source language and translated text.
""", message, rfTranslatedText)
            Question = LLMResult.message.parsed.text

            if LangMode == "auto":
                SourceLang = LLMResult.message.parsed.lang
            else:
                SourceLang = ChooseLang

            yield m(f"Lang:{SourceLang}({LangMode}), English Prompt: {Question}")
            yield m("searching...")
            try:
                FoundScripts = search(Question, onConnectionError=OnConnError)
            except Exception as e:
                print('Search Error:')
                print(e)
                yield m("Houve alguma falha ao fazer a pesquisa. Tente novamente. Se persistir, veja orientações na aba Help!")
                return

            yield m("Doing rerank")
            doclist = [doc['ScriptContent'] for doc in FoundScripts]

            # Run the reranker. Scores are stored as strings so the rows stay
            # serializable by json.dumps further down.
            for score in rerank(Question, doclist):
                i = score['corpus_id']
                FoundScripts[i]['rank'] = str(score['score'])

            RankedScripts = sorted(FoundScripts, key=lambda item: float(item['rank']), reverse=True)

            ScriptTable = []
            for script in RankedScripts:
                link = "https://github.com/rrg92/sqlserver-lib/tree/main/" + script['RelPath']
                script['link'] = link
                ScriptTable.append({
                    'Link': f'<a title="{link}" href="{link}" target="_blank">{script["RelPath"]}</a>',
                    'Length': script['ContentLength'],
                    'Cosine Similarity': script['Similaridade'],
                    'Rank': script['rank'],
                })

            CurrentTable = pd.DataFrame(ScriptTable)
            yield m("Found scripts, check Rank tab for details!")

            WaitMessage = ai(f"""
You will analyze some T-SQL scripts in order to check which is best for the user.
You found scripts, presented them to the user, and now will do some work that takes time.
Generate a message to tell the user to wait while you work, in the same language as the user.
You will receive the question the user sent that triggered this process.
Use the user’s original question to customize the message.
Answer in lang: {SourceLang}
""", message, rfGenericText).message.parsed.text
            yield Reply(WaitMessage)

            yield m("Analyzing scripts...")
            ResultJson = json.dumps(RankedScripts)
            SystemPrompt = f"""
You are an assistant that helps users find the best T-SQL scripts for their specific needs.
These scripts were created by Rodrigo Ribeiro Gomes and are publicly available for users to query and use.
The user will provide a short description of what they are looking for, and your task is to present the most relevant scripts.
To assist you, here is a JSON object with the top matches based on the current user query:
{ResultJson}
---
This JSON contains all the scripts that matched the user's input.
Analyze each script's name and content, and create a ranked summary of the best recommendations according to the user's need.
Only use the information available in the provided JSON. Do not reference or mention anything outside of this list.
You can include parts of the scripts in your answer to illustrate or give usage examples based on the user's request.
Re-rank the results if necessary, presenting them from the most to the least relevant.
You may filter out scripts that appear unrelated to the user query.
---
### Output Rules
- Review each script and evaluate how well it matches the user’s request.
- Summarize each script, ordering from the most relevant to the least relevant.
- Write personalized and informative review text for each recommendation.
- If applicable, explain how the user should run the script, including parameters or sections (like `WHERE` clauses) they might need to customize.
- When referencing a script, include the link provided in the JSON — all scripts are hosted on GitHub
- YOU MUST ANSWER THAT LANGUAGE: {SourceLang}
"""
            ScriptPrompt = [
                {'role': 'system', 'content': SystemPrompt},
                {'role': 'user', 'content': message},
            ]
            llmanswer = llm(ScriptPrompt, stream=True)
            yield m.done()

            answer = BotMessage("")
            for chunk in llmanswer:
                # Streamed chunks may carry no choices or no text at all;
                # skip them instead of indexing into an empty list or
                # yielding a useless update.
                if not chunk.choices:
                    continue
                content = chunk.choices[0].delta.content
                if not content:
                    continue
                yield answer(content, noNewLine=True)
    except GeneratorExit:
        # The consumer closed the generator: yielding now would raise
        # RuntimeError, so only remember that no final update may be sent.
        consumer_closed = True
        raise
    finally:
        # Always take the progress message out of the "pending" state, and
        # push that final state whenever the consumer is still listening
        # (normal end, early return, or an error about to propagate).
        m.done()
        if not consumer_closed:
            yield ChatBotOutput()
def SearchFiles(message):
    """Search and rerank scripts for ``message`` and return them as JSON text.

    Args:
        message: Description of the script being looked for.

    Returns:
        A JSON string with the rows returned by ``search``, each extended
        with ``rank`` (reranker score, as a string) and ``link`` (GitHub URL),
        sorted by ``rank`` in descending order. If the search fails, a
        user-facing error message is returned instead.
    """
    Question = message
    try:
        FoundScripts = search(Question)
    except Exception as e:
        # Report the failure as plain text. The previous handler called a
        # name that does not exist at module level and so raised NameError.
        print('Search Error:')
        print(e)
        return "Houve alguma falha ao executar a consulta no banco. Tente novamente. Se persistir, veja orientações na aba Help!"

    doclist = [doc['ScriptContent'] for doc in FoundScripts]

    # Run the reranker and annotate each row with its score and link.
    for score in rerank(Question, doclist):
        script = FoundScripts[score['corpus_id']]
        script['rank'] = str(score['score'])
        script['link'] = "https://github.com/rrg92/sqlserver-lib/tree/main/" + script['RelPath']

    RankedScripts = sorted(FoundScripts, key=lambda item: float(item['rank']), reverse=True)
    return json.dumps(RankedScripts)
# NOTE(review): the nesting of this layout was reconstructed from the order of
# the statements; confirm it against the deployed app, in particular that
# TextResults belongs to the hidden "Other" tab.

# JavaScript run in the browser when the language option changes: it returns
# the browser language when the "browser" option is selected.
_BROWSER_LANG_JS = """
function(opt){
if(opt == "browser"){
return navigator ? navigator.language : "en-US";
}
}
"""

# Text shown in the "Help" tab.
_HELP_MARKDOWN = """
Bem-vindo ao Space SQL Server Lib
Este space permite que você encontre scripts SQL do https://github.com/rrg92/sqlserver-lib com base nas suas necessidades
## Instruções de Uso
Apenas descreva o que você precisa no campo de chat e aguarde a IA analisar os melhores scripts do repositório para você.
Além de uma explicação feita pela IA, a aba "Rank", contém uma tabela com os scripts encontrados e seus respectictos rank.
A coluna Cosine Similarity é o nível de similaridades da sua pergunta com o script (calculado baseado nos embeddings do seu texto e do script).
A coluna Rank é um score onde quanto maior o valor mais relacionado ao seu texto o script é (calculado usando rerank/cross encoders). A tabela vem ordenada por essa coluna.
## Fluxo básico
- Quando você digita o texto, iremos fazer uma busca usando embeddings em um banco Azure SQL Database
- Os embeddings são calculados usando um modelo carregado no proprio script, via ZeroGPU.
- Os top 20 resultados mais similares são retornados e então um rerank é feito
- O rerank também é feito por um modelo que roda no próprio script, em ZeroGPU
- Estes resultados ordenados por reran, são então enviados ao LLM para que analise e monte uma resposta para você.
## Sobre o uso e eventuais erros
Eu tento usar o máximo de recursos FREE e open possíveis, e portanto, eventualmente, o Space pode falhar por alguma limitação.
Alguns possíveis pontos de falha:
- Créditos free do google ou rate limit
- Azure SQL database offline devido a crédito ou ao auto-pause (devido ao free tier)
- Limites de uso do ZeroGPU do Hugging Face.
Você pode me procurar no [linkedin](https://www.linkedin.com/in/rodrigoribeirogomes/), caso receba erroslimit
"""

# Created outside the Blocks context and placed later with .render().
resultTable = gr.Dataframe(
    datatype=['html', 'number', 'number'],
    interactive=False,
    show_search="search",
)
TextResults = gr.Textbox()

with gr.Blocks(fill_height=True) as demo:
    with gr.Column():
        # The Settings tab is built first, because the chat needs its inputs,
        # but it is rendered after the Chat tab.
        tabSettings = gr.Tab("Settings", render=False)
        with tabSettings:
            LangOpts = gr.Radio(
                [("Auto Detect from text", "auto"), ("Use browser language", "browser")],
                value="auto",
                label="Language",
                info="Choose lang used by AI to answer you!",
            )
            LangChoose = gr.Textbox(
                info="This will be filled with detect browser language, but you can change"
            )
            LangOpts.change(None, [LangOpts], [LangChoose], js=_BROWSER_LANG_JS)

        with gr.Tab("Chat", scale=1):
            ChatTextBox = gr.Textbox(
                max_length=500,
                info="Which script are you looking for?",
                submit_btn=True,
            )
            gr.ChatInterface(
                ChatFunc,
                additional_outputs=[resultTable],
                additional_inputs=[LangOpts, LangChoose],
                type="messages",
                textbox=ChatTextBox,
            )

        tabSettings.render()

        with gr.Tab("Rank"):
            txtSearchTable = gr.Textbox(
                label="Search script files",
                info="Description of what you want",
                visible=False,
            )
            AsJson = gr.Checkbox(visible=False)
            resultTable.render()
            txtSearchTable.submit(SearchFiles, [txtSearchTable], [TextResults])

        with gr.Tab("Help"):
            gr.Markdown(_HELP_MARKDOWN)

        with gr.Tab("Other", visible=False):
            txtEmbed = gr.Text(label="Text to embed", visible=False)
            btnEmbed = gr.Button("embed")
            btnEmbed.click(embed, [txtEmbed], [txtEmbed])
            TextResults.render()
if __name__ == "__main__":
    # Listen on every interface on port 7860, with no share link, no debug
    # mode and an empty list of additionally allowed file paths.
    launch_options = {
        "share": False,
        "debug": False,
        "server_port": 7860,
        "server_name": "0.0.0.0",
        "allowed_paths": [],
    }
    demo.launch(**launch_options)