seawolf2357 commited on
Commit
d87e79b
·
verified ·
1 Parent(s): 655b8e8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -10
app.py CHANGED
@@ -8,6 +8,7 @@ import re
8
  import requests
9
  from PIL import Image
10
  import io
 
11
 
12
  # 로깅 설정
13
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
@@ -36,21 +37,22 @@ def modify_caption(caption: str) -> str:
36
 
37
  return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
38
 
39
- def create_captions_rich(image: Image.Image) -> str:
40
  prompt = "caption en"
41
- # 이미지 데이터를 전처리하여 processor에 전달
42
  image_tensor = processor(images=image, return_tensors="pt").pixel_values.to("cpu")
43
- # 이미지 범위 조정 [0, 1]에서 [0, 255]로
44
  image_tensor = (image_tensor * 255).type(torch.uint8)
45
  model_inputs = processor(text=prompt, images=image_tensor, return_tensors="pt").to("cpu")
46
  input_len = model_inputs["input_ids"].shape[-1]
47
 
48
- with torch.no_grad():
49
- generation = model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
50
- generation = generation[0][input_len:]
51
- decoded = processor.decode(generation, skip_special_tokens=True)
 
 
 
52
 
53
- modified_caption = modify_caption(decoded)
54
  return modified_caption
55
 
56
  # 특정 채널 ID 설정
@@ -64,9 +66,12 @@ class MyClient(discord.Client):
64
 
65
  async def on_ready(self):
66
  logging.info(f'{self.user}로 로그인되었습니다!')
67
- subprocess.Popen(["python", "web.py"])
68
  logging.info("Web.py 서버가 시작되었습니다.")
69
 
 
 
 
70
  async def on_message(self, message):
71
  if message.author == self.user:
72
  return
@@ -90,7 +95,7 @@ class MyClient(discord.Client):
90
 
91
  async def process_image(image_url, message):
92
  image = await download_image(image_url)
93
- caption = create_captions_rich(image)
94
  return f"{message.author.mention}, 인식된 이미지 설명: {caption}"
95
 
96
  async def download_image(url):
 
8
  import requests
9
  from PIL import Image
10
  import io
11
+ import asyncio
12
 
13
  # 로깅 설정
14
  logging.basicConfig(level=logging.DEBUG, format='%(asctime)s:%(levelname)s:%(name)s: %(message)s', handlers=[logging.StreamHandler()])
 
37
 
38
  return re.sub(pattern, replace_fn, caption, count=1, flags=re.IGNORECASE)
39
 
40
+ async def create_captions_rich(image: Image.Image) -> str:
41
  prompt = "caption en"
 
42
  image_tensor = processor(images=image, return_tensors="pt").pixel_values.to("cpu")
 
43
  image_tensor = (image_tensor * 255).type(torch.uint8)
44
  model_inputs = processor(text=prompt, images=image_tensor, return_tensors="pt").to("cpu")
45
  input_len = model_inputs["input_ids"].shape[-1]
46
 
47
+ loop = asyncio.get_event_loop()
48
+ generation = await loop.run_in_executor(
49
+ None,
50
+ lambda: model.generate(**model_inputs, max_new_tokens=256, do_sample=False)
51
+ )
52
+ generation = generation[0][input_len:]
53
+ decoded = processor.decode(generation, skip_special_tokens=True)
54
 
55
+ modified_caption = modify_caption(decoded)
56
  return modified_caption
57
 
58
  # 특정 채널 ID 설정
 
66
 
67
  async def on_ready(self):
68
  logging.info(f'{self.user}로 로그인되었습니다!')
69
+ asyncio.create_task(self.start_gradio_server())
70
  logging.info("Web.py 서버가 시작되었습니다.")
71
 
72
+ async def start_gradio_server(self):
73
+ subprocess.run(["python", "web.py"], check=True)
74
+
75
  async def on_message(self, message):
76
  if message.author == self.user:
77
  return
 
95
 
96
  async def process_image(image_url, message):
97
  image = await download_image(image_url)
98
+ caption = await create_captions_rich(image)
99
  return f"{message.author.mention}, 인식된 이미지 설명: {caption}"
100
 
101
  async def download_image(url):