Guy24 commited on
Commit
d844e87
Β·
1 Parent(s): 682420f

adding application

Browse files
app.py CHANGED
@@ -12,10 +12,10 @@ import torch
12
  from tqdm import tqdm
13
  from abc import ABC, abstractmethod
14
 
15
- from .utils.enums import MultiTokenKind, RetrievalTechniques
16
  from .processor import RetrievalProcessor
17
- from .utils.logit_lens import ReverseLogitLens
18
- from .utils.model_utils import extract_token_i_hidden_states
19
 
20
 
21
  class WordRetrieverBase(ABC):
@@ -118,15 +118,15 @@ class PatchscopesRetriever(WordRetrieverBase):
118
  return patchscopes_description_by_layers, last_token_hidden_states
119
 
120
 
121
- class ReverseLogitLensRetriever(WordRetrieverBase):
122
- def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
123
- super().__init__(model, tokenizer)
124
- self.reverse_logit_lens = ReverseLogitLens.from_model(model).to(device).to(dtype)
125
-
126
- def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
127
- result = self.reverse_logit_lens(hidden_states, layer_idx)
128
- token = self.tokenizer.decode(torch.argmax(result, dim=-1).item())
129
- return token
130
 
131
 
132
  class AnalysisWordRetriever:
 
12
  from tqdm import tqdm
13
  from abc import ABC, abstractmethod
14
 
15
+ from enums import MultiTokenKind, RetrievalTechniques
16
  from .processor import RetrievalProcessor
17
+ # from .utils.logit_lens import ReverseLogitLens
18
+ from model_utils import extract_token_i_hidden_states
19
 
20
 
21
  class WordRetrieverBase(ABC):
 
118
  return patchscopes_description_by_layers, last_token_hidden_states
119
 
120
 
121
+ # class ReverseLogitLensRetriever(WordRetrieverBase):
122
+ # def __init__(self, model, tokenizer, device='cuda', dtype=torch.float16):
123
+ # super().__init__(model, tokenizer)
124
+ # self.reverse_logit_lens = ReverseLogitLens.from_model(model).to(device).to(dtype)
125
+ #
126
+ # def retrieve_word(self, hidden_states, layer_idx=None, num_tokens_to_generate=3):
127
+ # result = self.reverse_logit_lens(hidden_states, layer_idx)
128
+ # token = self.tokenizer.decode(torch.argmax(result, dim=-1).item())
129
+ # return token
130
 
131
 
132
  class AnalysisWordRetriever:
utils/calibration_utils.py β†’ calibration_utils.py RENAMED
File without changes
utils/data_utils.py β†’ data_utils.py RENAMED
File without changes
utils/enums.py β†’ enums.py RENAMED
File without changes
utils/eval_utils.py β†’ eval_utils.py RENAMED
File without changes
utils/file_utils.py β†’ file_utils.py RENAMED
File without changes
utils/logit_lens.py β†’ logit_lens.py RENAMED
File without changes
utils/model_utils.py β†’ model_utils.py RENAMED
File without changes
utils/__init__.py DELETED
File without changes