YDluffy committed on
Commit
484cdf8
·
verified ·
1 Parent(s): 767eb12

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -40
app.py CHANGED
@@ -3,60 +3,43 @@ import gradio as gr
3
  import xgboost as xgb
4
  import numpy as np
5
  import pandas as pd
6
- import requests
7
  from huggingface_hub import hf_hub_download
8
 
9
- # **📌 LLM 服务器地址(修改为你的 LLM Space 地址)**
10
- LLM_API_URL = "https://your-llm-space.gradio.app" # 替换为 LLM 服务器地址
11
 
12
- # **📌 先运行 `preprocess.py` 处理数据**
13
- if not os.path.exists("processed_data.csv"):
14
  print("📌 运行 `preprocess.py` 进行数据处理...")
15
  os.system("python preprocess.py")
16
 
17
  # **📌 加载处理后的数据**
18
- processed_data_path = "processed_data.csv"
19
- if not os.path.exists(processed_data_path):
20
- raise FileNotFoundError("❌ `processed_data.csv` 未找到,请先运行 `preprocess.py` 处理数据!")
21
-
22
  df = pd.read_csv(processed_data_path)
23
 
 
 
 
 
 
24
  # **📌 加载 XGBoost 预测模型**
25
  model_path = hf_hub_download(repo_id="YDluffy/lottery_prediction", filename="lottery_xgboost_model.json")
26
  model = xgb.XGBRegressor()
27
  model.load_model(model_path)
28
 
29
- # **📌 调用 LLM API 解析用户输入**
30
- def get_prediction_from_llm(user_input):
 
 
31
  try:
32
- # 发送请求到 LLM API
33
- response = requests.post(LLM_API_URL + "/api/predict", json={"input": user_input})
34
-
35
- # **确保返回 JSON**
36
- if response.status_code == 200:
37
- try:
38
- llm_output = response.json().get("prediction", "")
39
- except requests.exceptions.JSONDecodeError:
40
- return "❌ LLM 服务器返回了无效的 JSON 数据,请检查 LLM Space 是否正常运行。"
41
- else:
42
- return f"❌ LLM 服务器错误,状态码: {response.status_code}"
43
-
44
- except requests.exceptions.ConnectionError:
45
- return "❌ 无法连接到 LLM 服务器,请确保 LLM Space 正在运行。"
46
-
47
- # **📌 解析 LLM 输出**
48
- year, period = 2025, 16
49
- nums = [5, 12, 23, 34, 45, 56]
50
- special = 7
51
-
52
- # **📌 进行 XGBoost 预测**
53
- prediction = predict_lottery(year, period, *nums, special)
54
-
55
- return f"📊 预测的号码是: {prediction}\n\n📢 LLM 解析的特征:{llm_output}"
56
 
57
  # **📌 预测函数**
58
  def predict_lottery(year, period, num1, num2, num3, num4, num5, num6, special):
59
- # 从历史数据查找对应的月份和日期
60
  history = df[(df['期号_年份'] == year) & (df['期数'] == period)]
61
  if not history.empty:
62
  month = history.iloc[0]['月份']
@@ -64,19 +47,33 @@ def predict_lottery(year, period, num1, num2, num3, num4, num5, num6, special):
64
  else:
65
  month, day = 1, 1 # 默认值
66
 
67
- # 计算中奖号码均值
68
  avg_number = np.mean([num1, num2, num3, num4, num5, num6])
69
 
70
- # 形成特征数组
71
  features = np.array([[year, period, month, day, num1, num2, num3, num4, num5, num6, special, avg_number]])
72
 
73
- # 进行预测
74
  prediction = model.predict(features)
75
  return prediction
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  # **📌 Gradio Web 界面**
78
  iface = gr.Interface(
79
- fn=get_prediction_from_llm,
80
  inputs=gr.Textbox(label="请输入问题或期号信息"),
81
  outputs="text",
82
  title="六合彩预测模型",
 
3
  import xgboost as xgb
4
  import numpy as np
5
  import pandas as pd
6
+ from transformers import pipeline
7
  from huggingface_hub import hf_hub_download
8
 
9
+ # **📌 先运行 `preprocess.py`,确保数据可用**
10
+ processed_data_path = "processed_data.csv"
11
 
12
+ if not os.path.exists(processed_data_path):
 
13
  print("📌 运行 `preprocess.py` 进行数据处理...")
14
  os.system("python preprocess.py")
15
 
16
  # **📌 加载处理后的数据**
 
 
 
 
17
  df = pd.read_csv(processed_data_path)
18
 
19
+ # **📌 加载 Hugging Face GPT-Neo 作为 LLM**
20
+ print("📌 正在加载 GPT-Neo 模型...")
21
+ generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')
22
+ print("✅ GPT-Neo 模型加载成功!")
23
+
24
  # **📌 加载 XGBoost 预测模型**
25
  model_path = hf_hub_download(repo_id="YDluffy/lottery_prediction", filename="lottery_xgboost_model.json")
26
  model = xgb.XGBRegressor()
27
  model.load_model(model_path)
28
 
29
+ # **📌 LLM 解析用户输入**
30
def chat_with_llm(user_input):
    """Ask the text-generation pipeline to extract prediction parameters.

    Builds a prompt that asks the model to pull the year, period and winning
    numbers out of the user's question, then runs it through the module-level
    ``generator`` pipeline.

    Args:
        user_input: Free-form question or draw information typed by the user.

    Returns:
        The ``generated_text`` of the first returned sequence, or an error
        message string if the pipeline raises.
    """
    prompt = f"请从以下问题中提取出预测所需的参数(年份、期号、中奖号码等):'{user_input}'"

    try:
        # `max_length` counts the prompt tokens as well, so a long question
        # would leave little or no room for generated text. Limit only the
        # newly generated tokens instead.
        response = generator(prompt, max_new_tokens=100, num_return_sequences=1)
        return response[0]['generated_text']
    except Exception as e:
        # UI boundary: report the failure as text instead of crashing the app.
        return f"❌ GPT-Neo 处理错误: {str(e)}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  # **📌 预测函数**
41
  def predict_lottery(year, period, num1, num2, num3, num4, num5, num6, special):
42
+ # **从历史数据查找对应的月份和日期**
43
  history = df[(df['期号_年份'] == year) & (df['期数'] == period)]
44
  if not history.empty:
45
  month = history.iloc[0]['月份']
 
47
  else:
48
  month, day = 1, 1 # 默认值
49
 
50
+ # **计算中奖号码均值**
51
  avg_number = np.mean([num1, num2, num3, num4, num5, num6])
52
 
53
+ # **形成特征数组**
54
  features = np.array([[year, period, month, day, num1, num2, num3, num4, num5, num6, special, avg_number]])
55
 
56
+ # **进行预测**
57
  prediction = model.predict(features)
58
  return prediction
59
 
60
+ # **📌 结合 LLM 和 XGBoost**
61
def predict_and_interact(user_input):
    """Run the LLM on the user's text and pair it with an XGBoost prediction.

    The LLM output is only echoed back in the result; it is not parsed. The
    prediction itself is computed from fixed placeholder parameters.

    Args:
        user_input: Free-form question or draw information typed by the user.

    Returns:
        A display string containing the model prediction followed by the raw
        LLM output.
    """
    parsed_text = chat_with_llm(user_input)

    # Placeholder draw parameters used until the LLM output is actually parsed.
    draw_year = 2025
    draw_period = 16
    draw_numbers = [5, 12, 23, 34, 45, 56]
    special_number = 7

    forecast = predict_lottery(draw_year, draw_period, *draw_numbers, special_number)

    return f"📊 预测的号码是: {forecast}\n\n📢 LLM 解析的特征:{parsed_text}"
73
+
74
  # **📌 Gradio Web 界面**
75
  iface = gr.Interface(
76
+ fn=predict_and_interact,
77
  inputs=gr.Textbox(label="请输入问题或期号信息"),
78
  outputs="text",
79
  title="六合彩预测模型",