CyranoB committed
Commit 8c28786 · 1 Parent(s): df527c8

Added copywriter mode

Files changed (5)
  1. copywriter.py +77 -0
  2. requirements.txt +1 -0
  3. search_agent.py +40 -4
  4. search_agent_ui.py +1 -1
  5. web_rag.py +6 -5
copywriter.py ADDED
@@ -0,0 +1,77 @@
+from langchain.schema import SystemMessage, HumanMessage
+from langchain.prompts.chat import (
+    HumanMessagePromptTemplate,
+    SystemMessagePromptTemplate,
+    ChatPromptTemplate
+)
+from langchain.prompts.prompt import PromptTemplate
+
+
+
+def get_comments_prompt(query, draft):
+    system_message = SystemMessage(
+        content="""
+        You are an AI text reviewer with a keen eye for detail and a deep understanding of language, style, and grammar.
+        Your task is to refine and improve the draft content provided by the writers, offering advanced copyediting techniques and suggestions to enhance the overall quality of the text.
+        When a user submits a piece of writing, follow these steps:
+        1. Read the original query from the user so you clearly understand the request that was given to the writer.
+        2. Read through the draft text carefully, identifying areas that need improvement in terms of grammar, punctuation, spelling, syntax, and style.
+        3. Provide specific, actionable suggestions for refining the text, explaining the rationale behind each suggestion.
+        4. Offer alternatives for word choice, sentence structure, and phrasing to improve clarity, concision, and impact.
+        5. Ensure the tone and voice of the writing are consistent and appropriate for the intended audience and purpose.
+        6. Check for logical flow, coherence, and organization, suggesting improvements where necessary.
+        7. Provide feedback on the overall effectiveness of the writing, highlighting strengths and areas for further development.
+
+        Your suggestions should be constructive, insightful, and designed to help the user elevate the quality of their writing.
+        You never generate the corrected text itself. *Only* give the comments.
+        """
+    )
+    human_message = HumanMessage(
+        content=f"""
+        Original query: {query}
+        ------------------------
+        Draft text: {draft}
+        """
+    )
+    return [system_message, human_message]
+
+
+def generate_comments(chat_llm, query, draft, callbacks=[]):
+    messages = get_comments_prompt(query, draft)
+    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
+    return response.content
+
+
+
+def get_final_text_prompt(query, draft, comments):
+    system_message = SystemMessage(
+        content="""
+        You are an AI copyeditor with a keen eye for detail and a deep understanding of language, style, and grammar.
+        Your role is to elevate the quality of the writing.
+        You are given:
+        1. The original query from the user
+        2. The draft text from the writer
+        3. The comments from the reviewer
+        Your task is to refine and improve the draft text, taking into account the comments from the reviewer.
+        Output a fully edited version that takes into account the original query, the draft text, and the comments from the reviewer.
+        Keep the references of the draft untouched!
+        """
+    )
+    human_message = HumanMessage(
+        content=f"""
+        Original query: {query}
+        -------------------------------------
+        Draft text: {draft}
+        -------------------------------------
+        Comments from the reviewer: {comments}
+        -------------------------------------
+        Final text:
+        """
+    )
+    return [system_message, human_message]
+
+
+def generate_final_text(chat_llm, query, draft, comments, callbacks=[]):
+    messages = get_final_text_prompt(query, draft, comments)
+    response = chat_llm.invoke(messages, config={"callbacks": callbacks})
+    return response.content
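Together, the two functions form a draft → review → rewrite loop: generate_comments produces reviewer feedback on a draft, and generate_final_text folds that feedback back into the text. A minimal usage sketch, assuming an OpenAI chat model via langchain_openai (any LangChain chat model exposing .invoke() should work; the query and draft strings are placeholders):

# Sketch: chain the reviewer pass and the rewrite pass (assumes OPENAI_API_KEY is set).
from langchain_openai import ChatOpenAI
import copywriter as cw

chat = ChatOpenAI(model="gpt-3.5-turbo", temperature=0.0)
query = "Explain retrieval-augmented generation to a general audience."  # placeholder
draft = "RAG combine a search step with a language model ..."            # placeholder draft

comments = cw.generate_comments(chat, query, draft)                # reviewer pass: critique only
final_text = cw.generate_final_text(chat, query, draft, comments)  # copyeditor pass: rewritten draft
print(final_text)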
requirements.txt CHANGED
@@ -15,6 +15,7 @@ langchain_experimental
 langchain_openai
 langchain_groq
 langsmith
+schema
 streamlit
 selenium
 rich
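The new schema dependency backs the argument-validation block that is left commented out in search_agent.py below. A minimal standalone sketch of that validation, reusing the option names from that block (the argv values are illustrative):

# Sketch: validate and cast docopt results with the `schema` library.
from docopt import docopt
from schema import Schema, Use, SchemaError

usage = """Usage: search_agent.py [--max_pages=num] [--temperature=temp] QUERY"""

arguments = docopt(usage, argv=["--max_pages=10", "--temperature=0.5", "hello"])
validator = Schema({
    '--max_pages': Use(int, error='--max_pages must be an integer'),
    '--temperature': Use(float, error='--temperature must be a float'),
}, ignore_extra_keys=True)  # QUERY is deliberately left unvalidated

try:
    arguments = validator.validate(arguments)  # casts the option strings to int/float
except SchemaError as e:
    exit(e)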
search_agent.py CHANGED
@@ -6,6 +6,7 @@ Usage:
     [--provider=provider]
     [--model=model]
     [--temperature=temp]
+    [--copywrite]
     [--max_pages=num]
     [--output=text]
     SEARCH_QUERY
@@ -14,6 +15,7 @@ Usage:
 Options:
   -h --help                           Show this screen.
   --version                           Show version.
+  -c --copywrite                      First produce a draft, review it, and rewrite it into a final text
   -d domain --domain=domain           Limit search to a specific domain
   -t temp --temperature=temp          Set the temperature of the LLM [default: 0.0]
   -p provider --provider=provider     Use a specific LLM (choices: bedrock,openai,groq,ollama,cohere,fireworks) [default: openai]
@@ -26,6 +28,7 @@ Options:
 import os
 
 from docopt import docopt
+#from schema import Schema, Use, SchemaError
 import dotenv
 
 from langchain.callbacks import LangChainTracer
@@ -37,6 +40,7 @@ from rich.markdown import Markdown
 
 import web_rag as wr
 import web_crawler as wc
+import copywriter as cw
 
 console = Console()
 dotenv.load_dotenv()
@@ -69,7 +73,18 @@ if os.getenv("LANGCHAIN_API_KEY"):
 
 if __name__ == '__main__':
     arguments = docopt(__doc__, version='Search Agent 0.1')
-
+
+    #schema = Schema({
+    #    '--max_pages': Use(int, error='--max_pages must be an integer'),
+    #    '--temperature': Use(float, error='--temperature must be a float'),
+    #})
+
+    #try:
+    #    arguments = schema.validate(arguments)
+    #except SchemaError as e:
+    #    exit(e)
+
+    copywrite_mode = arguments["--copywrite"]
     provider = arguments["--provider"]
     model = arguments["--model"]
     temperature = float(arguments["--temperature"])
@@ -101,11 +116,32 @@ if __name__ == '__main__':
     vector_store = wc.vectorize(contents, embedding_model)
 
     with console.status("[bold green]Querying LLM relevant context", spinner='dots8Bit'):
-        respomse = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k = 5, callbacks=callbacks)
+        draft = wr.query_rag(chat, query, optimize_search_query, vector_store, top_k=5, callbacks=callbacks)
 
     console.rule(f"[bold green]Response from {provider}")
     if output == "text":
-        console.print(respomse)
+        console.print(draft)
     else:
-        console.print(Markdown(respomse))
+        console.print(Markdown(draft))
     console.rule("[bold green]")
+
+    if copywrite_mode:
+        with console.status("[bold green]Getting comments from the reviewer", spinner="dots8Bit"):
+            comments = cw.generate_comments(chat, query, draft, callbacks=callbacks)
+
+        console.rule("[bold green]Response from reviewer")
+        if output == "text":
+            console.print(comments)
+        else:
+            console.print(Markdown(comments))
+        console.rule("[bold green]")
+
+        with console.status("[bold green]Writing the final text", spinner="dots8Bit"):
+            final_text = cw.generate_final_text(chat, query, draft, comments, callbacks=callbacks)
+
+        console.rule("[bold green]Final text")
+        if output == "text":
+            console.print(final_text)
+        else:
+            console.print(Markdown(final_text))
+        console.rule("[bold green]")
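With the new flag wired in, a copywriter-mode run would look something like this (the query is illustrative):

python search_agent.py --copywrite --provider openai "history of the transformer architecture"

The script prints the draft under "Response from openai", then the reviewer comments, then the final text, each under its own console rule.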
search_agent_ui.py CHANGED
@@ -52,7 +52,7 @@ with st.sidebar:
     model_provider = st.selectbox("🧠 Model provider 🧠", st.session_state["providers"])
     temperature = st.slider("🌡️ Model temperature 🌡️", 0.0, 1.0, 0.1, help="The higher the more creative")
     max_pages = st.slider("🔍 Max pages to retrieve 🔍", 1, 20, 15, help="How many web pages to retrieve from the internet")
-    top_k_documents = st.slider("📄 How many document extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
+    top_k_documents = st.slider("📄 How many doc extracts to consider 📄", 1, 20, 5, help="How many of the top extracts to consider")
 
 if "messages" not in st.session_state:
     st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
web_rag.py CHANGED
@@ -47,12 +47,13 @@ def get_models(provider, model=None, temperature=0.0):
             chat_llm = BedrockChat(
                 credentials_profile_name=credentials_profile_name,
                 model_id=model,
-                model_kwargs={"temperature": temperature },
-            )
-            embedding_model = BedrockEmbeddings(
-                model_id='cohere.embed-multilingual-v3',
-                credentials_profile_name=credentials_profile_name
+                model_kwargs={"temperature": temperature, 'max_tokens': 8192 },
             )
+            #embedding_model = BedrockEmbeddings(
+            #    model_id='cohere.embed-multilingual-v3',
+            #    credentials_profile_name=credentials_profile_name
+            #)
+            embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
         case 'openai':
             if model is None:
                 model = "gpt-3.5-turbo"
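One side effect of this hunk worth flagging: in the bedrock case, embeddings now go through OpenAI rather than Bedrock, so an OPENAI_API_KEY is required even when the chat model runs on Bedrock. A minimal sketch of the swapped-in embedding model, assuming langchain_openai is installed:

# Sketch: the replacement embedding model introduced above.
from langchain_openai import OpenAIEmbeddings

embedding_model = OpenAIEmbeddings(model='text-embedding-3-small')
vector = embedding_model.embed_query("retrieval-augmented generation")
print(len(vector))  # 1536 dimensions by default for text-embedding-3-small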