Sanchit2207 committed on
Commit c960062 · verified · 1 Parent(s): 49c9a25

Update app.py

Files changed (1)
  1. app.py +41 -58
app.py CHANGED
@@ -1,60 +1,43 @@
- import gradio as gr
- from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
+ from transformers import AutoTokenizer, AutoModelForCausalLM
  import torch

- # ---------------- Agent 1: Intent Classifier ----------------
- intent_classifier = pipeline("zero-shot-classification", model="MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli")
-
-
- def detect_intent(text):
-     labels = {
-         "weather": "The user wants to know the weather.",
-         "faq": "The user is asking for help.",
-         "smalltalk": "The user is making casual conversation."
-     }
-     best_intent = "smalltalk"
-     best_score = 0
-     for label, hypothesis in labels.items():
-         result = intent_classifier(text=text, text_pair=hypothesis)[0]
-         if result['label'] == 'ENTAILMENT' and result['score'] > best_score:
-             best_score = result['score']
-             best_intent = label
-     return best_intent
-
- # ---------------- Agent 2: Domain Logic ----------------
- def handle_logic(intent):
-     if intent == "weather":
-         return "It's sunny and 26°C today."
-     elif intent == "faq":
-         return "To reset your password, use the 'Forgot Password' option."
-     else:
-         return "That's great! Anything else you'd like to talk about?"
-
- # ---------------- Agent 3: Natural Language Generation ----------------
- tokenizer = AutoTokenizer.from_pretrained("microsoft/DialoGPT-small")
- model = AutoModelForCausalLM.from_pretrained("microsoft/DialoGPT-small")
-
- def generate_reply(prompt):
-     input_ids = tokenizer.encode(prompt + tokenizer.eos_token, return_tensors='pt')
-     output_ids = model.generate(input_ids, max_length=100, pad_token_id=tokenizer.eos_token_id)
-     response = tokenizer.decode(output_ids[:, input_ids.shape[-1]:][0], skip_special_tokens=True)
-     return response
-
- # ---------------- Chatbot Pipeline ----------------
- def chatbot(user_input):
-     intent = detect_intent(user_input)
-     logic = handle_logic(intent)
-     response = generate_reply(logic)
-     return response
-
- # ---------------- Gradio UI ----------------
- gr.Interface(
-     fn=chatbot,
-     inputs=gr.Textbox(label="User Input"),
-     outputs=gr.Textbox(label="Chatbot Response"),
-     title="3-Agent Chatbot",
-     description="Intent Detection → Domain Logic → Natural Language Generation"
- ).launch()
-
-
-
+ # Load same or different models for each agent
+ tokenizer1 = AutoTokenizer.from_pretrained("gpt2")
+ model1 = AutoModelForCausalLM.from_pretrained("gpt2")
+
+ tokenizer2 = AutoTokenizer.from_pretrained("gpt2-medium")
+ model2 = AutoModelForCausalLM.from_pretrained("gpt2-medium")
+
+ tokenizer3 = AutoTokenizer.from_pretrained("gpt2-large")
+ model3 = AutoModelForCausalLM.from_pretrained("gpt2-large")
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model1, model2, model3 = model1.to(device), model2.to(device), model3.to(device)
+
+ def generate_response(model, tokenizer, prompt):
+     inputs = tokenizer(prompt, return_tensors="pt").to(device)
+     outputs = model.generate(inputs["input_ids"], max_length=100, pad_token_id=tokenizer.eos_token_id)
+     return tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ import gradio as gr
+
+ def multi_agent_chat(user_input):
+     res1 = generate_response(model1, tokenizer1, user_input)
+     res2 = generate_response(model2, tokenizer2, user_input)
+     res3 = generate_response(model3, tokenizer3, user_input)
+
+     return res1, res2, res3
+
+ interface = gr.Interface(
+     fn=multi_agent_chat,
+     inputs=gr.Textbox(lines=2, placeholder="Ask something..."),
+     outputs=[
+         gr.Textbox(label="Agent 1 (GPT-2)"),
+         gr.Textbox(label="Agent 2 (GPT-2 Medium)"),
+         gr.Textbox(label="Agent 3 (GPT-2 Large)")
+     ],
+     title="3-Agent AI Chatbot"
+ )
+
+ interface.launch()
+
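A note on the new generation helper (not part of the commit): model.generate(..., max_length=100) counts the prompt tokens toward the 100-token cap, so a longer prompt leaves little or no room for a reply, and the GPT-2 checkpoints ship without a pad token. Below is a minimal sketch of an alternative helper, assuming the same gpt2 checkpoint; the name generate_reply_v2 and the sampling settings are illustrative, not from the repository.

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Hypothetical variant of the commit's generate_response helper; not from the repository.
device = "cuda" if torch.cuda.is_available() else "cpu"
tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)

def generate_reply_v2(prompt: str, max_new_tokens: int = 60) -> str:
    # Keep the attention mask so generate() knows which positions hold real tokens.
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    output_ids = model.generate(
        inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,        # budget for the reply only, independent of prompt length
        do_sample=True,                       # greedy GPT-2 tends to loop; sampling varies the output
        top_p=0.9,
        pad_token_id=tokenizer.eos_token_id,  # GPT-2 has no pad token, so reuse EOS
    )
    # Drop the prompt tokens and decode only the newly generated continuation.
    new_tokens = output_ids[0, inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

print(generate_reply_v2("What is the weather like today?"))

The same max_new_tokens and attention_mask arguments would apply unchanged to the three per-agent models in the committed app.py, since generate_response already receives the model and tokenizer as parameters.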