GuhanAein committed (verified)
Commit 2e97b24 · Parent(s): c3e7791

Update main.py

Files changed (1): main.py +70 -0
main.py CHANGED
@@ -0,0 +1,70 @@
from fastapi import FastAPI
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
from datasets import load_dataset
from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings, StorageContext
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from llama_index.vector_stores.faiss import FaissVectorStore
import faiss
import json
import os
from huggingface_hub import login

app = FastAPI()

# Log in to Hugging Face using the HF_TOKEN environment variable
hf_token = os.getenv("HF_TOKEN")
if not hf_token:
    raise ValueError("HF_TOKEN environment variable not set")
login(hf_token)

# Load the APPS dataset and write a small knowledge base to disk
ds = load_dataset("codeparrot/apps", "all", split="train")
os.makedirs("knowledge_base", exist_ok=True)
for i, example in enumerate(ds.select(range(100))):  # only 100 problems to stay within the free tier
    # "solutions" is stored as a JSON-encoded list of strings in this dataset
    solutions = json.loads(example["solutions"]) if example["solutions"] else []
    solution = solutions[0] if solutions else "No solution available"
    with open(f"knowledge_base/doc_{i}.txt", "w", encoding="utf-8") as f:
        f.write(f"### Problem\n{example['question']}\n\n### Solution\n{solution}")
documents = SimpleDirectoryReader("knowledge_base").load_data()

# Set up RAG: MiniLM embeddings indexed in a flat L2 FAISS index
embed_model = HuggingFaceEmbedding(model_name="sentence-transformers/all-MiniLM-L6-v2")
Settings.embed_model = embed_model
d = 384  # embedding dimension of all-MiniLM-L6-v2
faiss_index = faiss.IndexFlatL2(d)
vector_store = FaissVectorStore(faiss_index=faiss_index)
storage_context = StorageContext.from_defaults(vector_store=vector_store)
index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)

# Load the LLaMA model, 4-bit quantized when a GPU is available
model_name = "meta-llama/Llama-3.2-1B-Instruct"
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
device = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    quantization_config=quant_config if device == "cuda" else None,  # bitsandbytes 4-bit requires CUDA
    device_map="auto" if device == "cuda" else None
)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

@app.get("/solve")
async def solve_problem(problem: str, top_k: int = 1):
    # Retrieve the most similar stored problem/solution pairs as context
    retriever = index.as_retriever(similarity_top_k=top_k)
    retrieved_nodes = retriever.retrieve(problem)
    context = retrieved_nodes[0].text if retrieved_nodes else "No relevant context found."
    prompt = f"Given the following competitive programming problem:\n\n{problem}\n\nRelevant context:\n{context}\n\nGenerate a solution in Python:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    outputs = model.generate(
        **inputs,
        max_new_tokens=200,
        temperature=0.7,
        top_p=0.9,
        do_sample=True
    )
    # Decode only the newly generated tokens, not the echoed prompt
    solution = tokenizer.decode(outputs[0][inputs["input_ids"].shape[1]:], skip_special_tokens=True)
    return {"solution": solution, "context": context}
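Once the app is served (for example with `uvicorn main:app`), the /solve endpoint accepts the problem statement and an optional top_k as query parameters. A minimal client sketch follows, assuming the server is reachable at http://localhost:8000; the example problem text and the use of the requests library are illustrative and not part of this commit:

# Hypothetical client for the /solve endpoint (assumes a local server on port 8000)
import requests

resp = requests.get(
    "http://localhost:8000/solve",
    params={"problem": "Given an integer n, print the sum of the integers from 1 to n.", "top_k": 1},
)
resp.raise_for_status()
data = resp.json()
print(data["context"])   # retrieved APPS problem/solution used as context
print(data["solution"])  # model-generated Python solution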