Spaces:
Build error
Build error
import gradio as gr | |
from smolagents import HfApiModel, CodeAgent, Tool | |
from smolagents import CodeAgent, DuckDuckGoSearchTool, HfApiModel, load_tool, tool | |
from huggingface_hub import login | |
from llama_index.retrievers.bm25 import BM25Retriever | |
import spaces | |
import torch | |
from transformers.models.speecht5.number_normalizer import EnglishNumberNormalizer | |
from string import punctuation | |
import re | |
from parler_tts import ParlerTTSForConditionalGeneration | |
from transformers import AutoTokenizer, AutoFeatureExtractor, set_seed | |
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Parler-TTS checkpoint used for speech synthesis.
# A larger alternative checkpoint is "parler-tts/parler-tts-large-v1".
repo_id = "parler-tts/parler-tts-mini-v1"

# Model, text tokenizer and feature extractor are all loaded from the same repo.
tts_model = ParlerTTSForConditionalGeneration.from_pretrained(repo_id)
tts_model = tts_model.to(device)
tokenizer = AutoTokenizer.from_pretrained(repo_id)
feature_extractor = AutoFeatureExtractor.from_pretrained(repo_id)

# Output audio sample rate, as reported by the checkpoint's feature extractor.
SAMPLE_RATE = feature_extractor.sampling_rate
# Fixed seed that gen_tts re-applies before every generation call.
SEED = 42

# Normalizes numbers in text (presumably spelling digits out as words)
# before the text is tokenized for TTS.
number_normalizer = EnglishNumberNormalizer()
def preprocess(text):
    """Normalize *text* so the TTS model reads it naturally.

    Steps:
      1. Run ``number_normalizer`` over the text and strip surrounding whitespace.
      2. Replace hyphens with spaces.
      3. Append a period if the text does not already end in punctuation.
      4. Split upper-case abbreviations into space-separated letters
         ("USA" -> "U S A", "U.S" -> "U S") so they are read letter by letter.

    Args:
        text: Raw text to be spoken.

    Returns:
        The normalized text. If the text is empty after normalization, ``""``
        is returned instead of raising ``IndexError``.
    """
    text = number_normalizer(text).strip()
    text = text.replace("-", " ")
    if not text:
        # Nothing to speak; also avoids indexing text[-1] on an empty string.
        return text
    if text[-1] not in punctuation:
        text = f"{text}."

    def separate_abb(match):
        # Drop the dots and space out the letters: "U.S.A" -> "U S A".
        return " ".join(match.group(0).replace(".", ""))

    # Substitute every match exactly once. The former findall + str.replace
    # loop could rewrite text produced by an earlier replacement
    # (e.g. "US and USA" became "U S and U SA").
    return re.sub(r"\b[A-Z][A-Z\.]+\b", separate_abb, text)
def gen_tts(text, description):
    """Synthesize speech for *text* in the voice described by *description*.

    Args:
        text: The words to speak; normalized with ``preprocess`` first.
        description: Natural-language description of the desired voice.

    Returns:
        A ``(sample_rate, audio)`` tuple where ``audio`` is the generated
        output moved to the CPU, converted to NumPy and squeezed.
    """
    description_ids = tokenizer(description.strip(), return_tensors="pt").to(device)
    prompt_ids = tokenizer(preprocess(text), return_tensors="pt").to(device)

    # Reset the RNG to the same state before every call, since sampling is on.
    set_seed(SEED)
    waveform = tts_model.generate(
        input_ids=description_ids.input_ids,
        attention_mask=description_ids.attention_mask,
        prompt_input_ids=prompt_ids.input_ids,
        prompt_attention_mask=prompt_ids.attention_mask,
        do_sample=True,
        temperature=1.0,
    )
    return SAMPLE_RATE, waveform.cpu().numpy().squeeze()
class RetrieverTool(Tool):
    """smolagents tool that looks up passages in a persisted BM25 index.

    The agent calls it with a string ``query``; the tool returns all
    retrieved documents concatenated into one string, each preceded by a
    ``===== Document i =====`` header.
    """

    name = "retriever"
    # NOTE(review): this says "semantic search" and "transformers documentation",
    # but the index is BM25 (lexical) and is loaded from a caller-supplied path;
    # confirm the wording the agent sees is what is intended.
    description = "Uses semantic search to retrieve the parts of transformers documentation that could be most relevant to answer your query."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. Ask the question as an human would, with simple explanation. The underlying index is BM25.",
        }
    }
    output_type = "string"

    def __init__(self, path, **kwargs):
        """Load the BM25 retriever persisted in directory *path*.

        Extra keyword arguments are forwarded to ``Tool.__init__``.
        """
        super().__init__(**kwargs)
        self.retriever = BM25Retriever.from_persist_dir(path)

    def forward(self, query: str) -> str:
        """Retrieve documents for *query* and format them as one string.

        Raises:
            TypeError: if *query* is not a ``str``.
        """
        # Explicit check rather than ``assert``, which is stripped under ``python -O``.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")
        docs = self.retriever.retrieve(query)
        return "\nRetrieved documents:\n" + "".join(
            f"\n\n===== Document {i} =====\n" + doc.text
            for i, doc in enumerate(docs)
        )
# Directory holding the persisted BM25 index that RetrieverTool loads.
path = "./ml_notes_index"

# Hosted LLM shared by both agents below.
model = HfApiModel(
    model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
    temperature=0.5,
    max_tokens=4086,  # NOTE(review): possibly meant 4096 — confirm.
    custom_role_conversions=None,
)

retriever_tool = RetrieverTool(path)

# Answering agent: may use the retriever, for at most 4 steps.
agent = CodeAgent(model=model, tools=[retriever_tool], max_steps=4, verbosity_level=2)

# Tool-less, single-step agent used only to rephrase the answer for TTS.
summarization_agent = CodeAgent(model=model, tools=[], max_steps=1, verbosity_level=2)
def greet(question):
    """Answer *question* with the retrieval agent and read the answer aloud.

    Pipeline:
      1. ``agent``, which can call the retriever tool, produces an answer.
      2. ``summarization_agent`` rephrases that answer for speech.
      3. ``gen_tts`` synthesizes audio from the rephrased text.

    Args:
        question: The user's question from the Gradio textbox.

    Returns:
        ``(answer_text, (sample_rate, audio))`` for the Textbox and Audio
        outputs. An empty or blank question returns ``("", None)`` without
        calling the agents.
    """
    if not question or not question.strip():
        # Skip two LLM calls and a TTS run when there is nothing to answer.
        return "", None

    agent_output = agent.run(question)
    result = summarization_agent.run(
        "Rephrase the following output so it can be passed to a "
        f"Text-To-Speech model: {agent_output}"
    )
    # NOTE: the agent's final answer is not guaranteed to be a str; preprocess()
    # calls string methods on it, so coerce before synthesizing.
    result = str(result)

    # Generate audio from the text.
    description = "Laura's voice is monotone yet slightly fast in delivery, with a very close recording that almost has no background noise."
    sample_rate, audio = gen_tts(result, description)
    return result, (sample_rate, audio)
# login() | |
# Custom stylesheet passed to gr.Blocks(css=...).
# NOTE(review): no component in this file sets elem_id="share-btn-container",
# so this rule may be an unused leftover from a template — confirm.
css = """
#share-btn-container {
  display: flex;
  padding-left: 0.5rem !important;
  padding-right: 0.5rem !important;
  background-color: #000000;
  justify-content: center;
  align-items: center;
  border-radius: 9999px !important;
  width: 13rem;
  margin-top: 10px;
  margin-left: auto;
  flex: unset !important;
}
"""
# Two-column Gradio UI: the question on the left, text and voice answers on the right.
with gr.Blocks(css=css) as block:
    # Centered page title.
    gr.HTML(
        """
        <div style="text-align: center; max-width: 700px; margin: 0 auto;">
          <div style="display: inline-flex; align-items: center; gap: 0.8rem; font-size: 1.75rem;">
            <h1 style="font-weight: 900; margin-bottom: 7px; line-height: normal;">
              ML Professor with Voice 🗣️
            </h1>
          </div>
        </div>
        """
    )
    with gr.Row():
        with gr.Column():
            input_text = gr.Textbox(label="Your Question", lines=2)
            run_button = gr.Button("Ask Question", variant="primary")
        with gr.Column():
            text_output = gr.Textbox(label="Answer", lines=4)
            audio_out = gr.Audio(label="Voice Answer", type="numpy")

    # Event listeners must be attached inside the Blocks context.
    run_button.click(fn=greet, inputs=[input_text], outputs=[text_output, audio_out])

# Enable request queuing, then start the server with a public share link.
block.queue().launch(share=True)