benkada committed on
Commit 6991b14 · verified · 1 Parent(s): f7e9534

Update main.py

Files changed (1)
  1. main.py +87 -124
main.py CHANGED
@@ -1,126 +1,89 @@
- import os
- from fastapi import FastAPI, UploadFile, File, HTTPException
- from fastapi.middleware.cors import CORSMiddleware
  from fastapi.responses import JSONResponse
- from pydantic import BaseModel
- from typing import Optional
  from PIL import Image
- import pytesseract
- from transformers import pipeline
- from langchain.chains import LLMChain
- from langchain.prompts import PromptTemplate
- from langchain_community.llms import HuggingFaceHub
-
- # Ensure HF cache directory is set before any HF import uses it
- os.environ.setdefault("HF_HOME", os.getenv("HF_HOME", "/app/cache"))
-
- # FastAPI application
- app = FastAPI(
-     title="AI-Powered Web Application API",
-     description="API for document summarization, image captioning, and question answering",
-     version="1.0.0"
- )
-
- # CORS middleware
- app.add_middleware(
-     CORSMiddleware,
-     allow_origins=["*"],
-     allow_credentials=True,
-     allow_methods=["*"],
-     allow_headers=["*"],
- )
-
- # ----------------
- # Schemas
- # ----------------
- class SummarizeRequest(BaseModel):
-     text: str
-     max_length: Optional[int] = 150
-     min_length: Optional[int] = 40
-
- class QARequest(BaseModel):
-     question: str
-     context: Optional[str] = None
-
- # ----------------
- # Model loaders (lazy)
- # ----------------
- _cache_dir = os.getenv("HF_HOME", "/app/cache")
- _summarizer = None
- _captioner = None
- _qa_chain = None
-
-
- def get_summarizer():
-     global _summarizer
-     if _summarizer is None:
-         _summarizer = pipeline(
-             "summarization",
-             model="facebook/bart-large-cnn",
-             cache_dir=_cache_dir
-         )
-     return _summarizer
-
-
- def get_image_captioner():
-     global _captioner
-     if _captioner is None:
-         _captioner = pipeline(
-             "image-to-text",
-             model="nlpconnect/vit-gpt2-image-captioning",
-             cache_dir=_cache_dir
-         )
-     return _captioner
-
-
- def get_qa_chain():
-     global _qa_chain
-     if _qa_chain is None:
-         llm = HuggingFaceHub(
-             repo_id="google/flan-t5-large",
-             model_kwargs={"cache_dir": _cache_dir},
-             huggingfacehub_api_token=os.getenv("HUGGINGFACEHUB_API_TOKEN", None)
-         )
-         prompt = PromptTemplate(
-             input_variables=["context", "question"],
-             template="""
- Use the following context to answer the question:
-
- {context}
-
- Question: {question}
- Answer:"""
-         )
-         _qa_chain = LLMChain(llm=llm, prompt=prompt)
-     return _qa_chain
-
- # ----------------
- # Routes
- # ----------------
- @app.post("/summarize")
- def summarize(req: SummarizeRequest):
-     summarizer = get_summarizer()
-     result = summarizer(
-         req.text,
-         max_length=req.max_length,
-         min_length=req.min_length,
-         clean_up_tokenization_spaces=True
-     )
-     return JSONResponse(content={"summary": result[0]["summary_text"]})
-
- @app.post("/caption")
- async def caption_image(file: UploadFile = File(...)):
-     try:
-         img = Image.open(file.file).convert("RGB")
-         captioner = get_image_captioner()
-         result = captioner(img)
-         return JSONResponse(content={"caption": result[0]["generated_text"]})
-     except Exception as e:
-         raise HTTPException(status_code=400, detail=str(e))
-
- @app.post("/qa")
- def question_answer(req: QARequest):
-     chain = get_qa_chain()
-     context = req.context or ""
-     answer = chain.run({"context": context, "question": req.question})
-     return JSONResponse(content={"answer": answer})
 
+ from fastapi import FastAPI, UploadFile, File, Form
  from fastapi.responses import JSONResponse
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+ import uvicorn
+ import tempfile
+ import os
  from PIL import Image
+ import torch
+
+ app = FastAPI()
+
+ # Load the tokenizers up front; the full models are loaded lazily on first use
+ tokenizers = {
+     "qwen": AutoTokenizer.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct", trust_remote_code=True),
+     "deepseek": AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-V2-Chat", trust_remote_code=True),
+     "llama": AutoTokenizer.from_pretrained("meta-llama/Llama-2-70b-chat-hf", trust_remote_code=True),
+ }
+
+ models = {}
+
+ def load_model(name):
+     if name not in models:
+         if name == "qwen":
+             models[name] = AutoModelForCausalLM.from_pretrained(
+                 "Qwen/Qwen2.5-VL-7B-Instruct",
+                 device_map="auto",
+                 trust_remote_code=True,
+                 torch_dtype=torch.float16
+             )
+         elif name == "deepseek":
+             models[name] = AutoModelForCausalLM.from_pretrained(
+                 "deepseek-ai/DeepSeek-V2-Chat",
+                 device_map="auto",
+                 trust_remote_code=True,
+                 torch_dtype=torch.float16
+             )
+         elif name == "llama":
+             models[name] = AutoModelForCausalLM.from_pretrained(
+                 "meta-llama/Llama-2-70b-chat-hf",
+                 device_map="auto",
+                 trust_remote_code=True,
+                 torch_dtype=torch.float16
+             )
+     return models[name]
+
+ @app.post("/api/summarize")
+ async def summarize(file: UploadFile = File(...)):
+     ext = os.path.splitext(file.filename)[1].lower()
+     temp_path = os.path.join(tempfile.gettempdir(), file.filename)
+     with open(temp_path, "wb") as f:
+         f.write(await file.read())
+
+     # For now: run plain text summarization over the uploaded file
+     summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
+     with open(temp_path, 'r', errors='ignore') as f:
+         text = f.read()
+
+     if len(text) > 1024:
+         text = text[:1024]
+
+     summary = summarizer(text, max_length=150, min_length=40, do_sample=False)[0]['summary_text']
+     return JSONResponse({"result": summary})
+
+ @app.post("/api/caption")
+ async def caption(file: UploadFile = File(...)):
+     image = Image.open(file.file).convert("RGB")  # Image.open needs a file object, not raw bytes
+     # For now: use a lightweight vision model, since loading Qwen2.5-VL takes a long time
+     captioner = pipeline("image-to-text", model="nlpconnect/vit-gpt2-image-captioning")
+     caption = captioner(image)[0]['generated_text']
+     return JSONResponse({"result": caption})
+
+ @app.post("/api/qa")
+ async def question_answer(file: UploadFile = File(...), question: str = Form(...)):
+     temp_path = os.path.join(tempfile.gettempdir(), file.filename)
+     with open(temp_path, "wb") as f:
+         f.write(await file.read())
+
+     # For now: answer with the DeepSeek model (the uploaded file is saved but not yet fed into the prompt)
+     tokenizer = tokenizers["deepseek"]
+     model = load_model("deepseek")
+
+     inputs = tokenizer(question, return_tensors="pt").to(model.device)
+     outputs = model.generate(**inputs, max_new_tokens=100)
+     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+     return JSONResponse({"result": answer})
+
+ if __name__ == "__main__":
+     uvicorn.run(app, host="0.0.0.0", port=7860)
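
For reference, a minimal client sketch for the new endpoints. This is an illustration, not part of the commit: it assumes the app is running locally on port 7860 as in the __main__ block above, uses the third-party requests library, and the file names and question text are placeholders.

import requests  # assumed client-side dependency, not part of this repo

BASE = "http://localhost:7860"  # host/port taken from uvicorn.run() above

# /api/summarize and /api/caption each take a single multipart file upload
with open("report.txt", "rb") as f:  # placeholder document
    print(requests.post(f"{BASE}/api/summarize", files={"file": f}).json()["result"])

with open("photo.jpg", "rb") as f:  # placeholder image
    print(requests.post(f"{BASE}/api/caption", files={"file": f}).json()["result"])

# /api/qa additionally takes a "question" form field
with open("report.txt", "rb") as f:
    resp = requests.post(f"{BASE}/api/qa", files={"file": f},
                         data={"question": "What is the main finding?"})
print(resp.json()["result"])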