benkada committed
Commit e1933c4 · verified · 1 Parent(s): aca5e64

Update main.py

Files changed (1)
  1. main.py +102 -71
main.py CHANGED
@@ -1,89 +1,120 @@
- from fastapi import FastAPI, UploadFile, File, Form
- from fastapi.responses import JSONResponse
- from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
- import uvicorn
- import tempfile
  import os
  from PIL import Image
- import torch

  app = FastAPI()

- # Load tokenizers fast but not full models immediately
- tokenizers = {
-     "qwen": AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True),
-     "deepseek": AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Chat", trust_remote_code=True),
-     "llama": AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-chat-hf", trust_remote_code=True),
- }
-
- models = {}
-
- def load_model(name):
-     if name not in models:
-         if name == "qwen":
-             models[name] = AutoModelForCausalLM.from_pretrained(
-                 "Qwen/Qwen2.5-VL-7B-Instruct",
-                 device_map="auto",
-                 trust_remote_code=True,
-                 torch_dtype=torch.float16
-             )
-         elif name == "deepseek":
-             models[name] = AutoModelForCausalLM.from_pretrained(
-                 "deepseek-ai/DeepSeek-V2-Chat",
-                 device_map="auto",
-                 trust_remote_code=True,
-                 torch_dtype=torch.float16
-             )
-         elif name == "llama":
-             models[name] = AutoModelForCausalLM.from_pretrained(
-                 "meta-llama/Llama-2-70b-chat-hf",
-                 device_map="auto",
-                 trust_remote_code=True,
-                 torch_dtype=torch.float16
-             )
-     return models[name]

- @app.post("/api/summarize")
- async def summarize(file: UploadFile = File(...)):
-     ext = os.path.splitext(file.filename)[1].lower()
-     temp_path = os.path.join(tempfile.gettempdir(), file.filename)
-     with open(temp_path, "wb") as f:
-         f.write(await file.read())
-
-     # For now: Just simulate basic summarization
-     summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
-     with open(temp_path, 'r', errors='ignore') as f:
-         text = f.read()
-
-     if len(text) > 1024:
-         text = text[:1024]
-
-     summary = summarizer(text, max_length=150, min_length=40, do_sample=False)[0]['summary_text']
-     return JSONResponse({"result": summary})

  @app.post("/api/caption")
- async def caption(file: UploadFile = File(...)):
-     image = Image.open(await file.read())
-     # For now: Use a simple vision model, because Qwen2.5 VL loading takes a lot of time
-     captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
-     caption = captioner(image)[0]['generated_text']
-     return JSONResponse({"result": caption})

  @app.post("/api/qa")
- async def question_answer(file: UploadFile = File(...), question: str = Form(...)):
-     temp_path = os.path.join(tempfile.gettempdir(), file.filename)
-     with open(temp_path, "wb") as f:
-         f.write(await file.read())
-
-     # For now: pick deepseek model for QA
-     tokenizer = tokenizers["deepseek"]
-     model = load_model("deepseek")
-
-     inputs = tokenizer(question, return_tensors="pt").to(model.device)
-     outputs = model.generate(**inputs, max_new_tokens=100)
-     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-     return JSONResponse({"result": answer})

  if __name__ == "__main__":
-     uvicorn.run(app, host="0.0.0.0", port=7860)

  import os
+ import io
+ from fastapi import FastAPI, UploadFile, File, Form
+ from fastapi.middleware.cors import CORSMiddleware
+ from fastapi.responses import JSONResponse, HTMLResponse
+ from huggingface_hub import InferenceClient
+ from PyPDF2 import PdfReader
+ from docx import Document
  from PIL import Image
+ from io import BytesIO
+
+ # Load Hugging Face Token securely
+ HUGGINGFACE_TOKEN = os.getenv("HF_TOKEN")

  app = FastAPI()

+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )

+ # Initialize Hugging Face clients
+ summary_client = InferenceClient(model="facebook/bart-large-cnn", token=HUGGINGFACE_TOKEN)
+ qa_client = InferenceClient(model="deepset/roberta-base-squad2", token=HUGGINGFACE_TOKEN)
+ image_caption_client = InferenceClient(model="nlpconnect/vit-gpt2-image-captioning", token=HUGGINGFACE_TOKEN)
+
+ def extract_text_from_pdf(content: bytes) -> str:
+     reader = PdfReader(io.BytesIO(content))
+     return "\n".join(page.extract_text() or "" for page in reader.pages).strip()

+ def extract_text_from_docx(content: bytes) -> str:
+     doc = Document(io.BytesIO(content))
+     return "\n".join(para.text for para in doc.paragraphs).strip()

+ def process_uploaded_file(file: UploadFile) -> str:
+     content = file.file.read()
+     extension = file.filename.split('.')[-1].lower()
+     if extension == "pdf":
+         return extract_text_from_pdf(content)
+     elif extension == "docx":
+         return extract_text_from_docx(content)
+     elif extension == "txt":
+         return content.decode("utf-8").strip()
+     else:
+         raise ValueError("Unsupported file type.")

+ @app.get("/", response_class=HTMLResponse)
+ async def serve_homepage():
+     with open("index.html", "r", encoding="utf-8") as f:
+         return HTMLResponse(content=f.read(), status_code=200)
+
+ @app.post("/api/summarize")
+ async def summarize_document(file: UploadFile = File(...)):
+     try:
+         text = process_uploaded_file(file)
+         if len(text) < 20:
+             return {"result": "Document too short to summarize."}
+         summary = summary_client.summarization(text[:3000])
+         return {"result": summary}
+     except Exception as e:
+         return JSONResponse(status_code=500, content={"error": str(e)})

  @app.post("/api/caption")
+ async def caption_image(file: UploadFile = File(...)):
+     try:
+         image_bytes = await file.read()
+         image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+         image_pil.thumbnail((1024, 1024))
+         img_byte_arr = BytesIO()
+         image_pil.save(img_byte_arr, format='JPEG')
+         img_byte_arr = img_byte_arr.getvalue()
+         result = image_caption_client.image_to_text(img_byte_arr)
+
+         if isinstance(result, dict):
+             caption = result.get("generated_text") or result.get("caption") or "No caption found."
+         elif isinstance(result, list) and result:
+             caption = result[0].get("generated_text", "No caption found.")
+         elif isinstance(result, str):
+             caption = result
+         else:
+             caption = "No caption found."
+
+         return {"result": caption}
+     except Exception as e:
+         return JSONResponse(status_code=500, content={"error": str(e)})

  @app.post("/api/qa")
+ async def question_answering(file: UploadFile = File(...), question: str = Form(...)):
+     try:
+         content_type = file.content_type
+         if content_type.startswith("image/"):
+             image_bytes = await file.read()
+             image_pil = Image.open(io.BytesIO(image_bytes)).convert("RGB")
+             image_pil.thumbnail((1024, 1024))
+             img_byte_arr = BytesIO()
+             image_pil.save(img_byte_arr, format='JPEG')
+             img_byte_arr = img_byte_arr.getvalue()
+             result = image_caption_client.image_to_text(img_byte_arr)
+             context = result.get("generated_text") if isinstance(result, dict) else result
+         else:
+             text = process_uploaded_file(file)
+             if len(text) < 20:
+                 return {"result": "Document too short to answer questions."}
+             context = text[:3000]
+
+         if not context:
+             return {"result": "No context available to answer."}
+
+         answer = qa_client.question_answering(question=question, context=context)
+         return {"result": answer.get("answer", "No answer found.")}
+
+     except Exception as e:
+         return JSONResponse(status_code=500, content={"error": str(e)})

  if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
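A minimal client sketch for exercising the updated endpoints, assuming the server is running locally on the new port 8000 and that the requests library is installed; the file names and the sample question are placeholders, not part of the commit:

# Hypothetical usage sketch (assumptions: server on localhost:8000,
# requests installed; "report.pdf", "photo.jpg", and the question
# below are placeholder values).
import requests

BASE = "http://localhost:8000"

# Summarize a PDF, DOCX, or TXT document.
with open("report.pdf", "rb") as f:
    print(requests.post(f"{BASE}/api/summarize", files={"file": f}).json())

# Caption an image.
with open("photo.jpg", "rb") as f:
    print(requests.post(f"{BASE}/api/caption", files={"file": f}).json())

# Ask a question about a document; image uploads are captioned first,
# and the caption is used as the QA context.
with open("report.pdf", "rb") as f:
    print(requests.post(
        f"{BASE}/api/qa",
        files={"file": f},
        data={"question": "What is the main topic?"},
    ).json())

The question is sent as form data because the endpoint declares question: str = Form(...).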