NikhilJoson committed on
Commit
e200100
·
verified ·
1 Parent(s): e0402e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -59
app.py CHANGED
@@ -1,68 +1,45 @@
1
- import gradio as gr
2
  import torch
3
- from transformers import AutoConfig, AutoModelForCausalLM, AutoProcessor, AutoModelForVision2Seq, pipeline
4
- from janus.models import MultiModalityCausalLM, VLChatProcessor
5
- from janus.utils.io import load_pil_images
6
- from PIL import Image
7
- import numpy as np
8
- import os
9
- import time
10
- from Upsample import RealESRGAN
11
- import spaces # Import spaces for ZeroGPU compatibility
12
- import re
13
-
14
-
15
- # Load Janus Pro models for vision and text tasks
16
- vision_model = AutoModelForVision2Seq.from_pretrained("deepseek-ai/janus-pro-7b", torch_dtype=torch.float16, device_map="auto")
17
- text_model = AutoModelForCausalLM.from_pretrained("deepseek-ai/janus-pro-7b", torch_dtype=torch.float16, device_map="auto")
18
- processor = AutoProcessor.from_pretrained("deepseek-ai/janus-pro-7b")
19
- image_pipe = pipeline("text-to-image", model="deepseek-ai/janus-pro-7b")
20
-
21
- last_uploaded_image = None
22
-
23
- def detect_image_request(user_input):
24
- image_keywords = ["generate an image", "create an image", "show me a picture", "draw", "visualize","generate image",
25
- "image generation", "get me an image", "get an image", "need an image", "need image",]
26
- return any(re.search(keyword, user_input, re.IGNORECASE) for keyword in image_keywords)
27
-
28
- def chatbot(history, image=None, user_input=""):
29
- global last_uploaded_image
30
-
31
- if image:
32
- last_uploaded_image = image # Store the latest uploaded image
33
-
34
- if detect_image_request(user_input):
35
- image = image_pipe(user_input)
36
- history.append((user_input, "[Generated Image]"))
37
- return history, "", image[0]["image"]
38
-
39
- if last_uploaded_image:
40
- inputs = processor(images=last_uploaded_image, return_tensors="pt").to("cuda")
41
- output = vision_model.generate(**inputs)
42
- response = processor.decode(output[0], skip_special_tokens=True)
43
  else:
44
- response = text_model.generate(user_input)
45
- response = processor.decode(response[0], skip_special_tokens=True)
 
 
46
 
47
  history.append((user_input, response))
48
- return history, "", None
49
-
50
- def reset_chat():
51
- global last_uploaded_image
52
- last_uploaded_image = None
53
- return [], ""
54
 
55
  with gr.Blocks() as demo:
56
- gr.Markdown("# Janus Pro Chatbot with Vision & Image Generation")
57
- chatbot_interface = gr.Chatbot()
58
- with gr.Row():
59
- image_input = gr.Image(type="pil", label="Upload Image")
60
- text_input = gr.Textbox(label="Type your message")
61
- send_button = gr.Button("Send")
62
- reset_button = gr.Button("Reset Chat")
63
- image_output = gr.Image(label="Generated Image")
64
 
65
- send_button.click(chatbot, [chatbot_interface, image_input, text_input], [chatbot_interface, text_input, image_output])
66
- reset_button.click(reset_chat, [], [chatbot_interface, text_input])
 
 
 
 
67
 
68
  demo.launch()
 
 
1
  import torch
2
+ import argparse
3
+ import gradio as gr
4
+ from janus import JanusProcessor, JanusForConditionalGeneration
5
+ from transformers import AutoTokenizer
6
+
7
+ # Load Model and Processor
8
+ model_id = "allenai/janus-pro-7b"
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+
11
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
12
+ processor = JanusProcessor.from_pretrained(model_id)
13
+ model = JanusForConditionalGeneration.from_pretrained(
14
+ model_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
15
+ ).to(device)
16
+
17
+ def chat_with_model(history, user_input, image=None):
18
+ if image is not None:
19
+ inputs = processor(text=user_input, images=image, return_tensors="pt").to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  else:
21
+ inputs = processor(text=user_input, return_tensors="pt").to(device)
22
+
23
+ generated_ids = model.generate(**inputs, max_new_tokens=100)
24
+ response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
25
 
26
  history.append((user_input, response))
27
+ return history, ""
 
 
 
 
 
28
 
29
  with gr.Blocks() as demo:
30
+ gr.Markdown("# Chat with Janus Pro 7B (Multimodal AI)")
31
+
32
+ chat_history = gr.State([])
33
+ chatbot = gr.Chatbot()
34
+ user_input = gr.Textbox(label="Your message")
35
+ image_input = gr.Image(label="Upload an image (optional)", type="pil", optional=True)
36
+ send_btn = gr.Button("Send")
 
37
 
38
+ send_btn.click(chat_with_model, inputs=[chat_history, user_input, image_input], outputs=[chatbot, user_input])
39
+
40
+ # gr.Examples([
41
+ # ["Describe this image", "example_image.jpg"],
42
+ # ["Generate an image of a futuristic city"],
43
+ # ], inputs=[user_input, image_input])
44
 
45
  demo.launch()