Rajesh3338 committed · Commit cba278b · verified · 1 Parent(s): 42ebb52

Update app.py

Files changed (1)
  1. app.py +144 -86
app.py CHANGED
@@ -1,92 +1,150 @@
-from transformers import MllamaForConditionalGeneration, AutoProcessor, TextIteratorStreamer
-from PIL import Image
-import requests
-import torch
-from threading import Thread
 import gradio as gr
-from gradio import FileData
-import time
-import spaces
-ckpt = "mrcuddle/llama3.2-11B-Vision_instruct-Coder"
-model = MllamaForConditionalGeneration.from_pretrained(ckpt,
-                                                       torch_dtype=torch.bfloat16).to("cuda")
-processor = AutoProcessor.from_pretrained(ckpt)
-
-
-@spaces.GPU
-def bot_streaming(message, history, max_new_tokens=250):
-
-    txt = message["text"]
-    ext_buffer = f"{txt}"
-
-    messages = []
-    images = []
-
-
-    for i, msg in enumerate(history):
-        if isinstance(msg[0], tuple):
-            messages.append({"role": "user", "content": [{"type": "text", "text": history[i+1][0]}, {"type": "image"}]})
-            messages.append({"role": "assistant", "content": [{"type": "text", "text": history[i+1][1]}]})
-            images.append(Image.open(msg[0][0]).convert("RGB"))
-        elif isinstance(history[i-1], tuple) and isinstance(msg[0], str):
-            # messages are already handled
-            pass
-        elif isinstance(history[i-1][0], str) and isinstance(msg[0], str):  # text only turn
-            messages.append({"role": "user", "content": [{"type": "text", "text": msg[0]}]})
-            messages.append({"role": "assistant", "content": [{"type": "text", "text": msg[1]}]})
-
-    # add current message
-    if len(message["files"]) == 1:
-
-        if isinstance(message["files"][0], str):  # examples
-            image = Image.open(message["files"][0]).convert("RGB")
-        else:  # regular input
-            image = Image.open(message["files"][0]["path"]).convert("RGB")
-        images.append(image)
-        messages.append({"role": "user", "content": [{"type": "text", "text": txt}, {"type": "image"}]})
-    else:
-        messages.append({"role": "user", "content": [{"type": "text", "text": txt}]})
-
-
-    texts = processor.apply_chat_template(messages, add_generation_prompt=True)
-
-    if images == []:
-        inputs = processor(text=texts, return_tensors="pt").to("cuda")
-    else:
-        inputs = processor(text=texts, images=images, return_tensors="pt").to("cuda")
-    streamer = TextIteratorStreamer(processor, skip_special_tokens=True, skip_prompt=True)
-
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
-    generated_text = ""
-
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-    buffer = ""
-
-    for new_text in streamer:
-        buffer += new_text
-        generated_text_without_prompt = buffer
-        time.sleep(0.01)
-        yield buffer
-
-
-demo = gr.ChatInterface(fn=bot_streaming, title="Multimodal Llama", examples=[
-    [{"text": "Replicate this webpage using Typescript and ChakraUI.", "files": ["./examples/Untitled.png"]},
-     2000],
-    ],
-    textbox=gr.MultimodalTextbox(),
-    additional_inputs=[gr.Slider(
-        minimum=10,
-        maximum=2500,
-        value=500,
-        step=10,
-        label="Maximum number of new tokens to generate",
-    )
-    ],
-    cache_examples=False,
-    description="Yes, this space can replicate (to the model's best ability) a webpage in your preferred language.",
-    stop_btn="Stop Generation",
-    fill_height=True,
-    multimodal=True)
-
-demo.launch(debug=True)
+import os
 import gradio as gr
+import torch
+from langchain_huggingface import HuggingFaceEmbeddings
+from langchain_community.document_loaders import TextLoader
+from langchain_community.vectorstores import FAISS
+from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.chains import RetrievalQA
+from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+from langchain_huggingface import HuggingFacePipeline
+
+# Configure GPU settings
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print(f"Using device: {device}")
+
+class CPSLChatbot:
+    def __init__(self):
+        self.initialize_components()
+
+    def initialize_components(self):
+        try:
+            # Load and process document
+            doc_loader = TextLoader("dataset.txt")
+            docs = doc_loader.load()
+            text_splitter = RecursiveCharacterTextSplitter(
+                chunk_size=1000,
+                chunk_overlap=100
+            )
+            split_docs = text_splitter.split_documents(docs)
+
+            # Initialize embeddings and vector store
+            self.embeddings = HuggingFaceEmbeddings(
+                model_name="all-MiniLM-L6-v2",
+                model_kwargs={'device': device}
+            )
+            self.vectordb = FAISS.from_documents(split_docs, self.embeddings)
+
+            # Load model and tokenizer
+            model_name = "01-ai/Yi-Coder-9B-Chat"
+            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+            self.model = AutoModelForCausalLM.from_pretrained(
+                model_name,
+                device_map="auto",
+                torch_dtype=torch.float16 if device == "cuda" else torch.float32,
+                trust_remote_code=True
+            )
+
+            # Set up QA pipeline
+            self.qa_pipeline = pipeline(
+                "text-generation",
+                model=self.model,
+                tokenizer=self.tokenizer,
+                max_new_tokens=750,
+                pad_token_id=self.tokenizer.eos_token_id,
+                device=0 if device == "cuda" else -1
+            )
+
+            # Initialize LangChain components
+            llm = HuggingFacePipeline(pipeline=self.qa_pipeline)
+            retriever = self.vectordb.as_retriever(search_kwargs={"k": 5})
+            self.qa_chain = RetrievalQA.from_chain_type(
+                retriever=retriever,
+                chain_type="stuff",
+                llm=llm,
+                return_source_documents=False
+            )
+            print("Initialization completed successfully")
+
+        except Exception as e:
+            print(f"Initialization error: {str(e)}")
+            raise
+
+    def preprocess_query(self, query):
+        if "script" in query.lower() or "code" in query.lower():
+            return f"Write a CPSL script: {query}"
+        return query
+
+    def clean_response(self, response):
+        result = response.get("result", "")
+        if "Answer:" in result:
+            return result.split("Answer:")[1].strip()
+        return result.strip()
+
+    def get_response(self, user_input):
+        try:
+            processed_query = self.preprocess_query(user_input)
+            raw_response = self.qa_chain.invoke({"query": processed_query})
+            return self.clean_response(raw_response)
+        except Exception as e:
+            return f"Error processing query: {str(e)}"
+
+def create_gradio_interface():
+    chatbot = CPSLChatbot()
+
+    with gr.Blocks(title="CPSL Chatbot") as chat_interface:
+        gr.Markdown("# CPSL Chatbot with GPU Support")
+        gr.Markdown("Using Yi-Coder-9B-Chat model for CPSL script generation and queries")
+
+        chat_history = gr.Chatbot(
+            value=[],
+            elem_id="chatbot",
+            height=600
+        )
+
+        with gr.Row():
+            user_input = gr.Textbox(
+                label="Your Message:",
+                placeholder="Type your message here...",
+                show_label=True,
+                elem_id="user-input"
+            )
+            send_button = gr.Button("Send", variant="primary")
+
+        def chat_response(user_message, history):
+            if not user_message:
+                return history, history
+
+            bot_response = chatbot.get_response(user_message)
+            history.append((user_message, bot_response))
+            return history, history
+
+        send_button.click(
+            chat_response,
+            inputs=[user_input, chat_history],
+            outputs=[chat_history, chat_history],
+            api_name="chat"
+        )
+
+        # Clear the input textbox after sending
+        send_button.click(lambda: "", None, user_input)
+
+        # Also allow Enter key to send message
+        user_input.submit(
+            chat_response,
+            inputs=[user_input, chat_history],
+            outputs=[chat_history, chat_history],
+        )
+        user_input.submit(lambda: "", None, user_input)
+
+    return chat_interface
+
+if __name__ == "__main__":
+    interface = create_gradio_interface()
+    interface.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=True,
+        enable_queue=True
+    )
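
The commit only updates app.py, so the Space's dependency pins are not part of this diff. A minimal requirements sketch that would plausibly cover the new imports (package names and the Gradio pin below are assumptions, not taken from the repo; note that launch(enable_queue=True) is only accepted by Gradio 3.x, since Gradio 4 moved queuing to Blocks.queue()):

# hypothetical requirements.txt -- not included in this commit
gradio<4
torch
transformers
accelerate            # needed for device_map="auto"
sentence-transformers  # backs HuggingFaceEmbeddings("all-MiniLM-L6-v2")
faiss-cpu              # FAISS backend for the vector store
langchain
langchain-community
langchain-huggingface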
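
The prompt and response shaping can be sanity-checked without loading Yi-Coder-9B-Chat; the sketch below is a standalone copy of the two string helpers from CPSLChatbot (only the sample strings are made up for illustration):

# Standalone illustration of the query/response shaping used by CPSLChatbot.
def preprocess_query(query):
    # Requests mentioning "script" or "code" are reframed as CPSL script requests.
    if "script" in query.lower() or "code" in query.lower():
        return f"Write a CPSL script: {query}"
    return query

def clean_response(response):
    # RetrievalQA.invoke() returns a dict; only the text after "Answer:" is kept.
    result = response.get("result", "")
    if "Answer:" in result:
        return result.split("Answer:")[1].strip()
    return result.strip()

print(preprocess_query("Show me a sample script"))
# -> Write a CPSL script: Show me a sample script
print(clean_response({"result": "Some retrieved context\nAnswer: Here is the requested script."}))
# -> Here is the requested script.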