AshenBorn committed
Commit f05fc7c · verified · 1 Parent(s): 2511f4c

Create app.py

Files changed (1)
  1. app.py +77 -0
app.py ADDED
@@ -0,0 +1,77 @@
+ import streamlit as st
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
+ from peft import PeftModel
+
+ def main():
+     st.title("Math Meme Repair (LoRA-Fine-Tuned)")
+
+     st.markdown("""
+ **Instructions**:
+ 1. Enter your incorrect math meme in the format:
+ ```
+ Math Meme Correction:
+ Incorrect: 5-3-1 = 3?
+ Correct:
+ ```
+ 2. Click **Repair Math Meme** to generate a corrected explanation.
+
+ **Note**: This app runs on CPU, so it may be slow and memory-intensive for a 7B model.
+ """)
+
+     # 1. Load the base model from Hugging Face
+     model_name = "deepseek-ai/deepseek-math-7b-base"
+     tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+     # float32 is the safe dtype here; float16 is unsupported or slow
+     # on many CPUs.
+     base_model = AutoModelForCausalLM.from_pretrained(
+         model_name,
+         torch_dtype=torch.float32  # CPU-friendly dtype
+     )
+     base_model = base_model.to("cpu")  # run everything on CPU
+
+     # 2. Load the LoRA adapter (a local directory containing
+     #    adapter_config.json and adapter_model.safetensors)
+     adapter_dir = "trained-math-meme-model"
+     model = PeftModel.from_pretrained(base_model, adapter_dir)
+     model = model.to("cpu")
+     model.eval()
+
+     # 3. Configure generation. do_sample=True is needed for
+     #    temperature and top_p to actually take effect.
+     generation_config = GenerationConfig(
+         max_new_tokens=100,
+         do_sample=True,
+         temperature=0.7,
+         top_p=0.7,
+         pad_token_id=tokenizer.eos_token_id
+     )
+
+     # 4. User input area
+     user_input = st.text_area(
+         "Enter your math meme input:",
+         value="Math Meme Correction:\nIncorrect: 5-3-1 = 3?\nCorrect:"
+     )
+
+     if st.button("Repair Math Meme"):
+         if user_input.strip() == "":
+             st.warning("Please enter a math meme input following the required format.")
+         else:
+             with torch.no_grad():
+                 # Tokenize on CPU
+                 encoding = tokenizer(user_input, return_tensors="pt").to("cpu")
+                 outputs = model.generate(
+                     input_ids=encoding.input_ids,
+                     attention_mask=encoding.attention_mask,
+                     generation_config=generation_config
+                 )
+
+             # Decode and display
+             result = tokenizer.decode(outputs[0], skip_special_tokens=True)
+             st.subheader("Repaired Math Meme")
+             st.write(result)
+
+     st.markdown("\n**Error Rating:** 90% sass, 10% patience (on CPU)")
+
+ if __name__ == "__main__":
+     main()
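
One caveat for anyone adapting this file: Streamlit reruns the whole script on every interaction, so `main()` reloads the 7B base model and the adapter on each button click. A minimal sketch of how the loading steps could be cached with Streamlit's `st.cache_resource` decorator (the `load_model` helper is illustrative, not part of the committed file):

```python
import streamlit as st
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel

@st.cache_resource  # load once per process, reuse across Streamlit reruns
def load_model(model_name: str, adapter_dir: str):
    # Same loading steps as app.py, wrapped so Streamlit caches the result
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.float32  # CPU-friendly dtype, as in app.py
    ).to("cpu")
    model = PeftModel.from_pretrained(base_model, adapter_dir).to("cpu")
    model.eval()
    return tokenizer, model

# Inside main(), sections 1 and 2 would then collapse to:
# tokenizer, model = load_model("deepseek-ai/deepseek-math-7b-base",
#                               "trained-math-meme-model")
```

With that change the model loads once per process rather than once per click; the rest of the file stays the same. Locally, the app starts with `streamlit run app.py`, assuming streamlit, torch, transformers, and peft are installed.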