OjciecTadeusz commited on
Commit
404e508
·
verified ·
1 Parent(s): ab2de94

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -21
app.py CHANGED
@@ -1,6 +1,6 @@
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
- from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
 
6
  app = FastAPI()
@@ -9,18 +9,31 @@ app = FastAPI()
9
  MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
- # Initialize model and tokenizer
13
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
15
- classifier = pipeline("sentiment-analysis", model=model, tokenizer=tokenizer, device=DEVICE)
 
 
 
 
 
 
 
 
 
16
 
17
  class TextInput(BaseModel):
18
  text: str
19
 
 
 
 
 
20
  @app.post("/analyze-sentiment")
21
  async def analyze_sentiment(input_data: TextInput):
22
  try:
23
- result = classifier(input_data.text)
24
  return {
25
  "sentiment": result[0]['label'],
26
  "score": float(result[0]['score'])
@@ -28,31 +41,41 @@ async def analyze_sentiment(input_data: TextInput):
28
  except Exception as e:
29
  raise HTTPException(status_code=500, detail=str(e))
30
 
31
- # Przykład dla większego modelu (np. GPT-2)
32
- MODEL_NAME_LARGE = "gpt2-large"
33
- tokenizer_large = AutoTokenizer.from_pretrained(MODEL_NAME_LARGE)
34
- model_large = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME_LARGE)
35
-
36
- class GenerationInput(BaseModel):
37
- prompt: str
38
- max_length: int = 100
39
-
40
  @app.post("/generate-text")
41
  async def generate_text(input_data: GenerationInput):
42
  try:
43
- inputs = tokenizer_large(input_data.prompt, return_tensors="pt")
44
- outputs = model_large.generate(
 
 
 
 
45
  inputs["input_ids"],
46
  max_length=input_data.max_length,
47
  num_return_sequences=1,
48
- no_repeat_ngram_size=2
 
 
 
 
 
 
49
  )
50
- generated_text = tokenizer_large.decode(outputs[0], skip_special_tokens=True)
51
  return {"generated_text": generated_text}
52
  except Exception as e:
53
  raise HTTPException(status_code=500, detail=str(e))
54
 
55
- # Dodanie podstawowego health checka
56
  @app.get("/health")
57
  async def health_check():
58
- return {"status": "healthy"}
 
 
 
 
 
 
 
 
 
 
 
1
  from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
4
  import torch
5
 
6
  app = FastAPI()
 
9
  MODEL_NAME = "nlptown/bert-base-multilingual-uncased-sentiment"
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
 
12
+ # Initialize sentiment analysis model
13
+ sentiment_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
14
+ sentiment_classifier = pipeline(
15
+ "sentiment-analysis",
16
+ model=MODEL_NAME,
17
+ tokenizer=sentiment_tokenizer,
18
+ device=DEVICE
19
+ )
20
+
21
+ # Initialize GPT-2 for text generation
22
+ MODEL_NAME_LARGE = "gpt2-large"
23
+ generation_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_LARGE)
24
+ generation_model = AutoModelForCausalLM.from_pretrained(MODEL_NAME_LARGE).to(DEVICE)
25
 
26
  class TextInput(BaseModel):
27
  text: str
28
 
29
+ class GenerationInput(BaseModel):
30
+ prompt: str
31
+ max_length: int = 100
32
+
33
  @app.post("/analyze-sentiment")
34
  async def analyze_sentiment(input_data: TextInput):
35
  try:
36
+ result = sentiment_classifier(input_data.text)
37
  return {
38
  "sentiment": result[0]['label'],
39
  "score": float(result[0]['score'])
 
41
  except Exception as e:
42
  raise HTTPException(status_code=500, detail=str(e))
43
 
 
 
 
 
 
 
 
 
 
44
  @app.post("/generate-text")
45
  async def generate_text(input_data: GenerationInput):
46
  try:
47
+ inputs = generation_tokenizer(
48
+ input_data.prompt,
49
+ return_tensors="pt"
50
+ ).to(DEVICE)
51
+
52
+ outputs = generation_model.generate(
53
  inputs["input_ids"],
54
  max_length=input_data.max_length,
55
  num_return_sequences=1,
56
+ no_repeat_ngram_size=2,
57
+ pad_token_id=generation_tokenizer.eos_token_id
58
+ )
59
+
60
+ generated_text = generation_tokenizer.decode(
61
+ outputs[0],
62
+ skip_special_tokens=True
63
  )
64
+
65
  return {"generated_text": generated_text}
66
  except Exception as e:
67
  raise HTTPException(status_code=500, detail=str(e))
68
 
 
69
  @app.get("/health")
70
  async def health_check():
71
+ return {
72
+ "status": "healthy",
73
+ "sentiment_model": MODEL_NAME,
74
+ "generation_model": MODEL_NAME_LARGE,
75
+ "device": str(DEVICE)
76
+ }
77
+
78
+ # Dodaj to na końcu pliku
79
+ if __name__ == "__main__":
80
+ import uvicorn
81
+ uvicorn.run(app, host="0.0.0.0", port=8000)