nuojohnchen committed
Commit ca06ed0 · verified · 1 Parent(s): e53ff8e

Update app.py

Files changed (1):
  1. app.py +85 -52
app.py CHANGED
@@ -6,13 +6,13 @@ import PyPDF2
 from io import BytesIO
 import torch
 
-# 设置环境变量
+# Set environment variables
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
 
 DESCRIPTION = '''
 <div>
 <h1 style="text-align: center;">Academic Paper Improver</h1>
-<p>This Space helps you improve sections of your academic paper using the <a href="https://huggingface.co/Xtra-Computing/XtraGPT-7B"><b>XtraGPT-7B</b></a> model.</p>
+<p>This Space helps you improve sections of your academic paper using the <a href="https://huggingface.co/Xtra-Computing/XtraGPT-7B"><b>XtraGPT</b></a> model series.</p>
 <p>Upload your PDF paper, select a section of text you want to improve, and specify your requirements.</p>
 </div>
 '''
@@ -32,7 +32,7 @@ CITATION = """
 LICENSE = """
 <p/>
 ---
-Built with XtraGPT-7B
+Built with XtraGPT models
 """
 
 css = """
@@ -48,78 +48,102 @@ h1 {
 }
 """
 
-# 默认论文内容
+# Default paper content
 default_paper_content = """
 The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration. The best performing models also connect the encoder and decoder through an attention mechanism. We propose a new simple network architecture, the Transformer, based solely on attention mechanisms, dispensing with recurrence and convolutions entirely. Experiments on two machine translation tasks show these models to be superior in quality while being more parallelizable and requiring significantly less time to train. Our model achieves 28.4 BLEU on the WMT 2014 English-to-German translation task, improving over the existing best results, including ensembles by over 2 BLEU. On the WMT 2014 English-to-French translation task, our model establishes a new single-model state-of-the-art BLEU score of 41.8 after training for 3.5 days on eight GPUs, a small fraction of the training costs of the best models from the literature. We show that the Transformer generalizes well to other tasks by applying it successfully to English constituency parsing both with large and limited training data.
 """
 
-# 直接加载模型和分词器
-tokenizer = AutoTokenizer.from_pretrained("Xtra-Computing/XtraGPT-7B")
-model = AutoModelForCausalLM.from_pretrained("Xtra-Computing/XtraGPT-7B", device_map="auto")
+# Available models
+AVAILABLE_MODELS = {
+    "XtraGPT-1.5B": "Xtra-Computing/XtraGPT-1.5B",
+    "XtraGPT-3B": "Xtra-Computing/XtraGPT-3B",
+    "XtraGPT-7B": "Xtra-Computing/XtraGPT-7B",
+    "XtraGPT-14B": "Xtra-Computing/XtraGPT-14B"
+}
+
+# Global variables for model and tokenizer
+current_model = None
+current_tokenizer = None
+current_model_name = None
 
 def extract_text_from_pdf(pdf_bytes):
-    """从上传的PDF文件中提取文本"""
+    """Extract text from uploaded PDF file"""
     if pdf_bytes is None:
         return default_paper_content
 
     try:
-        # 确保pdf_bytes是字节类型
+        # Ensure pdf_bytes is bytes type
         if isinstance(pdf_bytes, str):
-            return pdf_bytes  # 如果已经是字符串,直接返回
+            return pdf_bytes  # If already a string, return directly
 
-        # 直接使用字节对象
+        # Use bytes object directly
         pdf_reader = PyPDF2.PdfReader(BytesIO(pdf_bytes))
 
-        # 从所有页面提取文本
+        # Extract text from all pages
         text = ""
         for page_num in range(len(pdf_reader.pages)):
            page = pdf_reader.pages[page_num]
            text += page.extract_text() + "\n\n"
 
-        # 限制文本长度,防止超出模型最大长度
-        if len(text) > 10000:  # 保守估计,留出足够空间给提示和生成
-            text = text[:10000] + "...(文本已截断)"
-
        return text
    except Exception as e:
-        print(f"PDF提取错误: {str(e)}")
+        print(f"PDF extraction error: {str(e)}")
        return default_paper_content
 
+def load_model(model_name):
+    """Load model and tokenizer on demand"""
+    global current_model, current_tokenizer, current_model_name
+
+    # If the requested model is already loaded, return it
+    if current_model_name == model_name and current_model is not None and current_tokenizer is not None:
+        return current_tokenizer, current_model
+
+    # Clear GPU memory if a model is already loaded
+    if current_model is not None:
+        del current_model
+        del current_tokenizer
+        torch.cuda.empty_cache()
+
+    # Load the requested model
+    model_path = AVAILABLE_MODELS[model_name]
+    current_tokenizer = AutoTokenizer.from_pretrained(model_path)
+    current_model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
+    current_model_name = model_name
+
+    return current_tokenizer, current_model
+
 @spaces.GPU(duration=200)
-def improve_paper_section(paper_content, selected_content, improvement_prompt, temperature=0.1, max_new_tokens=512):
+def improve_paper_section(model_name, paper_content, selected_content, improvement_prompt, temperature=0.1, max_new_tokens=512, progress=gr.Progress()):
     """
-    改进学术论文的一个部分 - 使用非流式生成
+    Improve a section of an academic paper - non-streaming generation
     """
-    # 检查输入
+    # Check inputs
     if not selected_content or not improvement_prompt:
-        return "请同时提供要改进的文本和改进要求。"
+        return "Please provide both text to improve and improvement requirements."
 
     try:
-        # 限制paper_content长度,防止超出模型最大长度
-        if len(paper_content) > 20000:  # 保守估计
-            paper_content = paper_content[:20000] + "...(文本已截断)"
-
-        # 构建提示
+        progress(0.1, desc="Loading model...")
+        # Load the selected model
+        tokenizer, model = load_model(model_name)
+
+        progress(0.3, desc="Processing input...")
+        # Build prompt
         content = f"""
 Please improve the selected content based on the following. Act as an expert model for improving articles **PAPER_CONTENT**.
-
 The output needs to answer the **QUESTION** on **SELECTED_CONTENT** in the input. Avoid adding unnecessary length, unrelated details, overclaims, or vague statements.
 Focus on clear, concise, and evidence-based improvements that align with the overall context of the paper.
-
 <PAPER_CONTENT>
 {paper_content}
 </PAPER_CONTENT>
-
 <SELECTED_CONTENT>
 {selected_content}
 </SELECTED_CONTENT>
-
 <QUESTION>
 {improvement_prompt}
 </QUESTION>
 """
 
-        # 准备输入
+        # Prepare input
         messages = [
            {"role": "user", "content": content}
        ]
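
Note on the hunk above: `load_model` keeps at most one checkpoint resident and frees the previous one before switching. A minimal smoke test for the cache path, assuming this `app.py` is importable as `app` outside Spaces and the smallest checkpoint fits on the local device:

```python
# Hypothetical smoke test for load_model's cache behaviour; "app" refers to
# this app.py on the import path (an assumption, not part of the commit).
import app

tok_a, model_a = app.load_model("XtraGPT-1.5B")
tok_b, model_b = app.load_model("XtraGPT-1.5B")
assert model_a is model_b and tok_a is tok_b   # second call is a cache hit

tok_c, model_c = app.load_model("XtraGPT-7B")  # cache miss: 1.5B is freed first
assert model_c is not model_a
```

Note that `del` on the globals plus `torch.cuda.empty_cache()` only releases GPU memory once no other references to the old model survive.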
@@ -130,14 +154,15 @@ Focus on clear, concise, and evidence-based improvements that align with the ove
             add_generation_prompt=True
         )
 
-        # 检查输入长度并截断
+        # Check input length and truncate to first 10k tokens
         input_tokens = tokenizer.encode(text)
-        if len(input_tokens) > 15000:  # 为生成留出空间
-            input_tokens = input_tokens[:15000]
+        if len(input_tokens) > 10000:  # Limit to 10k tokens as requested
+            input_tokens = input_tokens[:10000]
             text = tokenizer.decode(input_tokens)
-            print(f"输入已截断至15000个token")
+            print(f"Input truncated to 10000 tokens")
 
-        # 使用非流式方式生成
+        progress(0.5, desc="Generating improved text...")
+        # Generate non-streaming
         input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
 
         with torch.no_grad():
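
The truncation above now happens in token space rather than character space, so the cap matches what the model actually consumes. A standalone sketch of the same encode, truncate, decode round trip, with `gpt2` standing in for the XtraGPT tokenizer:

```python
# Sketch of the encode/truncate/decode round trip from the hunk above;
# gpt2 is a stand-in tokenizer (assumption), the limit mirrors the commit.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
MAX_INPUT_TOKENS = 10_000

text = "improve this sentence " * 3000        # deliberately oversized input
input_tokens = tokenizer.encode(text)
if len(input_tokens) > MAX_INPUT_TOKENS:
    input_tokens = input_tokens[:MAX_INPUT_TOKENS]
    text = tokenizer.decode(input_tokens)      # cut at a token boundary
print(len(input_tokens))                       # 10000
```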
@@ -149,38 +174,47 @@ Focus on clear, concise, and evidence-based improvements that align with the ove
             pad_token_id=tokenizer.eos_token_id
         )
 
-        # 只保留新生成的部分
+        # Only keep the newly generated part
         generated_ids = output_ids[0, len(input_ids[0]):]
         response = tokenizer.decode(generated_ids, skip_special_tokens=True)
 
+        progress(1.0, desc="Complete!")
         return response
 
     except Exception as e:
         import traceback
         error_details = traceback.format_exc()
-        print(f"生成错误: {str(e)}\n{error_details}")
-        return f"生成文本时出错: {str(e)}\n\n请尝试使用不同的参数或输入。"
+        print(f"Generation error: {str(e)}\n{error_details}")
+        return f"Error generating text: {str(e)}\n\nPlease try with different parameters or input."
 
-# 创建Gradio界面
+# Create Gradio interface
 with gr.Blocks(fill_height=True, css=css) as demo:
-    # 存储提取的PDF文本
+    # Store extracted PDF text
     extracted_pdf_text = gr.State(default_paper_content)
 
     gr.Markdown(DESCRIPTION)
-    # gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
 
     with gr.Row():
         with gr.Column():
-            # 步骤1:上传PDF
+            # Step 1: Upload PDF
             with gr.Group():
                 gr.Markdown("### Step 1: Upload your academic paper")
                 pdf_file = gr.File(
                     label="Upload PDF",
                     file_types=[".pdf"],
-                    type="binary"  # 直接获取二进制数据
+                    type="binary"  # Get binary data directly
+                )
+
+            # Model selection
+            with gr.Group():
+                gr.Markdown("### Select Model")
+                model_dropdown = gr.Dropdown(
+                    choices=list(AVAILABLE_MODELS.keys()),
+                    value="XtraGPT-7B",  # Default selection
+                    label="Select XtraGPT Model"
                 )
 
-            # 步骤2:提取并选择文本
+            # Step 2: Extract and select text
             with gr.Group():
                 gr.Markdown("### Step 2: Enter the text section to improve")
                 selected_content = gr.Textbox(
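
The generation path itself is unchanged in structure: render the chat messages through the tokenizer's template, generate without streaming, then slice the prompt tokens off the output. A minimal end-to-end sketch, using Qwen/Qwen2.5-0.5B-Instruct as a small stand-in chat model (an assumption; it only needs to expose the same chat-template interface):

```python
# Minimal generate-and-slice sketch mirroring improve_paper_section's core;
# the model name is a small stand-in, not one of the XtraGPT checkpoints.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

name = "Qwen/Qwen2.5-0.5B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(name)
model = AutoModelForCausalLM.from_pretrained(name)

messages = [{"role": "user", "content": "Improve: The results are good."}]
text = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)

input_ids = tokenizer.encode(text, return_tensors="pt").to(model.device)
with torch.no_grad():
    output_ids = model.generate(
        input_ids,
        max_new_tokens=64,
        do_sample=True,
        temperature=0.1,
        pad_token_id=tokenizer.eos_token_id,
    )

# generate() returns prompt + completion; keep only the newly generated part.
generated_ids = output_ids[0, len(input_ids[0]):]
print(tokenizer.decode(generated_ids, skip_special_tokens=True))
```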
@@ -190,7 +224,7 @@ with gr.Blocks(fill_height=True, css=css) as demo:
                     value="The dominant sequence transduction models are based on complex recurrent or convolutional neural networks in an encoder-decoder configuration."
                 )
 
-            # 步骤3:指定改进要求
+            # Step 3: Specify improvement requirements
             with gr.Group():
                 gr.Markdown("### Step 3: Specify your improvement requirements")
                 improvement_prompt = gr.Textbox(
@@ -207,10 +241,10 @@ with gr.Blocks(fill_height=True, css=css) as demo:
             submit_btn = gr.Button("Improve Text")
 
         with gr.Column():
-            # 输出
+            # Output
             output = gr.Textbox(label="Improved Text", lines=20)
 
-            # 显示提取的PDF文本(可折叠)
+            # Display extracted PDF text (collapsible)
             with gr.Accordion("Extracted PDF Content (for reference)", open=False):
                 pdf_content_display = gr.Textbox(
                     label="Paper Content",
@@ -218,7 +252,7 @@ with gr.Blocks(fill_height=True, css=css) as demo:
                     value=default_paper_content
                 )
 
-    # PDF上传时自动提取文本
+    # Automatically extract text when PDF is uploaded
     def update_pdf_content(pdf_bytes):
         if pdf_bytes is not None:
             content = extract_text_from_pdf(pdf_bytes)
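
Because `gr.File(type="binary")` hands the callback raw bytes, the extraction path can be checked without the UI. A hedged example, assuming a local `paper.pdf` and that this module imports cleanly as `app`:

```python
# Feed extract_text_from_pdf the same bytes object gr.File(type="binary")
# would deliver; "paper.pdf" and the `app` import are assumptions.
from app import extract_text_from_pdf

with open("paper.pdf", "rb") as f:
    pdf_bytes = f.read()

text = extract_text_from_pdf(pdf_bytes)
print(text[:300])   # first few hundred characters of the extracted paper
```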
@@ -231,15 +265,14 @@ with gr.Blocks(fill_height=True, css=css) as demo:
         outputs=[extracted_pdf_text, pdf_content_display]
     )
 
-    # 处理文本改进
+    # Process text improvement
     submit_btn.click(
         fn=improve_paper_section,
-        inputs=[extracted_pdf_text, selected_content, improvement_prompt, temperature, max_tokens],
+        inputs=[model_dropdown, extracted_pdf_text, selected_content, improvement_prompt, temperature, max_tokens],
         outputs=[output]
     )
 
-    # gr.Markdown(LICENSE)
-    gr.Markdown(CITATION)
+    gr.HTML(CITATION)
 
 if __name__ == "__main__":
     demo.launch()
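
The click handler can also be driven headlessly, mirroring the `inputs` list wired to `submit_btn.click` above. A sketch under the same import assumption, with the progress hook stubbed out since no UI is running, and a device large enough for the chosen checkpoint:

```python
# Hedged headless call of the new handler; argument order mirrors the
# [model_dropdown, extracted_pdf_text, selected_content, improvement_prompt,
# temperature, max_tokens] wiring above.
from app import improve_paper_section, default_paper_content

result = improve_paper_section(
    "XtraGPT-7B",
    default_paper_content,
    "The dominant sequence transduction models are based on complex "
    "recurrent or convolutional neural networks in an encoder-decoder "
    "configuration.",
    "Make this sentence more concise.",
    temperature=0.1,
    max_new_tokens=256,
    progress=lambda *args, **kwargs: None,   # stand-in for gr.Progress
)
print(result)
```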