reshinthadith committed on
Commit
ecce996
·
verified ·
1 Parent(s): 4287a63

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +66 -8
README.md CHANGED
@@ -39,17 +39,62 @@ The model generates the repository in the following format, Code to parse it and
39
  import torch
40
  from transformers import AutoModelForCausalLM, AutoTokenizer
41
  import fire
42
-
43
-
44
- def main(model_path:str="./models_dir/repo_coder_v1"):
45
- input_prompt = "###Instruction: {prompt}".format(prompt="Generate a small python repo for matplotlib to visualize timeseries data to read from timeseries.csv file using pandas.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
 
47
  def load_model(model_path):
48
  """
49
  Load the model and tokenizer from the specified path.
50
  """
51
  tokenizer = AutoTokenizer.from_pretrained(model_path)
52
- model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")
 
53
  return model, tokenizer
54
 
55
 
@@ -57,12 +102,25 @@ def main(model_path:str="./models_dir/repo_coder_v1"):
57
  print(f"Loaded model from {model_path}.")
58
 
59
  input = tokenizer(input_prompt, return_tensors="pt").to(model.device)
60
- print(input)
61
  with torch.no_grad():
62
  output = model.generate(**input, max_length=1024, do_sample=True, temperature=0.9, top_p=0.95, top_k=50)
63
- output_text = tokenizer.decode(output[0], skip_special_tokens=True)
64
- print(f"Generated text: {output_text}")
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  if __name__ == "__main__":
67
  fire.Fire(main)
 
68
  ```
 
39
  import torch
40
  from transformers import AutoModelForCausalLM, AutoTokenizer
41
  import fire
42
+ from pathlib import Path
43
+ import os
44
+ import re
45
+
46
def generate_repo_from_string(input_str: str, output_dir: str) -> None:
    """
    Materialize the files described by an <output> section under output_dir.

    The first <output>...</output> section of input_str is scanned for
    entries of the form:
        <file1>path/to/file.ext<content>...file content...</content></file1>
        <bashfile>script.sh<content>...script content...</content></bashfile>

    Every entry is validated before anything is written, so a rejected
    entry never leaves a half-generated repository behind.

    Args:
        input_str: The full string containing <output> markup.
        output_dir: Directory where files will be created. Existing files will be overwritten.

    Raises:
        ValueError: If there is no <output> section, if an entry has no file
            name, or if an entry's path resolves outside output_dir.
    """
    # Extract the content inside <output>...</output>
    out_match = re.search(r"<output>(.*?)</output>", input_str, re.DOTALL)
    if not out_match:
        raise ValueError("No <output> section found in input.")
    output_section = out_match.group(1)

    # Regex to find file tags: file1, file2, file3, ... and bashfile
    pattern = re.compile(
        r"<(file\d+|bashfile)>([^<]+?)<content>(.*?)</content></\1>",
        re.DOTALL,
    )

    # Resolve symlinks so the containment check below compares real locations.
    base_dir = os.path.realpath(output_dir)

    planned = []
    for _tag, filename, content in pattern.findall(output_section):
        relative_name = filename.strip()
        # Rejects empty names and names ending in a separator ("pkg/").
        if not os.path.basename(relative_name):
            raise ValueError(f"Entry has no file name: {filename!r}")

        # The names come from generated text, so treat them as untrusted:
        # "../x" or an absolute path must not write outside output_dir.
        file_path = os.path.realpath(os.path.join(base_dir, relative_name))
        if file_path == base_dir or os.path.commonpath([base_dir, file_path]) != base_dir:
            raise ValueError(
                f"Refusing to write outside the output directory: {relative_name!r}"
            )
        planned.append((file_path, content))

    for file_path, content in planned:
        # Ensure parent directory exists
        os.makedirs(os.path.dirname(file_path), exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            # Strip only one leading newline if present; further blank
            # lines at the top of the file are part of its content.
            f.write(content[1:] if content.startswith('\n') else content)

    print(f"Repository generated at: {output_dir}")
84
+
85
+
86
+ def main(model_path:str="./models_dir/repo_coder_v1",
87
+ prompt:str="Generate a small python repo for matplotlib to visualize timeseries data to read from timeseries.csv file using polars."
88
+ ,output_path="./output_dir/demo2"):
89
+ input_prompt = "###Instruction: {prompt}".format(prompt=prompt)
90
 
91
  def load_model(model_path):
92
  """
93
  Load the model and tokenizer from the specified path.
94
  """
95
  tokenizer = AutoTokenizer.from_pretrained(model_path)
96
+ model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto").to("cuda:0")
97
+ model.eval()
98
  return model, tokenizer
99
 
100
 
 
102
  print(f"Loaded model from {model_path}.")
103
 
104
  input = tokenizer(input_prompt, return_tensors="pt").to(model.device)
 
105
  with torch.no_grad():
106
  output = model.generate(**input, max_length=1024, do_sample=True, temperature=0.9, top_p=0.95, top_k=50)
107
+ generated_code_repo = tokenizer.decode(output[0], skip_special_tokens=True)
108
+ print(f"Generated code repo: {generated_code_repo}")
109
+ Path(output_path).mkdir(parents=True, exist_ok=True)
110
+ generate_repo_from_string(generated_code_repo, output_path)
111
+
112
def list_files(startpath):
    """
    Print an indented tree of the directories and files under startpath.

    Each directory is printed with a trailing "/" and indented four spaces
    per level below startpath; its files are printed one level deeper.

    Args:
        startpath: Root directory to walk. A trailing path separator is
            tolerated.
    """
    # Normalize first so a trailing separator cannot produce an empty
    # directory name or a wrong depth.
    startpath = os.path.normpath(startpath)
    for root, _dirs, files in os.walk(startpath):
        # Derive the depth from the relative path rather than by textual
        # replacement, which also matches startpath deeper inside root.
        rel = os.path.relpath(root, startpath)
        level = 0 if rel == os.curdir else rel.count(os.sep) + 1
        indent = ' ' * 4 * level
        print('{}{}/'.format(indent, os.path.basename(root)))
        subindent = ' ' * 4 * (level + 1)
        for f in files:
            print('{}{}'.format(subindent, f))
120
+ list_files(output_path)
121
+
122
 
123
  if __name__ == "__main__":
124
  fire.Fire(main)
125
+
126
  ```