ericjohnson97 committed
Commit 6255f8a · 1 Parent(s): be39a04

Added Gradio chatbot interface

Files changed (5)
  1. .gitignore +2 -1
  2. data/data.txt +0 -45
  3. llm/gptPlotCreator.py +103 -0
  4. llm_plot.py +41 -91
  5. plot.py +0 -42
.gitignore CHANGED
@@ -1 +1,2 @@
- .env
+ .env
+ *.pyc
data/data.txt DELETED
@@ -1,45 +0,0 @@
- LOCAL_POSITION_NED {time_boot_ms : 239230, x : 109.40231323242188, y : 37.670654296875, z : -9.969955444335938, vx : -0.018984168767929077, vy : -0.005424499046057463, vz : 5.3835941798752174e-05}
- LOCAL_POSITION_NED {time_boot_ms : 239480, x : 109.40264129638672, y : 37.67082977294922, z : -9.96972942352295, vx : -0.019311394542455673, vy : -0.005466990172863007, vz : -5.457072620629333e-05}
- LOCAL_POSITION_NED {time_boot_ms : 239729, x : 109.4027099609375, y : 37.67092514038086, z : -9.969511032104492, vx : -0.01970984973013401, vy : -0.005530376452952623, vz : -0.0002979390264954418}
- LOCAL_POSITION_NED {time_boot_ms : 239979, x : 109.40276336669922, y : 37.671024322509766, z : -9.969305992126465, vx : -0.020000549033284187, vy : -0.0055604190565645695, vz : -0.0005820284713990986}
- LOCAL_POSITION_NED {time_boot_ms : 240229, x : 109.40313720703125, y : 37.67121887207031, z : -9.969114303588867, vx : -0.02006402052938938, vy : -0.00547247938811779, vz : -0.000677842297591269}
- LOCAL_POSITION_NED {time_boot_ms : 240479, x : 109.4032211303711, y : 37.67136764526367, z : -9.968948364257812, vx : -0.0202533807605505, vy : -0.005525950342416763, vz : -0.0007673363434150815}
- LOCAL_POSITION_NED {time_boot_ms : 240729, x : 109.40300750732422, y : 37.67150115966797, z : -9.968831062316895, vx : -0.020623182877898216, vy : -0.0055752284824848175, vz : -0.0009132251143455505}
- LOCAL_POSITION_NED {time_boot_ms : 240980, x : 109.4027328491211, y : 37.67162322998047, z : -9.968841552734375, vx : -0.021015815436840057, vy : -0.0057647754438221455, vz : -0.0011039149248972535}
- LOCAL_POSITION_NED {time_boot_ms : 241230, x : 109.40284729003906, y : 37.67182540893555, z : -9.96899127960205, vx : -0.020866867154836655, vy : -0.005960384849458933, vz : -0.0008990211645141244}
- LOCAL_POSITION_NED {time_boot_ms : 241480, x : 109.40279388427734, y : 37.67191696166992, z : -9.96912670135498, vx : -0.020786112174391747, vy : -0.006111173890531063, vz : -0.000912059098482132}
- LOCAL_POSITION_NED {time_boot_ms : 241729, x : 109.40264129638672, y : 37.67238998413086, z : -9.969264030456543, vx : -0.020849552005529404, vy : -0.006145048886537552, vz : -0.0006911428063176572}
- LOCAL_POSITION_NED {time_boot_ms : 241979, x : 109.40258026123047, y : 37.673919677734375, z : -9.969388008117676, vx : -0.021021265536546707, vy : -0.006493883207440376, vz : -0.0007432479178532958}
- LOCAL_POSITION_NED {time_boot_ms : 242229, x : 109.40287017822266, y : 37.675846099853516, z : -9.96950626373291, vx : -0.021021228283643723, vy : -0.007321516517549753, vz : -0.0005821734084747732}
- LOCAL_POSITION_NED {time_boot_ms : 242479, x : 109.40294647216797, y : 37.67726135253906, z : -9.969613075256348, vx : -0.02099592424929142, vy : -0.008481874130666256, vz : -0.000457199988886714}
- LOCAL_POSITION_NED {time_boot_ms : 242729, x : 109.40283203125, y : 37.67806625366211, z : -9.969687461853027, vx : -0.02108212560415268, vy : -0.009890645742416382, vz : -0.000333481642883271}
- LOCAL_POSITION_NED {time_boot_ms : 242979, x : 109.40362548828125, y : 37.67841339111328, z : -9.96973705291748, vx : -0.02084990404546261, vy : -0.011260980740189552, vz : -0.00045963365118950605}
- LOCAL_POSITION_NED {time_boot_ms : 243230, x : 109.40609741210938, y : 37.678443908691406, z : -9.969796180725098, vx : -0.021079031750559807, vy : -0.012367655523121357, vz : -0.0001387261290801689}
- LOCAL_POSITION_NED {time_boot_ms : 243480, x : 109.40825653076172, y : 37.67808151245117, z : -9.96983528137207, vx : -0.02236475795507431, vy : -0.012887658551335335, vz : -0.00010744269820861518}
- LOCAL_POSITION_NED {time_boot_ms : 243729, x : 109.40962219238281, y : 37.677486419677734, z : -9.969886779785156, vx : -0.023954948410391808, vy : -0.013113063760101795, vz : -0.00017305012443102896}
- LOCAL_POSITION_NED {time_boot_ms : 243979, x : 109.40973663330078, y : 37.677276611328125, z : -9.969881057739258, vx : -0.04371815547347069, vy : -0.004350676201283932, vz : 0.0006270252051763237}
- LOCAL_POSITION_NED {time_boot_ms : 244229, x : 109.38226318359375, y : 37.69022750854492, z : -9.969354629516602, vx : -0.2798684239387512, vy : 0.10651721060276031, vz : 0.004445536062121391}
- LOCAL_POSITION_NED {time_boot_ms : 244479, x : 109.25566864013672, y : 37.74885940551758, z : -9.967907905578613, vx : -0.8151731491088867, vy : 0.35761335492134094, vz : 0.006137563847005367}
- LOCAL_POSITION_NED {time_boot_ms : 244729, x : 108.98152160644531, y : 37.87641525268555, z : -9.967415809631348, vx : -1.4317526817321777, vy : 0.6452293395996094, vz : -0.0019165349658578634}
- LOCAL_POSITION_NED {time_boot_ms : 244979, x : 108.56168365478516, y : 38.0730094909668, z : -9.968149185180664, vx : -2.0058188438415527, vy : 0.9171269536018372, vz : -0.0052565005607903}
- LOCAL_POSITION_NED {time_boot_ms : 245230, x : 108.00172424316406, y : 38.3388557434082, z : -9.969520568847656, vx : -2.558675765991211, vy : 1.1873902082443237, vz : -0.004303304478526115}
- LOCAL_POSITION_NED {time_boot_ms : 245480, x : 107.30584716796875, y : 38.671424865722656, z : -9.970559120178223, vx : -3.1081371307373047, vy : 1.4497607946395874, vz : -0.0013086843537166715}
- LOCAL_POSITION_NED {time_boot_ms : 245730, x : 106.47059631347656, y : 39.06772994995117, z : -9.970938682556152, vx : -3.657071590423584, vy : 1.6894506216049194, vz : 0.0002285348455188796}
- LOCAL_POSITION_NED {time_boot_ms : 245979, x : 105.5009536743164, y : 39.52318572998047, z : -9.971338272094727, vx : -4.182033538818359, vy : 1.926192045211792, vz : -0.0010198309319093823}
- LOCAL_POSITION_NED {time_boot_ms : 246229, x : 104.40209197998047, y : 40.04393768310547, z : -9.971746444702148, vx : -4.702603340148926, vy : 2.200885772705078, vz : 0.0009427944314666092}
- LOCAL_POSITION_NED {time_boot_ms : 246479, x : 103.16869354248047, y : 40.63428497314453, z : -9.971199035644531, vx : -5.25104284286499, vy : 2.47697114944458, vz : 0.00528548052534461}
- LOCAL_POSITION_NED {time_boot_ms : 246729, x : 101.79534912109375, y : 41.29187774658203, z : -9.969746589660645, vx : -5.810123443603516, vy : 2.742892265319824, vz : 0.00835045799612999}
- LOCAL_POSITION_NED {time_boot_ms : 246979, x : 100.2834243774414, y : 42.01643753051758, z : -9.968152046203613, vx : -6.356932163238525, vy : 3.00797700881958, vz : 0.007713631726801395}
- LOCAL_POSITION_NED {time_boot_ms : 247230, x : 98.63285064697266, y : 42.80854797363281, z : -9.967652320861816, vx : -6.869801044464111, vy : 3.2539689540863037, vz : 1.1250964234932326e-05}
- LOCAL_POSITION_NED {time_boot_ms : 247480, x : 96.86815643310547, y : 43.65073013305664, z : -9.969475746154785, vx : -7.309067249298096, vy : 3.4248688220977783, vz : -0.007807408459484577}
- LOCAL_POSITION_NED {time_boot_ms : 247730, x : 94.9985122680664, y : 44.52605056762695, z : -9.971796989440918, vx : -7.700419902801514, vy : 3.5252227783203125, vz : -0.005552171263843775}
- LOCAL_POSITION_NED {time_boot_ms : 247979, x : 93.03529357910156, y : 45.42596435546875, z : -9.972488403320312, vx : -8.041450500488281, vy : 3.617892265319824, vz : 0.0015491923550143838}
- LOCAL_POSITION_NED {time_boot_ms : 248229, x : 90.99347686767578, y : 46.35050582885742, z : -9.97187328338623, vx : -8.306546211242676, vy : 3.7119176387786865, vz : 0.0021588108502328396}
- LOCAL_POSITION_NED {time_boot_ms : 248479, x : 88.89547729492188, y : 47.29640197753906, z : -9.971541404724121, vx : -8.481266021728516, vy : 3.7974555492401123, vz : 0.00032809507683850825}
- LOCAL_POSITION_NED {time_boot_ms : 248729, x : 86.76085662841797, y : 48.263916015625, z : -9.971596717834473, vx : -8.600085258483887, vy : 3.8975508213043213, vz : 0.0003860758733935654}
- LOCAL_POSITION_NED {time_boot_ms : 248979, x : 84.5998306274414, y : 49.25948715209961, z : -9.972135543823242, vx : -8.690932273864746, vy : 4.027074337005615, vz : 0.0017167243640869856}
- LOCAL_POSITION_NED {time_boot_ms : 249229, x : 82.42630004882812, y : 50.28388977050781, z : -9.972501754760742, vx : -8.765235900878906, vy : 4.157995223999023, vz : 0.004778474103659391}
- LOCAL_POSITION_NED {time_boot_ms : 249480, x : 80.2116928100586, y : 51.346527099609375, z : -9.971891403198242, vx : -8.841904640197754, vy : 4.258383274078369, vz : 0.007347704842686653}
- LOCAL_POSITION_NED {time_boot_ms : 249730, x : 77.99267578125, y : 52.42091369628906, z : -9.970446586608887, vx : -8.91469955444336, vy : 4.3127851486206055, vz : 0.006476235575973988}
- LOCAL_POSITION_NED {time_boot_ms : 249979, x : 75.75646209716797, y : 53.50323486328125, z : -9.969244956970215, vx : -8.981115341186523, vy : 4.32005500793457, vz : 0.004069608170539141}
- LOCAL_POSITION_NED {time_boot_ms : 250229, x : 73.50287628173828, y : 54.58274459838867, z : -9.968433380126953, vx : -9.047001838684082, vy : 4.29080057144165, vz : 0.0032091387547552586}
llm/gptPlotCreator.py ADDED
@@ -0,0 +1,103 @@
+ import re
+ import random
+ import linecache
+ import subprocess
+ from langchain.prompts import PromptTemplate
+ from langchain.chat_models import ChatOpenAI
+ from langchain.chains import LLMChain
+ from langchain.llms import OpenAI
+ from langchain.chains import ConversationChain
+ from langchain.memory import ConversationBufferMemory
+ from langchain.prompts.chat import (
+     ChatPromptTemplate,
+     HumanMessagePromptTemplate,
+ )
+ import os
+ from dotenv import load_dotenv
+ from PIL import Image
+
+
+
+ class PlotCreator:
+     def __init__(self):
+         load_dotenv()
+         llm = ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=2000, temperature=0)
+
+         mavlink_data_prompt = PromptTemplate(
+             input_variables=["human_input", "file"],
+             template="You are an AI conversation agent that will be used for generating python scripts to plot mavlink data provided by the user. Please create a python script using matplotlib and pymavlink's mavutil to plot the data provided by the user. Please do not explain the code just return the script. Please plot each independent variable over time in seconds. Please save the plot to file named plot.png with at least 400 dpi. \n\nHUMAN: {human_input} \n\nplease read this data from the file {file}.",
+         )
+         self.chain = LLMChain(verbose=True, llm=llm, prompt=mavlink_data_prompt)
+
+     @staticmethod
+     def sample_lines(filename, num_lines=5):
+         with open(filename) as f:
+             total_lines = sum(1 for _ in f)
+
+         if total_lines < num_lines:
+             raise ValueError("File has fewer lines than the number of lines requested.")
+
+         line_numbers = random.sample(range(1, total_lines + 1), num_lines)
+         lines = [linecache.getline(filename, line_number).rstrip() for line_number in line_numbers]
+
+         return '\n'.join(lines)
+
+     @staticmethod
+     def extract_code_snippets(text):
+         pattern = r'```(.*?)```'
+         snippets = re.findall(pattern, text, re.DOTALL)
+         if len(snippets) == 0:
+             snippets = [text]
+         return snippets
+
+     @staticmethod
+     def write_plot_script(filename, text):
+         with open(filename, 'w') as file:
+             file.write(text)
+
+     @staticmethod
+     def attempt_to_fix_sctript(filename, error_message):
+         llm = ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=2000, temperature=0)
+
+         fix_plot_script_template = PromptTemplate(
+             input_variables=["error", "script"],
+             template="You are an AI agent that is designed to debug scripts created to plot mavlink data using matplotlib and pymavlink's mavutil. the following script produced this error: \n\n{script}\n\nThe error is: \n\n{error}\n\nPlease fix the script so that it produces the correct plot.",
+         )
+
+         # read script from file
+         with open(filename, 'r') as file:
+             script = file.read()
+
+         chain = LLMChain(verbose=True, llm=llm, prompt=fix_plot_script_template)
+         response = chain.run({"error": error_message, "script": script})
+         print(response)
+         code = PlotCreator.extract_code_snippets(response)
+         PlotCreator.write_plot_script("plot.py", code[0])
+
+         # run the script
+         os.system("python plot.py")
+
+     def create_plot(self, human_input):
+         file = "data/2023-01-04 20-51-25.tlog"
+
+         # prompt the user for the what plot they would like to generate
+         # human_input = input("Please enter a description of the plot you would like to generate: ")
+
+         response = self.chain.run({"file": file, "human_input": human_input})
+         print(response)
+
+         # parse the code from the response
+         code = self.extract_code_snippets(response)
+         self.write_plot_script("plot.py", code[0])
+
+         # run the script if it doesn't work capture output and call attempt_to_fix_script
+         try:
+             subprocess.check_output(["python", "plot.py"], stderr=subprocess.STDOUT)
+         except subprocess.CalledProcessError as e:
+             print(e.output.decode())
+             self.attempt_to_fix_sctript("plot.py", e.output.decode())
+         except Exception as e:
+             print(e)
+             self.attempt_to_fix_sctript("plot.py", str(e))
+
+         return ("plot.png", None)
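For reference, a minimal usage sketch of the new class (assumptions: an OPENAI_API_KEY is available via the .env file, and the tlog path hard-coded in create_plot exists; the prompt text below is only an example, not part of the commit):

    from llm.gptPlotCreator import PlotCreator

    creator = PlotCreator()
    # create_plot() asks GPT-3.5 for a matplotlib/pymavlink plotting script,
    # writes it to plot.py, runs it, and falls back to attempt_to_fix_sctript
    # if the generated script errors; it returns ("plot.png", None).
    image_path, _ = creator.create_plot("Plot x, y, and z from LOCAL_POSITION_NED over time")
    print(image_path)  # "plot.png"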
llm_plot.py CHANGED
@@ -1,99 +1,49 @@
- import re
- import random
- import linecache
- import subprocess
- from langchain.prompts import PromptTemplate
- from langchain.chat_models import ChatOpenAI
- from langchain.chains import LLMChain
- from langchain.llms import OpenAI
- from langchain.chains import ConversationChain
- from langchain.memory import ConversationBufferMemory
- from langchain.prompts.chat import (
-     ChatPromptTemplate,
-     HumanMessagePromptTemplate,
- )
- import os
- from dotenv import load_dotenv
-
- # Load environment variables from .env file
- load_dotenv()
-
- def sample_lines(filename, num_lines=5):
-     with open(filename) as f:
-         total_lines = sum(1 for _ in f)
-
-     if total_lines < num_lines:
-         raise ValueError("File has fewer lines than the number of lines requested.")
-
-     line_numbers = random.sample(range(1, total_lines + 1), num_lines)
-     lines = [linecache.getline(filename, line_number).rstrip() for line_number in line_numbers]
-
-     return '\n'.join(lines)
-
-
- def extract_code_snippets(text):
-     pattern = r'```(.*?)```'
-     snippets = re.findall(pattern, text, re.DOTALL)
-     if len(snippets) == 0:
-         snippets = [text]
-     return snippets
-
- def write_plot_script(filename, text):
-     with open(filename, 'w') as file:
-         file.write(text)
-
- def attempt_to_fix_sctript(filename, error_message):
-     llm = ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=2000, temperature=0)
-
-     fix_plot_script_template = PromptTemplate(
-         input_variables=["error", "script"],
-         template="You are an AI agent that is designed to debug scripts created to plot mavlink data using matplotlib and pymavlink's mavutil. the following script produced this error: \n\n{script}\n\nThe error is: \n\n{error}\n\nPlease fix the script so that it produces the correct plot.",
-     )
-
-     # read script from file
-     with open(filename, 'r') as file:
-         script = file.read()
-
-     chain = LLMChain(verbose=True, llm=llm, prompt=fix_plot_script_template)
-     response = chain.run({"error": error_message, "script": script})
-     print(response)
-     code = extract_code_snippets(response)
-     write_plot_script("plot.py", code[0])
-
-     # run the script
-     os.system("python plot.py")
-
-
- if __name__ == "__main__":
-
-
-     file = "data/2023-01-04 20-51-25.tlog"
-     llm = ChatOpenAI(model_name="gpt-3.5-turbo", max_tokens=2000, temperature=0)
-
-     mavlink_data_prompt = PromptTemplate(
-         input_variables=["human_input", "file"],
-         template="You are an AI conversation agent that will be used for generating python scripts to plot mavlink data provided by the user. Please create a python script using matplotlib and pymavlink's mavutil to plot the data provided by the user. Please do not explain the code just return the script. Please plot each independent variable over time. \n\nHUMAN: {human_input} \n\nplease read this data from the file {file}.",
-     )
-
-     chain = LLMChain(verbose=True, llm=llm, prompt=mavlink_data_prompt)
-
-     # prompt the user for the what plot they would like to generate
-     human_input = input("Please enter a description of the plot you would like to generate: ")
-
-     # human_input = "Please create a script to plot x y and z from LOCAL_POSITION_NED from the following data."
-     response = chain.run({"file": file, "human_input": human_input})
-     print(response)
-
-     # parse the code from the response
-     code = extract_code_snippets(response)
-     write_plot_script("plot.py", code[0])
-
-     # run the script if it doesn't work capture output and call attempt_to_fix_script
-     try:
-         subprocess.check_output(["python", "plot.py"], stderr=subprocess.STDOUT)
-     except subprocess.CalledProcessError as e:
-         print(e.output.decode())
-         attempt_to_fix_sctript("plot.py", e.output.decode())
-     except Exception as e:
-         print(e)
-         attempt_to_fix_sctript("plot.py", str(e))
+ import gradio as gr
+ from llm.gptPlotCreator import PlotCreator
+
+ plot_creator = PlotCreator()
+
+ def add_text(history, text):
+     history = history + [(text, None)]
+     return history, ""
+
+ def add_file(history, file):
+     history = history + [((file.name,), None)]
+     return history
+
+ def bot(history):
+     # Get the last input from the user
+     user_input = history[-1][0]
+
+     # Check if it is a string
+     if isinstance(user_input, str):
+         # Generate the plot
+         img = plot_creator.create_plot(user_input)
+         response = img
+     else:
+         response = "**That's cool!**"
+
+     history[-1][1] = ('plot.png', None)
+     return history
+
+ with gr.Blocks() as demo:
+     chatbot = gr.Chatbot([], elem_id="chatbot").style(height=750)
+
+     with gr.Row():
+         with gr.Column(scale=0.85):
+             txt = gr.Textbox(
+                 show_label=False,
+                 placeholder="Enter text and press enter, or upload an image",
+             ).style(container=False)
+         with gr.Column(scale=0.15, min_width=0):
+             btn = gr.UploadButton("📁", file_types=["image", "video", "audio"])
+
+     txt.submit(add_text, [chatbot, txt], [chatbot, txt]).then(
+         bot, chatbot, chatbot
+     )
+     btn.upload(add_file, [chatbot, btn], [chatbot]).then(
+         bot, chatbot, chatbot
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
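With this change, llm_plot.py becomes the entry point for the chat UI rather than a one-shot command-line script. A quick way to try it locally (the address and port are Gradio defaults, not set anywhere in this commit):

    # launches the Blocks app, by default at http://127.0.0.1:7860
    python llm_plot.py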
plot.py DELETED
@@ -1,42 +0,0 @@
-
- import matplotlib.pyplot as plt
- from pymavlink import mavutil
-
- # Open the MAVLink log file
- mlog = mavutil.mavlink_connection('data/2023-01-04 20-51-25.tlog')
-
- # Initialize lists to store the data
- time_stamps = []
- latitudes = []
- longitudes = []
- altitudes = []
-
- # Loop through the log file and extract the data
- while True:
-     msg = mlog.recv_match()
-     if not msg:
-         break
-     if msg.get_type() == 'GLOBAL_POSITION_INT':
-         time_stamps.append(msg.time_boot_ms / 1000.0)
-         latitudes.append(msg.lat / 1e7)
-         longitudes.append(msg.lon / 1e7)
-         altitudes.append(msg.alt / 1000.0)
-
- # Plot the data
- plt.plot(time_stamps, latitudes)
- plt.xlabel('Time (s)')
- plt.ylabel('Latitude')
- plt.title('Aircraft Position Over Time')
- plt.show()
-
- plt.plot(time_stamps, longitudes)
- plt.xlabel('Time (s)')
- plt.ylabel('Longitude')
- plt.title('Aircraft Position Over Time')
- plt.show()
-
- plt.plot(time_stamps, altitudes)
- plt.xlabel('Time (s)')
- plt.ylabel('Altitude (km)')
- plt.title('Aircraft Position Over Time')
- plt.show()