reshinthadith commited on
Commit
f304736
·
verified ·
1 Parent(s): c70bb0f

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +32 -2
README.md CHANGED
@@ -22,5 +22,35 @@ Generates and Edits minimal multi-file python code. Right now consistently gener
22
  <!-- Provide the basic links for the model. -->
23
 
24
  - **Repository:** https://github.com/reshinthadithyan/repo-level-code/tree/main
25
- - **Paper [optional]:** [More Information Needed]
26
- - **Demo [optional]:** [More Information Needed]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  <!-- Provide the basic links for the model. -->
23
 
24
  - **Repository:** https://github.com/reshinthadithyan/repo-level-code/tree/main
25
+ ### Usage
26
+ ```python
27
+ import torch
28
+ from transformers import AutoModelForCausalLM, AutoTokenizer
29
+ import fire
30
+
31
+
32
+ def main(model_path:str="./models_dir/repo_coder_v1"):
33
+ input_prompt = "###Instruction: {prompt}".format(prompt="Generate a small python repo for matplotlib to visualize timeseries data to read from timeseries.csv file using pandas.")
34
+
35
+ def load_model(model_path):
36
+ """
37
+ Load the model and tokenizer from the specified path.
38
+ """
39
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
40
+ model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype="auto")
41
+ return model, tokenizer
42
+
43
+
44
+ model, tokenizer = load_model(model_path)
45
+ print(f"Loaded model from {model_path}.")
46
+
47
+ input = tokenizer(input_prompt, return_tensors="pt").to(model.device)
48
+ print(input)
49
+ with torch.no_grad():
50
+ output = model.generate(**input, max_length=1024, do_sample=True, temperature=0.9, top_p=0.95, top_k=50)
51
+ output_text = tokenizer.decode(output[0], skip_special_tokens=True)
52
+ print(f"Generated text: {output_text}")
53
+
54
+ if __name__ == "__main__":
55
+ fire.Fire(main)
56
+ ```