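"""Voice chat demo built on FastRTC and Groq.

Microphone audio is transcribed with a Whisper model on Groq, answered by a
Llama 4 chat model, and spoken back via text-to-speech, while the full
conversation is mirrored to the browser over server-sent events.
"""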
import json
import os
from pathlib import Path

import gradio as gr
import numpy as np
from dotenv import load_dotenv
from fastapi import FastAPI
from fastapi.responses import HTMLResponse, StreamingResponse
from fastrtc import (
    AdditionalOutputs,
    ReplyOnPause,
    Stream,
    audio_to_bytes,
    get_cloudflare_turn_credentials_async,
    get_current_context,
    get_tts_model,
)
from groq import Groq
from numpy.typing import NDArray

curr_dir = Path(__file__).parent
load_dotenv()

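# Default FastRTC text-to-speech model; the Groq client authenticates with
# GROQ_API_KEY from the environment loaded above.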
tts_model = get_tts_model()
groq = Groq(api_key=os.getenv("GROQ_API_KEY"))


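# Per-connection chat history, keyed by WebRTC connection ID.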
conversations: dict[str, list[dict[str, str]]] = {}


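# Handler for one user utterance: FastRTC passes (sample_rate, int16 samples);
# the generator yields TTS audio chunks, then the updated history as an
# additional output.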
def response(user_audio: tuple[int, NDArray[np.int16]]):
    context = get_current_context()
    if context.webrtc_id not in conversations:
        conversations[context.webrtc_id] = [
            {
                "role": "system",
                "content": (
                    "You are a helpful assistant that can answer questions and help with tasks."
                    'Please return a short (that will be converted to audio using a text-to-speech model) response and long response to this question. They can be the same if appropriate. Please return in JSON format\n\n{"short":, "long"}\n\n'
                ),
            }
        ]
    messages = conversations[context.webrtc_id]

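    # Transcribe the utterance with a distilled Whisper model hosted on Groq.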
    transcription = groq.audio.transcriptions.create(
        file=("audio.wav", audio_to_bytes(user_audio)),
        model="distil-whisper-large-v3-en",
        response_format="verbose_json",
    )
    print(transcription.text)

    messages.append({"role": "user", "content": transcription.text})

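    # Ask the model for the JSON object requested by the system prompt:
    # a short spoken answer and a long written one.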
    completion = groq.chat.completions.create(  # type: ignore
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        messages=messages,  # type: ignore
        temperature=1,
        max_completion_tokens=1024,
        top_p=1,
        stream=False,
        response_format={"type": "json_object"},
        stop=None,
    )
    reply = completion.choices[0].message.content
    parsed = json.loads(reply)
    short_response = parsed["short"]
    long_response = parsed["long"]
    # Keep the long answer in the history; only the short one is spoken.
    messages.append({"role": "assistant", "content": long_response})
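    # Stream the short answer as synthesized audio, then emit the updated
    # history so the chat UI (and /outputs listeners) can render it.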
    yield from tts_model.stream_tts_sync(short_response)
    yield AdditionalOutputs(messages)


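# ReplyOnPause invokes the handler whenever the caller stops speaking; the
# Cloudflare TURN credentials let WebRTC connect across NATs.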
stream = Stream(
    ReplyOnPause(response),
    modality="audio",
    mode="send-receive",
    additional_outputs=[gr.Chatbot(type="messages")],
    additional_outputs_handler=lambda old, new: new,
    rtc_configuration=get_cloudflare_turn_credentials_async,
)

app = FastAPI()
stream.mount(app)


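# Serve the single-page client, injecting fresh TURN credentials into the
# HTML template.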
@app.get("/")
async def _():
    rtc_config = await get_cloudflare_turn_credentials_async()
    html_content = (curr_dir / "index.html").read_text()
    html_content = html_content.replace("__RTC_CONFIGURATION__", json.dumps(rtc_config))
    return HTMLResponse(content=html_content)


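# Server-sent events: whenever the handler yields AdditionalOutputs, push the
# newest user/assistant message pair to the browser.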
@app.get("/outputs")
async def _(webrtc_id: str):
    async def output_stream():
        async for output in stream.output_stream(webrtc_id):
            state = output.args[0]
            for msg in state[-2:]:
                data = {
                    "message": msg,
                }
                yield f"event: output\ndata: {json.dumps(data)}\n\n"

    return StreamingResponse(output_stream(), media_type="text/event-stream")


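# MODE selects the entry point: the built-in Gradio UI, or the FastAPI app
# served by uvicorn (phone mode is not implemented here).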
if __name__ == "__main__":
    import os

    if (mode := os.getenv("MODE")) == "UI":
        stream.ui.launch(server_port=7860)
    elif mode == "PHONE":
        raise ValueError("Phone mode not supported")
    else:
        import uvicorn

        uvicorn.run(app, host="0.0.0.0", port=7860)