NikhilJoson committed on
Commit 61e1955 · verified · 1 Parent(s): e200100

Update app.py

Files changed (1)
  1. app.py +84 -34
app.py CHANGED
@@ -1,45 +1,95 @@
-import torch
-import argparse
 import gradio as gr
-from janus import JanusProcessor, JanusForConditionalGeneration
-from transformers import AutoTokenizer

-# Load Model and Processor
-model_id = "allenai/janus-pro-7b"
-device = "cuda" if torch.cuda.is_available() else "cpu"

-tokenizer = AutoTokenizer.from_pretrained(model_id)
-processor = JanusProcessor.from_pretrained(model_id)
-model = JanusForConditionalGeneration.from_pretrained(
-    model_id, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32
-).to(device)

-def chat_with_model(history, user_input, image=None):
-    if image is not None:
-        inputs = processor(text=user_input, images=image, return_tensors="pt").to(device)
-    else:
-        inputs = processor(text=user_input, return_tensors="pt").to(device)

-    generated_ids = model.generate(**inputs, max_new_tokens=100)
-    response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

-    history.append((user_input, response))
-    return history, ""

-with gr.Blocks() as demo:
-    gr.Markdown("# Chat with Janus Pro 7B (Multimodal AI)")

-    chat_history = gr.State([])
-    chatbot = gr.Chatbot()
-    user_input = gr.Textbox(label="Your message")
-    image_input = gr.Image(label="Upload an image (optional)", type="pil", optional=True)
-    send_btn = gr.Button("Send")

-    send_btn.click(chat_with_model, inputs=[chat_history, user_input, image_input], outputs=[chatbot, user_input])

-    # gr.Examples([
-    #     ["Describe this image", "example_image.jpg"],
-    #     ["Generate an image of a futuristic city"],
-    # ], inputs=[user_input, image_input])

-demo.launch()
 import gradio as gr
+import torch
+from transformers import AutoConfig, AutoModelForCausalLM
+from janus.models import MultiModalityCausalLM, VLChatProcessor
+from janus.utils.io import load_pil_images
+from PIL import Image
+
+import numpy as np
+import os
+import time
+from Upsample import RealESRGAN
+import spaces  # Import spaces for ZeroGPU compatibility

+# Load model and processor
+model_path = "deepseek-ai/Janus-Pro-7B"
+config = AutoConfig.from_pretrained(model_path)
+language_config = config.language_config
+language_config._attn_implementation = 'eager'
+vl_gpt = AutoModelForCausalLM.from_pretrained(model_path, language_config=language_config, trust_remote_code=True)
+if torch.cuda.is_available():
+    vl_gpt = vl_gpt.to(torch.bfloat16).cuda()
+else:
+    vl_gpt = vl_gpt.to(torch.float16)

+vl_chat_processor = VLChatProcessor.from_pretrained(model_path)
+tokenizer = vl_chat_processor.tokenizer
+cuda_device = 'cuda' if torch.cuda.is_available() else 'cpu'

+# SR model
+sr_model = RealESRGAN(torch.device('cuda' if torch.cuda.is_available() else 'cpu'), scale=2)
+sr_model.load_weights('weights/RealESRGAN_x2.pth', download=False)
+
+@torch.inference_mode()
+@spaces.GPU(duration=120)
+def multimodal_understanding(image, question, seed, top_p, temperature, progress=gr.Progress(track_tqdm=True)):
+    # Clear CUDA cache before generating
+    torch.cuda.empty_cache()
+
+    # set seed
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.manual_seed(seed)

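+    # Janus chat template: a user turn carrying one <image_placeholder> token per
+    # attached image, followed by an empty assistant turn for the model to complete.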
+    conversation = [
+        {
+            "role": "<|User|>",
+            "content": f"<image_placeholder>\n{question}",
+            "images": [image],
+        },
+        {"role": "<|Assistant|>", "content": ""},
+    ]

+    pil_images = [Image.fromarray(image)]
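+    # The processor tokenizes the prompt and preprocesses the image into one batch;
+    # prepare_inputs_embeds then merges the image features into the text embedding sequence.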
+    prepare_inputs = vl_chat_processor(
+        conversations=conversation, images=pil_images, force_batchify=True
+    ).to(cuda_device, dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float16)
+
+    inputs_embeds = vl_gpt.prepare_inputs_embeds(**prepare_inputs)
+
+    outputs = vl_gpt.language_model.generate(
+        inputs_embeds=inputs_embeds, attention_mask=prepare_inputs.attention_mask,
+        pad_token_id=tokenizer.eos_token_id, bos_token_id=tokenizer.bos_token_id,
+        eos_token_id=tokenizer.eos_token_id, max_new_tokens=512,
+        temperature=temperature, top_p=top_p,
+        do_sample=False if temperature == 0 else True, use_cache=True,
+    )
+
+    answer = tokenizer.decode(outputs[0].cpu().tolist(), skip_special_tokens=True)
+    return answer

+# Gradio interface
+css = '''
+.gradio-container {max-width: 960px !important}
+'''
+with gr.Blocks(css=css) as demo:
+    gr.Markdown("# Janus Pro 7B Chat Interface")

+    chat_history = gr.Chatbot(label="Chat History")
+    message_input = gr.Textbox(label="Type your message here...")
+    image_input = gr.Image(label="Upload an image (optional)", type="numpy", tool="editor")

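+    # Only the image path invokes the GPU-decorated multimodal_understanding;
+    # text-only submissions get a fallback prompt instead.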
+    def respond(message, image):
+        if image is not None:
+            # Call multimodal understanding with the image and message
+            response = multimodal_understanding(image, message, seed=42, top_p=0.95, temperature=0.1)
+        else:
+            # Without an image there is nothing to analyse, so ask the user for one
+            response = "Please provide an image for multimodal understanding."
+
+        return response
+
+    def submit_message(message, image, history):
+        # Append the exchange to the Chatbot's (user, bot) pair list and clear the textbox
+        response = respond(message, image)
+        history = history + [(message, response)]
+        return "", history

+    message_input.submit(submit_message, inputs=[message_input, image_input, chat_history],
+                         outputs=[message_input, chat_history])

+demo.launch(share=True)