Spaces:
Runtime error
Runtime error
import apache_beam as beam | |
import gradio as gr | |
import huggingface_hub | |
import pandas as pd | |
import plotly.graph_objects as go | |
import spaces | |
import textwrap | |
import torch | |
import us | |
from apache_beam.options.pipeline_options import PipelineOptions, SetupOptions | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import json | |
import logging | |
import os | |
import requests | |
# Hugging Face Hub id of the model that writes the per-state summaries.
MODEL_NAME = "google/gemma-2-2b-it"

# Prompt for one state's alerts; the single "{}" placeholder receives the
# JSON dump of that state's alert record.
PROMPT_TEMPLATE = (
    "Write a succinct summary of the following weather alerts. "
    "Do not comment on missing information - just summarize the "
    "information provided/available.\n"
    "```json\n"
    "{}\n"
    "```\n"
    "Summary (In the state...):\n"
)

# Filled by FetchWeatherAlerts while the Beam pipeline runs; read afterwards.
alerts = []
# Define a transform for fetching weather alerts
class FetchWeatherAlerts(beam.DoFn):
    """Fetch the active weather.gov alerts for one US state.

    Each input element is a state abbreviation. Features whose
    ``messageType`` is ``"Alert"`` are reduced to their event, headline and
    instruction. The resulting record is appended to the module-level
    ``alerts`` list and also emitted as this DoFn's output.
    """

    # Seconds to wait for weather.gov before giving up on a state; without a
    # timeout a single stalled connection would hang the whole pipeline.
    REQUEST_TIMEOUT = 30

    def process(self, state):
        """Fetch and record the active alerts for ``state``.

        Args:
            state: State abbreviation, sent as the ``area`` query parameter.

        Yields:
            ``{"features": [...], "state": state}`` after a successful (HTTP
            200) fetch; nothing when the request errors or returns another
            status code (a warning is logged instead).
        """
        logging.info(f"Fetching weather alerts for {state} from weather.gov")
        url = f"https://api.weather.gov/alerts/active?area={state}"
        try:
            response = requests.get(
                url,
                headers={
                    "User-Agent": "(Neal DeBuhr, https://huggingface.co/spaces/ndebuhr/streaming-llm-weather-alerts)",
                    "Accept": "application/geo+json",
                },
                timeout=self.REQUEST_TIMEOUT,
            )
        except requests.RequestException as err:
            # One unreachable state should not abort the fetch for all others.
            logging.warning("Fetching weather alerts for %s failed: %s", state, err)
            return
        if response.status_code != 200:
            logging.warning(
                "weather.gov returned HTTP %s for %s; skipping",
                response.status_code,
                state,
            )
            return
        logging.info(f"Fetched weather alerts for {state} from weather.gov")
        features = response.json()["features"]
        record = {
            "features": [
                {
                    # .get(): a property missing from one feature should not
                    # discard the whole state's alerts.
                    "event": feature["properties"].get("event"),
                    "headline": feature["properties"].get("headline"),
                    "instruction": feature["properties"].get("instruction"),
                }
                for feature in features
                if feature["properties"].get("messageType") == "Alert"
            ],
            "state": state,
        }
        # NOTE(review): appending to a module global from a DoFn is only
        # visible to the script when the runner executes the DoFn in the same
        # process (e.g. an in-process DirectRunner). It is kept because the
        # code after the pipeline reads `alerts`; the record is also yielded
        # so a downstream transform could collect it instead.
        alerts.append(record)
        yield record
pipeline_options = PipelineOptions()
# Pickle the __main__ session so that FetchWeatherAlerts (defined in this
# script) and the globals it uses can be unpickled when Beam runs it.
pipeline_options.view_as(SetupOptions).save_main_session = True

# One element per state abbreviation, each handed to FetchWeatherAlerts.
# Exiting the `with` block runs the pipeline.
with beam.Pipeline(options=pipeline_options) as pipeline:
    (
        pipeline
        | "Create States"
        >> beam.Create([us_state.abbr for us_state in us.states.STATES])
        | "Fetch Weather Alerts" >> beam.ParDo(FetchWeatherAlerts())
    )
# Define a function to generate alert summaries using transformers and ZeroGPU
# On ZeroGPU hardware CUDA is only usable inside functions decorated with
# @spaces.GPU; `spaces` is imported for exactly this purpose.
# NOTE(review): summarizing every state in one call may exceed the default
# ZeroGPU time slot — consider @spaces.GPU(duration=...) if it times out.
@spaces.GPU
def generate_summaries(alerts):
    """Add an LLM-written ``"summary"`` to every alert record.

    Args:
        alerts: List of dicts as built by FetchWeatherAlerts (keys
            ``"features"`` and ``"state"``). Each dict is mutated in place.

    Returns:
        The same list, with a ``"summary"`` string added to every dict.

    Raises:
        KeyError: if the ``HUGGINGFACE_TOKEN`` environment variable is unset.
    """
    huggingface_hub.login(token=os.environ["HUGGINGFACE_TOKEN"])
    device = torch.device("cuda")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForCausalLM.from_pretrained(MODEL_NAME).to(device)
    for alert in alerts:
        prompt = PROMPT_TEMPLATE.format(json.dumps(alert, indent=2))
        inputs = tokenizer(prompt, return_tensors="pt").to(device)
        prompt_length = inputs["input_ids"].shape[-1]
        with torch.no_grad():
            outputs = model.generate(
                **inputs, max_new_tokens=256, pad_token_id=tokenizer.eos_token_id
            )
        # generate() returns prompt + continuation. Decode only the new
        # tokens: stripping the prompt text with str.replace silently fails
        # whenever the decode round-trip doesn't reproduce it byte-for-byte,
        # leaking the whole prompt into the summary.
        alert["summary"] = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        ).strip()
    return alerts
# Attach an LLM summary to every fetched state record.
alerts = generate_summaries(alerts)
# One row per state for the map. Passing `columns` keeps the "state" and
# "summary" columns even when `alerts` is empty (e.g. every fetch failed),
# so get_map() draws an empty map instead of raising KeyError.
df = pd.DataFrame(
    [{"state": alert["state"], "summary": alert["summary"]} for alert in alerts],
    columns=["state", "summary"],
)
def get_map():
    """Build the US choropleth showing each state's alert summary on hover.

    Reads the module-level ``df`` (``state`` and ``summary`` columns). Every
    state gets the same light-grey fill; the hover text is the summary
    wrapped to 50-character lines.

    Returns:
        A ``plotly.graph_objects.Figure`` for the Gradio ``Plot`` output.
    """

    def wrap_text(text, width=50):
        # Plotly hover labels break lines on <br>, not on "\n".
        return "<br>".join(textwrap.wrap(text, width=width))

    # Computed locally rather than stored as a new column on the shared
    # module-level `df`, which every Gradio request would otherwise rewrite.
    wrapped_summary = df["summary"].apply(wrap_text)
    fig = go.Figure(
        go.Choropleth(
            locations=df["state"],
            z=[1] * len(df),  # constant value: only the hover text varies
            locationmode="USA-states",
            colorscale=[
                [0, "lightgrey"],
                [1, "lightgrey"],
            ],  # Single color for all states
            showscale=False,
            text=wrapped_summary,
            hoverinfo="text",
            hovertemplate="%{text}<extra></extra>",
        )
    )
    fig.update_layout(title_text="Streaming LLM Weather Alerts", geo_scope="usa")
    return fig
# Gradio UI with no inputs: the page shows the Plotly figure from get_map.
iface = gr.Interface(
    fn=get_map,
    inputs=None,
    outputs=gr.Plot(),
)

# Serve the app.
iface.launch()