cwhuh committed on
Commit
889c270
·
1 Parent(s): 52ee639
Files changed (2) hide show
  1. llm_wrapper.py +99 -99
  2. requirements.txt +1 -2
llm_wrapper.py CHANGED
@@ -1,101 +1,101 @@
1
- import logging
2
- from PIL import Image
3
- from io import BytesIO
4
- import requests, os, json, time
5
-
6
- from google import genai
7
-
8
- prompt_base_path = "src/llm_wrapper/prompt"
9
- client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
10
-
11
-
12
- def encode_image(image_source):
13
- """
14
- ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ๊ฐ€ URL์ด๋“  ๋กœ์ปฌ ํŒŒ์ผ์ด๋“  Pillow Image ๊ฐ์ฒด์ด๋“  ๋™์ผํ•˜๊ฒŒ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜.
15
- 이미지를 열어 google.genai.types.Part 객체로 변환합니다.
16
- Pillow์—์„œ ์ง€์›๋˜์ง€ ์•Š๋Š” ํฌ๋งท์— ๋Œ€ํ•ด์„œ๋Š” ์˜ˆ์™ธ๋ฅผ ๋ฐœ์ƒ์‹œํ‚ต๋‹ˆ๋‹ค.
17
- """
18
- try:
19
- # 이미 Pillow 이미지 객체인 경우 그대로 사용
20
- if isinstance(image_source, Image.Image):
21
- image = image_source
22
- else:
23
- # URL์—์„œ ์ด๋ฏธ์ง€ ๋‹ค์šด๋กœ๋“œ
24
- if isinstance(image_source, str) and (
25
- image_source.startswith("http://")
26
- or image_source.startswith("https://")
27
- ):
28
- response = requests.get(image_source)
29
- image = Image.open(BytesIO(response.content))
30
- # 로컬 파일에서 이미지 열기
31
- else:
32
- image = Image.open(image_source)
33
-
34
- # ์ด๋ฏธ์ง€ ํฌ๋งท์ด None์ธ ๊ฒฝ์šฐ (๋ฉ”๋ชจ๋ฆฌ์—์„œ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ๋“ฑ)
35
- if image.format is None:
36
- image_format = "JPEG"
37
- else:
38
- image_format = image.format
39
-
40
- # 이미지 포맷이 지원되지 않는 경우 예외 발생
41
- if image_format not in Image.registered_extensions().values():
42
- raise ValueError(f"Unsupported image format: {image_format}.")
43
-
44
- buffered = BytesIO()
45
- # PIL์—์„œ ์ง€์›๋˜์ง€ ์•Š๋Š” ํฌ๋งท์ด๋‚˜ ๋‹ค์–‘ํ•œ ์ฑ„๋„์„ RGB๋กœ ๋ณ€ํ™˜ ํ›„ ์ €์žฅ
46
- if image.mode in ("RGBA", "P", "CMYK"): # RGBA, ํŒ”๋ ˆํŠธ, CMYK ๋“ฑ์€ RGB๋กœ ๋ณ€ํ™˜
47
- image = image.convert("RGB")
48
- image.save(buffered, format="JPEG")
49
 
50
- return genai.types.Part.from_bytes(data=buffered.getvalue(), mime_type="image/jpeg")
51
-
52
- except requests.exceptions.RequestException as e:
53
- raise ValueError(f"Failed to download the image from URL: {e}")
54
- except IOError as e:
55
- raise ValueError(f"Failed to process the image file: {e}")
56
- except ValueError as e:
57
- raise ValueError(e)
58
-
59
-
60
- def run_gemini(
61
- target_prompt: str,
62
- prompt_in_path: str,
63
- img_in_data: str = None,
64
- model: str = "gemini-2.0-flash",
65
- ) -> str:
66
- """
67
- GEMINI API๋ฅผ ๋™๊ธฐ ๋ฐฉ์‹์œผ๋กœ ํ˜ธ์ถœํ•˜์—ฌ ๋ฌธ์ž์—ด ์‘๋‹ต์„ ๋ฐ›์Šต๋‹ˆ๋‹ค.
68
- retry 논리는 제거되었습니다.
69
- """
70
- with open(os.path.join(prompt_base_path, prompt_in_path), "r", encoding="utf-8") as file:
71
- prompt_dict = json.load(file)
72
-
73
- system_prompt = prompt_dict["system_prompt"]
74
- user_prompt_head = prompt_dict["user_prompt"]["head"]
75
- user_prompt_tail = prompt_dict["user_prompt"]["tail"]
76
-
77
- user_prompt_text = "\n".join([user_prompt_head, target_prompt, user_prompt_tail])
78
- input_content = [user_prompt_text]
79
-
80
- if img_in_data is not None:
81
- encoded_image = encode_image(img_in_data)
82
- input_content.append(encoded_image)
83
 
84
- logging.info("Requested API for chat completion response (sync call)...")
85
- start_time = time.time()
86
-
87
- # 동기 방식: client.models.generate_content(...)
88
- chat_completion = client.models.generate_content(
89
- model=model,
90
- contents=input_content,
91
- )
92
-
93
- chat_output = chat_completion.parsed
94
- input_token = chat_completion.usage_metadata.prompt_token_count
95
- output_token = chat_completion.usage_metadata.candidates_token_count
96
- pricing = input_token / 1000000 * 0.1 * 1500 + output_token / 1000000 * 0.7 * 1500
97
-
98
- logging.info(
99
- f"[GEMINI] Request completed (sync). Time taken: {time.time()-start_time:.2f}s / Pricing(KRW): {pricing:.2f}"
100
- )
101
- return chat_output, chat_completion
 
1
+ # import logging
2
+ # from PIL import Image
3
+ # from io import BytesIO
4
+ # import requests, os, json, time
5
+
6
+ # from google import genai
7
+
8
+ # prompt_base_path = "src/llm_wrapper/prompt"
9
+ # client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
10
+
11
+
12
+ # def encode_image(image_source):
13
+ # """
14
+ # ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ๊ฐ€ URL์ด๋“  ๋กœ์ปฌ ํŒŒ์ผ์ด๋“  Pillow Image ๊ฐ์ฒด์ด๋“  ๋™์ผํ•˜๊ฒŒ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜.
15
+ # 이미지를 열어 google.genai.types.Part 객체로 변환합니다.
16
+ # Pillow์—์„œ ์ง€์›๋˜์ง€ ์•Š๋Š” ํฌ๋งท์— ๋Œ€ํ•ด์„œ๋Š” ์˜ˆ์™ธ๋ฅผ ๋ฐœ์ƒ์‹œํ‚ต๋‹ˆ๋‹ค.
17
+ # """
18
+ # try:
19
+ # # 이미 Pillow 이미지 객체인 경우 그대로 사용
20
+ # if isinstance(image_source, Image.Image):
21
+ # image = image_source
22
+ # else:
23
+ # # URL์—์„œ ์ด๋ฏธ์ง€ ๋‹ค์šด๋กœ๋“œ
24
+ # if isinstance(image_source, str) and (
25
+ # image_source.startswith("http://")
26
+ # or image_source.startswith("https://")
27
+ # ):
28
+ # response = requests.get(image_source)
29
+ # image = Image.open(BytesIO(response.content))
30
+ # # 로컬 파일에서 이미지 열기
31
+ # else:
32
+ # image = Image.open(image_source)
33
+
34
+ # # ์ด๋ฏธ์ง€ ํฌ๋งท์ด None์ธ ๊ฒฝ์šฐ (๋ฉ”๋ชจ๋ฆฌ์—์„œ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ๋“ฑ)
35
+ # if image.format is None:
36
+ # image_format = "JPEG"
37
+ # else:
38
+ # image_format = image.format
39
+
40
+ # # 이미지 포맷이 지원되지 않는 경우 예외 발생
41
+ # if image_format not in Image.registered_extensions().values():
42
+ # raise ValueError(f"Unsupported image format: {image_format}.")
43
+
44
+ # buffered = BytesIO()
45
+ # # PIL์—์„œ ์ง€์›๋˜์ง€ ์•Š๋Š” ํฌ๋งท์ด๋‚˜ ๋‹ค์–‘ํ•œ ์ฑ„๋„์„ RGB๋กœ ๋ณ€ํ™˜ ํ›„ ์ €์žฅ
46
+ # if image.mode in ("RGBA", "P", "CMYK"): # RGBA, ํŒ”๋ ˆํŠธ, CMYK ๋“ฑ์€ RGB๋กœ ๋ณ€ํ™˜
47
+ # image = image.convert("RGB")
48
+ # image.save(buffered, format="JPEG")
49
 
50
+ # return genai.types.Part.from_bytes(data=buffered.getvalue(), mime_type="image/jpeg")
51
+
52
+ # except requests.exceptions.RequestException as e:
53
+ # raise ValueError(f"Failed to download the image from URL: {e}")
54
+ # except IOError as e:
55
+ # raise ValueError(f"Failed to process the image file: {e}")
56
+ # except ValueError as e:
57
+ # raise ValueError(e)
58
+
59
+
60
+ # def run_gemini(
61
+ # target_prompt: str,
62
+ # prompt_in_path: str,
63
+ # img_in_data: str = None,
64
+ # model: str = "gemini-2.0-flash",
65
+ # ) -> str:
66
+ # """
67
+ # GEMINI API๋ฅผ ๋™๊ธฐ ๋ฐฉ์‹์œผ๋กœ ํ˜ธ์ถœํ•˜์—ฌ ๋ฌธ์ž์—ด ์‘๋‹ต์„ ๋ฐ›์Šต๋‹ˆ๋‹ค.
68
+ # retry 논리는 제거되었습니다.
69
+ # """
70
+ # with open(os.path.join(prompt_base_path, prompt_in_path), "r", encoding="utf-8") as file:
71
+ # prompt_dict = json.load(file)
72
+
73
+ # system_prompt = prompt_dict["system_prompt"]
74
+ # user_prompt_head = prompt_dict["user_prompt"]["head"]
75
+ # user_prompt_tail = prompt_dict["user_prompt"]["tail"]
76
+
77
+ # user_prompt_text = "\n".join([user_prompt_head, target_prompt, user_prompt_tail])
78
+ # input_content = [user_prompt_text]
79
+
80
+ # if img_in_data is not None:
81
+ # encoded_image = encode_image(img_in_data)
82
+ # input_content.append(encoded_image)
83
 
84
+ # logging.info("Requested API for chat completion response (sync call)...")
85
+ # start_time = time.time()
86
+
87
+ # # 동기 방식: client.models.generate_content(...)
88
+ # chat_completion = client.models.generate_content(
89
+ # model=model,
90
+ # contents=input_content,
91
+ # )
92
+
93
+ # chat_output = chat_completion.parsed
94
+ # input_token = chat_completion.usage_metadata.prompt_token_count
95
+ # output_token = chat_completion.usage_metadata.candidates_token_count
96
+ # pricing = input_token / 1000000 * 0.1 * 1500 + output_token / 1000000 * 0.7 * 1500
97
+
98
+ # logging.info(
99
+ # f"[GEMINI] Request completed (sync). Time taken: {time.time()-start_time:.2f}s / Pricing(KRW): {pricing:.2f}"
100
+ # )
101
+ # return chat_output, chat_completion
requirements.txt CHANGED
@@ -4,5 +4,4 @@ torch
4
  transformers==4.42.4
5
  xformers
6
  sentencepiece
7
- peft==0.12.0
8
- google-genai
 
4
  transformers==4.42.4
5
  xformers
6
  sentencepiece
7
+ peft==0.12.0