import os
import logging
from typing import Optional, List, Union, Literal
from pathlib import Path

from pydantic import BaseModel, Field
from gradio import Interface, TabbedInterface
from gradio.components import Component
from gradio.data_classes import FileData, GradioModel, GradioRootModel
from transformers import pipeline
from diffusers import DiffusionPipeline
import torch
# Load the gated image model (FLUX.1-dev requires accepting the license on the
# Hugging Face Hub and passing a valid access token)
image_model = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    token=os.getenv("HUGGINGFACE_TOKEN"),  # `use_auth_token` is deprecated in recent diffusers
)
# Offload submodules to CPU between forward passes to reduce GPU memory usage
image_model.enable_model_cpu_offload()
# Define data models
class FileDataDict(BaseModel):
    path: str
    url: Optional[str] = None
    size: Optional[int] = None
    orig_name: Optional[str] = None
    mime_type: Optional[str] = None
    is_stream: Optional[bool] = False

    class Config:
        arbitrary_types_allowed = True


class MessageDict(BaseModel):
    content: Union[str, FileDataDict, tuple, Component]
    role: Literal["user", "assistant", "system"]
    metadata: Optional[dict] = None
    options: Optional[List[dict]] = None

    class Config:
        arbitrary_types_allowed = True


class ChatMessage(GradioModel):
    role: Literal["user", "assistant", "system"]
    content: Union[str, FileData, Component]
    metadata: dict = Field(default_factory=dict)
    options: Optional[List[dict]] = None

    class Config:
        arbitrary_types_allowed = True


class ChatbotDataMessages(GradioRootModel):
    root: List[ChatMessage]
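
# Illustrative sketch (not called by the app): how the message models above
# serialize. ChatbotDataMessages is a pydantic root model, so a list of
# ChatMessage objects round-trips to plain JSON-compatible dicts.
def _example_message_payload() -> ChatbotDataMessages:
    history = ChatbotDataMessages(root=[
        ChatMessage(role="user", content="Hello!"),
        ChatMessage(role="assistant", content="Hi, how can I help?"),
    ])
    # model_dump() yields e.g. [{"role": "user", "content": "Hello!", ...}, ...]
    logging.debug(history.model_dump())
    return history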
# Universal Reasoning Aggregator
class UniversalReasoning:
    def __init__(self, config):
        self.config = config
        self.context_history = []
        # Default sentiment pipeline (DistilBERT SST-2 under the hood)
        self.sentiment_analyzer = pipeline("sentiment-analysis")
        # Note: despite the attribute names, these are small open models,
        # not DeepSeek or Davinci
        self.deepseek_model = pipeline(
            "text-classification",
            model="distilbert-base-uncased-finetuned-sst-2-english",
            truncation=True,
        )
        self.davinci_model = pipeline(
            "text2text-generation",
            model="t5-small",
            truncation=True,
        )
        self.additional_model = pipeline(
            "text-generation",
            model="EleutherAI/gpt-neo-125M",
            truncation=True,
        )
        self.image_model = image_model

    async def generate_response(self, question: str) -> str:
        # The pipeline calls below are synchronous and block the event loop;
        # Gradio tolerates this, but heavy traffic would need a thread pool
        self.context_history.append(question)
        sentiment_score = self.analyze_sentiment(question)
        deepseek_response = self.deepseek_model(question)[0]
        davinci_response = self.davinci_model(question, max_length=50)[0]["generated_text"]
        additional_response = self.additional_model(question, max_length=100)[0]["generated_text"]
        responses = [
            f"Sentiment score: {sentiment_score}",
            f"DeepSeek Response: {deepseek_response['label']} ({deepseek_response['score']:.3f})",
            f"T5 Response: {davinci_response}",
            f"Additional Model Response: {additional_response}",
        ]
        return "\n\n".join(responses)

    def generate_image(self, prompt: str):
        image = self.image_model(
            prompt,
            height=1024,
            width=1024,
            guidance_scale=3.5,
            num_inference_steps=50,
            max_sequence_length=512,
            generator=torch.Generator("cpu").manual_seed(0),
        ).images[0]
        # Fixed filename: each generation overwrites the previous image on disk
        image.save("flux-dev.png")
        return image

    def analyze_sentiment(self, text: str) -> list:
        sentiment_score = self.sentiment_analyzer(text)
        logging.info(f"Sentiment analysis result: {sentiment_score}")
        return sentiment_score
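
# Illustrative sketch (not called by the app): exercising the aggregator
# directly. generate_response() is a coroutine, so outside Gradio it must be
# driven by an event loop, e.g. via asyncio.run().
def _example_reasoning_call() -> str:
    import asyncio

    reasoning = UniversalReasoning(config={})
    answer = asyncio.run(reasoning.generate_response("Is this library easy to use?"))
    # `answer` concatenates the sentiment, classification, and generation outputs
    return answer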
# Main Component
class MultimodalChatbot(Component):
    # Gradio uses `data_model` to (de)serialize values for this component
    data_model = ChatbotDataMessages

    def __init__(
        self,
        value: Optional[List[MessageDict]] = None,
        label: Optional[str] = None,
        render: bool = True,
        log_file: Optional[Path] = None,
    ):
        value = value or []
        # `render` is handled by the base class; assigning self.render here
        # would shadow the inherited render() method
        super().__init__(label=label, value=value, render=render)
        self.log_file = log_file
        self.universal_reasoning = UniversalReasoning({})

    def preprocess(self, payload: Optional[ChatbotDataMessages]) -> List[ChatMessage]:
        return payload.root if payload else []

    def postprocess(self, messages: Optional[List[MessageDict]]) -> ChatbotDataMessages:
        messages = messages or []
        return ChatbotDataMessages(root=messages)

    # Recent Gradio versions declare these as abstract on Component
    def example_payload(self):
        return ChatbotDataMessages(root=[])

    def example_value(self):
        return []
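
# Illustrative sketch (not called by the app): the round trip Gradio performs
# around an event handler. postprocess() wraps a plain message list into the
# data model; preprocess() unwraps a payload coming back from the frontend.
def _example_round_trip() -> list:
    chatbot = MultimodalChatbot(value=[])
    payload = chatbot.postprocess([{"role": "user", "content": "Ping"}])
    # preprocess() returns the list of ChatMessage objects inside the payload
    return chatbot.preprocess(payload)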
# Gradio Interface
class HuggingFaceChatbot:
    def __init__(self):
        self.chatbot = MultimodalChatbot(value=[])

    def setup_interface(self):
        # Gradio accepts async functions as event handlers
        async def chatbot_logic(input_text: str) -> str:
            return await self.chatbot.universal_reasoning.generate_response(input_text)

        def image_logic(prompt: str):
            return self.chatbot.universal_reasoning.generate_image(prompt)

        interface = Interface(
            fn=chatbot_logic,
            inputs="text",
            outputs="text",
            title="Hugging Face Multimodal Chatbot",
        )
        image_interface = Interface(
            fn=image_logic,
            inputs="text",
            outputs="image",
            title="Image Generator",
        )
        # Blocks() does not accept a list of interfaces;
        # TabbedInterface is the supported way to combine them
        return TabbedInterface(
            [interface, image_interface],
            tab_names=["Chatbot", "Image Generator"],
        )

    def launch(self):
        interface = self.setup_interface()
        interface.launch()
# Standalone launch
if __name__ == "__main__":
    logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(levelname)s - %(message)s')
    chatbot = HuggingFaceChatbot()
    chatbot.launch()
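
# Illustrative sketch (not called by the app): once the app is running, it can
# be queried programmatically with gradio_client. The URL is Gradio's local
# default and the api_name is an assumption; check the running app's API page.
def _example_client_call() -> str:
    from gradio_client import Client

    client = Client("http://127.0.0.1:7860")
    # "/predict" is the default endpoint name for the first Interface
    return client.predict("Hello!", api_name="/predict")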