import spaces  # Hugging Face Spaces helper (used for GPU allocation on Spaces)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def generate_answer(llm_name, texts, query, queries, response_lang, mode='validate'):
    # Load the requested instruction-tuned LLM and its tokenizer.
    if llm_name == 'solar':
        tokenizer = AutoTokenizer.from_pretrained("Upstage/SOLAR-10.7B-Instruct-v1.0", use_fast=True)
        llm_model = AutoModelForCausalLM.from_pretrained(
            "Upstage/SOLAR-10.7B-Instruct-v1.0",
            device_map="auto",
        )
    elif llm_name == 'mistral':
        tokenizer = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2", use_fast=True)
        llm_model = AutoModelForCausalLM.from_pretrained(
            "mistralai/Mistral-7B-Instruct-v0.2",
            device_map="cuda",
            torch_dtype=torch.float16,
        )
    elif llm_name == 'phi3mini':
        tokenizer = AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct", use_fast=True)
        llm_model = AutoModelForCausalLM.from_pretrained(
            "microsoft/Phi-3-mini-128k-instruct",
            device_map="auto",
            torch_dtype="auto",
            trust_remote_code=False,
        )
template_texts ="" | |
for i, text in enumerate(texts): | |
template_texts += f'{i+1}. {text} \n' | |
if mode == 'validate': | |
conversation = [ {'role': 'user', 'content': f'Given the following query: "{query}"? \nIs the following document relevant to answer this query?\n{template_texts} \nResponse: Yes / No'} ] | |
elif mode == 'summarize': | |
conversation = [ {'role': 'user', 'content': f'For the following query and documents, try to answer the given query based on the documents.\nQuery: {query} \nDocuments: {template_texts}.'} ] | |
elif mode == 'h_summarize': | |
conversation = [ {'role': 'user', 'content': f'The documents below describe a developing disaster event. Based on these documents, write a brief summary in the form of a paragraph, highlighting the most crucial information. \nDocuments: {template_texts}'} ] | |
elif mode == "multi_summarize": | |
conversation = [ {'role': 'user', 'content': f"""For the following queries and documents, in a brief paragraph try to answer the given queries based on the documents. | |
Then, return the top 5 documents as provided that answer the queries.\nQueries: {queries} \nDocuments: {template_texts}. Give your response in {response_lang} language"""} ] | |
    # Apply the model's chat template, tokenize, and generate a response.
    prompt = tokenizer.apply_chat_template(conversation, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(llm_model.device)
    outputs = llm_model.generate(**inputs, use_cache=True, max_length=4096, do_sample=True, temperature=0.7, top_p=0.95, top_k=10, repetition_penalty=1.1)
    output_text = tokenizer.decode(outputs[0])

    # Strip the prompt portion so only the assistant's reply remains;
    # the split marker depends on each model's chat template.
    if llm_name == "solar":
        assistant_respond = output_text.split("Assistant:")[1]
    elif llm_name == "phi3mini":
        assistant_respond = output_text.split("<|assistant|>")[1]
        assistant_respond = assistant_respond[:-7]  # drop the trailing "<|end|>" token
    else:
        assistant_respond = output_text.split("[/INST]")[1]
    # Interpret the reply according to the task.
    if mode == 'validate':
        # Relevance check: any "Yes" in the reply counts as relevant.
        return 'Yes' in assistant_respond
    elif mode in ('summarize', 'h_summarize', 'multi_summarize'):
        return assistant_respond
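

# Hypothetical usage sketch (not part of the original file): the documents, query
# strings, and 'english' target language below are made-up placeholders showing
# how generate_answer might be called for the 'validate' and 'multi_summarize' modes.
if __name__ == "__main__":
    docs = [
        "Flooding has closed several roads in the city center.",
        "Emergency shelters opened at two local schools.",
    ]
    # Relevance check: returns True if the model answers "Yes".
    is_relevant = generate_answer(
        'mistral', docs, query="Which roads are closed?", queries=None,
        response_lang='english', mode='validate',
    )
    # Multi-query summary: returns the assistant's free-text reply.
    summary = generate_answer(
        'mistral', docs, query=None,
        queries=["Which roads are closed?", "Where are the shelters?"],
        response_lang='english', mode='multi_summarize',
    )
    print(is_relevant)
    print(summary)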