"""import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import torch
from fastapi import FastAPI
from fastapi.responses import RedirectResponse

# Initialize FastAPI
app = FastAPI()

# Load models - Using microsoft/git-large-coco
try:
    # Load the better model
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    print("Successfully loaded microsoft/git-large-coco model")
    USE_GIT = True
except Exception as e:
    print(f"Failed to load GIT model: {e}. Falling back to smaller model")
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
    USE_GIT = False

def generate_caption(image_path):
    "Generate caption using the best available model""
    try:
        if USE_GIT:
            image = Image.open(image_path)
            inputs = processor(images=image, return_tensors="pt")
            outputs = git_model.generate(**inputs, max_length=50)
            return processor.batch_decode(outputs, skip_special_tokens=True)[0]
        else:
            result = captioner(image_path)
            return result[0]['generated_text']
    except Exception as e:
        print(f"Caption generation error: {e}")
        return "Could not generate caption"

def process_image(file_path: str):
    "Handle image processing for Gradio interface"
    if not file_path:
        return "Please upload an image first"
    
    try:
        caption = generate_caption(file_path)
        return f"πŸ“· Image Caption:\n{caption}"
    except Exception as e:
        return f"Error processing image: {str(e)}"

# Gradio Interface
with gr.Blocks(title="Image Captioning Service", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# πŸ–ΌοΈ Image Captioning Service")
    gr.Markdown("Upload an image to get automatic captioning")
    
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="filepath")
            analyze_btn = gr.Button("Generate Caption", variant="primary")
        
        with gr.Column():
            output = gr.Textbox(label="Caption Result", lines=5)
    
    analyze_btn.click(
        fn=process_image,
        inputs=[image_input],
        outputs=[output]
    )

# Mount Gradio app to FastAPI
app = gr.mount_gradio_app(app, demo, path="/")

@app.get("/")
def redirect_to_interface():
    return RedirectResponse(url="/")
"""
import gradio as gr
from transformers import AutoProcessor, AutoModelForCausalLM, pipeline
from PIL import Image
import torch
from fastapi import FastAPI, UploadFile, Form
from fastapi.responses import RedirectResponse, JSONResponse, FileResponse
from fastapi.middleware.cors import CORSMiddleware
import os
import tempfile

# βœ… Initialize FastAPI
app = FastAPI()

# βœ… Enable CORS (so frontend JS can call backend)
# NOTE: wildcard origins are convenient in development; restrict them in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# βœ… Load caption model
USE_GIT = False
try:
    processor = AutoProcessor.from_pretrained("microsoft/git-large-coco")
    git_model = AutoModelForCausalLM.from_pretrained("microsoft/git-large-coco")
    git_model.eval()
    USE_GIT = True
except Exception as e:
    print(f"[INFO] Falling back to ViT: {e}")
    captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
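
# A GPU, when available, speeds up GIT generation considerably. A minimal
# sketch of the change (the model and the processor outputs must share one device):
#
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   git_model.to(device)
#   # ...and in generate_caption: inputs = inputs.to(device)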

# βœ… Image captioning logic
def generate_caption(image_path: str) -> str:
    try:
        if USE_GIT:
            image = Image.open(image_path).convert("RGB")
            inputs = processor(images=image, return_tensors="pt")
            # Inference only, so skip gradient tracking
            with torch.no_grad():
                outputs = git_model.generate(**inputs, max_length=50)
            caption = processor.batch_decode(outputs, skip_special_tokens=True)[0]
        else:
            result = captioner(image_path)
            caption = result[0]["generated_text"]
        return caption
    except Exception as e:
        return f"Error: {str(e)}"

# βœ… For Gradio demo
def process_image(file_path: str):
    if not file_path:
        return "Please upload an image."
    return f"πŸ“· Image Caption:\n{generate_caption(file_path)}"

# βœ… FastAPI endpoint for frontend POSTs
@app.post("/imagecaption/")
async def caption_from_frontend(file: UploadFile, question: str = Form("")):
    # `question` is accepted for frontend compatibility; captioning ignores it.
    try:
        # Save the upload to a temp file; basename guards against path components
        contents = await file.read()
        safe_name = os.path.basename(file.filename or "upload.jpg")
        tmp_path = os.path.join(tempfile.gettempdir(), safe_name)
        with open(tmp_path, "wb") as f:
            f.write(contents)

        caption = generate_caption(tmp_path)

        # Optionally synthesize the caption as speech (requires the gTTS package)
        from gtts import gTTS
        audio_path = tmp_path + ".mp3"
        tts = gTTS(text=caption)
        tts.save(audio_path)

        return {
            "answer": caption,
            "audio": f"/files/{os.path.basename(audio_path)}"
        }

    except Exception as e:
        return JSONResponse({"error": str(e)}, status_code=500)
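
# Example request against the caption endpoint (a sketch; assumes the app is
# served locally on port 7860 and that test.jpg exists):
#
#   curl -X POST http://localhost:7860/imagecaption/ \
#        -F "file=@test.jpg" -F "question="
#
# The JSON response carries the caption under "answer" and a relative URL to
# the generated MP3 under "audio".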

# βœ… Serve static files
@app.get("/files/{file_name}")
async def serve_file(file_name: str):
    # basename strips any path components, keeping lookups inside the temp dir
    path = os.path.join(tempfile.gettempdir(), os.path.basename(file_name))
    if os.path.exists(path):
        return FileResponse(path)
    return JSONResponse({"error": "File not found"}, status_code=404)

# βœ… Mount Gradio demo for testing. Mounted under /demo so it does not shadow
# the root route: a mount at "/" would intercept "/" before the redirect below.
with gr.Blocks(title="πŸ–ΌοΈ Image Captioning") as demo:
    gr.Markdown("# πŸ–ΌοΈ Image Captioning Demo")
    image_input = gr.Image(type="filepath", label="Upload Image")
    result_box = gr.Textbox(label="Caption")
    btn = gr.Button("Generate Caption")
    btn.click(fn=process_image, inputs=[image_input], outputs=[result_box])

app = gr.mount_gradio_app(app, demo, path="/demo")

# βœ… Optional root redirect to the frontend (assumes home.html is served
# separately at /templates/home.html, e.g. by a static host or reverse proxy)
@app.get("/")
def redirect_to_frontend():
    return RedirectResponse(url="/templates/home.html")
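
# Minimal local entry point (a sketch, assuming uvicorn is installed; hosted
# environments such as Hugging Face Spaces typically launch the ASGI app directly).
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)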