yuiseki commited on
Commit
85a9901
ยท
1 Parent(s): 02793e0
Files changed (2) hide show
  1. app.py +17 -16
  2. requirements.txt +1 -0
app.py CHANGED
@@ -2,32 +2,33 @@ import gradio as gr
2
  import openai
3
  import os
4
  import json
 
 
 
 
 
 
 
 
5
 
6
  openai.organization = os.getenv("API_ORG")
7
  openai.api_key = os.getenv("API_KEY")
8
  app_password = os.getenv("APP_PASSWORD")
9
  app_username = os.getenv("APP_USERNAME")
10
 
11
-
12
- def generate(prompt):
13
- response = openai.Image.create(
14
- prompt=prompt,
15
- n=1,
16
- size="256x256"
17
- )
18
- return response['data'][0]['url']
19
-
20
- examples = [
21
- ["ใใฎใ“ใฎๅฑฑ"],
22
- ["ใŸใ‘ใฎใ“ใฎ้‡Œ"],
23
- ]
24
 
25
  demo = gr.Interface(
26
  fn=generate,
27
- inputs=gr.components.Textbox(lines=5, label="Prompt"),
28
- outputs=gr.components.Image(type="filepath", label="Generated Image"),
29
  flagging_options=[],
30
- examples=examples
31
  )
32
 
33
  demo.launch(share=False, auth=(app_username, app_password))
 
2
  import openai
3
  import os
4
  import json
5
+ import numpy as np
6
+ import torch
7
+
8
+ from transformers import AutoProcessor, AutoModelForCausalLM
9
+
10
+ checkpoint = "microsoft/git-base"
11
+ processor = AutoProcessor.from_pretrained(checkpoint)
12
+ model = AutoModelForCausalLM.from_pretrained(checkpoint)
13
 
14
  openai.organization = os.getenv("API_ORG")
15
  openai.api_key = os.getenv("API_KEY")
16
  app_password = os.getenv("APP_PASSWORD")
17
  app_username = os.getenv("APP_USERNAME")
18
 
19
+ def generate(input_image):
20
+ device = "cuda" if torch.cuda.is_available() else "cpu"
21
+ inputs = processor(images=input_image, return_tensors="pt").to(device)
22
+ pixel_values = inputs.pixel_values
23
+ generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
24
+ generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
25
+ return generated_caption
 
 
 
 
 
 
26
 
27
  demo = gr.Interface(
28
  fn=generate,
29
+ inputs=gr.Image(label="Input", elem_id="input_image", type="pil"),
30
+ outputs=gr.Text(label="Generated Caption"),
31
  flagging_options=[],
 
32
  )
33
 
34
  demo.launch(share=False, auth=(app_username, app_password))
requirements.txt CHANGED
@@ -1,2 +1,3 @@
1
  gradio
2
  openai
 
 
1
  gradio
2
  openai
3
+ transformers