YDluffy committed on
Commit
be0f990
·
verified ·
1 Parent(s): 0686373

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +28 -85
app.py CHANGED
@@ -1,104 +1,47 @@
1
- import os
2
  import gradio as gr
3
  import xgboost as xgb
4
  import numpy as np
5
  import pandas as pd
6
- from transformers import pipeline
7
  from huggingface_hub import hf_hub_download
8
 
9
- # **📌 先运行 `preprocess.py` 处理数据**
10
- processed_data_path = "processed_data.csv"
 
 
11
 
12
- if not os.path.exists(processed_data_path):
13
- print("📌 运行 `preprocess.py` 进行数据处理...")
14
- os.system("python preprocess.py")
15
-
16
- # **📌 加载处理后的数据**
17
- df = pd.read_csv(processed_data_path)
18
-
19
- # **📌 加载 Hugging Face GPT-Neo 作为 LLM**
20
- print("📌 正在加载 GPT-Neo 模型...")
21
- generator = pipeline('text-generation', model='EleutherAI/gpt-neo-2.7B')
22
- print("✅ GPT-Neo 模型加载成功!")
23
-
24
- # **📌 加载 XGBoost 预测模型**
25
- model_path = hf_hub_download(repo_id="YDluffy/lottery_prediction", filename="lottery_xgboost_model.json")
26
- model = xgb.XGBRegressor()
27
  model.load_model(model_path)
28
 
29
- # **📌 预测函数**
30
- def predict_lottery(year, period, num1, num2, num3, num4, num5, num6, special):
31
- # **从历史数据查找对应的月份和日期**
32
- history = df[(df['期号_年份'] == year) & (df['期数'] == period)]
33
- if not history.empty:
34
- month = history.iloc[0]['月份']
35
- day = history.iloc[0]['日期']
36
- else:
37
- month, day = 1, 1 # 默认值
38
-
39
- # **计算中奖号码均值**
40
- avg_number = np.mean([num1, num2, num3, num4, num5, num6])
41
-
42
- # **形成特征数组**
43
- features = np.array([[year, period, month, day, num1, num2, num3, num4, num5, num6, special, avg_number]])
44
-
45
- # **进行预测**
46
- prediction = model.predict(features)
47
 
48
- # **🚀 修正浮点数问题:四舍五入为整数**
49
- prediction = np.round(prediction).astype(int)
50
-
51
- return prediction.tolist()
52
-
53
- # **📌 LLM 解析用户输入并调用 XGBoost**
54
- def chat_with_llm(user_input):
55
- prompt = (
56
- "请从以下问题中提取出预测所需的参数,并调用 XGBoost 预测模型。\n"
57
- "返回格式:\n"
58
- "年份:2025, 期号:16, 号码:[5,12,23,34,45,56], 特别号码:7\n"
59
- "请勿返回额外信息,仅输出这个格式的数据。\n"
60
- f"输入问题: {user_input}"
61
- )
62
-
63
- try:
64
- response = generator(
65
- prompt,
66
- max_new_tokens=50,
67
- temperature=0.1,
68
- num_return_sequences=1
69
- )
70
- extracted_text = response[0]['generated_text']
71
-
72
- # **🚀 确保 UTF-8 编码,移除特殊字符**
73
- extracted_text = extracted_text.encode("utf-8", "ignore").decode("utf-8").strip()
74
-
75
- # **✅ 确保返回格式正确**
76
- if not all(keyword in extracted_text for keyword in ["年份:", "期号:", "号码:", "特别号码:"]):
77
- return f"❌ LLM 解析失败:返回数据格式不正确。\n📢 LLM 解析结果: {extracted_text}"
78
-
79
- # **解析 GPT-Neo 生成的文本**
80
- parts = extracted_text.split(",")
81
- year = int(parts[0].split(":")[1].strip())
82
- period = int(parts[1].split(":")[1].strip())
83
- nums = [int(x) for x in parts[2].split(":")[1].strip("[]").split()]
84
- special = int(parts[3].split(":")[1].strip())
85
 
86
- # **调用 XGBoost 进行预测**
87
- prediction = predict_lottery(year, period, *nums, special)
 
 
 
 
88
 
89
- return f"📊 预测的号码是: {prediction}\n\n📢 LLM 解析的特征:{extracted_text}"
 
 
90
 
91
- except Exception as e:
92
- return f"❌ LLM 解析数据失败: {str(e)}\n📢 LLM 解析结果: {extracted_text}"
93
 
94
- # **📌 Gradio Web 界面**
95
  iface = gr.Interface(
96
- fn=chat_with_llm,
97
- inputs=gr.Textbox(label="请输入问题或期号信息"),
98
  outputs="text",
99
- title="六合彩预测模型",
100
- description="GPT 解析输入信息,并调用 XGBoost 进行预测"
101
  )
102
 
103
- # **📌 启动 Gradio 应用**
104
  iface.launch(share=True)
 
 
1
  import gradio as gr
2
  import xgboost as xgb
3
  import numpy as np
4
  import pandas as pd
 
5
  from huggingface_hub import hf_hub_download
6
 
7
# Fetch the trained XGBoost model artifact from the Hugging Face Hub.
repo_id = "YDluffy/lottery_prediction"
model_filename = "lottery_xgboost_model.ubj"
model_path = hf_hub_download(filename=model_filename, repo_id=repo_id)

# Restore the booster directly from the downloaded model file.
model = xgb.Booster(model_file=model_path)

# Load the historical draw records that predict_lottery consults.
history_data_path = "Mark_Six.csv"
history_data = pd.read_csv(history_data_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
def predict_lottery(year, period):
    """Predict draw numbers for the given year and draw period.

    Selects every row of ``history_data`` whose ``期数`` column equals
    *period*, takes the per-column mode of the last seven columns of those
    rows, and feeds ``[year, period, *modes]`` to the loaded XGBoost booster.

    Args:
        year: Draw year as entered in the UI.
        period: Draw period; matched against the ``期数`` column.

    Returns:
        The booster output rounded to the nearest integers, converted to a
        (possibly nested) Python list.

    Raises:
        gr.Error: If either input is missing, or if no historical record
            matches *period*.
    """
    # A Gradio "number" field left blank arrives as None.
    if year is None or period is None:
        raise gr.Error("请输入年份和期数。")

    # 1. Historical draws that share this period number.
    same_period_rows = history_data[history_data["期数"] == period]

    # mode() of an empty frame has no rows, so `.iloc[0]` below would raise
    # an opaque IndexError; report the real cause instead.
    if same_period_rows.empty:
        raise gr.Error(f"历史记录中没有第 {period} 期的开奖数据,无法预测。")

    # 2. Most frequent value in each of the last seven columns.
    # NOTE(review): presumably the six drawn numbers plus the special
    # number - confirm against the column layout of Mark_Six.csv.
    most_frequent_numbers = same_period_rows.iloc[:, -7:].mode().iloc[0].tolist()

    # 3. Single-row feature matrix for the booster.
    test_features = np.array([[year, period] + most_frequent_numbers])
    dtest = xgb.DMatrix(test_features)

    # 4. Predict and round to whole numbers.
    prediction = model.predict(dtest)
    return np.round(prediction).astype(int).tolist()
 
37
 
38
# Expose the predictor through a two-field Gradio form and start serving.
iface = gr.Interface(
    predict_lottery,
    ["number", "number"],
    "text",
    title="六合彩智能预测 API",
    description="输入年份和期数,自动分析历史开奖记录,预测开奖号码",
)
iface.launch(share=True)