yannjj commited on
Commit
5b5f612
·
verified ·
1 Parent(s): 81933e4

Create chronos-gradio.py

Browse files
Files changed (1) hide show
  1. chronos-gradio.py +59 -0
chronos-gradio.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from chronos import ChronosPipeline
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ # 从 Hugging Face 加载模型
8
+ # model_name = "amazon/chronos-t5-small" # 替换为你在 Hugging Face 上的模型名称
9
+ # model = AutoModelForConditionalGeneration.from_pretrained(model_name)
10
+ # model.eval()
11
+ model = ChronosPipeline.from_pretrained(
12
+ "amazon/chronos-t5-small",
13
+ device_map="cuda",
14
+ torch_dtype=torch.bfloat16,
15
+ )
16
+
17
+ def predict_with_chronos(input_data):
18
+ prediction = model.predict(
19
+ context=input_data,
20
+ prediction_length=24,
21
+ num_samples=1
22
+ )
23
+ return np.round(prediction.mean(axis=0).squeeze().cpu().numpy()).astype(int)
24
+
25
+ def predict_from_csv(csv_file):
26
+ df = pd.read_csv(csv_file.name)
27
+
28
+ raw_values = pd.to_numeric(df['value'], errors='coerce').dropna().values
29
+ print(raw_values)
30
+ print('输入数据长度为:',len(raw_values))
31
+ input_data = torch.tensor(
32
+ raw_values.astype(np.float32)
33
+ )
34
+
35
+ predictions = predict_with_chronos(input_data)
36
+ predictions = np.asarray(predictions).ravel()
37
+
38
+ forecast_index = range(1, len(predictions)+1)
39
+
40
+ assert len(forecast_index) == len(predictions), "数组长度不一致"
41
+
42
+ output_df = pd.DataFrame({
43
+ 'period': forecast_index,
44
+ 'value': predictions
45
+ })
46
+
47
+ output_path = "/tmp/predictions.csv"
48
+ output_df.to_csv(output_path, index=False)
49
+ return output_path
50
+
51
+ iface = gr.Interface(
52
+ fn=predict_from_csv,
53
+ inputs=gr.File(label="上传包含时序数据的 CSV 文件"),
54
+ outputs=gr.File(label="预测结果下载", file_count="single"),
55
+ title="Chronos时序预测",
56
+ description="上传包含时序数据的 CSV 文件,获取未来24步预测结果。"
57
+ )
58
+
59
+ iface.launch()