bearking58 committed
Commit fb8f2dc · 1 Parent(s): d1d0c5f

chore: try to move hf model loading to dockerfile

cloudbuild-model.yaml → cloudbuild.yaml RENAMED
@@ -4,10 +4,13 @@ steps:
     args:
       [
         "build",
+        "--build-arg",
+        "HF_TOKEN=${_HF_TOKEN}",
         "-t",
         "us-central1-docker.pkg.dev/${PROJECT_ID}/interview-ai-detector/model-prediction:latest",
         ".",
       ]
+    secretEnv: ["HF_TOKEN"]
 
   - name: "gcr.io/cloud-builders/docker"
     args:
@@ -18,3 +21,8 @@ steps:
 
 images:
   - "us-central1-docker.pkg.dev/${PROJECT_ID}/interview-ai-detector/model-prediction:latest"
+
+availableSecrets:
+  secretManager:
+    - versionName: "projects/${PROJECT_ID}/secrets/HF_TOKEN/versions/1"
+      env: "HF_TOKEN"
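The build now reads the Hugging Face token from Secret Manager at build time instead of at runtime. A minimal pre-flight sketch, assuming the google-cloud-secret-manager package is installed and substituting your own project ID, for confirming that the secret version referenced by availableSecrets actually resolves (this script is illustrative and not part of the commit):

from google.cloud import secretmanager

# Check that the HF_TOKEN secret version referenced in cloudbuild.yaml is readable.
client = secretmanager.SecretManagerServiceClient()
name = "projects/YOUR_PROJECT_ID/secrets/HF_TOKEN/versions/1"  # mirrors versionName above
response = client.access_secret_version(request={"name": name})
print("HF_TOKEN resolves; token length:", len(response.payload.data.decode("UTF-8")))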
core-model-prediction/Dockerfile CHANGED
@@ -1,3 +1,6 @@
+# HF Token args
+ARG HF_TOKEN
+
 # Use an official Python runtime as a base image
 FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
 
@@ -17,6 +20,16 @@ RUN python -m nltk.downloader punkt wordnet averaged_perceptron_tagger
 # Unzip wordnet
 RUN unzip /root/nltk_data/corpora/wordnet.zip -d /root/nltk_data/corpora/
 
+# Download HuggingFace model
+RUN python -c "from transformers import AutoTokenizer, AutoModelForCausalLM; \
+    tokenizer = AutoTokenizer.from_pretrained('google/gemma-2b', token='$HF_TOKEN'); \
+    model = AutoModelForCausalLM.from_pretrained('google/gemma-2b', token='$HF_TOKEN'); \
+    tokenizer.save_pretrained('/app/gemma-2b'); \
+    model.save_pretrained('/app/gemma-2b')"
+
+# Model env
+ENV MODEL_DIR=/app/gemma-2b
+
 # Make port 8080 available to the world outside this container
 EXPOSE 8080
 
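With the gemma-2b weights baked into the image at /app/gemma-2b and exposed through MODEL_DIR, the layer can be sanity-checked without network access. A minimal smoke-test sketch, assuming it runs inside a container built from this image (the TRANSFORMERS_OFFLINE check is illustrative, not part of the commit):

import os

# Force transformers to read only local files so a missing bake shows up as an error.
os.environ["TRANSFORMERS_OFFLINE"] = "1"

from transformers import AutoModelForCausalLM, AutoTokenizer

model_dir = os.getenv("MODEL_DIR", "/app/gemma-2b")  # set by ENV MODEL_DIR above
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForCausalLM.from_pretrained(model_dir)
print("Loaded", model.config.model_type, "from", model_dir)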
 
core-model-prediction/gemma2b_dependencies.py CHANGED
@@ -1,10 +1,10 @@
+import os
 from transformers import AutoTokenizer, AutoModelForCausalLM
 import torch
 from torch.nn.functional import cosine_similarity
 from collections import Counter
 import numpy as np
 from device_manager import DeviceManager
-from google.cloud import secretmanager
 
 
 class Gemma2BDependencies:
@@ -13,21 +13,13 @@ class Gemma2BDependencies:
     def __new__(cls):
         if cls._instance is None:
             cls._instance = super(Gemma2BDependencies, cls).__new__(cls)
-            token = cls._instance.access_hf_token_secret()
-            cls._instance.tokenizer = AutoTokenizer.from_pretrained(
-                "google/gemma-2b", token=token)
-            cls._instance.model = AutoModelForCausalLM.from_pretrained(
-                "google/gemma-2b", token=token)
+            model_dir = os.getenv("MODEL_DIR", "/app/gemma-2b")
+            cls._instance.tokenizer = AutoTokenizer.from_pretrained(model_dir)
+            cls._instance.model = AutoModelForCausalLM.from_pretrained(model_dir)
             cls._instance.device = DeviceManager()
             cls._instance.model.to(cls._instance.device)
         return cls._instance
 
-    def access_hf_token_secret(self):
-        client = secretmanager.SecretManagerServiceClient()
-        name = "projects/steady-climate-416810/secrets/HF_TOKEN/versions/1"
-        response = client.access_secret_version(request={"name": name})
-        return response.payload.data.decode('UTF-8')
-
     def calculate_perplexity(self, text: str):
         inputs = self.tokenizer(text, return_tensors="pt",
                                 truncation=True, max_length=1024)
@@ -42,7 +34,6 @@ class Gemma2BDependencies:
         return perplexity.item()
 
     def calculate_burstiness(self, text: str):
-        # Tokenize the text using GPT-2 tokenizer
        tokens = self.tokenizer.encode(text, add_special_tokens=False)
 
         # Count token frequencies
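Because the tokenizer and model now load from the local MODEL_DIR, the singleton no longer needs Secret Manager access or network connectivity at startup. A minimal usage sketch (the sample text is made up):

from gemma2b_dependencies import Gemma2BDependencies

deps = Gemma2BDependencies()   # first call loads tokenizer/model from MODEL_DIR
again = Gemma2BDependencies()  # later calls reuse the same instance
assert deps is again

text = "A candidate answer to score."
print(deps.calculate_perplexity(text))
print(deps.calculate_burstiness(text))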
public-prediction/kafka_consumer.py CHANGED
@@ -52,7 +52,6 @@ def send_results_back(full_results: dict[str, any], job_application_id: str):
 
     response = requests.patch(url, json=body, headers=headers)
     print(f"Data sent with status code {response.status_code}")
-    print(response.content)
 
 
 def consume_messages():
@@ -72,7 +71,7 @@ def consume_messages():
 
     for message in consumer:
         try:
-            incoming_message = json.loads(message.value.decode("utf-8"))
+            incoming_message = json.loads(json.loads(message.value.decode("utf-8")))
             full_batch = incoming_message["data"]
         except json.JSONDecodeError:
             print("Failed to decode JSON from message:", message.value)
@@ -84,6 +83,7 @@ def consume_messages():
 
     full_results = []
     for i in range(0, len(full_batch), BATCH_SIZE):
+        print(f"Processing batch {i} to {i+BATCH_SIZE}")
         batch = full_batch[i:i+BATCH_SIZE]
         batch_results = process_batch(batch, BATCH_SIZE, gpt_helper)
         full_results.extend(batch_results)
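The extra json.loads assumes the producer publishes a double-encoded payload: the first decode yields a JSON string, the second yields the dict carrying the "data" key. A small self-contained illustration of that assumption:

import json

payload = {"data": [1, 2, 3]}

# A producer that dumps an already-serialized string emits double-encoded JSON.
raw = json.dumps(json.dumps(payload)).encode("utf-8")

first = json.loads(raw.decode("utf-8"))  # -> str containing JSON
second = json.loads(first)               # -> dict
assert isinstance(first, str) and second == payload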
public-prediction/predict_custom_model.py CHANGED
@@ -20,13 +20,9 @@ def predict_custom_trained_model(
     # The AI Platform services require regional API endpoints.
     client_options = {"api_endpoint": api_endpoint}
 
-    credentials = service_account.Credentials.from_service_account_file(
-        os.getenv("GOOGLE_APPLICATION_CREDENTIALS"))
     # Initialize client that will be used to create and send requests.
     # This client only needs to be created once, and can be reused for multiple requests.
-    client = aiplatform.gapic.PredictionServiceClient(
-        credentials=credentials,
-        client_options=client_options)
+    client = aiplatform.gapic.PredictionServiceClient(client_options=client_options)
     # The format of each instance should conform to the deployed model's prediction input schema.
     instances = instances if isinstance(instances, list) else [instances]
     instances = [
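With the explicit service-account key file removed, the PredictionServiceClient falls back to Application Default Credentials, e.g. the service account attached to the workload on GCP. A minimal sketch of how that resolution can be inspected (illustrative only):

import google.auth

# google-auth resolves ADC from GOOGLE_APPLICATION_CREDENTIALS, then gcloud user
# credentials, then the attached service account on GCP, in that order.
credentials, project_id = google.auth.default()
print("Prediction client will authenticate via ADC for project:", project_id)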