Tonic committed on
Commit
8ff5503
·
verified ·
1 Parent(s): 978d21a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -33
app.py CHANGED
@@ -25,39 +25,43 @@ os.system('python -m spacy download en_core_web_sm')
25
  nlp = spacy.load("en_core_web_sm")
26
 
27
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate a model continuation for ``prompt`` and token-level highlights.

    The raw prompt is wrapped in the "### Text ###" template the model
    expects, sampled from with the given decoding controls, and — when the
    output contains a "### Correction ###" marker — trimmed to the section
    after that marker.

    Returns a ``(highlighted_text, generated_text)`` pair where
    ``highlighted_text`` is a list of ``(token, token_type)`` tuples
    (BPE "Ġ" prefixes stripped) and ``generated_text`` is the decoded text.
    """
    templated = f"### Text ###\n{prompt}"
    encoded = tokenizer(templated, return_tensors="pt", padding=True, truncation=True, max_length=1024)
    ids = encoded["input_ids"].to(device)
    mask = encoded["attention_mask"].to(device)

    generated = model.generate(
        ids,
        attention_mask=mask,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
        top_k=top_k,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
        repetition_penalty=repetition_penalty,
        bos_token_id=tokenizer.bos_token_id,
        eos_token_id=tokenizer.eos_token_id,
    )

    generated_text = tokenizer.decode(generated[0], skip_special_tokens=True)

    # Keep only the correction section when the model emitted the marker.
    marker = "### Correction ###"
    if marker in generated_text:
        generated_text = generated_text.split(marker)[1].strip()

    # Pair each display token with its vocabulary form, "Ġ" prefix removed.
    highlighted_text = []
    for tok in tokenizer.tokenize(generated_text):
        stripped = tok.replace("Ġ", "")
        vocab_form = tokenizer.convert_ids_to_tokens([tokenizer.convert_tokens_to_ids(tok)])[0].replace("Ġ", "")
        highlighted_text.append((stripped, vocab_form))

    return highlighted_text, generated_text
 
 
 
 
61
 
62
  def text_analysis(text):
63
  doc = nlp(text)
 
25
  nlp = spacy.load("en_core_web_sm")
26
 
27
def historical_generation(prompt, max_new_tokens=600, top_k=50, temperature=0.7, top_p=0.95, repetition_penalty=1.0):
    """Generate text from the historical-correction model.

    Parameters
    ----------
    prompt : str
        Raw user text; wrapped in the "### Text ###" template before encoding.
    max_new_tokens, top_k, temperature, top_p, repetition_penalty
        Sampling controls forwarded directly to ``model.generate``.

    Returns
    -------
    tuple[list[tuple[str, str]], str]
        ``(highlighted_text, generated_text)`` where ``highlighted_text``
        pairs each cleaned token with its type label and ``generated_text``
        is the decoded output (only the part after "### Correction ###"
        when that marker is present).
    """
    # Inference only: disable autograd so no graph is built and peak
    # memory stays low during generation.
    with torch.no_grad():
        prompt = f"### Text ###\n{prompt}"
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=1024)
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        output = model.generate(
            input_ids,
            attention_mask=attention_mask,
            max_new_tokens=max_new_tokens,
            pad_token_id=tokenizer.eos_token_id,
            top_k=top_k,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,
            repetition_penalty=repetition_penalty,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

        generated_text = tokenizer.decode(output[0], skip_special_tokens=True)

        # Keep only the model's correction section when the marker is present.
        if "### Correction ###" in generated_text:
            generated_text = generated_text.split("### Correction ###")[1].strip()

        tokens = tokenizer.tokenize(generated_text)

        # FIX: the previous per-token round-trip
        # convert_tokens_to_ids -> convert_ids_to_tokens is an identity
        # mapping for tokens produced by tokenize() (they are already
        # vocabulary entries), so it only wasted O(n) tokenizer calls.
        # The "type" of a token is simply the cleaned token itself.
        highlighted_text = []
        for token in tokens:
            clean_token = token.replace("Ġ", "")
            highlighted_text.append((clean_token, clean_token))

        # Drop intermediate tensors before returning; only poke the CUDA
        # allocator when a GPU is actually available (empty_cache() is a
        # meaningless call on CPU-only hosts).
        del inputs, input_ids, attention_mask, output, tokens
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return highlighted_text, generated_text
65
 
66
  def text_analysis(text):
67
  doc = nlp(text)