mgbam committed
Commit ac1cd8a · verified
1 Parent(s): c5ccf13

Update app.py

Files changed (1)
  1. app.py +39 -15
app.py CHANGED
@@ -5,26 +5,32 @@ from diffusers import AutoencoderKL
 import numpy as np
 import gradio as gr
 
-# Configure device and attention implementation
+# Configure device and disable FlashAttention
 device = "cuda" if torch.cuda.is_available() else "cpu"
-attn_implementation = "flash_attention_2" if device == "cuda" else "eager"
-print(f"Using device: {device} with {attn_implementation}")
+torch_dtype = torch.bfloat16 if device == "cuda" else torch.float32
+print(f"Using device: {device}")
 
 # Initialize medical imaging components
 def load_medical_models():
     try:
-        processor = VLChatProcessor.from_pretrained("deepseek-ai/Janus-1.3B")
+        # Load processor with medical-specific configuration
+        processor = VLChatProcessor.from_pretrained(
+            "deepseek-ai/Janus-1.3B",
+            medical_mode=True
+        )
 
+        # Load model with CPU/GPU optimization
         model = MultiModalityCausalLM.from_pretrained(
             "deepseek-ai/Janus-1.3B",
-            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32,
-            attn_implementation=attn_implementation,
-            use_flash_attention_2=(attn_implementation == "flash_attention_2")
+            torch_dtype=torch_dtype,
+            attn_implementation="eager",  # Force standard attention
+            low_cpu_mem_usage=True
         ).to(device).eval()
 
+        # Load VAE with reduced precision
         vae = AutoencoderKL.from_pretrained(
             "stabilityai/sdxl-vae",
-            torch_dtype=torch.bfloat16 if device == "cuda" else torch.float32
+            torch_dtype=torch_dtype
         ).to(device).eval()
 
         return processor, model, vae
@@ -34,31 +40,40 @@ def load_medical_models():
 
 processor, model, vae = load_medical_models()
 
-# Medical image analysis function with attention control
+# Medical image analysis function
 def medical_analysis(image, question, seed=42):
     try:
+        # Set random seed for reproducibility
         torch.manual_seed(seed)
         np.random.seed(seed)
 
+        # Convert and validate input image
         if isinstance(image, np.ndarray):
            image = Image.fromarray(image).convert("RGB")
 
+        # Prepare medical-specific input
         inputs = processor(
             text=f"<medical_query>{question}</medical_query>",
             images=[image],
-            return_tensors="pt"
+            return_tensors="pt",
+            max_length=512,
+            truncation=True
         ).to(device)
 
+        # Generate medical analysis
         outputs = model.generate(
             inputs.input_ids,
             attention_mask=inputs.attention_mask,
             max_new_tokens=512,
             temperature=0.1,
             top_p=0.95,
-            pad_token_id=processor.tokenizer.eos_token_id
+            pad_token_id=processor.tokenizer.eos_token_id,
+            do_sample=True
         )
 
-        return processor.decode(outputs[0], skip_special_tokens=True)
+        # Clean and return medical report
+        report = processor.decode(outputs[0], skip_special_tokens=True)
+        return report.replace("##MEDICAL_REPORT##", "").strip()
     except Exception as e:
         return f"Radiology analysis error: {str(e)}"
 
@@ -70,11 +85,14 @@ with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
     with gr.Tab("Diagnostic Imaging"):
         with gr.Row():
             med_image = gr.Image(label="DICOM Image", type="pil")
-            med_question = gr.Textbox(label="Clinical Query",
-                                      placeholder="Describe findings in this CT scan...")
+            med_question = gr.Textbox(
+                label="Clinical Query",
+                placeholder="Describe findings in this CT scan..."
+            )
         analysis_btn = gr.Button("Analyze", variant="primary")
         report_output = gr.Textbox(label="Radiology Report", interactive=False)
 
+        # Connect components
         med_question.submit(
             medical_analysis,
             inputs=[med_image, med_question],
@@ -86,4 +104,10 @@ with gr.Blocks(title="Medical Imaging Assistant", theme=gr.themes.Soft()) as demo:
         outputs=report_output
     )
 
-demo.launch(server_name="0.0.0.0", server_port=7860)
+# Launch with CPU optimization
+demo.launch(
+    server_name="0.0.0.0",
+    server_port=7860,
+    enable_queue=True,
+    max_threads=2
+)
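
A quick way to sanity-check the updated analysis path outside the Gradio UI is to call medical_analysis() directly. The sketch below is a minimal smoke test, assuming the new app.py imports cleanly and load_medical_models() succeeds on the target hardware; the image path is a hypothetical local test file, not part of this commit.

# Minimal smoke test for the updated medical_analysis() function.
# Assumes app.py is on the import path and its models load successfully;
# "sample_ct_slice.png" is a hypothetical local test image.
from PIL import Image

from app import medical_analysis

scan = Image.open("sample_ct_slice.png").convert("RGB")
report = medical_analysis(scan, "Describe findings in this CT scan.", seed=42)
print(report)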