import gradio as gr
import os
import json
import subprocess
import threading
import uuid
from pathlib import Path
from huggingface_hub import InferenceClient, HfFolder

"""
Shedify app - uses a fine-tuned Llama 3.3 49B model for document assistance
"""

# Model settings
DEFAULT_MODEL = "Borislav18/Shedify"  # Your Hugging Face username/model name
LOCAL_MODEL = os.environ.get("LOCAL_MODEL", None)  # Set this if testing locally

# Get Hugging Face token
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# App title and description
title = "Shedify - Document Assistant powered by Llama 3.3"
description = """
This app uses a fine-tuned version of the Llama 3.3 49B model, trained on your documents.
Ask questions about the documents, generate insights, or request summaries!
"""

# Initialize inference client with your model
client = InferenceClient(
    DEFAULT_MODEL,
    token=HF_TOKEN,
)
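
# Sanity check (an assumption about deployment, not required by the client):
# InferenceClient can be constructed without a token, but calls against a
# private or gated model will fail at request time, so warn early.
if HF_TOKEN is None:
    print("Warning: HF_TOKEN is not set; inference against private or gated models will fail.")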

# Training status tracking
class TrainingState:
    def __init__(self):
        self.status = "idle"  # idle, running, success, failed
        self.progress = 0.0   # 0.0 to 1.0
        self.message = ""
        self.id = str(uuid.uuid4())[:8]  # Generate a unique ID for this session
        
        # Check if state file exists and load it
        self.state_file = Path("training_state.json")
        self.load_state()
    
    def load_state(self):
        """Load state from file if it exists"""
        if self.state_file.exists():
            try:
                with open(self.state_file, "r") as f:
                    state = json.load(f)
                self.status = state.get("status", "idle")
                self.progress = state.get("progress", 0.0)
                self.message = state.get("message", "")
                self.id = state.get("id", self.id)
            except Exception as e:
                print(f"Error loading state: {e}")
    
    def save_state(self):
        """Save current state to file"""
        try:
            with open(self.state_file, "w") as f:
                json.dump({
                    "status": self.status,
                    "progress": self.progress,
                    "message": self.message,
                    "id": self.id
                }, f)
        except Exception as e:
            print(f"Error saving state: {e}")
    
    def update(self, status=None, progress=None, message=None):
        """Update state and save it"""
        if status is not None:
            self.status = status
        if progress is not None:
            self.progress = progress
        if message is not None:
            self.message = message
        self.save_state()
        return self.status, self.progress, self.message

# Initialize the training state
training_state = TrainingState()
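# Persisting to training_state.json means training progress survives a process
# restart (e.g. a Space reboot); TrainingState.load_state re-reads it on startup.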

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    messages = [{"role": "system", "content": system_message}]

    # Format history to match chat completion format
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})

    messages.append({"role": "user", "content": message})

    response = ""

    # Stream the completion so partial responses appear in real time
    for chunk in client.chat_completion(
        messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        token = chunk.choices[0].delta.content

        # Some stream chunks (e.g. the final one) may carry no content
        if token:
            response += token
            yield response
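
# A quick manual check of the streaming generator (hedged; assumes the
# inference endpoint for DEFAULT_MODEL is reachable with your token):
#
#   for partial in respond("Hello", [], "You are a helpful assistant.", 64, 0.7, 0.9):
#       print(partial)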

def run_training_process(pdf_dir, output_name, progress_callback):
    """Run the PDF processing and fine-tuning process"""
    try:
        # Create processed_data directory if it doesn't exist
        os.makedirs("processed_data", exist_ok=True)
        
        # Update state
        progress_callback("running", 0.05, "Processing PDFs...")
        
        # Process PDFs
        pdf_process = subprocess.run(
            ["python", "pdf_processor.py", "--pdf_dir", pdf_dir, "--output_dir", "processed_data"],
            capture_output=True,
            text=True
        )
        
        if pdf_process.returncode != 0:
            progress_callback("failed", 0.0, f"PDF processing failed: {pdf_process.stderr}")
            return False
        
        # Update state
        progress_callback("running", 0.3, "PDFs processed. Starting fine-tuning...")
        
        # Get Hugging Face token
        hf_token = HF_TOKEN or HfFolder.get_token()
        if not hf_token:
            progress_callback("failed", 0.0, "No Hugging Face token found. Please set the HF_TOKEN environment variable.")
            return False
        
        # Run fine-tuning
        finetune_process = subprocess.run(
            [
                "python", "finetune_llama3.py",
                "--dataset_path", "processed_data/training_data",
                "--hub_model_id", f"Borislav18/{output_name}",
                "--epochs", "1",  # Starting with 1 epoch for quicker feedback
                "--gradient_accumulation_steps", "4"
            ],
            env={**os.environ, "HF_TOKEN": hf_token},
            capture_output=True,
            text=True
        )
        
        if finetune_process.returncode != 0:
            progress_callback("failed", 0.0, f"Fine-tuning failed: {finetune_process.stderr}")
            return False
        
        # Update state
        progress_callback("success", 1.0, f"Training complete! Model pushed to Hugging Face as Borislav18/{output_name}")
        return True
    
    except Exception as e:
        progress_callback("failed", 0.0, f"Training process failed with error: {str(e)}")
        return False
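
# Both stages shell out to helper scripts (pdf_processor.py and
# finetune_llama3.py) that are assumed to sit next to app.py; because
# capture_output=True, their stdout/stderr only surface via the failure
# messages above.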

def training_thread(pdf_dir, output_name):
    """Background thread for running training"""
    def progress_callback(status, progress, message):
        training_state.update(status, progress, message)
    
    # Post an initial status so the UI reflects the run immediately
    progress_callback("running", 0.01, "Starting training process...")
    
    # Run the actual training process
    run_training_process(pdf_dir, output_name, progress_callback)

def start_training(pdf_dir, output_name):
    """Start the training process in a background thread"""
    if not pdf_dir or not output_name:
        return "Please provide both a PDF directory and output model name", 0.0, "idle"
    
    # Check if already running
    if training_state.status == "running":
        return f"Training already in progress: {training_state.message}", training_state.progress, training_state.status
    
    # Start a daemon thread: it won't block interpreter shutdown, and an app
    # restart abandons the in-flight run (its last status stays in
    # training_state.json)
    thread = threading.Thread(
        target=training_thread,
        args=(pdf_dir, output_name),
        daemon=True
    )
    thread.start()
    
    return "Training started...", 0.0, "running"

def get_training_status():
    """Get the current training status for UI updates"""
    return training_state.message, training_state.progress, training_state.status


# Create the main application
with gr.Blocks(title="Shedify - Document Assistant") as demo:
    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown(f"# {title}")
            gr.Markdown(description)
        
        with gr.Column(scale=1):
            # Training controls
            with gr.Group(visible=True):
                gr.Markdown("## Train New Model")
                pdf_dir = gr.Textbox(label="PDF Directory", placeholder="Path to directory containing PDFs")
                output_name = gr.Textbox(label="Model Name", placeholder="Name for your fine-tuned model", value="Shedify-v1")
                train_btn = gr.Button("Start Training")
                
                training_message = gr.Textbox(label="Training Status", interactive=False)
                training_progress = gr.Slider(
                    minimum=0, maximum=1, value=0, 
                    label="Progress", interactive=False
                )
                training_status = gr.Textbox(visible=False)
    
    # Chat interface
    chatbot = gr.ChatInterface(
        fn=respond,
        additional_inputs=[
            gr.Textbox(
                value="You are an AI assistant trained on specific documents. Answer questions based only on information from these documents. If you don't know the answer from the documents, say so clearly.",
                label="System message"
            ),
            gr.Slider(minimum=1, maximum=2048, value=1024, step=1, label="Max new tokens"),
            gr.Slider(minimum=0.1, maximum=2.0, value=0.7, step=0.1, label="Temperature"),
            gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.9,
                step=0.05,
                label="Top-p (nucleus sampling)",
            ),
        ],
        examples=[
            ["Summarize the key points from all documents you were trained on."],
            ["What are the main themes discussed in the documents?"],
            ["Extract the most important concepts mentioned in the documents."],
            ["Explain the relationship between the different topics in the documents."],
            ["What recommendations or conclusions can be drawn from the documents?"],
        ]
    )
    
    # Set up event handlers
    train_btn.click(
        fn=start_training, 
        inputs=[pdf_dir, output_name], 
        outputs=[training_message, training_progress, training_status]
    )
    
    # Poll training status on load and every few seconds thereafter
    demo.load(
        get_training_status,
        outputs=[training_message, training_progress, training_status],
        every=5,
    )
    
    def update_ui(message, progress, status):
        is_running = status == "running"
        # gr.Textbox renders plain text, so prefix the status name rather
        # than wrapping the message in HTML color markup
        status_message = f"[{status}] {message}" if message else message
        return status_message, progress, gr.update(interactive=not is_running)
    
    training_status.change(
        fn=update_ui, 
        inputs=[training_message, training_progress, training_status], 
        outputs=[training_message, training_progress, train_btn]
    )
    

if __name__ == "__main__":
    demo.launch()