cwhuh committed on
Commit
dc592b9
ยท
1 Parent(s): eb9cbe1

fix : google-genai -> openai

Browse files
__pycache__/live_preview_helpers.cpython-310.pyc CHANGED
Binary files a/__pycache__/live_preview_helpers.cpython-310.pyc and b/__pycache__/live_preview_helpers.cpython-310.pyc differ
 
__pycache__/llm_wrapper.cpython-310.pyc CHANGED
Binary files a/__pycache__/llm_wrapper.cpython-310.pyc and b/__pycache__/llm_wrapper.cpython-310.pyc differ
 
llm_wrapper.py CHANGED
@@ -1,105 +1,43 @@
1
- import logging
2
- from PIL import Image
3
- from io import BytesIO
4
- import requests, os, json, time
5
-
6
- from google import genai
7
 
8
  prompt_base_path = ""
9
- client = genai.Client(api_key=os.getenv("GEMINI_API_KEY"))
10
-
11
-
12
- def encode_image(image_source):
13
- """
14
- ์ด๋ฏธ์ง€ ๊ฒฝ๋กœ๊ฐ€ URL์ด๋“  ๋กœ์ปฌ ํŒŒ์ผ์ด๋“  Pillow Image ๊ฐ์ฒด์ด๋“  ๋™์ผํ•˜๊ฒŒ ์ฒ˜๋ฆฌํ•˜๋Š” ํ•จ์ˆ˜.
15
- ์ด๋ฏธ์ง€๋ฅผ ์—ด์–ด google.genai.types.Part ๊ฐ์ฒด๋กœ ๋ณ€ํ™˜ํ•ฉ๋‹ˆ๋‹ค.
16
- Pillow์—์„œ ์ง€์›๋˜์ง€ ์•Š๋Š” ํฌ๋งท์— ๋Œ€ํ•ด์„œ๋Š” ์˜ˆ์™ธ๋ฅผ ๋ฐœ์ƒ์‹œํ‚ต๋‹ˆ๋‹ค.
17
- """
18
- try:
19
- # ์ด๋ฏธ Pillow ์ด๋ฏธ์ง€ ๊ฐ์ฒด์ธ ๊ฒฝ์šฐ ๊ทธ๋Œ€๋กœ ์‚ฌ์šฉ
20
- if isinstance(image_source, Image.Image):
21
- image = image_source
22
- else:
23
- # URL์—์„œ ์ด๋ฏธ์ง€ ๋‹ค์šด๋กœ๋“œ
24
- if isinstance(image_source, str) and (
25
- image_source.startswith("http://")
26
- or image_source.startswith("https://")
27
- ):
28
- response = requests.get(image_source)
29
- image = Image.open(BytesIO(response.content))
30
- # ๋กœ์ปฌ ํŒŒ์ผ์—์„œ ์ด๋ฏธ์ง€ ์—ด๊ธฐ
31
- else:
32
- image = Image.open(image_source)
33
-
34
- # ์ด๋ฏธ์ง€ ํฌ๋งท์ด None์ธ ๊ฒฝ์šฐ (๋ฉ”๋ชจ๋ฆฌ์—์„œ ์ƒ์„ฑ๋œ ์ด๋ฏธ์ง€ ๋“ฑ)
35
- if image.format is None:
36
- image_format = "JPEG"
37
- else:
38
- image_format = image.format
39
-
40
- # ์ด๋ฏธ์ง€ ํฌ๋งท์ด ์ง€์›๋˜์ง€ ์•Š๋Š” ๊ฒฝ์šฐ ์˜ˆ์™ธ ๋ฐœ์ƒ
41
- if image_format not in Image.registered_extensions().values():
42
- raise ValueError(f"Unsupported image format: {image_format}.")
43
-
44
- buffered = BytesIO()
45
- # PIL์—์„œ ์ง€์›๋˜์ง€ ์•Š๋Š” ํฌ๋งท์ด๋‚˜ ๋‹ค์–‘ํ•œ ์ฑ„๋„์„ RGB๋กœ ๋ณ€ํ™˜ ํ›„ ์ €์žฅ
46
- if image.mode in ("RGBA", "P", "CMYK"): # RGBA, ํŒ”๋ ˆํŠธ, CMYK ๋“ฑ์€ RGB๋กœ ๋ณ€ํ™˜
47
- image = image.convert("RGB")
48
- image.save(buffered, format="JPEG")
49
-
50
- return genai.types.Part.from_bytes(data=buffered.getvalue(), mime_type="image/jpeg")
51
 
52
- except requests.exceptions.RequestException as e:
53
- raise ValueError(f"Failed to download the image from URL: {e}")
54
- except IOError as e:
55
- raise ValueError(f"Failed to process the image file: {e}")
56
- except ValueError as e:
57
- raise ValueError(e)
58
 
59
 
60
  def run_gemini(
61
  target_prompt: str,
62
  prompt_in_path: str,
63
- img_in_data: str = None,
64
- model: str = "gemini-2.0-flash",
65
  ) -> str:
66
  """
67
- GEMINI API๋ฅผ ๋™๊ธฐ ๋ฐฉ์‹์œผ๋กœ ํ˜ธ์ถœํ•˜์—ฌ ๋ฌธ์ž์—ด ์‘๋‹ต์„ ๋ฐ›์Šต๋‹ˆ๋‹ค.
68
- retry ๋…ผ๋ฆฌ๋Š” ์ œ๊ฑฐ๋˜์—ˆ์Šต๋‹ˆ๋‹ค.
69
  """
70
- with open(os.path.join(prompt_base_path, prompt_in_path), "r", encoding="utf-8") as file:
 
 
 
 
71
  prompt_dict = json.load(file)
72
 
73
  system_prompt = prompt_dict["system_prompt"]
74
- user_prompt_head = prompt_dict["user_prompt"]["head"]
75
- user_prompt_tail = prompt_dict["user_prompt"]["tail"]
76
-
77
- user_prompt_text = "\n".join([user_prompt_head, target_prompt, user_prompt_tail])
78
- input_content = [user_prompt_text]
79
-
80
- if img_in_data is not None:
81
- encoded_image = encode_image(img_in_data)
82
- input_content.append(encoded_image)
83
-
84
- logging.info("Requested API for chat completion response (sync call)...")
85
- start_time = time.time()
86
-
87
- # ๋™๊ธฐ ๋ฐฉ์‹: client.models.generate_content(...)
88
- chat_completion = client.models.generate_content(
89
- model=model,
90
- contents=input_content,
91
- config={
92
- "system_instruction": system_prompt,
93
- }
94
  )
95
- print(f"Chat Completion: {chat_completion}")
96
-
97
- chat_output = chat_completion.candidates[0].content.parts[0].text
98
- input_token = chat_completion.usage_metadata.prompt_token_count
99
- output_token = chat_completion.usage_metadata.candidates_token_count
100
- pricing = input_token / 1000000 * 0.1 * 1500 + output_token / 1000000 * 0.7 * 1500
101
 
102
- logging.info(
103
- f"[GEMINI] Request completed (sync). Time taken: {time.time()-start_time:.2f}s / Pricing(KRW): {pricing:.2f}"
 
 
 
 
 
 
 
104
  )
 
105
  return chat_output
 
1
+ import openai, os, json
 
 
 
 
 
2
 
3
  prompt_base_path = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ client = openai.OpenAI(
6
+ api_key=os.getenv("GEMINI_API_KEY"),
7
+ base_url="https://generativelanguage.googleapis.com/v1beta/openai/",
8
+ )
 
 
9
 
10
 
11
def run_gemini(
    target_prompt: str,
    prompt_in_path: str,
    llm_model: str = "gemini-2.0-flash-exp",
) -> str:
    """Call a Gemini model via the OpenAI-compatible endpoint and return its reply.

    Loads a JSON prompt template, sandwiches ``target_prompt`` between the
    template's user-prompt ``head`` and ``tail`` sections, and sends the
    result as a single chat completion request.

    Args:
        target_prompt: Text to insert into the user prompt template.
        prompt_in_path: Path, joined onto ``prompt_base_path``, of a JSON
            file with keys ``system_prompt`` and ``user_prompt`` (the latter
            containing ``head`` and ``tail`` strings).
        llm_model: Model identifier passed to the API.

    Returns:
        The assistant message content of the first choice.
    """
    # Load the prompt template.
    with open(
        os.path.join(prompt_base_path, prompt_in_path), "r", encoding="utf-8"
    ) as file:
        prompt_dict = json.load(file)

    system_prompt = prompt_dict["system_prompt"]
    user_prompt_head = prompt_dict["user_prompt"]["head"]
    user_prompt_tail = prompt_dict["user_prompt"]["tail"]

    user_prompt_text = "\n".join([user_prompt_head, target_prompt, user_prompt_tail])
    input_content = [{"type": "text", "text": user_prompt_text}]

    # Use the stable chat-completions endpoint: the beta ``parse`` helper is
    # intended for structured outputs (``response_format=...``) and is
    # unnecessary for a plain text reply.
    chat_completion = client.chat.completions.create(
        model=llm_model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": input_content},
        ],
    )
    return chat_completion.choices[0].message.content
requirements.txt CHANGED
@@ -5,5 +5,5 @@ transformers==4.42.4
5
  xformers
6
  sentencepiece
7
  peft==0.12.0
8
- google-genai
9
- gradio
 
5
  xformers
6
  sentencepiece
7
  peft==0.12.0
8
+ openai
9
+ gradio==4.43.0