Added copywriter mode
Files changed:
- copywriter.py       +77 -0
- requirements.txt     +1 -0
- search_agent.py     +40 -4
- search_agent_ui.py   +1 -1
- web_rag.py           +6 -5
copywriter.py
ADDED
@@ -0,0 +1,77 @@
+from langchain.schema import SystemMessage, HumanMessage
+from langchain.prompts.chat import (
+    HumanMessagePromptTemplate,
+    SystemMessagePromptTemplate,
+    ChatPromptTemplate
+)
+from langchain.prompts.prompt import PromptTemplate
+
+
+
+def get_comments_prompt(query, draft):
+    system_message = SystemMessage(
+        content="""
+            You are an AI text reviewer with a keen eye for detail and a deep understanding of language, style, and grammar.
+            Your task is to refine and improve the draft content provided by the writers, offering advanced copyediting techniques and suggestions to enhance the overall quality of the text.
+            When a user submits a piece of writing, follow these steps:
+            1. Read the original query from the user so you understand clearly the request that was given to the writer.
+            2. Read through the draft text carefully, identifying areas that need improvement in terms of grammar, punctuation, spelling, syntax, and style.
+            3. Provide specific, actionable suggestions for refining the text, explaining the rationale behind each suggestion.
+            4. Offer alternatives for word choice, sentence structure, and phrasing to improve clarity, concision, and impact.
+            5. Ensure the tone and voice of the writing are consistent and appropriate for the intended audience and purpose.
+            6. Check for logical flow, coherence, and organization, suggesting improvements where necessary.
+            7. Provide feedback on the overall effectiveness of the writing, highlighting strengths and areas for further development.
+
+            Your suggestions should be constructive, insightful, and designed to help the user elevate the quality of their writing.
+            You never generate the corrected text by itself. *Only* give the comment.
+        """
+    )
+    human_message = HumanMessage(
+        content=f"""
+            Original query: {query}
+            ------------------------
+            Draft text: {draft}
+        """
+    )
+    return [system_message, human_message]
+
+
+def generate_comments(chat_llm, query, draft, callbacks=[]):
+    messages = get_comments_prompt(query, draft)
+    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
+    return response.content
+
+
+
+def get_final_text_prompt(query, draft, comments):
+    system_message = SystemMessage(
+        content="""
+            You are an AI copyeditor with a keen eye for detail and a deep understanding of language, style, and grammar.
+            Your role is to elevate the quality of the writing.
+            You are given:
+            1. The original query from the user
+            2. The draft text from the writer
+            3. The comments from the reviewer
+            Your task is to refine and improve the draft text taking into account the comments from the reviewer.
+            Output a fully edited version that takes into account the original query, the draft text, and the comments from the reviewer.
+            Keep the references of the draft untouched!
+        """
+    )
+    human_message = HumanMessage(
+        content=f"""
+            Original query: {query}
+            -------------------------------------
+            Draft text: {draft}
+            -------------------------------------
+            Comments from the reviewer: {comments}
+            -------------------------------------
+            Final text:
+        """
+    )
+    return [system_message, human_message]
+
+
+def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
+    messages = get_final_text_prompt(query, draft, comments)
+    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
+    return response.content
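For reference, a minimal sketch of how the two helpers chain together into the review-then-rewrite loop. The ChatOpenAI model and the query/draft strings are illustrative assumptions, not part of this commit:

    # Hypothetical usage sketch; assumes an OpenAI API key is configured.
    from langchain_openai import ChatOpenAI
    import copywriter as cw

    chat = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.0)
    query = "Write a short product announcement"          # illustrative
    draft = "Our new tool make search easy and fast."     # illustrative

    comments = cw.generate_comments(chat, query, draft)                # reviewer pass
    final_text = cw.generate_final_text(chat, query, draft, comments)  # copyedit pass
    print(final_text)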
requirements.txt
CHANGED
@@ -15,6 +15,7 @@ langchain_experimental
 langchain_openai
 langchain_groq
 langsmith
+schema
 streamlit
 selenium
 rich
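The new `schema` dependency backs the argument validation that search_agent.py adds in commented-out form below. A minimal sketch of that validation, using a hand-written dict as a stand-in for docopt's output:

    # Sketch only: coerces and validates numeric CLI options.
    from schema import Schema, Use, SchemaError

    arguments = {"--max_pages": "10", "--temperature": "0.0"}  # stand-in for docopt(...)

    validator = Schema({
        "--max_pages": Use(int, error="--max_pages must be an integer"),
        "--temperature": Use(float, error="--temperature must be a float"),
    }, ignore_extra_keys=True)

    try:
        arguments = validator.validate(arguments)  # "10" -> 10, "0.0" -> 0.0
    except SchemaError as e:
        exit(e)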
search_agent.py
CHANGED
@@ -6,6 +6,7 @@ Usage:
     [--provider=provider]
     [--model=model]
     [--temperature=temp]
+    [--copywrite]
     [--max_pages=num]
     [--output=text]
     SEARCH_QUERY
@@ -14,6 +15,7 @@ Usage:
 Options:
   -h --help                        Show this screen.
   --version                        Show version.
+  -c --copywrite                   First produce a draft, review it, and rewrite it into a final text
   -d domain --domain=domain        Limit search to a specific domain
   -t temp --temperature=temp       Set the temperature of the LLM [default: 0.0]
   -p provider --provider=provider  Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
@@ -26,6 +28,7 @@ Options:
 import os
 
 from docopt import docopt
+#from schema import Schema, Use, SchemaError
 import dotenv
 
 from langchain.callbacks import LangChainTracer
@@ -37,6 +40,7 @@ from rich.markdown import Markdown
 
 import web_rag as wr
 import web_crawler as wc
+import copywriter as cw
 
 console = Console()
 dotenv.load_dotenv()
@@ -69,7 +73,18 @@ if os.getenv("LANGCHAIN_API_KEY"):
 
 if __name__ == '__main__':
     arguments = docopt(__doc__, version='Search Agent 0.1')
-
+
+    #schema = Schema({
+    #    '--max_pages': Use(int, error='--max_pages must be an integer'),
+    #    '--temperature': Use(float, error='--temperature must be a float'),
+    #})
+
+    #try:
+    #    arguments = schema.validate(arguments)
+    #except SchemaError as e:
+    #    exit(e)
+
+    copywrite_mode = arguments["--copywrite"]
     provider = arguments["--provider"]
     model = arguments["--model"]
     temperature = float(arguments["--temperature"])
@@ -101,11 +116,32 @@ if __name__ == '__main__':
     vector_store = wc.vectorize(contents, embedding_model)
 
     with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
-        …
+        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = 5, callbacks=callbacks)
 
     console.rule(f"[bold green]Response from {provider}")
     if output == "text":
-        console.print(…
+        console.print(draft)
     else:
-        console.print(Markdown(…
+        console.print(Markdown(draft))
     console.rule("[bold green]")
+
+    if(copywrite_mode):
+        with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
+            comments = cw.generate_comments(chat, query, draft, callbacks=callbacks)
+
+        console.rule(f"[bold green]Response from reviewer")
+        if output == "text":
+            console.print(comments)
+        else:
+            console.print(Markdown(comments))
+        console.rule("[bold green]")
+
+        with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
+            final_text = cw.generate_final_text(chat, query, draft, comments, callbacks=callbacks)
+
+        console.rule(f"[bold green]Final text")
+        if output == "text":
+            console.print(final_text)
+        else:
+            console.print(Markdown(final_text))
+        console.rule("[bold green]")
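With the flag wired through docopt, a single invocation now runs the whole draft, review, and final-text pipeline, e.g. (illustrative query):

    python search_agent.py --copywrite "Write a short article about local-first software"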
search_agent_ui.py
CHANGED
@@ -52,7 +52,7 @@ with st.sidebar:
     model_provider = st.selectbox("🧠 Model provider 🧠", st.session_state["providers"])
     temperature = st.slider("🌡️ Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
     max_pages = st.slider("🔍 Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrieve from the internet")
-    top_k_documents = st.slider("📄 How many …
+    top_k_documents = st.slider("📄 How many doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
 
 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
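The slider value presumably feeds the retrieval call the same way the CLI's hard-coded top_k = 5 does. A sketch under that assumption; chat, query, optimize_search_query, vector_store, and callbacks are stand-ins for objects the surrounding Streamlit app builds:

    # Assumed wiring; mirrors the query_rag call signature seen in search_agent.py.
    import web_rag as wr

    response = wr.query_rag(
        chat, query, optimize_search_query, vector_store,
        top_k=top_k_documents,   # slider value instead of the CLI's fixed 5
        callbacks=callbacks,
    )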
web_rag.py
CHANGED
@@ -47,12 +47,13 @@ def get_models(provider, model=None, temperature=0.0):
             chat_llm = BedrockChat(
                 credentials_profile_name=credentials_profile_name,
                 model_id=model,
-                model_kwargs={"temperature": temperature },
-            )
-            embedding_model = BedrockEmbeddings(
-                model_id='cohere.embed-multilingual-v3',
-                credentials_profile_name=credentials_profile_name
+                model_kwargs={"temperature": temperature, 'max_tokens': 8192 },
             )
+            #embedding_model = BedrockEmbeddings(
+            #    model_id='cohere.embed-multilingual-v3',
+            #    credentials_profile_name=credentials_profile_name
+            #)
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'openai':
             if model is None:
                 model = "gpt-3.5-turbo"
|