yannjj's picture
Rename chronos-gradio.py to app.py
2c4fd7d verified
import os
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline
# Load the Chronos model from the Hugging Face Hub once, at import time, so
# every request reuses the same pipeline.
# Use the GPU when one is present; otherwise fall back to CPU instead of
# crashing on a CPU-only host. bfloat16 is kept for GPU as in the original;
# float32 is used on CPU because bfloat16 CPU support varies by hardware.
_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = ChronosPipeline.from_pretrained(
    "amazon/chronos-t5-small",
    device_map=_DEVICE,
    torch_dtype=torch.bfloat16 if _DEVICE == "cuda" else torch.float32,
)
def _samples_to_point_forecast(samples):
    """Collapse sampled forecast paths into a rounded integer point forecast.

    ``samples`` is expected to be indexed as (series, sample, step), the layout
    documented for ``ChronosPipeline.predict``. The paths are averaged over the
    sample axis. A single series yields a 1-D array of length ``step``;
    several series yield a 2-D (series, step) array.
    """
    # Cast first: numpy cannot represent bfloat16 tensors.
    point = samples.to(torch.float32).mean(dim=1)
    if point.shape[0] == 1:
        # Index instead of squeeze() so a 1-step forecast stays 1-D.
        point = point[0]
    return np.round(point.cpu().numpy()).astype(int)


def predict_with_chronos(input_data, prediction_length=24, num_samples=1):
    """Forecast future values of ``input_data`` with the global Chronos model.

    Args:
        input_data: Context tensor. A 1-D tensor is treated as one series.
        prediction_length: Number of future steps to forecast (default 24).
        num_samples: Number of sample paths to draw and average. More samples
            give a steadier point forecast at proportionally higher cost.

    Returns:
        numpy int array of rounded mean forecasts, shape (prediction_length,)
        for a single series.
    """
    samples = model.predict(
        context=input_data,
        prediction_length=prediction_length,
        num_samples=num_samples,
    )
    return _samples_to_point_forecast(samples)
def predict_from_csv(csv_file):
    """Forecast the next steps of the ``value`` column of an uploaded CSV.

    Args:
        csv_file: The uploaded file. This is either a path (str / PathLike) or
            an object exposing the path via ``.name``. Different Gradio
            versions pass one or the other.

    Returns:
        Path (str) to a newly written CSV with columns ``period`` (1-based
        forecast step) and ``value`` (rounded integer forecast).

    Raises:
        ValueError: If the CSV has no ``value`` column, or if that column
            contains no numeric entries.
    """
    if isinstance(csv_file, (str, os.PathLike)):
        csv_path = csv_file
    else:
        csv_path = csv_file.name
    df = pd.read_csv(csv_path)
    if "value" not in df.columns:
        raise ValueError("CSV must contain a 'value' column")

    # Non-numeric cells are coerced to NaN and dropped rather than aborting.
    raw_values = (
        pd.to_numeric(df["value"], errors="coerce")
        .dropna()
        .to_numpy(dtype=np.float32)
    )
    if raw_values.size == 0:
        raise ValueError("The 'value' column contains no numeric data")
    print(raw_values)
    print('输入数据长度为:', len(raw_values))

    predictions = np.asarray(predict_with_chronos(torch.tensor(raw_values))).ravel()
    output_df = pd.DataFrame({
        "period": range(1, len(predictions) + 1),
        "value": predictions,
    })

    # Write into a fresh per-request directory. A single shared path would let
    # concurrent requests overwrite each other's results, and a hard-coded
    # /tmp is not portable.
    output_path = os.path.join(tempfile.mkdtemp(), "predictions.csv")
    output_df.to_csv(output_path, index=False)
    return output_path
# Gradio UI: upload a CSV whose numeric column is named "value", and download
# the 24-step forecast as a CSV file.
iface = gr.Interface(
    fn=predict_from_csv,
    # Restrict the picker to CSV files, the only format predict_from_csv reads.
    inputs=gr.File(label="上传包含时序数据的 CSV 文件", file_types=[".csv"]),
    outputs=gr.File(label="预测结果下载", file_count="single"),
    title="Chronos时序预测",
    # Tell users the required column name; otherwise uploads fail opaquely.
    description="上传包含时序数据的 CSV 文件(数值列名须为 value),获取未来24步预测结果。",
)

# Launch only when run as a script, so importing this module does not start a
# server as a side effect.
if __name__ == "__main__":
    iface.launch()