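# Gradio app: an "ML Professor" that answers questions from a BM25 index of ML notes
# via a smolagents CodeAgent, then reads the answer aloud with Parler-TTS.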
import gradio as gr
from smolagents import CodeAgent, HfApiModel, Tool
from huggingface_hub import login
from llama_index.retrievers.bm25 import BM25Retriever
import spaces
import torch
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer
from string import punctuation
import re
from parler_tts import ParlerTTSForConditionalGeneration
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed
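
# Load the Parler-TTS model, tokenizer, and feature extractor once at startup.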
device = "cuda:0" if torch.cuda.is_available() else "cpu"
repo_id = "parler-tts/parler-tts-mini-v1"
# repo_id_large = "parler-tts/parler-tts-large-v1"
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id).to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)
SAMPLE_RATE = feature_extractor.sampling_rate
SEED = 42
number_normalizer = EnglishNumberNormalizer()
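
# Normalize text for TTS: spell out numbers, ensure terminal punctuation,
# and space out uppercase abbreviations so they are read letter by letter.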
def preprocess(text):
    text = number_normalizer(text).strip()
    text = text.replace("-", " ")
    if text and text[-1] not in punctuation:
        text = f"{text}."

    abbreviations_pattern = r'\b[A-Z][A-Z\.]+\b'

    def separate_abb(chunk):
        chunk = chunk.replace(".", "")
        return " ".join(chunk)

    abbreviations = re.findall(abbreviations_pattern, text)
    for abv in abbreviations:
        if abv in text:
            text = text.replace(abv, separate_abb(abv))
    return text
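
# Synthesize speech: the `description` prompt steers the voice, while the
# preprocessed question/answer text is what gets spoken.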
@spaces.GPU
def gen_tts(text, description):
    inputs = tokenizer(description.strip(), return_tensors="pt").to(device)
    prompt = tokenizer(preprocess(text), return_tensors="pt").to(device)

    set_seed(SEED)
    generation = tts_model.generate(
        input_ids=inputs.input_ids,
        prompt_input_ids=prompt.input_ids,
        attention_mask=inputs.attention_mask,
        prompt_attention_mask=prompt.attention_mask,
        do_sample=True,
        temperature=1.0,
    )
    audio_arr = generation.cpu().numpy().squeeze()
    return SAMPLE_RATE, audio_arr
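
# Tool that lets the agent query the persisted BM25 index of the ML notes.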
class RetrieverTool(Tool):
    name = "retriever"
    description = "Uses BM25 retrieval to return the parts of the ML notes that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. Ask the question as a human would, in plain language. The underlying index is BM25.",
        }
    }
    output_type = "string"

    def __init__(self, path, **kwargs):
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_persist_dir(path)

    def forward(self, query: str) -> str:
        assert isinstance(query, str), "Your search query must be a string"
        docs = self.retriever.retrieve(query)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.text
            for i, doc in enumerate(docs)
        )
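
# Path to the llama_index BM25 index persisted ahead of time from the ML notes.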
path = "./ml_notes_index"
model = HfApiModel(
    max_tokens=4086,
    temperature=0.5,
    model_id='Qwen/Qwen2.5-Coder-32B-Instruct',
    custom_role_conversions=None,
)
retriever_tool = RetrieverTool(path)
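
# Main agent answers the question with help from the retriever tool; a second
# single-step agent rephrases the answer so it reads well as speech.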
agent = CodeAgent(
    tools=[retriever_tool],
    model=model,
    max_steps=4,
    verbosity_level=2,
)

summarization_agent = CodeAgent(
    tools=[],
    model=model,
    max_steps=1,
    verbosity_level=2,
)
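
# End-to-end pipeline behind the UI: run the agent, rephrase for speech, synthesize audio.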
def greet(question):
    agent_output = agent.run(question)
    result = summarization_agent.run(
        f"Rephrase the following output, since it will be passed to a text-to-speech model: {agent_output}"
    )

    # Generate audio from the text
    description = "Laura's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
    sample_rate, audio = gen_tts(result, description)
    return result, (sample_rate, audio)
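
# Uncomment to authenticate with the Hugging Face Hub, e.g. if the inference model requires it.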
# login()
css = """
#share-btn-container {
display: flex;
padding-left: 0.5rem !important;
padding-right: 0.5rem !important;
background-color: #000000;
justify-content: center;
align-items: center;
border-radius: 9999px !important;
width: 13rem;
margin-top: 10px;
margin-left: auto;
flex: unset !important;
}
"""
with gr.Blocks(css=css) as block:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 700px; margin: 0 auto;">
          <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
            <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
              ML Professor with Voice 🗣️
            </h1>
          </div>
        </div>
        """
    )
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Your Question", lines=2)
            run_button = gr.Button("Ask Question", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(label="Answer", lines=4)
            audio_out = gr.Audio(label="Voice Answer", type="numpy")

    run_button.click(fn=greet, inputs=[input_text], outputs=[text_output, audio_out])
block.queue()
block.launch(share=True)