MasoodishWisdom / scripts /generate_quote_gradio.py
0xrushi
examples
2c5fb56
raw
history blame contribute delete
3.78 kB
from huggingface_hub import login
import os
# Authenticate with the Hugging Face Hub only when an explicit token is provided.
# login() with no token falls back to an interactive prompt, which blocks or
# fails on a headless server. When the variable is unset, huggingface_hub keeps
# using whatever credentials it can already find (e.g. HF_TOKEN or a cached login).
token = os.environ.get("HUGGINGFACE_TOKEN")
if token:
    login(token=token)
import gradio as gr
import spaces
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
from f5py import generate_tts
from stitch import create_music_speech_mix
import traceback
import warnings
# Hide the "Can't initialize NVML" warning, presumably emitted by torch's CUDA
# probing on hosts without a usable NVIDIA driver. The message is matched as a
# regex against the start of the warning text.
warnings.filterwarnings(action="ignore", message="Can't initialize NVML")
# Hub id of the base model and local path (relative to the current working
# directory) of the fine-tuned PEFT adapter applied on top of it.
_BASE_MODEL = "mistralai/Mistral-7B-Instruct-v0.3"
_ADAPTER_PATH = "./checkpoints/epoch-11"
_PROMPT = "Generate a Masood Boomgaard style quote:"

# Text-generation pipeline, built on first use and reused by later calls in the
# same process so the 7B model is not reloaded from disk on every click.
# NOTE(review): on ZeroGPU Spaces the @spaces.GPU call may run in a separate
# worker process; if so, this cache will not persist between calls and each call
# loads the model afresh (the previous behaviour). Confirm on the target hardware.
_generator = None


def _build_generator():
    """Load the base model, apply the PEFT adapter and wrap both in a pipeline."""
    device = "cuda" if torch.cuda.is_available() and torch.cuda.device_count() > 0 else "cpu"
    print(f"Using device: {device}")
    tokenizer = AutoTokenizer.from_pretrained(_BASE_MODEL)
    model = AutoModelForCausalLM.from_pretrained(
        _BASE_MODEL,
        # Half precision only on GPU; CPU inference stays in float32.
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
        device_map="auto" if device == "cuda" else None,
    )
    model = PeftModel.from_pretrained(model, _ADAPTER_PATH)
    model.eval()
    return pipeline("text-generation", model=model, tokenizer=tokenizer)


def _get_generator():
    """Return the shared pipeline, building it on the first call."""
    global _generator
    if _generator is None:
        _generator = _build_generator()
    return _generator


@spaces.GPU()
def generate_quote(temperature, top_p, max_length):
    """Generate a quote with the fine-tuned model and render it as mixed audio.

    Args:
        temperature: Sampling temperature passed to the generator.
        top_p: Nucleus-sampling threshold passed to the generator.
        max_length: Maximum number of new tokens to generate. It is cast to
            int because UI sliders may deliver a float.

    Returns:
        On success, ``(text, audio_path, None)``, where ``audio_path`` is the
        value returned by ``create_music_speech_mix``.
        If any step fails, ``(None, None, error_message)``, where the message
        includes the traceback.
    """
    try:
        generator = _get_generator()
        output = generator(
            _PROMPT,
            max_new_tokens=int(max_length),
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            top_k=50,
            num_return_sequences=1,
            # Ask the pipeline for the continuation only, rather than deleting
            # the prompt from the full text with str.replace. That approach
            # would also remove any repeat of the prompt inside the quote.
            return_full_text=False,
        )
        text = output[0]["generated_text"].strip()
        if not text:
            raise ValueError("The model generated an empty quote.")
        output_path = generate_tts(input_text=text)
        final_audio_path = create_music_speech_mix(speech_path=output_path)
        return text, final_audio_path, None
    except Exception as e:
        # UI boundary: report any failure (model load, generation, TTS, mixing)
        # in the error box instead of failing the request.
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        return None, None, error_msg
# Directory of pre-rendered example clips, resolved against the current working
# directory (like the adapter checkpoint path).
_EXAMPLES_DIR = "examples"

with gr.Blocks() as demo:
    gr.Markdown("# MasoodishWisdom")
    with gr.Row():
        # Left column: sampling controls, the trigger button and example clips.
        with gr.Column():
            temperature = gr.Slider(
                minimum=0.1, maximum=1.0, step=0.1, value=0.9,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.1, maximum=1.0, step=0.05, value=0.95,
                label="Top-p",
            )
            max_length = gr.Slider(
                minimum=50, maximum=200, step=10, value=100,
                label="Max Length",
            )
            generate_btn = gr.Button("Generate Quote")
            gr.Markdown("---")
            gr.Markdown("## Examples")
            # One player per "output*.wav" file, in name order. If the directory
            # is missing, show no players rather than crashing at startup with
            # FileNotFoundError.
            if os.path.isdir(_EXAMPLES_DIR):
                for file_name in sorted(os.listdir(_EXAMPLES_DIR)):
                    if file_name.startswith("output") and file_name.endswith(".wav"):
                        gr.Audio(
                            os.path.join(_EXAMPLES_DIR, file_name),
                            label=file_name,
                            type="filepath",
                        )
        # Right column: generated text, generated audio and any error report.
        with gr.Column():
            text_output = gr.Textbox(label="Generated Quote")
            audio_output = gr.Audio(label="Generated Audio")
            error_output = gr.Textbox(label="Error Log", visible=True)

    def handle_generation(*args):
        """Run generate_quote and map its result onto the three outputs.

        On error, the text and audio outputs are cleared and only the error
        box is filled.
        """
        text, audio, error = generate_quote(*args)
        if error:
            return [None, None, error]
        return [text, audio, None]

    generate_btn.click(
        handle_generation,
        inputs=[temperature, top_p, max_length],
        outputs=[text_output, audio_output, error_output],
    )
if __name__ == "__main__":
    # Bind to all network interfaces so the app is reachable from other hosts
    # (e.g. when running in a container). Show Python errors in the browser and
    # do not create a public share link.
    launch_kwargs = {"server_name": "0.0.0.0", "show_error": True, "share": False}
    demo.launch(**launch_kwargs)