cwhuh committed on
Commit
2c61b33
·
1 Parent(s): ac860b2

add : llm wrapper

Browse files
Files changed (1) hide show
  1. llm_wrapper.py +99 -99
llm_wrapper.py CHANGED
@@ -1,101 +1,101 @@
1
- # import logging
2
- # from PIL import Image
3
- # from io import BytesIO
4
- # import requests, os, json, time
5
-
6
- # from google import genai
7
-
8
- # prompt_base_path = "src/llm_wrapper/prompt"
9
- # client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
10
-
11
-
12
- # def encode_image(image_source):
13
- # """
14
- # ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ๊ฐ€ URL์ด๋“  ๋กœ์ปฌ ํŒŒ์ผ์ด๋“  Pillow Image ๊ฐ์ฒด์ด๋“  ๋™์ผํ•˜๊ฒŒ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜.
15
- # 이미지를 열어 google.genai.types.Part 객체로 변환합니다.
16
- # Pillow์—์„œ ์ง€์›๋˜์ง€ ์•Š๋Š” ํฌ๋งท์— ๋Œ€ํ•ด์„œ๋Š” ์˜ˆ์™ธ๋ฅผ ๋ฐœ์ƒ์‹œํ‚ต๋‹ˆ๋‹ค.
17
- # """
18
- # try:
19
- # # 이미 Pillow 이미지 객체인 경우 그대로 사용
20
- # if isinstance(image_source, Image.Image):
21
- # image = image_source
22
- # else:
23
- # # URL์—์„œ ์ด๋ฏธ์ง€ ๋‹ค์šด๋กœ๋“œ
24
- # if isinstance(image_source, str) and (
25
- # image_source.startswith("http://")
26
- # or image_source.startswith("https://")
27
- # ):
28
- # response = requests.get(image_source)
29
- # image = Image.open(BytesIO(response.content))
30
- # # 로컬 파일에서 이미지 열기
31
- # else:
32
- # image = Image.open(image_source)
33
-
34
- # # ์ด๋ฏธ์ง€ ํฌ๋งท์ด None์ธ ๊ฒฝ์šฐ (๋ฉ”๋ชจ๋ฆฌ์—์„œ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ๋“ฑ)
35
- # if image.format is None:
36
- # image_format = "JPEG"
37
- # else:
38
- # image_format = image.format
39
-
40
- # # ์ด๋ฏธ์ง€ ํฌ๋งท์ด ์ง€์›๋˜์ง€ ์•Š๋Š” ๊ฒฝ์šฐ ์˜ˆ์™ธ ๋ฐœ์ƒ
41
- # if image_format not in Image.registered_extensions().values():
42
- # raise ValueError(f"Unsupported image format: {image_format}.")
43
-
44
- # buffered = BytesIO()
45
- # # PIL์—์„œ ์ง€์›๋˜์ง€ ์•Š๋Š” ํฌ๋งท์ด๋‚˜ ๋‹ค์–‘ํ•œ ์ฑ„๋„์„ RGB๋กœ ๋ณ€ํ™˜ ํ›„ ์ €์žฅ
46
- # if image.mode in ("RGBA", "P", "CMYK"): # RGBA, ํŒ”๋ ˆํŠธ, CMYK ๋“ฑ์€ RGB๋กœ ๋ณ€ํ™˜
47
- # image = image.convert("RGB")
48
- # image.save(buffered, format="JPEG")
49
 
50
- # return genai.types.Part.from_bytes(data=buffered.getvalue(), mime_type="image/jpeg")
51
-
52
- # except requests.exceptions.RequestException as e:
53
- # raise ValueError(f"Failed to download the image from URL: {e}")
54
- # except IOError as e:
55
- # raise ValueError(f"Failed to process the image file: {e}")
56
- # except ValueError as e:
57
- # raise ValueError(e)
58
-
59
-
60
- # def run_gemini(
61
- # target_prompt: str,
62
- # prompt_in_path: str,
63
- # img_in_data: str = None,
64
- # model: str = "gemini-2.0-flash",
65
- # ) -> str:
66
- # """
67
- # GEMINI API๋ฅผ ๋™๊ธฐ ๋ฐฉ์‹์œผ๋กœ ํ˜ธ์ถœํ•˜์—ฌ ๋ฌธ์ž์—ด ์‘๋‹ต์„ ๋ฐ›์Šต๋‹ˆ๋‹ค.
68
- # retry 논리는 제거되었습니다.
69
- # """
70
- # with open(os.path.join(prompt_base_path, prompt_in_path), "r", encoding="utf-8") as file:
71
- # prompt_dict = json.load(file)
72
-
73
- # system_prompt = prompt_dict["system_prompt"]
74
- # user_prompt_head = prompt_dict["user_prompt"]["head"]
75
- # user_prompt_tail = prompt_dict["user_prompt"]["tail"]
76
-
77
- # user_prompt_text = "\n".join([user_prompt_head, target_prompt, user_prompt_tail])
78
- # input_content = [user_prompt_text]
79
-
80
- # if img_in_data is not None:
81
- # encoded_image = encode_image(img_in_data)
82
- # input_content.append(encoded_image)
83
 
84
- # logging.info("Requested API for chat completion response (sync call)...")
85
- # start_time = time.time()
86
-
87
- # # 동기 방식: client.models.generate_content(...)
88
- # chat_completion = client.models.generate_content(
89
- # model=model,
90
- # contents=input_content,
91
- # )
92
-
93
- # chat_output = chat_completion.parsed
94
- # input_token = chat_completion.usage_metadata.prompt_token_count
95
- # output_token = chat_completion.usage_metadata.candidates_token_count
96
- # pricing = input_token / 1000000 * 0.1 * 1500 + output_token / 1000000 * 0.7 * 1500
97
-
98
- # logging.info(
99
- # f"[GEMINI] Request completed (sync). Time taken: {time.time()-start_time:.2f}s / Pricing(KRW): {pricing:.2f}"
100
- # )
101
- # return chat_output, chat_completion
 
1
+ import logging
2
+ from PIL import Image
3
+ from io import BytesIO
4
+ import requests, os, json, time
5
+
6
+ from google import genai
7
+
8
# Relative directory that run_gemini() joins with `prompt_in_path` to locate
# the prompt JSON files.
prompt_base_path = "src/llm_wrapper/prompt"
# Module-level Gemini client shared by every call; the API key is read from
# the GEMINI_API_KEY environment variable when this module is imported.
client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
10
+
11
+
12
def encode_image(image_source):
    """Convert an image into a JPEG-encoded ``google.genai.types.Part``.

    The source may be an ``http://``/``https://`` URL, a local path (or any
    other value ``PIL.Image.open`` accepts), or an already-loaded Pillow
    ``Image``. Whatever the input format, the image is re-encoded as JPEG
    before being wrapped in a ``Part``.

    Args:
        image_source: URL string, local file path / file object, or a
            ``PIL.Image.Image`` instance.

    Returns:
        The object produced by ``genai.types.Part.from_bytes`` with
        ``mime_type="image/jpeg"``.

    Raises:
        ValueError: If the download fails (including HTTP error statuses),
            if the image cannot be opened or encoded, or if its format is
            not one of Pillow's registered formats.
    """
    # Only images opened by this function are closed here; a caller-supplied
    # Image object stays under the caller's control.
    opened_here = None
    try:
        if isinstance(image_source, Image.Image):
            image = image_source
        elif isinstance(image_source, str) and image_source.startswith(
            ("http://", "https://")
        ):
            # A timeout prevents hanging forever on an unresponsive host, and
            # raise_for_status() stops an HTML error page from being handed to
            # Pillow as if it were image data.
            response = requests.get(image_source, timeout=30)
            response.raise_for_status()
            image = opened_here = Image.open(BytesIO(response.content))
        else:
            image = opened_here = Image.open(image_source)

        # Images created in memory have no format; treat them as JPEG.
        image_format = image.format if image.format is not None else "JPEG"
        if image_format not in Image.registered_extensions().values():
            raise ValueError(f"Unsupported image format: {image_format}.")

        # Normalise every mode other than plain RGB / greyscale (RGBA, LA, P,
        # CMYK, I, F, ...) to RGB so the JPEG encoder accepts the image.
        if image.mode not in ("RGB", "L"):
            image = image.convert("RGB")

        buffered = BytesIO()
        image.save(buffered, format="JPEG")
        return genai.types.Part.from_bytes(
            data=buffered.getvalue(), mime_type="image/jpeg"
        )

    # RequestException derives from IOError, so it must be handled first.
    except requests.exceptions.RequestException as e:
        raise ValueError(f"Failed to download the image from URL: {e}") from e
    except IOError as e:
        raise ValueError(f"Failed to process the image file: {e}") from e
    finally:
        if opened_here is not None:
            opened_here.close()
58
+
59
+
60
def run_gemini(
    target_prompt: str,
    prompt_in_path: str,
    img_in_data: "str | Image.Image | None" = None,
    model: str = "gemini-2.0-flash",
) -> tuple:
    """Call the Gemini API synchronously and return its response.

    The prompt JSON at ``prompt_base_path/prompt_in_path`` must provide
    ``system_prompt`` and ``user_prompt`` (with ``head`` and ``tail``). The user
    message is ``head``, ``target_prompt`` and ``tail`` joined by newlines,
    optionally followed by an encoded image. No retry is performed.

    Args:
        target_prompt: Text inserted between the prompt file's head and tail.
        prompt_in_path: Prompt JSON file path, relative to ``prompt_base_path``.
        img_in_data: Optional image source passed to ``encode_image``.
        model: Gemini model name.

    Returns:
        A ``(chat_output, chat_completion)`` tuple. ``chat_output`` is the
        response's ``parsed`` value when present, otherwise its ``text``;
        ``chat_completion`` is the raw response object.

    Raises:
        KeyError: If the prompt file lacks one of the expected keys.
        ValueError: If the image cannot be encoded (raised by ``encode_image``).
    """
    with open(os.path.join(prompt_base_path, prompt_in_path), "r", encoding="utf-8") as file:
        prompt_dict = json.load(file)

    system_prompt = prompt_dict["system_prompt"]
    user_prompt = prompt_dict["user_prompt"]
    user_prompt_text = "\n".join([user_prompt["head"], target_prompt, user_prompt["tail"]])

    input_content = [user_prompt_text]
    if img_in_data is not None:
        input_content.append(encode_image(img_in_data))

    # The system prompt used to be loaded and then dropped; forward it to the
    # model as the system instruction whenever one is configured.
    config = (
        genai.types.GenerateContentConfig(system_instruction=system_prompt)
        if system_prompt
        else None
    )

    logging.info("Requested API for chat completion response (sync call)...")
    start_time = time.time()

    chat_completion = client.models.generate_content(
        model=model,
        contents=input_content,
        config=config,
    )

    # `parsed` is only populated for structured-output requests; fall back to
    # the plain text so callers receive the actual answer instead of None.
    chat_output = chat_completion.parsed
    if chat_output is None:
        chat_output = chat_completion.text

    # Usage metadata or individual counts may be missing; count them as zero
    # rather than failing the cost estimate.
    usage = chat_completion.usage_metadata
    input_token = (usage.prompt_token_count or 0) if usage is not None else 0
    output_token = (usage.candidates_token_count or 0) if usage is not None else 0
    # NOTE(review): 0.1 / 0.7 look like per-million-token rates and 1500 like
    # a USD->KRW exchange rate -- confirm against current pricing.
    pricing = input_token / 1000000 * 0.1 * 1500 + output_token / 1000000 * 0.7 * 1500

    logging.info(
        "[GEMINI] Request completed (sync). Time taken: %.2fs / Pricing(KRW): %.2f",
        time.time() - start_time,
        pricing,
    )
    return chat_output, chat_completion