tosin2013 committed on
Commit 0647d83 · verified · 1 Parent(s): 8d4e900

Create app.py

Files changed (1)
  1. app.py +176 -0
app.py ADDED
@@ -0,0 +1,176 @@
+ import os
+ from openai import OpenAI
+ from langchain_community.embeddings import HuggingFaceEmbeddings
+ from datasets import load_dataset, Dataset
+ from sklearn.neighbors import NearestNeighbors
+ import numpy as np
+ from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM, TextStreamer
+ import torch
+ import gradio as gr
+
+ # Configuration
+
+ # Sample questions:
+ # 1. What are the key components of an effective persona for prompt generation?
+ # 2. How can I create a persona that generates creative writing prompts?
+ # 3. What strategies can I use to make my persona-driven prompts more engaging?
+
+ DEFAULT_QUESTION = "Ask me anything in the context of persona-driven prompt generation..."
+
+ # Provider configuration (OPENAI_API_KEY and HF_TOKEN must already be set in your environment)
+ os.environ['OPENAI_BASE'] = "https://api.openai.com/v1"
+ os.environ['OPENAI_MODEL'] = "gpt-4"
+ os.environ['MODEL_PROVIDER'] = "huggingface"
+ api_key = os.environ.get("OPENAI_API_KEY")
+ model_provider = os.environ.get("MODEL_PROVIDER")
+
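+ # Example invocation with placeholder credentials (illustrative only):
+ #   OPENAI_API_KEY=sk-... HF_TOKEN=hf_... python app.py
+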
+ # Instantiate the client for the configured provider
+ if model_provider.lower() == "openai":
+     MODEL_NAME = os.environ['OPENAI_MODEL']
+     client = OpenAI(
+         base_url=os.environ.get("OPENAI_BASE"),
+         api_key=api_key
+     )
+ else:
+     MODEL_NAME = "meta-llama/Llama-3.3-70B-Instruct"
+     # Initialize the Hugging Face model and tokenizer with stdout streaming support
+     tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, token=os.environ.get("HF_TOKEN"))
+     model = AutoModelForCausalLM.from_pretrained(
+         MODEL_NAME,
+         device_map='auto',
+         token=os.environ.get("HF_TOKEN"),
+         torch_dtype=torch.bfloat16,
+     )
+     streamer = TextStreamer(tokenizer, skip_prompt=True)
+     # The model is already placed on devices above, so device_map is not passed again here
+     question_answerer = pipeline(
+         "text-generation",
+         model=model,
+         tokenizer=tokenizer,
+         streamer=streamer,
+         max_new_tokens=512,
+         return_full_text=False
+     )
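+     # Note: 70B parameters in bfloat16 amount to roughly 140 GB of weights;
+     # device_map='auto' lets accelerate shard them across the available GPUs (and CPU).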
+
+ # Load the Hugging Face dataset
+ dataset = load_dataset('tosin2013/persona-driven-prompt-generator', streaming=True)
+ dataset = Dataset.from_list(list(dataset['train']))
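+ # Streaming avoids downloading the full dataset up front, but Dataset.from_list then
+ # materializes every row in memory, which assumes the train split is small.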
+
+ # Initialize embeddings
+ embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
+
+ # Extract texts from the dataset
+ texts = dataset['input']
+
+ # Create embeddings for the texts
+ text_embeddings = embeddings.embed_documents(texts)
+
+ # Fit a nearest neighbor model
+ nn = NearestNeighbors(n_neighbors=5, metric='cosine')
+ nn.fit(np.array(text_embeddings))
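+ # With metric='cosine', scikit-learn reports distance = 1 - cosine similarity and
+ # falls back to a brute-force search, which is fine at this corpus size.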
+
+ def get_relevant_documents(query, k=5):
+     """
+     Retrieves the k most relevant documents to the query.
+     """
+     query_embedding = embeddings.embed_query(query)
+     distances, indices = nn.kneighbors([query_embedding], n_neighbors=k)
+     relevant_docs = [texts[i] for i in indices[0]]
+     return relevant_docs
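+
+ # Illustrative call: get_relevant_documents("How do I define a persona?", k=3)
+ # returns the three dataset 'input' texts closest to the query in embedding space.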
+
+ def generate_response(question, history):
+     try:
+         print(f"\n[LOG] Received question: {question}")
+
+         # Get relevant documents based on the query
+         relevant_docs = get_relevant_documents(question, k=3)
+         print(f"[LOG] Retrieved {len(relevant_docs)} relevant documents")
+
+         # Create the prompt for the LLM
+         context = "\n".join(relevant_docs)
+         prompt = f"Context: {context}\n\nQuestion: {question}\n\nAnswer:"
+         print(f"[LOG] Generated prompt: {prompt[:200]}...")  # Log first 200 chars of prompt
+
+         if model_provider.lower() == "huggingface":
+             # Llama 3 instruct models define their own chat format, so build the
+             # prompt with the tokenizer's chat template rather than hand-written tags
+             messages = [
+                 {"role": "system", "content": "You are a helpful AI assistant. Answer the question based on the provided context."},
+                 {"role": "user", "content": prompt},
+             ]
+             formatted_prompt = tokenizer.apply_chat_template(
+                 messages, tokenize=False, add_generation_prompt=True
+             )
+             result = question_answerer(formatted_prompt)
+             response = result[0]['generated_text'] if isinstance(result, list) else result
+             print(f"[LOG] Using Hugging Face model: {MODEL_NAME}")
+             print(f"[LOG] Hugging Face response: {response[:200]}...")  # Log first 200 chars of response
+         elif model_provider.lower() == "openai":
+             completion = client.chat.completions.create(
+                 model=os.environ.get("OPENAI_MODEL"),
+                 messages=[
+                     {"role": "system", "content": "You are a helpful assistant. Answer the question based on the provided context."},
+                     {"role": "user", "content": prompt},
+                 ]
+             )
+             response = completion.choices[0].message.content
+             print(f"[LOG] Using OpenAI model: {os.environ.get('OPENAI_MODEL')}")
+             print(f"[LOG] OpenAI response: {response[:200]}...")  # Log first 200 chars of response
+
+         # Update chat history with new message pair
+         history.append((question, response))
+         return history
+     except Exception as e:
+         error_msg = f"Error generating response: {str(e)}"
+         print(f"[ERROR] {error_msg}")
+         history.append((question, error_msg))
+         return history
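+
+ # history uses gr.Chatbot's default tuple format: each entry is a
+ # (user_message, bot_message) pair, so appending one tuple adds one exchange.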
+
+ # Create Gradio interface
+ with gr.Blocks() as demo:
+     gr.Markdown(f"""
+     ## Persona-Driven Prompt Generator QA Agent
+     **Current Model:** {MODEL_NAME}
+
+     The Custom Prompt Generator is a Python application that leverages Large Language Models (LLMs) and the LiteLLM library to dynamically generate personas, fetch knowledge sources, resolve conflicts, and produce tailored prompts. This application is designed to assist in various software development tasks by providing context-aware prompts based on user input and predefined personas.
+
+     Sample questions:
+     1. What are the key components of an effective persona for prompt generation?
+     2. How can I create a persona that generates creative writing prompts?
+     3. What are the main features of the persona generator?
+
+     Related repository: [persona-driven-prompt-generator](https://github.com/tosin2013/persona-driven-prompt-generator)
+     """)
+
+     with gr.Row():
+         chatbot = gr.Chatbot(label="Chat History")
+
+     with gr.Row():
+         question = gr.Textbox(
+             value=DEFAULT_QUESTION,
+             label="Your Question",
+             placeholder=DEFAULT_QUESTION
+         )
+
+     with gr.Row():
+         submit_btn = gr.Button("Submit")
+         clear_btn = gr.Button("Clear")
+
+     # Event handlers
+     submit_btn.click(
+         generate_response,
+         inputs=[question, chatbot],
+         outputs=[chatbot]
+     )
+
+     clear_btn.click(
+         lambda: (None, ""),
+         inputs=[],
+         outputs=[chatbot, question]
+     )
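+     # The submit handler leaves the textbox populated; chaining
+     # .then(lambda: "", None, question) onto submit_btn.click would clear it.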
+
+ if __name__ == "__main__":
+     demo.launch()