aaurelions committed
Commit fad4129 · verified · 1 Parent(s): d7dc9a8

Create app.py

Files changed (1)
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
+ import gradio as gr
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+ from threading import Thread
+
+ # --- Configuration ---
+ MODEL_ID = "microsoft/bitnet-b1.58-2B-4T"
+ # Try 'cuda' if you have a GPU space, 'cpu' otherwise (will be slow)
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {DEVICE}")
+
+ # --- Load Model and Tokenizer ---
+ # Note: Loading might require specific trust_remote_code=True or other flags
+ # depending on the model implementation. Check the model card on Hugging Face.
+ # You might also need specific quantization configs if not handled automatically.
+ try:
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+     # Adjust loading parameters as needed (e.g., torch_dtype, device_map)
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_ID,
+         torch_dtype=torch.bfloat16,  # Or float16, adjust based on hardware/model reqs
+         device_map="auto",  # Automatically distribute across available devices (GPU/CPU)
+         # trust_remote_code=True  # May be required for some custom model code
+     )
+     # model.to(DEVICE)  # Usually handled by device_map="auto"
+     print("Model and tokenizer loaded successfully.")
+ except Exception as e:
+     print(f"Error loading model or tokenizer: {e}")
+     # Fallback or exit if loading fails
+     raise SystemExit("Failed to load model/tokenizer.")
+
+ # --- Chat Processing Function ---
+ def predict(message, history):
+     """
+     Generates a response to the user's message using the chat history.
+     """
+     history_transformer_format = []
+     for human, assistant in history:
+         # Basic alternating format - adjust if the model expects something different
+         history_transformer_format.append({"role": "user", "content": human})
+         history_transformer_format.append({"role": "assistant", "content": assistant})
+
+     # Add the current user message
+     history_transformer_format.append({"role": "user", "content": message})
+
+     # Use the tokenizer's chat template if available, otherwise manual formatting.
+     # Base models might not have a specific chat template.
+     try:
+         prompt = tokenizer.apply_chat_template(
+             history_transformer_format,
+             tokenize=False,
+             add_generation_prompt=True  # Important for generation
+         )
+     except Exception:
+         # Manual fallback prompt formatting (Example - adjust as needed!)
+         print("Warning: Using basic manual prompt formatting.")
+         prompt_parts = ["Chat History:"]
+         for turn in history_transformer_format:
+             prompt_parts.append(f"{turn['role'].capitalize()}: {turn['content']}")
+         prompt = "\n".join(prompt_parts) + "\nAssistant:"  # Ensure it ends ready for generation
+
+     print(f"\n--- Prompt Sent to Model ---\n{prompt}\n---------------------------\n")
+
+     # Use a streamer for interactive generation
+     streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     generation_kwargs = dict(
+         inputs,
+         streamer=streamer,
+         max_new_tokens=512,
+         do_sample=True,
+         top_p=0.9,
+         temperature=0.7,
+         # Add other generation parameters as needed
+         # eos_token_id=tokenizer.eos_token_id  # Important if the model needs it
+         pad_token_id=tokenizer.eos_token_id  # Often set for open-ended generation
+     )
+
+     # Run generation in a separate thread for streaming
+     thread = Thread(target=model.generate, kwargs=generation_kwargs)
+     thread.start()
+
+     # Yield tokens as they become available
+     partial_message = ""
+     for new_token in streamer:
+         partial_message += new_token
+         yield partial_message
+
+ # --- Gradio Interface ---
+ # Use gr.ChatInterface - it handles history management automatically
+ chatbot_interface = gr.ChatInterface(
+     fn=predict,
+     chatbot=gr.Chatbot(height=500),
+     textbox=gr.Textbox(placeholder="Ask me anything...", container=False, scale=7),
+     title="Chat with microsoft/bitnet-b1.58-2B-4T",
+     description="A basic chat interface for the BitNet 1.58-bit 2B parameter model. Remember it's a base model, so prompting matters!",
+     theme="soft",
+     examples=[["Hello!"], ["Explain the concept of 1.58-bit quantization."]],
+     cache_examples=False,  # Set to True to cache example results
+     retry_btn=None,
+     undo_btn="Delete Previous Turn",
+     clear_btn="Clear Chat",
+ )
+
+ # --- Launch the Interface ---
+ if __name__ == "__main__":
+     chatbot_interface.launch()  # Use share=True for a public link if running locally