Spaces:
Running
Running
Merge marimo-team/learn (add __marimo__ to gitignore)
Browse files- .github/workflows/deploy.yml +56 -0
- .github/workflows/hf_sync.yml +45 -0
- .github/workflows/typos.yaml +0 -1
- .gitignore +3 -0
- Dockerfile +26 -0
- README.md +18 -2
- _server/README.md +24 -0
- _server/main.py +90 -0
- duckdb/README.md +28 -0
- functional_programming/05_functors.py +1313 -0
- functional_programming/CHANGELOG.md +36 -0
- functional_programming/README.md +61 -0
- optimization/05_portfolio_optimization.py +1 -1
- polars/04_basic_operations.py +631 -0
- polars/10_strings.py +1004 -0
- polars/12_aggregations.py +355 -0
- polars/14_user_defined_functions.py +946 -0
- polars/README.md +1 -0
- probability/08_bayes_theorem.py +1 -1
- probability/10_probability_mass_function.py +711 -0
- probability/11_expectation.py +860 -0
- probability/12_variance.py +631 -0
- probability/13_bernoulli_distribution.py +427 -0
- probability/14_binomial_distribution.py +545 -0
- probability/15_poisson_distribution.py +805 -0
- python/006_dictionaries.py +2 -2
- scripts/build.py +1523 -0
- scripts/preview.py +76 -0
.github/workflows/deploy.yml
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Deploy to GitHub Pages
|
2 |
+
|
3 |
+
on:
|
4 |
+
push:
|
5 |
+
branches: ['main']
|
6 |
+
workflow_dispatch:
|
7 |
+
|
8 |
+
concurrency:
|
9 |
+
group: 'pages'
|
10 |
+
cancel-in-progress: false
|
11 |
+
|
12 |
+
env:
|
13 |
+
UV_SYSTEM_PYTHON: 1
|
14 |
+
|
15 |
+
jobs:
|
16 |
+
build:
|
17 |
+
runs-on: ubuntu-latest
|
18 |
+
steps:
|
19 |
+
- uses: actions/checkout@v4
|
20 |
+
|
21 |
+
- name: 🚀 Install uv
|
22 |
+
uses: astral-sh/setup-uv@v4
|
23 |
+
|
24 |
+
- name: 🐍 Set up Python
|
25 |
+
uses: actions/setup-python@v5
|
26 |
+
with:
|
27 |
+
python-version: 3.12
|
28 |
+
|
29 |
+
- name: 📦 Install dependencies
|
30 |
+
run: |
|
31 |
+
uv pip install marimo
|
32 |
+
|
33 |
+
- name: 🛠️ Export notebooks
|
34 |
+
run: |
|
35 |
+
python scripts/build.py
|
36 |
+
|
37 |
+
- name: 📤 Upload artifact
|
38 |
+
uses: actions/upload-pages-artifact@v3
|
39 |
+
with:
|
40 |
+
path: _site
|
41 |
+
|
42 |
+
deploy:
|
43 |
+
needs: build
|
44 |
+
|
45 |
+
permissions:
|
46 |
+
pages: write
|
47 |
+
id-token: write
|
48 |
+
|
49 |
+
environment:
|
50 |
+
name: github-pages
|
51 |
+
url: ${{ steps.deployment.outputs.page_url }}
|
52 |
+
runs-on: ubuntu-latest
|
53 |
+
steps:
|
54 |
+
- name: 🚀 Deploy to GitHub Pages
|
55 |
+
id: deployment
|
56 |
+
uses: actions/deploy-pages@v4
|
.github/workflows/hf_sync.yml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: Sync to Hugging Face hub
|
2 |
+
on:
|
3 |
+
push:
|
4 |
+
branches: [main]
|
5 |
+
|
6 |
+
# to run this workflow manually from the Actions tab
|
7 |
+
workflow_dispatch:
|
8 |
+
|
9 |
+
jobs:
|
10 |
+
sync-to-hub:
|
11 |
+
runs-on: ubuntu-latest
|
12 |
+
steps:
|
13 |
+
- uses: actions/checkout@v4
|
14 |
+
with:
|
15 |
+
fetch-depth: 0
|
16 |
+
|
17 |
+
- name: Configure Git
|
18 |
+
run: |
|
19 |
+
git config --global user.name "GitHub Action"
|
20 |
+
git config --global user.email "[email protected]"
|
21 |
+
|
22 |
+
- name: Prepend frontmatter to README
|
23 |
+
run: |
|
24 |
+
if [ -f README.md ] && ! grep -q "^---" README.md; then
|
25 |
+
FRONTMATTER="---
|
26 |
+
title: marimo learn
|
27 |
+
emoji: 🧠
|
28 |
+
colorFrom: blue
|
29 |
+
colorTo: indigo
|
30 |
+
sdk: docker
|
31 |
+
sdk_version: \"latest\"
|
32 |
+
app_file: app.py
|
33 |
+
pinned: false
|
34 |
+
---
|
35 |
+
|
36 |
+
"
|
37 |
+
echo "$FRONTMATTER$(cat README.md)" > README.md
|
38 |
+
git add README.md
|
39 |
+
git commit -m "Add HF frontmatter to README" || echo "No changes to commit"
|
40 |
+
fi
|
41 |
+
|
42 |
+
- name: Push to hub
|
43 |
+
env:
|
44 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
45 |
+
run: git push -f https://mylessss:[email protected]/spaces/marimo-team/marimo-learn main
|
.github/workflows/typos.yaml
CHANGED
@@ -13,4 +13,3 @@ jobs:
|
|
13 |
uses: styfle/[email protected]
|
14 |
- uses: actions/checkout@v4
|
15 |
- uses: crate-ci/[email protected]
|
16 |
-
name: Tests
|
|
|
13 |
uses: styfle/[email protected]
|
14 |
- uses: actions/checkout@v4
|
15 |
- uses: crate-ci/[email protected]
|
|
.gitignore
CHANGED
@@ -172,3 +172,6 @@ cython_debug/
|
|
172 |
|
173 |
# Marimo specific
|
174 |
__marimo__
|
|
|
|
|
|
|
|
172 |
|
173 |
# Marimo specific
|
174 |
__marimo__
|
175 |
+
|
176 |
+
# Generated site content
|
177 |
+
_site/
|
Dockerfile
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
# Create a non-root user
|
6 |
+
RUN useradd -m appuser
|
7 |
+
|
8 |
+
# Copy application files
|
9 |
+
COPY _server/main.py _server/main.py
|
10 |
+
COPY polars/ polars/
|
11 |
+
COPY duckdb/ duckdb/
|
12 |
+
|
13 |
+
# Set proper ownership
|
14 |
+
RUN chown -R appuser:appuser /app
|
15 |
+
|
16 |
+
# Switch to non-root user
|
17 |
+
USER appuser
|
18 |
+
|
19 |
+
# Create virtual environment and install dependencies
|
20 |
+
RUN uv venv
|
21 |
+
RUN uv export --script _server/main.py | uv pip install -r -
|
22 |
+
|
23 |
+
ENV PORT=7860
|
24 |
+
EXPOSE 7860
|
25 |
+
|
26 |
+
CMD ["uv", "run", "_server/main.py"]
|
README.md
CHANGED
@@ -3,7 +3,7 @@
|
|
3 |
</p>
|
4 |
|
5 |
<p align="center">
|
6 |
-
<em>A curated collection of educational <a href="https://github.com/marimo-team/marimo">marimo</a> notebooks</em
|
7 |
</p>
|
8 |
|
9 |
# 📚 Learn
|
@@ -29,6 +29,7 @@ notebooks for educators, students, and practitioners.
|
|
29 |
- 📏 Linear algebra
|
30 |
- ❄️ Polars
|
31 |
- 🔥 Pytorch
|
|
|
32 |
- 📈 Altair
|
33 |
- 📈 Plotly
|
34 |
- 📈 matplotlib
|
@@ -57,12 +58,27 @@ Here's a contribution checklist:
|
|
57 |
If you aren't comfortable adding a new notebook or course, you can also request
|
58 |
what you'd like to see by [filing an issue](https://github.com/marimo-team/learn/issues/new?template=example_request.yaml).
|
59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
60 |
## Community
|
61 |
|
62 |
We're building a community. Come hang out with us!
|
63 |
|
64 |
- 🌟 [Star us on GitHub](https://github.com/marimo-team/examples)
|
65 |
-
- 💬 [Chat with us on Discord](https://
|
66 |
- 📧 [Subscribe to our Newsletter](https://marimo.io/newsletter)
|
67 |
- ☁️ [Join our Cloud Waitlist](https://marimo.io/cloud)
|
68 |
- ✏️ [Start a GitHub Discussion](https://github.com/marimo-team/marimo/discussions)
|
|
|
3 |
</p>
|
4 |
|
5 |
<p align="center">
|
6 |
+
<span><em>A curated collection of educational <a href="https://github.com/marimo-team/marimo">marimo</a> notebooks</em> || <a href="https://discord.gg/rT48v2Y9fe">💬 Discord</a></span>
|
7 |
</p>
|
8 |
|
9 |
# 📚 Learn
|
|
|
29 |
- 📏 Linear algebra
|
30 |
- ❄️ Polars
|
31 |
- 🔥 Pytorch
|
32 |
+
- 🗄️ Duckdb
|
33 |
- 📈 Altair
|
34 |
- 📈 Plotly
|
35 |
- 📈 matplotlib
|
|
|
58 |
If you aren't comfortable adding a new notebook or course, you can also request
|
59 |
what you'd like to see by [filing an issue](https://github.com/marimo-team/learn/issues/new?template=example_request.yaml).
|
60 |
|
61 |
+
## Building and Previewing
|
62 |
+
|
63 |
+
The site is built using a Python script that exports marimo notebooks to HTML and generates an index page.
|
64 |
+
|
65 |
+
```bash
|
66 |
+
# Build the site
|
67 |
+
python scripts/build.py --output-dir _site
|
68 |
+
|
69 |
+
# Preview the site (builds first)
|
70 |
+
python scripts/preview.py
|
71 |
+
|
72 |
+
# Preview without rebuilding
|
73 |
+
python scripts/preview.py --no-build
|
74 |
+
```
|
75 |
+
|
76 |
## Community
|
77 |
|
78 |
We're building a community. Come hang out with us!
|
79 |
|
80 |
- 🌟 [Star us on GitHub](https://github.com/marimo-team/examples)
|
81 |
+
- 💬 [Chat with us on Discord](https://discord.gg/rT48v2Y9fe)
|
82 |
- 📧 [Subscribe to our Newsletter](https://marimo.io/newsletter)
|
83 |
- ☁️ [Join our Cloud Waitlist](https://marimo.io/cloud)
|
84 |
- ✏️ [Start a GitHub Discussion](https://github.com/marimo-team/marimo/discussions)
|
_server/README.md
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# marimo learn server
|
2 |
+
|
3 |
+
This folder contains server code for hosting marimo apps.
|
4 |
+
|
5 |
+
## Running the server
|
6 |
+
|
7 |
+
```bash
|
8 |
+
cd _server
|
9 |
+
uv run --no-project main.py
|
10 |
+
```
|
11 |
+
|
12 |
+
## Building a Docker image
|
13 |
+
|
14 |
+
From the root directory, run:
|
15 |
+
|
16 |
+
```bash
|
17 |
+
docker build -t marimo-learn .
|
18 |
+
```
|
19 |
+
|
20 |
+
## Running the Docker container
|
21 |
+
|
22 |
+
```bash
|
23 |
+
docker run -p 7860:7860 marimo-learn
|
24 |
+
```
|
_server/main.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.12"
|
3 |
+
# dependencies = [
|
4 |
+
# "fastapi",
|
5 |
+
# "marimo",
|
6 |
+
# "starlette",
|
7 |
+
# "python-dotenv",
|
8 |
+
# "pydantic",
|
9 |
+
# "duckdb",
|
10 |
+
# "altair==5.5.0",
|
11 |
+
# "beautifulsoup4==4.13.3",
|
12 |
+
# "httpx==0.28.1",
|
13 |
+
# "marimo",
|
14 |
+
# "nest-asyncio==1.6.0",
|
15 |
+
# "numba==0.61.0",
|
16 |
+
# "numpy==2.1.3",
|
17 |
+
# "polars==1.24.0",
|
18 |
+
# ]
|
19 |
+
# ///
|
20 |
+
|
21 |
+
import logging
|
22 |
+
import os
|
23 |
+
from pathlib import Path
|
24 |
+
|
25 |
+
import marimo
|
26 |
+
from dotenv import load_dotenv
|
27 |
+
from fastapi import FastAPI, Request
|
28 |
+
from fastapi.responses import HTMLResponse
|
29 |
+
|
30 |
+
# Load environment variables
|
31 |
+
load_dotenv()
|
32 |
+
|
33 |
+
# Set up logging
|
34 |
+
logging.basicConfig(level=logging.INFO)
|
35 |
+
logger = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
# Get port from environment variable or use default
|
38 |
+
PORT = int(os.environ.get("PORT", 7860))
|
39 |
+
|
40 |
+
root_dir = Path(__file__).parent.parent
|
41 |
+
|
42 |
+
ROOTS = [
|
43 |
+
root_dir / "polars",
|
44 |
+
root_dir / "duckdb",
|
45 |
+
]
|
46 |
+
|
47 |
+
|
48 |
+
server = marimo.create_asgi_app(include_code=True)
|
49 |
+
app_names: list[str] = []
|
50 |
+
|
51 |
+
for root in ROOTS:
|
52 |
+
for filename in root.iterdir():
|
53 |
+
if filename.is_file() and filename.suffix == ".py":
|
54 |
+
app_path = root.stem + "/" + filename.stem
|
55 |
+
server = server.with_app(path=f"/{app_path}", root=str(filename))
|
56 |
+
app_names.append(app_path)
|
57 |
+
|
58 |
+
# Create a FastAPI app
|
59 |
+
app = FastAPI()
|
60 |
+
|
61 |
+
logger.info(f"Mounting {len(app_names)} apps")
|
62 |
+
for app_name in app_names:
|
63 |
+
logger.info(f" /{app_name}")
|
64 |
+
|
65 |
+
|
66 |
+
@app.get("/")
|
67 |
+
async def home(request: Request):
|
68 |
+
html_content = """
|
69 |
+
<!DOCTYPE html>
|
70 |
+
<html>
|
71 |
+
<head>
|
72 |
+
<title>marimo learn</title>
|
73 |
+
</head>
|
74 |
+
<body>
|
75 |
+
<h1>Welcome to marimo learn!</h1>
|
76 |
+
<p>This is a collection of interactive tutorials for learning data science libraries with marimo.</p>
|
77 |
+
<p>Check out the <a href="https://github.com/marimo-team/learn">GitHub repository</a> for more information.</p>
|
78 |
+
</body>
|
79 |
+
</html>
|
80 |
+
"""
|
81 |
+
return HTMLResponse(content=html_content)
|
82 |
+
|
83 |
+
|
84 |
+
app.mount("/", server.build())
|
85 |
+
|
86 |
+
# Run the server
|
87 |
+
if __name__ == "__main__":
|
88 |
+
import uvicorn
|
89 |
+
|
90 |
+
uvicorn.run(app, host="0.0.0.0", port=PORT, log_level="info")
|
duckdb/README.md
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Learn DuckDB
|
2 |
+
|
3 |
+
_🚧 This collection is a work in progress. Please help us add notebooks!_
|
4 |
+
|
5 |
+
This collection of marimo notebooks is designed to teach you the basics of
|
6 |
+
DuckDB, a fast in-memory OLAP engine that can interoperate with Dataframes.
|
7 |
+
These notebooks also show how marimo gives DuckDB superpowers.
|
8 |
+
|
9 |
+
**Help us build this course! ⚒️**
|
10 |
+
|
11 |
+
We're seeking contributors to help us build these notebooks. Every contributor
|
12 |
+
will be acknowledged as an author in this README and in their contributed
|
13 |
+
notebooks. Head over to the [tracking
|
14 |
+
issue](https://github.com/marimo-team/learn/issues/48) to sign up for a planned
|
15 |
+
notebook or propose your own.
|
16 |
+
|
17 |
+
**Running notebooks.** To run a notebook locally, use
|
18 |
+
|
19 |
+
```bash
|
20 |
+
uvx marimo edit <file_url>
|
21 |
+
```
|
22 |
+
|
23 |
+
You can also open notebooks in our online playground by appending marimo.app/ to a notebook's URL.
|
24 |
+
|
25 |
+
|
26 |
+
**Authors.**
|
27 |
+
|
28 |
+
Thanks to all our notebook authors!
|
functional_programming/05_functors.py
ADDED
@@ -0,0 +1,1313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.9"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# ]
|
6 |
+
# ///
|
7 |
+
|
8 |
+
import marimo
|
9 |
+
|
10 |
+
__generated_with = "0.11.17"
|
11 |
+
app = marimo.App(app_title="Category Theory and Functors")
|
12 |
+
|
13 |
+
|
14 |
+
@app.cell(hide_code=True)
|
15 |
+
def _(mo):
|
16 |
+
mo.md(
|
17 |
+
"""
|
18 |
+
# Category Theory and Functors
|
19 |
+
|
20 |
+
In this notebook, you will learn:
|
21 |
+
|
22 |
+
* Why `length` is a *functor* from the category of `list concatenation` to the category of `integer addition`
|
23 |
+
* How to *lift* an ordinary function into a specific *computational context*
|
24 |
+
* How to write an *adapter* between two categories
|
25 |
+
|
26 |
+
In short, a mathematical functor is a **mapping** between two categories in category theory. In practice, a functor represents a type that can be mapped over.
|
27 |
+
|
28 |
+
/// admonition | Intuitions
|
29 |
+
|
30 |
+
- A simple intuition is that a `Functor` represents a **container** of values, along with the ability to apply a function uniformly to every element in the container.
|
31 |
+
- Another intuition is that a `Functor` represents some sort of **computational context**.
|
32 |
+
- Mathematically, `Functors` generalize the idea of a container or a computational context.
|
33 |
+
///
|
34 |
+
|
35 |
+
We will start with intuition, introduce the basics of category theory, and then examine functors from a categorical perspective.
|
36 |
+
|
37 |
+
/// details | Notebook metadata
|
38 |
+
type: info
|
39 |
+
|
40 |
+
version: 0.1.1 | last modified: 2025-03-16 | author: [métaboulie](https://github.com/metaboulie)<br/>
|
41 |
+
reviewer: [Haleshot](https://github.com/Haleshot)
|
42 |
+
|
43 |
+
///
|
44 |
+
"""
|
45 |
+
)
|
46 |
+
return
|
47 |
+
|
48 |
+
|
49 |
+
@app.cell(hide_code=True)
|
50 |
+
def _(mo):
|
51 |
+
mo.md(
|
52 |
+
"""
|
53 |
+
# Functor as a Computational Context
|
54 |
+
|
55 |
+
A [**Functor**](https://wiki.haskell.org/Functor) is an abstraction that represents a computational context with the ability to apply a function to every value inside it without altering the structure of the context itself. This enables transformations while preserving the shape of the data.
|
56 |
+
|
57 |
+
To understand this, let's look at a simple example.
|
58 |
+
|
59 |
+
## [The One-Way Wrapper Design Pattern](http://blog.sigfpe.com/2007/04/trivial-monad.html)
|
60 |
+
|
61 |
+
Often, we need to wrap data in some kind of context. However, when performing operations on wrapped data, we typically have to:
|
62 |
+
|
63 |
+
1. Unwrap the data.
|
64 |
+
2. Modify the unwrapped data.
|
65 |
+
3. Rewrap the modified data.
|
66 |
+
|
67 |
+
This process is tedious and inefficient. Instead, we want to wrap data **once** and apply functions directly to the wrapped data without unwrapping it.
|
68 |
+
|
69 |
+
/// admonition | Rules for a One-Way Wrapper
|
70 |
+
|
71 |
+
1. We can wrap values, but we cannot unwrap them.
|
72 |
+
2. We should still be able to apply transformations to the wrapped data.
|
73 |
+
3. Any operation that depends on wrapped data should itself return a wrapped result.
|
74 |
+
///
|
75 |
+
|
76 |
+
Let's define such a `Wrapper` class:
|
77 |
+
|
78 |
+
```python
|
79 |
+
from dataclasses import dataclass
|
80 |
+
from typing import Callable, Generic, TypeVar
|
81 |
+
|
82 |
+
A = TypeVar("A")
|
83 |
+
B = TypeVar("B")
|
84 |
+
|
85 |
+
@dataclass
|
86 |
+
class Wrapper(Generic[A]):
|
87 |
+
value: A
|
88 |
+
```
|
89 |
+
|
90 |
+
Now, we can create an instance of wrapped data:
|
91 |
+
|
92 |
+
```python
|
93 |
+
wrapped = Wrapper(1)
|
94 |
+
```
|
95 |
+
|
96 |
+
### Mapping Functions Over Wrapped Data
|
97 |
+
|
98 |
+
To modify wrapped data while keeping it wrapped, we define an `fmap` method:
|
99 |
+
|
100 |
+
```python
|
101 |
+
@dataclass
|
102 |
+
class Wrapper(Functor, Generic[A]):
|
103 |
+
value: A
|
104 |
+
|
105 |
+
@classmethod
|
106 |
+
def fmap(cls, f: Callable[[A], B], a: "Wrapper[A]") -> "Wrapper[B]":
|
107 |
+
return Wrapper(f(a.value))
|
108 |
+
```
|
109 |
+
|
110 |
+
Now, we can apply transformations without unwrapping:
|
111 |
+
|
112 |
+
```python
|
113 |
+
>>> Wrapper.fmap(lambda x: x + 1, wrapper)
|
114 |
+
Wrapper(value=2)
|
115 |
+
|
116 |
+
>>> Wrapper.fmap(lambda x: [x], wrapper)
|
117 |
+
Wrapper(value=[1])
|
118 |
+
```
|
119 |
+
|
120 |
+
> Try using the `Wrapper` in the cell below.
|
121 |
+
"""
|
122 |
+
)
|
123 |
+
return
|
124 |
+
|
125 |
+
|
126 |
+
@app.cell
|
127 |
+
def _(A, B, Callable, Functor, Generic, dataclass, pp):
|
128 |
+
@dataclass
|
129 |
+
class Wrapper(Functor, Generic[A]):
|
130 |
+
value: A
|
131 |
+
|
132 |
+
@classmethod
|
133 |
+
def fmap(cls, f: Callable[[A], B], a: "Wrapper[A]") -> "Wrapper[B]":
|
134 |
+
return Wrapper(f(a.value))
|
135 |
+
|
136 |
+
|
137 |
+
wrapper = Wrapper(1)
|
138 |
+
|
139 |
+
pp(Wrapper.fmap(lambda x: x + 1, wrapper))
|
140 |
+
pp(Wrapper.fmap(lambda x: [x], wrapper))
|
141 |
+
return Wrapper, wrapper
|
142 |
+
|
143 |
+
|
144 |
+
@app.cell(hide_code=True)
|
145 |
+
def _(mo):
|
146 |
+
mo.md(
|
147 |
+
"""
|
148 |
+
We can analyze the type signature of `fmap` for `Wrapper`:
|
149 |
+
|
150 |
+
* `f` is of type `Callable[[A], B]`
|
151 |
+
* `a` is of type `Wrapper[A]`
|
152 |
+
* The return value is of type `Wrapper[B]`
|
153 |
+
|
154 |
+
Thus, in Python's type system, we can express the type signature of `fmap` as:
|
155 |
+
|
156 |
+
```python
|
157 |
+
fmap(f: Callable[[A], B], a: Wrapper[A]) -> Wrapper[B]:
|
158 |
+
```
|
159 |
+
|
160 |
+
Essentially, `fmap`:
|
161 |
+
|
162 |
+
1. Takes a function `Callable[[A], B]` and a `Wrapper[A]` instance as input.
|
163 |
+
2. Applies the function to the value inside the wrapper.
|
164 |
+
3. Returns a new `Wrapper[B]` instance with the transformed value, leaving the original wrapper and its internal data unmodified.
|
165 |
+
|
166 |
+
Now, let's examine `list` as a similar kind of wrapper.
|
167 |
+
"""
|
168 |
+
)
|
169 |
+
return
|
170 |
+
|
171 |
+
|
172 |
+
@app.cell(hide_code=True)
|
173 |
+
def _(mo):
|
174 |
+
mo.md(
|
175 |
+
"""
|
176 |
+
## The List Wrapper
|
177 |
+
|
178 |
+
We can define a `List` class to represent a wrapped list that supports `fmap`:
|
179 |
+
|
180 |
+
```python
|
181 |
+
@dataclass
|
182 |
+
class List(Functor, Generic[A]):
|
183 |
+
value: list[A]
|
184 |
+
|
185 |
+
@classmethod
|
186 |
+
def fmap(cls, f: Callable[[A], B], a: "List[A]") -> "List[B]":
|
187 |
+
return List([f(x) for x in a.value])
|
188 |
+
```
|
189 |
+
|
190 |
+
Now, we can apply transformations:
|
191 |
+
|
192 |
+
```python
|
193 |
+
>>> flist = List([1, 2, 3, 4])
|
194 |
+
>>> List.fmap(lambda x: x + 1, flist)
|
195 |
+
List(value=[2, 3, 4, 5])
|
196 |
+
>>> List.fmap(lambda x: [x], flist)
|
197 |
+
List(value=[[1], [2], [3], [4]])
|
198 |
+
```
|
199 |
+
"""
|
200 |
+
)
|
201 |
+
return
|
202 |
+
|
203 |
+
|
204 |
+
@app.cell
|
205 |
+
def _(A, B, Callable, Functor, Generic, dataclass, pp):
|
206 |
+
@dataclass
|
207 |
+
class List(Functor, Generic[A]):
|
208 |
+
value: list[A]
|
209 |
+
|
210 |
+
@classmethod
|
211 |
+
def fmap(cls, f: Callable[[A], B], a: "List[A]") -> "List[B]":
|
212 |
+
return List([f(x) for x in a.value])
|
213 |
+
|
214 |
+
|
215 |
+
flist = List([1, 2, 3, 4])
|
216 |
+
pp(List.fmap(lambda x: x + 1, flist))
|
217 |
+
pp(List.fmap(lambda x: [x], flist))
|
218 |
+
return List, flist
|
219 |
+
|
220 |
+
|
221 |
+
@app.cell(hide_code=True)
|
222 |
+
def _(mo):
|
223 |
+
mo.md(
|
224 |
+
"""
|
225 |
+
### Extracting the Type of `fmap`
|
226 |
+
|
227 |
+
The type signature of `fmap` for `List` is:
|
228 |
+
|
229 |
+
```python
|
230 |
+
fmap(f: Callable[[A], B], a: List[A]) -> List[B]
|
231 |
+
```
|
232 |
+
|
233 |
+
Similarly, for `Wrapper`:
|
234 |
+
|
235 |
+
```python
|
236 |
+
fmap(f: Callable[[A], B], a: Wrapper[A]) -> Wrapper[B]
|
237 |
+
```
|
238 |
+
|
239 |
+
Both follow the same pattern, which we can generalize as:
|
240 |
+
|
241 |
+
```python
|
242 |
+
fmap(f: Callable[[A], B], a: Functor[A]) -> Functor[B]
|
243 |
+
```
|
244 |
+
|
245 |
+
where `Functor` can be `Wrapper`, `List`, or any other wrapper type that follows the same structure.
|
246 |
+
|
247 |
+
### Functors in Haskell (optional)
|
248 |
+
|
249 |
+
In Haskell, the type of `fmap` is:
|
250 |
+
|
251 |
+
```haskell
|
252 |
+
fmap :: Functor f => (a -> b) -> f a -> f b
|
253 |
+
```
|
254 |
+
|
255 |
+
or equivalently:
|
256 |
+
|
257 |
+
```haskell
|
258 |
+
fmap :: Functor f => (a -> b) -> (f a -> f b)
|
259 |
+
```
|
260 |
+
|
261 |
+
This means that `fmap` **lifts** an ordinary function into the **functor world**, allowing it to operate within a computational context.
|
262 |
+
|
263 |
+
Now, let's define an abstract class for `Functor`.
|
264 |
+
"""
|
265 |
+
)
|
266 |
+
return
|
267 |
+
|
268 |
+
|
269 |
+
@app.cell(hide_code=True)
|
270 |
+
def _(mo):
|
271 |
+
mo.md(
|
272 |
+
"""
|
273 |
+
## Defining Functor
|
274 |
+
|
275 |
+
Recall that, a **Functor** is an abstraction that allows us to apply a function to values inside a computational context while preserving its structure.
|
276 |
+
|
277 |
+
To define `Functor` in Python, we use an abstract base class:
|
278 |
+
|
279 |
+
```python
|
280 |
+
from dataclasses import dataclass
|
281 |
+
from typing import Callable, Generic, TypeVar
|
282 |
+
from abc import ABC, abstractmethod
|
283 |
+
|
284 |
+
A = TypeVar("A")
|
285 |
+
B = TypeVar("B")
|
286 |
+
|
287 |
+
@dataclass
|
288 |
+
class Functor(ABC, Generic[A]):
|
289 |
+
@classmethod
|
290 |
+
@abstractmethod
|
291 |
+
def fmap(f: Callable[[A], B], a: "Functor[A]") -> "Functor[B]":
|
292 |
+
raise NotImplementedError
|
293 |
+
```
|
294 |
+
|
295 |
+
We can now extend custom wrappers, containers, or computation contexts with this `Functor` base class, implement the `fmap` method, and apply any function.
|
296 |
+
|
297 |
+
Next, let's implement a more complex data structure: [RoseTree](https://en.wikipedia.org/wiki/Rose_tree).
|
298 |
+
"""
|
299 |
+
)
|
300 |
+
return
|
301 |
+
|
302 |
+
|
303 |
+
@app.cell(hide_code=True)
|
304 |
+
def _(mo):
|
305 |
+
mo.md(
|
306 |
+
"""
|
307 |
+
## Case Study: RoseTree
|
308 |
+
|
309 |
+
A **RoseTree** is a tree where:
|
310 |
+
|
311 |
+
- Each node holds a **value**.
|
312 |
+
- Each node has a **list of child nodes** (which are also RoseTrees).
|
313 |
+
|
314 |
+
This structure is useful for representing hierarchical data, such as:
|
315 |
+
|
316 |
+
- Abstract Syntax Trees (ASTs)
|
317 |
+
- File system directories
|
318 |
+
- Recursive computations
|
319 |
+
|
320 |
+
We can implement `RoseTree` by extending the `Functor` class:
|
321 |
+
|
322 |
+
```python
|
323 |
+
from dataclasses import dataclass
|
324 |
+
from typing import Callable, Generic, TypeVar
|
325 |
+
|
326 |
+
A = TypeVar("A")
|
327 |
+
B = TypeVar("B")
|
328 |
+
|
329 |
+
@dataclass
|
330 |
+
class RoseTree(Functor, Generic[a]):
|
331 |
+
|
332 |
+
value: A
|
333 |
+
children: list["RoseTree[A]"]
|
334 |
+
|
335 |
+
@classmethod
|
336 |
+
def fmap(cls, f: Callable[[A], B], a: "RoseTree[A]") -> "RoseTree[B]":
|
337 |
+
return RoseTree(
|
338 |
+
f(a.value), [cls.fmap(f, child) for child in a.children]
|
339 |
+
)
|
340 |
+
|
341 |
+
def __repr__(self) -> str:
|
342 |
+
return f"Node: {self.value}, Children: {self.children}"
|
343 |
+
```
|
344 |
+
|
345 |
+
- The function is applied **recursively** to each node's value.
|
346 |
+
- The tree structure **remains unchanged**.
|
347 |
+
- Only the values inside the tree are modified.
|
348 |
+
|
349 |
+
> Try using `RoseTree` in the cell below.
|
350 |
+
"""
|
351 |
+
)
|
352 |
+
return
|
353 |
+
|
354 |
+
|
355 |
+
@app.cell(hide_code=True)
|
356 |
+
def _(A, B, Callable, Functor, Generic, dataclass, mo):
|
357 |
+
@dataclass
|
358 |
+
class RoseTree(Functor, Generic[A]):
|
359 |
+
"""
|
360 |
+
### Doc: RoseTree
|
361 |
+
|
362 |
+
A Functor implementation of `RoseTree`, allowing transformation of values while preserving the tree structure.
|
363 |
+
|
364 |
+
**Attributes**
|
365 |
+
|
366 |
+
- `value (A)`: The value stored in the node.
|
367 |
+
- `children (list[RoseTree[A]])`: A list of child nodes forming the tree structure.
|
368 |
+
|
369 |
+
**Methods:**
|
370 |
+
|
371 |
+
- `fmap(f: Callable[[A], B], a: "RoseTree[A]") -> "RoseTree[B]"`
|
372 |
+
|
373 |
+
Applies a function to each value in the tree, producing a new `RoseTree[b]` with transformed values.
|
374 |
+
|
375 |
+
**Implementation logic:**
|
376 |
+
|
377 |
+
- The function `f` is applied to the root node's `value`.
|
378 |
+
- Each child in `children` recursively calls `fmap`, ensuring all values in the tree are mapped.
|
379 |
+
- The overall tree structure remains unchanged.
|
380 |
+
"""
|
381 |
+
|
382 |
+
value: A
|
383 |
+
children: list["RoseTree[A]"]
|
384 |
+
|
385 |
+
@classmethod
|
386 |
+
def fmap(cls, f: Callable[[A], B], a: "RoseTree[A]") -> "RoseTree[B]":
|
387 |
+
return RoseTree(
|
388 |
+
f(a.value), [cls.fmap(f, child) for child in a.children]
|
389 |
+
)
|
390 |
+
|
391 |
+
def __repr__(self) -> str:
|
392 |
+
return f"Node: {self.value}, Children: {self.children}"
|
393 |
+
|
394 |
+
|
395 |
+
mo.md(RoseTree.__doc__)
|
396 |
+
return (RoseTree,)
|
397 |
+
|
398 |
+
|
399 |
+
@app.cell
|
400 |
+
def _(RoseTree, pp):
|
401 |
+
rosetree = RoseTree(1, [RoseTree(2, []), RoseTree(3, [RoseTree(4, [])])])
|
402 |
+
|
403 |
+
pp(rosetree)
|
404 |
+
pp(RoseTree.fmap(lambda x: [x], rosetree))
|
405 |
+
pp(RoseTree.fmap(lambda x: RoseTree(x, []), rosetree))
|
406 |
+
return (rosetree,)
|
407 |
+
|
408 |
+
|
409 |
+
@app.cell(hide_code=True)
|
410 |
+
def _(mo):
|
411 |
+
mo.md(
|
412 |
+
"""
|
413 |
+
## Generic Functions that can be Used with Any Functor
|
414 |
+
|
415 |
+
One of the powerful features of functors is that we can write **generic functions** that can work with any functor.
|
416 |
+
|
417 |
+
Remember that in Haskell, the type of `fmap` can be written as:
|
418 |
+
|
419 |
+
```haskell
|
420 |
+
fmap :: Functor f => (a -> b) -> (f a -> f b)
|
421 |
+
```
|
422 |
+
|
423 |
+
Translating to Python, we get:
|
424 |
+
|
425 |
+
```python
|
426 |
+
def fmap(func: Callable[[A], B]) -> Callable[[Functor[A]], Functor[B]]
|
427 |
+
```
|
428 |
+
|
429 |
+
This means that `fmap`:
|
430 |
+
|
431 |
+
- Takes an **ordinary function** `Callable[[A], B]` as input.
|
432 |
+
- Outputs a function that:
|
433 |
+
- Takes a **functor** of type `Functor[A]` as input.
|
434 |
+
- Outputs a **functor** of type `Functor[B]`.
|
435 |
+
|
436 |
+
We can implement a similar idea in Python:
|
437 |
+
|
438 |
+
```python
|
439 |
+
fmap = lambda f, functor: functor.__class__.fmap(f, functor)
|
440 |
+
inc = lambda functor: fmap(lambda x: x + 1, functor)
|
441 |
+
```
|
442 |
+
|
443 |
+
- **`fmap`**: Lifts an ordinary function (`f`) to the functor world, allowing the function to operate on the wrapped value inside the functor.
|
444 |
+
- **`inc`**: A specific instance of `fmap` that operates on any functor. It takes a functor, applies the function `lambda x: x + 1` to every value inside it, and returns a new functor with the updated values.
|
445 |
+
|
446 |
+
Thus, **`fmap`** transforms an ordinary function into a **function that operates on functors**, and **`inc`** is a specific case where it increments the value inside the functor.
|
447 |
+
|
448 |
+
### Applying the `inc` Function to Various Functors
|
449 |
+
|
450 |
+
You can now apply `inc` to any functor like `Wrapper`, `List`, or `RoseTree`:
|
451 |
+
|
452 |
+
```python
|
453 |
+
# Applying `inc` to a Wrapper
|
454 |
+
wrapper = Wrapper(5)
|
455 |
+
inc(wrapper) # Wrapper(value=6)
|
456 |
+
|
457 |
+
# Applying `inc` to a List
|
458 |
+
list_wrapper = List([1, 2, 3])
|
459 |
+
inc(list_wrapper) # List(value=[2, 3, 4])
|
460 |
+
|
461 |
+
# Applying `inc` to a RoseTree
|
462 |
+
tree = RoseTree(1, [RoseTree(2, []), RoseTree(3, [])])
|
463 |
+
inc(tree) # RoseTree(value=2, children=[RoseTree(value=3, children=[]), RoseTree(value=4, children=[])])
|
464 |
+
```
|
465 |
+
|
466 |
+
> Try using `fmap` in the cell below.
|
467 |
+
"""
|
468 |
+
)
|
469 |
+
return
|
470 |
+
|
471 |
+
|
472 |
+
@app.cell
|
473 |
+
def _(flist, pp, rosetree, wrapper):
|
474 |
+
fmap = lambda f, functor: functor.__class__.fmap(f, functor)
|
475 |
+
inc = lambda functor: fmap(lambda x: x + 1, functor)
|
476 |
+
|
477 |
+
pp(inc(wrapper))
|
478 |
+
pp(inc(flist))
|
479 |
+
pp(inc(rosetree))
|
480 |
+
return fmap, inc
|
481 |
+
|
482 |
+
|
483 |
+
@app.cell(hide_code=True)
|
484 |
+
def _(mo):
|
485 |
+
mo.md(
|
486 |
+
"""
|
487 |
+
## Functor laws
|
488 |
+
|
489 |
+
In addition to providing a function `fmap` of the specified type, functors are also required to satisfy two equational laws:
|
490 |
+
|
491 |
+
```haskell
|
492 |
+
fmap id = id -- fmap preserves identity
|
493 |
+
fmap (g . h) = fmap g . fmap h -- fmap distributes over composition
|
494 |
+
```
|
495 |
+
|
496 |
+
1. `fmap` should preserve the **identity function**, in the sense that applying `fmap` to this function returns the same function as the result.
|
497 |
+
2. `fmap` should also preserve **function composition**. Applying two composed functions `g` and `h` to a functor via `fmap` should give the same result as first applying `fmap` to `g` and then applying `fmap` to `h`.
|
498 |
+
|
499 |
+
/// admonition |
|
500 |
+
- Any `Functor` instance satisfying the first law `(fmap id = id)` will automatically satisfy the [second law](https://github.com/quchen/articles/blob/master/second_functor_law.mo) as well.
|
501 |
+
///
|
502 |
+
|
503 |
+
### Functor Law Verification
|
504 |
+
|
505 |
+
We can define `id` and `compose` in `Python` as below:
|
506 |
+
|
507 |
+
```python
|
508 |
+
id = lambda x: x
|
509 |
+
compose = lambda f, g: lambda x: f(g(x))
|
510 |
+
```
|
511 |
+
|
512 |
+
We can add a helper function `check_functor_law` to verify that an instance satisfies the functor laws.
|
513 |
+
|
514 |
+
```Python
|
515 |
+
check_functor_law = lambda functor: repr(fmap(id, functor)) == repr(functor)
|
516 |
+
```
|
517 |
+
|
518 |
+
We can verify the functor we've defined.
|
519 |
+
"""
|
520 |
+
)
|
521 |
+
return
|
522 |
+
|
523 |
+
|
524 |
+
@app.cell
|
525 |
+
def _():
|
526 |
+
id = lambda x: x
|
527 |
+
compose = lambda f, g: lambda x: f(g(x))
|
528 |
+
return compose, id
|
529 |
+
|
530 |
+
|
531 |
+
@app.cell
|
532 |
+
def _(fmap, id):
|
533 |
+
check_functor_law = lambda functor: repr(fmap(id, functor)) == repr(functor)
|
534 |
+
return (check_functor_law,)
|
535 |
+
|
536 |
+
|
537 |
+
@app.cell
|
538 |
+
def _(check_functor_law, flist, pp, rosetree, wrapper):
|
539 |
+
for functor in (wrapper, flist, rosetree):
|
540 |
+
pp(check_functor_law(functor))
|
541 |
+
return (functor,)
|
542 |
+
|
543 |
+
|
544 |
+
@app.cell(hide_code=True)
|
545 |
+
def _(mo):
|
546 |
+
mo.md(
|
547 |
+
"""
|
548 |
+
And here is an `EvilFunctor`. We can verify it's not a valid `Functor`.
|
549 |
+
|
550 |
+
```python
|
551 |
+
@dataclass
|
552 |
+
class EvilFunctor(Functor, Generic[A]):
|
553 |
+
value: list[A]
|
554 |
+
|
555 |
+
@classmethod
|
556 |
+
def fmap(cls, f: Callable[[A], B], a: "EvilFunctor[A]") -> "EvilFunctor[B]":
|
557 |
+
return (
|
558 |
+
cls([a.value[0]] * 2 + list(map(f, a.value[1:])))
|
559 |
+
if a.value
|
560 |
+
else []
|
561 |
+
)
|
562 |
+
```
|
563 |
+
"""
|
564 |
+
)
|
565 |
+
return
|
566 |
+
|
567 |
+
|
568 |
+
@app.cell
|
569 |
+
def _(A, B, Callable, Functor, Generic, check_functor_law, dataclass, pp):
|
570 |
+
@dataclass
|
571 |
+
class EvilFunctor(Functor, Generic[A]):
|
572 |
+
value: list[A]
|
573 |
+
|
574 |
+
@classmethod
|
575 |
+
def fmap(
|
576 |
+
cls, f: Callable[[A], B], a: "EvilFunctor[A]"
|
577 |
+
) -> "EvilFunctor[B]":
|
578 |
+
return (
|
579 |
+
cls([a.value[0]] * 2 + [f(x) for x in a.value[1:]])
|
580 |
+
if a.value
|
581 |
+
else []
|
582 |
+
)
|
583 |
+
|
584 |
+
|
585 |
+
pp(check_functor_law(EvilFunctor([1, 2, 3, 4])))
|
586 |
+
return (EvilFunctor,)
|
587 |
+
|
588 |
+
|
589 |
+
@app.cell(hide_code=True)
|
590 |
+
def _(mo):
|
591 |
+
mo.md(
|
592 |
+
"""
|
593 |
+
## Final definition of Functor
|
594 |
+
|
595 |
+
We can now draft the final definition of `Functor` with some utility functions.
|
596 |
+
|
597 |
+
```Python
|
598 |
+
@classmethod
|
599 |
+
@abstractmethod
|
600 |
+
def fmap(cls, f: Callable[[A], B], a: "Functor[A]") -> "Functor[B]":
|
601 |
+
return NotImplementedError
|
602 |
+
|
603 |
+
@classmethod
|
604 |
+
def const_fmap(cls, a: "Functor[A]", b: B) -> "Functor[B]":
|
605 |
+
return cls.fmap(lambda _: b, a)
|
606 |
+
|
607 |
+
@classmethod
|
608 |
+
def void(cls, a: "Functor[A]") -> "Functor[None]":
|
609 |
+
return cls.const_fmap(a, None)
|
610 |
+
```
|
611 |
+
"""
|
612 |
+
)
|
613 |
+
return
|
614 |
+
|
615 |
+
|
616 |
+
@app.cell(hide_code=True)
|
617 |
+
def _(A, ABC, B, Callable, Generic, abstractmethod, dataclass, mo):
|
618 |
+
@dataclass
|
619 |
+
class Functor(ABC, Generic[A]):
|
620 |
+
"""
|
621 |
+
### Doc: Functor
|
622 |
+
|
623 |
+
A generic interface for types that support mapping over their values.
|
624 |
+
|
625 |
+
**Methods:**
|
626 |
+
|
627 |
+
- `fmap(f: Callable[[A], B], a: Functor[A]) -> Functor[B]`
|
628 |
+
Abstract method to apply a function to all values inside a functor.
|
629 |
+
|
630 |
+
- `const_fmap(a: "Functor[A]", b: B) -> Functor[B]`
|
631 |
+
Replaces all values inside a functor with a constant `b`, preserving the original structure.
|
632 |
+
|
633 |
+
- `void(a: "Functor[A]") -> Functor[None]`
|
634 |
+
Equivalent to `const_fmap(a, None)`, transforming all values in a functor into `None`.
|
635 |
+
"""
|
636 |
+
|
637 |
+
@classmethod
|
638 |
+
@abstractmethod
|
639 |
+
def fmap(cls, f: Callable[[A], B], a: "Functor[A]") -> "Functor[B]":
|
640 |
+
return NotImplementedError
|
641 |
+
|
642 |
+
@classmethod
|
643 |
+
def const_fmap(cls, a: "Functor[A]", b: B) -> "Functor[B]":
|
644 |
+
return cls.fmap(lambda _: b, a)
|
645 |
+
|
646 |
+
@classmethod
|
647 |
+
def void(cls, a: "Functor[A]") -> "Functor[None]":
|
648 |
+
return cls.const_fmap(a, None)
|
649 |
+
|
650 |
+
|
651 |
+
mo.md(Functor.__doc__)
|
652 |
+
return (Functor,)
|
653 |
+
|
654 |
+
|
655 |
+
@app.cell(hide_code=True)
|
656 |
+
def _(mo):
|
657 |
+
mo.md("""> Try with utility functions in the cell below""")
|
658 |
+
return
|
659 |
+
|
660 |
+
|
661 |
+
@app.cell
|
662 |
+
def _(List, RoseTree, flist, pp, rosetree):
|
663 |
+
pp(RoseTree.const_fmap(rosetree, "λ"))
|
664 |
+
pp(RoseTree.void(rosetree))
|
665 |
+
pp(List.const_fmap(flist, "λ"))
|
666 |
+
pp(List.void(flist))
|
667 |
+
return
|
668 |
+
|
669 |
+
|
670 |
+
@app.cell(hide_code=True)
|
671 |
+
def _(mo):
|
672 |
+
mo.md(
|
673 |
+
"""
|
674 |
+
## Functors for Non-Iterable Types
|
675 |
+
|
676 |
+
In the previous examples, we implemented functors for **iterables**, like `List` and `RoseTree`, which are inherently **iterable types**. This is a natural fit for functors, as iterables can be mapped over.
|
677 |
+
|
678 |
+
However, **functors are not limited to iterables**. There are cases where we want to apply the concept of functors to types that are not inherently iterable, such as types that represent optional values, computations, or other data structures.
|
679 |
+
|
680 |
+
### The Maybe Functor
|
681 |
+
|
682 |
+
One example is the **`Maybe`** type from Haskell, which is used to represent computations that can either result in a value or no value (`Nothing`).
|
683 |
+
|
684 |
+
We can define the `Maybe` functor as below:
|
685 |
+
|
686 |
+
```python
|
687 |
+
@dataclass
|
688 |
+
class Maybe(Functor, Generic[A]):
|
689 |
+
value: None | A
|
690 |
+
|
691 |
+
@classmethod
|
692 |
+
def fmap(cls, f: Callable[[A], B], a: "Maybe[A]") -> "Maybe[B]":
|
693 |
+
return (
|
694 |
+
cls(None) if a.value is None else cls(f(a.value))
|
695 |
+
)
|
696 |
+
|
697 |
+
def __repr__(self):
|
698 |
+
return "Nothing" if self.value is None else repr(self.value)
|
699 |
+
```
|
700 |
+
"""
|
701 |
+
)
|
702 |
+
return
|
703 |
+
|
704 |
+
|
705 |
+
@app.cell
|
706 |
+
def _(A, B, Callable, Functor, Generic, dataclass):
|
707 |
+
@dataclass
|
708 |
+
class Maybe(Functor, Generic[A]):
|
709 |
+
value: None | A
|
710 |
+
|
711 |
+
@classmethod
|
712 |
+
def fmap(cls, f: Callable[[A], B], a: "Maybe[A]") -> "Maybe[B]":
|
713 |
+
return cls(None) if a.value is None else cls(f(a.value))
|
714 |
+
|
715 |
+
def __repr__(self):
|
716 |
+
return "Nothing" if self.value is None else repr(self.value)
|
717 |
+
return (Maybe,)
|
718 |
+
|
719 |
+
|
720 |
+
@app.cell(hide_code=True)
|
721 |
+
def _(mo):
|
722 |
+
mo.md(
|
723 |
+
"""
|
724 |
+
**`Maybe`** is a functor that can either hold a value or be `Nothing` (equivalent to `None` in Python). The `fmap` method applies a function to the value inside the functor, if it exists. If the value is `None` (representing `Nothing`), `fmap` simply returns `None`.
|
725 |
+
|
726 |
+
By using `Maybe` as a functor, we gain the ability to apply transformations (`fmap`) to potentially absent values, without having to explicitly handle the `None` case every time.
|
727 |
+
|
728 |
+
> Try using `Maybe` in the cell below.
|
729 |
+
"""
|
730 |
+
)
|
731 |
+
return
|
732 |
+
|
733 |
+
|
734 |
+
@app.cell
|
735 |
+
def _(Maybe, pp):
|
736 |
+
mint = Maybe(1)
|
737 |
+
mnone = Maybe(None)
|
738 |
+
|
739 |
+
pp(Maybe.fmap(lambda x: x + 1, mint))
|
740 |
+
pp(Maybe.fmap(lambda x: x + 1, mnone))
|
741 |
+
return mint, mnone
|
742 |
+
|
743 |
+
|
744 |
+
@app.cell(hide_code=True)
|
745 |
+
def _(mo):
|
746 |
+
mo.md(
|
747 |
+
"""
|
748 |
+
## Limitations of Functor
|
749 |
+
|
750 |
+
Functors abstract the idea of mapping a function over each element of a structure. Suppose now that we wish to generalise this idea to allow functions with any number of arguments to be mapped, rather than being restricted to functions with a single argument. More precisely, suppose that we wish to define a hierarchy of `fmap` functions with the following types:
|
751 |
+
|
752 |
+
```haskell
|
753 |
+
fmap0 :: a -> f a
|
754 |
+
|
755 |
+
fmap1 :: (a -> b) -> f a -> f b
|
756 |
+
|
757 |
+
fmap2 :: (a -> b -> c) -> f a -> f b -> f c
|
758 |
+
|
759 |
+
fmap3 :: (a -> b -> c -> d) -> f a -> f b -> f c -> f d
|
760 |
+
```
|
761 |
+
|
762 |
+
And we have to declare a special version of the functor class for each case.
|
763 |
+
|
764 |
+
We will learn how to resolve this problem in the next notebook on `Applicatives`.
|
765 |
+
"""
|
766 |
+
)
|
767 |
+
return
|
768 |
+
|
769 |
+
|
770 |
+
@app.cell(hide_code=True)
|
771 |
+
def _(mo):
|
772 |
+
mo.md(
|
773 |
+
"""
|
774 |
+
# Introduction to Categories
|
775 |
+
|
776 |
+
A [category](https://en.wikibooks.org/wiki/Haskell/Category_theory#Introduction_to_categories) is, in essence, a simple collection. It has three components:
|
777 |
+
|
778 |
+
- A collection of **objects**.
|
779 |
+
- A collection of **morphisms**, each of which ties two objects (a _source object_ and a _target object_) together. If $f$ is a morphism with source object $C$ and target object $B$, we write $f : C → B$.
|
780 |
+
- A notion of **composition** of these morphisms. If $g : A → B$ and $f : B → C$ are two morphisms, they can be composed, resulting in a morphism $f ∘ g : A → C$.
|
781 |
+
|
782 |
+
## Category laws
|
783 |
+
|
784 |
+
There are three laws that categories need to follow.
|
785 |
+
|
786 |
+
1. The composition of morphisms needs to be **associative**. Symbolically, $f ∘ (g ∘ h) = (f ∘ g) ∘ h$
|
787 |
+
|
788 |
+
- Morphisms are applied right to left, so with $f ∘ g$ first $g$ is applied, then $f$.
|
789 |
+
|
790 |
+
2. The category needs to be **closed** under the composition operation. So if $f : B → C$ and $g : A → B$, then there must be some morphism $h : A → C$ in the category such that $h = f ∘ g$.
|
791 |
+
|
792 |
+
3. Given a category $C$ there needs to be for every object $A$ an **identity** morphism, $id_A : A → A$ that is an identity of composition with other morphisms. Put precisely, for every morphism $g : A → B$: $g ∘ id_A = id_B ∘ g = g$
|
793 |
+
|
794 |
+
/// attention | The definition of a category does not define:
|
795 |
+
|
796 |
+
- what `∘` is,
|
797 |
+
- what `id` is, or
|
798 |
+
- what `f`, `g`, and `h` might be.
|
799 |
+
|
800 |
+
Instead, category theory leaves it up to us to discover what they might be.
|
801 |
+
///
|
802 |
+
"""
|
803 |
+
)
|
804 |
+
return
|
805 |
+
|
806 |
+
|
807 |
+
@app.cell(hide_code=True)
|
808 |
+
def _(mo):
|
809 |
+
mo.md(
|
810 |
+
"""
|
811 |
+
## The Python category
|
812 |
+
|
813 |
+
The main category we'll be concerning ourselves with in this part is the Python category, or we can give it a shorter name: `Py`. `Py` treats Python types as objects and Python functions as morphisms. A function `def f(a: A) -> B` for types A and B is a morphism in Python.
|
814 |
+
|
815 |
+
Remember that we defined the `id` and `compose` function above as:
|
816 |
+
|
817 |
+
```Python
|
818 |
+
def id(x: Generic[A]) -> Generic[A]:
|
819 |
+
return x
|
820 |
+
|
821 |
+
def compose(f: Callable[[B], C], g: Callable[[A], B]) -> Callable[[A], C]:
|
822 |
+
return lambda x: f(g(x))
|
823 |
+
```
|
824 |
+
|
825 |
+
We can check second law easily.
|
826 |
+
|
827 |
+
For the first law, we have:
|
828 |
+
|
829 |
+
```python
|
830 |
+
# compose(f, g) = lambda x: f(g(x))
|
831 |
+
f ∘ (g ∘ h)
|
832 |
+
= compose(f, compose(g, h))
|
833 |
+
= lambda x: f(compose(g, h)(x))
|
834 |
+
= lambda x: f(lambda y: g(h(y))(x))
|
835 |
+
= lambda x: f(g(h(x)))
|
836 |
+
|
837 |
+
(f ∘ g) ∘ h
|
838 |
+
= compose(compose(f, g), h)
|
839 |
+
= lambda x: compose(f, g)(h(x))
|
840 |
+
= lambda x: lambda y: f(g(y))(h(x))
|
841 |
+
= lambda x: f(g(h(x)))
|
842 |
+
```
|
843 |
+
|
844 |
+
For the third law, we have:
|
845 |
+
|
846 |
+
```python
|
847 |
+
g ∘ id_A
|
848 |
+
= compose(g: Callable[[a], b], id: Callable[[a], a]) -> Callable[[a], b]
|
849 |
+
= lambda x: g(id(x))
|
850 |
+
= lambda x: g(x) # id(x) = x
|
851 |
+
= g
|
852 |
+
```
|
853 |
+
the similar proof can be applied to $id_B ∘ g =g$.
|
854 |
+
|
855 |
+
Thus `Py` is a valid category.
|
856 |
+
"""
|
857 |
+
)
|
858 |
+
return
|
859 |
+
|
860 |
+
|
861 |
+
@app.cell(hide_code=True)
|
862 |
+
def _(mo):
|
863 |
+
mo.md(
|
864 |
+
"""
|
865 |
+
# Functors, again
|
866 |
+
|
867 |
+
A functor is essentially a transformation between categories, so given categories $C$ and $D$, a functor $F : C → D$:
|
868 |
+
|
869 |
+
- Maps any object $A$ in $C$ to $F ( A )$, in $D$.
|
870 |
+
- Maps morphisms $f : A → B$ in $C$ to $F ( f ) : F ( A ) → F ( B )$ in $D$.
|
871 |
+
|
872 |
+
/// admonition |
|
873 |
+
|
874 |
+
Endofunctors are functors from a category to itself.
|
875 |
+
|
876 |
+
///
|
877 |
+
"""
|
878 |
+
)
|
879 |
+
return
|
880 |
+
|
881 |
+
|
882 |
+
@app.cell(hide_code=True)
|
883 |
+
def _(mo):
|
884 |
+
mo.md(
|
885 |
+
"""
|
886 |
+
## Functors on the category of Python
|
887 |
+
|
888 |
+
Remember that a functor has two parts: it maps objects in one category to objects in another and morphisms in the first category to morphisms in the second.
|
889 |
+
|
890 |
+
Functors in Python are from `Py` to `func`, where `func` is the subcategory of `Py` defined on just that functor's types. E.g. the RoseTree functor goes from `Py` to `RoseTree`, where `RoseTree` is the category containing only RoseTree types, that is, `RoseTree[T]` for any type `T`. The morphisms in `RoseTree` are functions defined on RoseTree types, that is, functions `Callable[[RoseTree[T]], RoseTree[U]]` for types `T`, `U`.
|
891 |
+
|
892 |
+
Recall the definition of `Functor`:
|
893 |
+
|
894 |
+
```Python
|
895 |
+
@dataclass
|
896 |
+
class Functor(ABC, Generic[A])
|
897 |
+
```
|
898 |
+
|
899 |
+
And RoseTree:
|
900 |
+
|
901 |
+
```Python
|
902 |
+
@dataclass
|
903 |
+
class RoseTree(Functor, Generic[A])
|
904 |
+
```
|
905 |
+
|
906 |
+
**Here's the key part:** the _type constructor_ `RoseTree` takes any type `T` to a new type, `RoseTree[T]`. Also, `fmap` restricted to `RoseTree` types takes a function `Callable[[A], B]` to a function `Callable[[RoseTree[A]], RoseTree[B]]`.
|
907 |
+
|
908 |
+
But that's it. We've defined two parts, something that takes objects in `Py` to objects in another category (that of `RoseTree` types and functions defined on `RoseTree` types), and something that takes morphisms in `Py` to morphisms in this category. So `RoseTree` is a functor.
|
909 |
+
|
910 |
+
To sum up:
|
911 |
+
|
912 |
+
- We work in the category **Py** and its subcategories.
|
913 |
+
- **Objects** are types (e.g., `int`, `str`, `list`).
|
914 |
+
- **Morphisms** are functions (`Callable[[A], B]`).
|
915 |
+
- **Things that take a type and return another type** are type constructors (`RoseTree[T]`).
|
916 |
+
- **Things that take a function and return another function** are higher-order functions (`Callable[[Callable[[A], B]], Callable[[C], D]]`).
|
917 |
+
- **Abstract base classes (ABC)** and duck typing provide a way to express polymorphism, capturing the idea that in category theory, structures are often defined over multiple objects at once.
|
918 |
+
"""
|
919 |
+
)
|
920 |
+
return
|
921 |
+
|
922 |
+
|
923 |
+
@app.cell(hide_code=True)
|
924 |
+
def _(mo):
|
925 |
+
mo.md(
|
926 |
+
"""
|
927 |
+
## Functor laws, again
|
928 |
+
|
929 |
+
Once again there are a few axioms that functors have to obey.
|
930 |
+
|
931 |
+
1. Given an identity morphism $id_A$ on an object $A$, $F ( id_A )$ must be the identity morphism on $F ( A )$, i.e.: ${\displaystyle F(\operatorname {id} _{A})=\operatorname {id} _{F(A)}}$
|
932 |
+
2. Functors must distribute over morphism composition, i.e. ${\displaystyle F(f\circ g)=F(f)\circ F(g)}$
|
933 |
+
"""
|
934 |
+
)
|
935 |
+
return
|
936 |
+
|
937 |
+
|
938 |
+
@app.cell(hide_code=True)
|
939 |
+
def _(mo):
|
940 |
+
mo.md(
|
941 |
+
"""
|
942 |
+
Remember that we defined the `fmap`, `id` and `compose` as
|
943 |
+
```python
|
944 |
+
fmap = lambda f, functor: functor.__class__.fmap(f, functor)
|
945 |
+
id = lambda x: x
|
946 |
+
compose = lambda f, g: lambda x: f(g(x))
|
947 |
+
```
|
948 |
+
|
949 |
+
Let's prove that `fmap` is a functor.
|
950 |
+
|
951 |
+
First, let's define a `Category` for a specific `Functor`. We choose to define the `Category` for the `Wrapper` as `WrapperCategory` here for simplicity, but remember that `Wrapper` can be any `Functor`(i.e. `List`, `RoseTree`, `Maybe` and more):
|
952 |
+
|
953 |
+
**Notice that** in this case, we can actually view `fmap` as:
|
954 |
+
```python
|
955 |
+
fmap = lambda f, functor: functor.fmap(f, functor)
|
956 |
+
```
|
957 |
+
|
958 |
+
We define `WrapperCategory` as:
|
959 |
+
|
960 |
+
```python
|
961 |
+
@dataclass
|
962 |
+
class WrapperCategory:
|
963 |
+
@staticmethod
|
964 |
+
def id(wrapper: Wrapper[A]) -> Wrapper[A]:
|
965 |
+
return Wrapper(wrapper.value)
|
966 |
+
|
967 |
+
@staticmethod
|
968 |
+
def compose(
|
969 |
+
f: Callable[[Wrapper[B]], Wrapper[C]],
|
970 |
+
g: Callable[[Wrapper[A]], Wrapper[B]],
|
971 |
+
wrapper: Wrapper[A]
|
972 |
+
) -> Callable[[Wrapper[A]], Wrapper[C]]:
|
973 |
+
return f(g(Wrapper(wrapper.value)))
|
974 |
+
```
|
975 |
+
|
976 |
+
And `Wrapper` is:
|
977 |
+
|
978 |
+
```Python
|
979 |
+
@dataclass
|
980 |
+
class Wrapper(Functor, Generic[A]):
|
981 |
+
value: A
|
982 |
+
|
983 |
+
@classmethod
|
984 |
+
def fmap(cls, f: Callable[[A], B], a: "Wrapper[A]") -> "Wrapper[B]":
|
985 |
+
return Wrapper(f(a.value))
|
986 |
+
```
|
987 |
+
"""
|
988 |
+
)
|
989 |
+
return
|
990 |
+
|
991 |
+
|
992 |
+
@app.cell(hide_code=True)
|
993 |
+
def _(mo):
|
994 |
+
mo.md(
|
995 |
+
"""
|
996 |
+
We can prove that:
|
997 |
+
|
998 |
+
```python
|
999 |
+
fmap(id, wrapper)
|
1000 |
+
= Wrapper.fmap(id, wrapper)
|
1001 |
+
= Wrapper(id(wrapper.value))
|
1002 |
+
= Wrapper(wrapper.value)
|
1003 |
+
= WrapperCategory.id(wrapper)
|
1004 |
+
```
|
1005 |
+
and:
|
1006 |
+
```python
|
1007 |
+
fmap(compose(f, g), wrapper)
|
1008 |
+
= Wrapper.fmap(compose(f, g), wrapper)
|
1009 |
+
= Wrapper(compose(f, g)(wrapper.value))
|
1010 |
+
= Wrapper(f(g(wrapper.value)))
|
1011 |
+
|
1012 |
+
WrapperCategory.compose(fmap(f, wrapper), fmap(g, wrapper), wrapper)
|
1013 |
+
= fmap(f, wrapper)(fmap(g, wrapper)(wrapper))
|
1014 |
+
= fmap(f, wrapper)(Wrapper.fmap(g, wrapper))
|
1015 |
+
= fmap(f, wrapper)(Wrapper(g(wrapper.value)))
|
1016 |
+
= Wrapper.fmap(f, Wrapper(g(wrapper.value)))
|
1017 |
+
= Wrapper(f(Wrapper(g(wrapper.value)).value))
|
1018 |
+
= Wrapper(f(g(wrapper.value))) # Wrapper(g(wrapper.value)).value = g(wrapper.value)
|
1019 |
+
```
|
1020 |
+
|
1021 |
+
So our `Wrapper` is a valid `Functor`.
|
1022 |
+
|
1023 |
+
> Try validating functor laws for `Wrapper` below.
|
1024 |
+
"""
|
1025 |
+
)
|
1026 |
+
return
|
1027 |
+
|
1028 |
+
|
1029 |
+
@app.cell
|
1030 |
+
def _(A, B, C, Callable, Wrapper, dataclass):
|
1031 |
+
@dataclass
|
1032 |
+
class WrapperCategory:
|
1033 |
+
@staticmethod
|
1034 |
+
def id(wrapper: Wrapper[A]) -> Wrapper[A]:
|
1035 |
+
return Wrapper(wrapper.value)
|
1036 |
+
|
1037 |
+
@staticmethod
|
1038 |
+
def compose(
|
1039 |
+
f: Callable[[Wrapper[B]], Wrapper[C]],
|
1040 |
+
g: Callable[[Wrapper[A]], Wrapper[B]],
|
1041 |
+
wrapper: Wrapper[A],
|
1042 |
+
) -> Callable[[Wrapper[A]], Wrapper[C]]:
|
1043 |
+
return f(g(Wrapper(wrapper.value)))
|
1044 |
+
return (WrapperCategory,)
|
1045 |
+
|
1046 |
+
|
1047 |
+
@app.cell
|
1048 |
+
def _(WrapperCategory, fmap, id, pp, wrapper):
|
1049 |
+
pp(fmap(id, wrapper) == WrapperCategory.id(wrapper))
|
1050 |
+
return
|
1051 |
+
|
1052 |
+
|
1053 |
+
@app.cell(hide_code=True)
|
1054 |
+
def _(mo):
|
1055 |
+
mo.md(
|
1056 |
+
"""
|
1057 |
+
## Length as a Functor
|
1058 |
+
|
1059 |
+
Remember that a functor is a transformation between two categories. It is not only limited to a functor from `Py` to `func`, but also includes transformations between other mathematical structures.
|
1060 |
+
|
1061 |
+
Let’s prove that **`length`** can be viewed as a functor. Specifically, we will demonstrate that `length` is a functor from the **category of list concatenation** to the **category of integer addition**.
|
1062 |
+
|
1063 |
+
### Category of List Concatenation
|
1064 |
+
|
1065 |
+
First, let’s define the category of list concatenation:
|
1066 |
+
|
1067 |
+
```python
|
1068 |
+
@dataclass
|
1069 |
+
class ListConcatenation(Generic[A]):
|
1070 |
+
value: list[A]
|
1071 |
+
|
1072 |
+
@staticmethod
|
1073 |
+
def id() -> "ListConcatenation[A]":
|
1074 |
+
return ListConcatenation([])
|
1075 |
+
|
1076 |
+
@staticmethod
|
1077 |
+
def compose(
|
1078 |
+
this: "ListConcatenation[A]", other: "ListConcatenation[A]"
|
1079 |
+
) -> "ListConcatenation[a]":
|
1080 |
+
return ListConcatenation(this.value + other.value)
|
1081 |
+
```
|
1082 |
+
"""
|
1083 |
+
)
|
1084 |
+
return
|
1085 |
+
|
1086 |
+
|
1087 |
+
@app.cell
|
1088 |
+
def _(A, Generic, dataclass):
|
1089 |
+
@dataclass
|
1090 |
+
class ListConcatenation(Generic[A]):
|
1091 |
+
value: list[A]
|
1092 |
+
|
1093 |
+
@staticmethod
|
1094 |
+
def id() -> "ListConcatenation[A]":
|
1095 |
+
return ListConcatenation([])
|
1096 |
+
|
1097 |
+
@staticmethod
|
1098 |
+
def compose(
|
1099 |
+
this: "ListConcatenation[A]", other: "ListConcatenation[A]"
|
1100 |
+
) -> "ListConcatenation[a]":
|
1101 |
+
return ListConcatenation(this.value + other.value)
|
1102 |
+
return (ListConcatenation,)
|
1103 |
+
|
1104 |
+
|
1105 |
+
@app.cell(hide_code=True)
|
1106 |
+
def _(mo):
|
1107 |
+
mo.md(
|
1108 |
+
"""
|
1109 |
+
- **Identity**: The identity element is an empty list (`ListConcatenation([])`).
|
1110 |
+
- **Composition**: The composition of two lists is their concatenation (`this.value + other.value`).
|
1111 |
+
"""
|
1112 |
+
)
|
1113 |
+
return
|
1114 |
+
|
1115 |
+
|
1116 |
+
@app.cell(hide_code=True)
|
1117 |
+
def _(mo):
|
1118 |
+
mo.md(
|
1119 |
+
"""
|
1120 |
+
### Category of Integer Addition
|
1121 |
+
|
1122 |
+
Now, let's define the category of integer addition:
|
1123 |
+
|
1124 |
+
```python
|
1125 |
+
@dataclass
|
1126 |
+
class IntAddition:
|
1127 |
+
value: int
|
1128 |
+
|
1129 |
+
@staticmethod
|
1130 |
+
def id() -> "IntAddition":
|
1131 |
+
return IntAddition(0)
|
1132 |
+
|
1133 |
+
@staticmethod
|
1134 |
+
def compose(this: "IntAddition", other: "IntAddition") -> "IntAddition":
|
1135 |
+
return IntAddition(this.value + other.value)
|
1136 |
+
```
|
1137 |
+
"""
|
1138 |
+
)
|
1139 |
+
return
|
1140 |
+
|
1141 |
+
|
1142 |
+
@app.cell
|
1143 |
+
def _(dataclass):
|
1144 |
+
@dataclass
|
1145 |
+
class IntAddition:
|
1146 |
+
value: int
|
1147 |
+
|
1148 |
+
@staticmethod
|
1149 |
+
def id() -> "IntAddition":
|
1150 |
+
return IntAddition(0)
|
1151 |
+
|
1152 |
+
@staticmethod
|
1153 |
+
def compose(this: "IntAddition", other: "IntAddition") -> "IntAddition":
|
1154 |
+
return IntAddition(this.value + other.value)
|
1155 |
+
return (IntAddition,)
|
1156 |
+
|
1157 |
+
|
1158 |
+
@app.cell(hide_code=True)
|
1159 |
+
def _(mo):
|
1160 |
+
mo.md(
|
1161 |
+
"""
|
1162 |
+
- **Identity**: The identity element is `IntAddition(0)` (the additive identity).
|
1163 |
+
- **Composition**: The composition of two integers is their sum (`this.value + other.value`).
|
1164 |
+
"""
|
1165 |
+
)
|
1166 |
+
return
|
1167 |
+
|
1168 |
+
|
1169 |
+
@app.cell(hide_code=True)
|
1170 |
+
def _(mo):
|
1171 |
+
mo.md(
|
1172 |
+
"""
|
1173 |
+
### Defining the Length Functor
|
1174 |
+
|
1175 |
+
We now define the `length` function as a functor, mapping from the category of list concatenation to the category of integer addition:
|
1176 |
+
|
1177 |
+
```python
|
1178 |
+
length = lambda l: IntAddition(len(l.value))
|
1179 |
+
```
|
1180 |
+
"""
|
1181 |
+
)
|
1182 |
+
return
|
1183 |
+
|
1184 |
+
|
1185 |
+
@app.cell(hide_code=True)
|
1186 |
+
def _(IntAddition):
|
1187 |
+
length = lambda l: IntAddition(len(l.value))
|
1188 |
+
return (length,)
|
1189 |
+
|
1190 |
+
|
1191 |
+
@app.cell(hide_code=True)
|
1192 |
+
def _(mo):
|
1193 |
+
mo.md("""This function takes an instance of `ListConcatenation`, computes its length, and returns an `IntAddition` instance with the computed length.""")
|
1194 |
+
return
|
1195 |
+
|
1196 |
+
|
1197 |
+
@app.cell(hide_code=True)
|
1198 |
+
def _(mo):
|
1199 |
+
mo.md(
|
1200 |
+
"""
|
1201 |
+
### Verifying Functor Laws
|
1202 |
+
|
1203 |
+
Now, let’s verify that `length` satisfies the two functor laws.
|
1204 |
+
|
1205 |
+
#### 1. **Identity Law**:
|
1206 |
+
The identity law states that applying the functor to the identity element of one category should give the identity element of the other category.
|
1207 |
+
|
1208 |
+
```python
|
1209 |
+
> length(ListConcatenation.id()) == IntAddition.id()
|
1210 |
+
True
|
1211 |
+
```
|
1212 |
+
"""
|
1213 |
+
)
|
1214 |
+
return
|
1215 |
+
|
1216 |
+
|
1217 |
+
@app.cell(hide_code=True)
|
1218 |
+
def _(mo):
|
1219 |
+
mo.md("""This ensures that the length of an empty list (identity in the `ListConcatenation` category) is `0` (identity in the `IntAddition` category).""")
|
1220 |
+
return
|
1221 |
+
|
1222 |
+
|
1223 |
+
@app.cell(hide_code=True)
|
1224 |
+
def _(mo):
|
1225 |
+
mo.md(
|
1226 |
+
"""
|
1227 |
+
#### 2. **Composition Law**:
|
1228 |
+
The composition law states that the functor should preserve composition. Applying the functor to a composed element should be the same as composing the functor applied to the individual elements.
|
1229 |
+
|
1230 |
+
```python
|
1231 |
+
> lista = ListConcatenation([1, 2])
|
1232 |
+
> listb = ListConcatenation([3, 4])
|
1233 |
+
> length(ListConcatenation.compose(lista, listb)) == IntAddition.compose(
|
1234 |
+
> length(lista), length(listb)
|
1235 |
+
> )
|
1236 |
+
True
|
1237 |
+
```
|
1238 |
+
"""
|
1239 |
+
)
|
1240 |
+
return
|
1241 |
+
|
1242 |
+
|
1243 |
+
@app.cell(hide_code=True)
|
1244 |
+
def _(mo):
|
1245 |
+
mo.md("""This ensures that the length of the concatenation of two lists is the same as the sum of the lengths of the individual lists.""")
|
1246 |
+
return
|
1247 |
+
|
1248 |
+
|
1249 |
+
@app.cell
|
1250 |
+
def _(IntAddition, ListConcatenation, length, pp):
|
1251 |
+
pp(length(ListConcatenation.id()) == IntAddition.id())
|
1252 |
+
lista = ListConcatenation([1, 2])
|
1253 |
+
listb = ListConcatenation([3, 4])
|
1254 |
+
pp(
|
1255 |
+
length(ListConcatenation.compose(lista, listb))
|
1256 |
+
== IntAddition.compose(length(lista), length(listb))
|
1257 |
+
)
|
1258 |
+
return lista, listb
|
1259 |
+
|
1260 |
+
|
1261 |
+
@app.cell(hide_code=True)
|
1262 |
+
def _(mo):
|
1263 |
+
mo.md(
|
1264 |
+
"""
|
1265 |
+
# Further reading
|
1266 |
+
|
1267 |
+
- [The Trivial Monad](http://blog.sigfpe.com/2007/04/trivial-monad.html)
|
1268 |
+
- [Haskellwiki. Category Theory](https://en.wikibooks.org/wiki/Haskell/Category_theory)
|
1269 |
+
- [Haskellforall. The Category Design Pattern](https://www.haskellforall.com/2012/08/the-category-design-pattern.html)
|
1270 |
+
- [Haskellforall. The Functor Design Pattern](https://www.haskellforall.com/2012/09/the-functor-design-pattern.html)
|
1271 |
+
|
1272 |
+
/// attention | ATTENTION
|
1273 |
+
The functor design pattern doesn't work at all if you aren't using categories in the first place. This is why you should structure your tools using the compositional category design pattern so that you can take advantage of functors to easily mix your tools together.
|
1274 |
+
///
|
1275 |
+
|
1276 |
+
- [Haskellwiki. Functor](https://wiki.haskell.org/index.php?title=Functor)
|
1277 |
+
- [Haskellwiki. Typeclassopedia#Functor](https://wiki.haskell.org/index.php?title=Typeclassopedia#Functor)
|
1278 |
+
- [Haskellwiki. Typeclassopedia#Category](https://wiki.haskell.org/index.php?title=Typeclassopedia#Category)
|
1279 |
+
"""
|
1280 |
+
)
|
1281 |
+
return
|
1282 |
+
|
1283 |
+
|
1284 |
+
@app.cell(hide_code=True)
|
1285 |
+
def _():
|
1286 |
+
import marimo as mo
|
1287 |
+
return (mo,)
|
1288 |
+
|
1289 |
+
|
1290 |
+
@app.cell(hide_code=True)
|
1291 |
+
def _():
|
1292 |
+
from abc import abstractmethod, ABC
|
1293 |
+
return ABC, abstractmethod
|
1294 |
+
|
1295 |
+
|
1296 |
+
@app.cell(hide_code=True)
|
1297 |
+
def _():
|
1298 |
+
from dataclasses import dataclass
|
1299 |
+
from typing import Callable, Generic, TypeVar
|
1300 |
+
from pprint import pp
|
1301 |
+
return Callable, Generic, TypeVar, dataclass, pp
|
1302 |
+
|
1303 |
+
|
1304 |
+
@app.cell(hide_code=True)
|
1305 |
+
def _(TypeVar):
|
1306 |
+
A = TypeVar("A")
|
1307 |
+
B = TypeVar("B")
|
1308 |
+
C = TypeVar("C")
|
1309 |
+
return A, B, C
|
1310 |
+
|
1311 |
+
|
1312 |
+
if __name__ == "__main__":
|
1313 |
+
app.run()
|
functional_programming/CHANGELOG.md
ADDED
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Changelog of the functional-programming course
|
2 |
+
|
3 |
+
## 2025-03-11
|
4 |
+
|
5 |
+
* Demo version of notebook `05_functors.py`
|
6 |
+
|
7 |
+
## 2025-03-13
|
8 |
+
|
9 |
+
* `0.1.0` version of notebook `05_functors`
|
10 |
+
|
11 |
+
Thank [Akshay](https://github.com/akshayka) and [Haleshot](https://github.com/Haleshot) for reviewing
|
12 |
+
|
13 |
+
## 2025-03-16
|
14 |
+
|
15 |
+
+ Use uppercased letters for `Generic` types, e.g. `A = TypeVar("A")`
|
16 |
+
+ Refactor the `Functor` class, changing `fmap` and utility methods to `classmethod`
|
17 |
+
|
18 |
+
For example:
|
19 |
+
|
20 |
+
```python
|
21 |
+
@dataclass
|
22 |
+
class Wrapper(Functor, Generic[A]):
|
23 |
+
value: A
|
24 |
+
|
25 |
+
@classmethod
|
26 |
+
def fmap(cls, f: Callable[[A], B], a: "Wrapper[A]") -> "Wrapper[B]":
|
27 |
+
return Wrapper(f(a.value))
|
28 |
+
|
29 |
+
>>> Wrapper.fmap(lambda x: x + 1, wrapper)
|
30 |
+
Wrapper(value=2)
|
31 |
+
```
|
32 |
+
|
33 |
+
+ Move the `check_functor_law` method from `Functor` class to a standard function
|
34 |
+
- Rename `ListWrapper` to `List` for simplicity
|
35 |
+
- Remove the `Just` class
|
36 |
+
+ Rewrite proofs
|
functional_programming/README.md
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Learn Functional Programming
|
2 |
+
|
3 |
+
_🚧 This collection is a
|
4 |
+
[work in progress](https://github.com/marimo-team/learn/issues/51)._
|
5 |
+
|
6 |
+
This series of marimo notebooks introduces the powerful paradigm of functional
|
7 |
+
programming through Python. Taking inspiration from Haskell and Category Theory,
|
8 |
+
we'll build a strong foundation in FP concepts that can transform how you
|
9 |
+
approach software development.
|
10 |
+
|
11 |
+
## What You'll Learn
|
12 |
+
|
13 |
+
**Using only Python's standard library**, we'll construct functional programming
|
14 |
+
concepts from first principles.
|
15 |
+
|
16 |
+
Topics include:
|
17 |
+
|
18 |
+
- Recursion and higher-order functions
|
19 |
+
- Category theory fundamentals
|
20 |
+
- Functors, applicatives, and monads
|
21 |
+
- Composable abstractions for robust code
|
22 |
+
|
23 |
+
## Timeline & Collaboration
|
24 |
+
|
25 |
+
I'm currently studying functional programming and Haskell, estimating about 2
|
26 |
+
months or even longer to complete this series. The structure may evolve as the
|
27 |
+
project develops.
|
28 |
+
|
29 |
+
If you're interested in collaborating or have questions, please reach out to me
|
30 |
+
on Discord (@eugene.hs).
|
31 |
+
|
32 |
+
**Running notebooks.** To run a notebook locally, use
|
33 |
+
|
34 |
+
```bash
|
35 |
+
uvx marimo edit <URL>
|
36 |
+
```
|
37 |
+
|
38 |
+
For example, run the `Functor` tutorial with
|
39 |
+
|
40 |
+
```bash
|
41 |
+
uvx marimo edit https://github.com/marimo-team/learn/blob/main/Functional_programming/05_functors.py
|
42 |
+
```
|
43 |
+
|
44 |
+
You can also open notebooks in our online playground by appending `marimo.app/`
|
45 |
+
to a notebook's URL:
|
46 |
+
[marimo.app/github.com/marimo-team/learn/blob/main/functional_programming/05_functors.py](https://marimo.app/https://github.com/marimo-team/learn/blob/main/functional_programming/05_functors.py).
|
47 |
+
|
48 |
+
# Description of notebooks
|
49 |
+
|
50 |
+
Check [here](https://github.com/marimo-team/learn/issues/51) for current series
|
51 |
+
structure.
|
52 |
+
|
53 |
+
| Notebook | Description | References |
|
54 |
+
| ----------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
55 |
+
| [05. Category and Functors](https://github.com/marimo-team/learn/blob/main/Functional_programming/05_functors.py) | Learn why `len` is a _Functor_ from `list concatenation` to `integer addition`, how to _lift_ an ordinary function into a _computation context_, and how to write an _adapter_ between two categories. | - [The Trivial Monad](http://blog.sigfpe.com/2007/04/trivial-monad.html) <br> - [Haskellwiki. Category Theory](https://en.wikibooks.org/wiki/Haskell/Category_theory) <br> - [Haskellforall. The Category Design Pattern](https://www.haskellforall.com/2012/08/the-category-design-pattern.html) <br> - [Haskellforall. The Functor Design Pattern](https://www.haskellforall.com/2012/09/the-functor-design-pattern.html) <br> - [Haskellwiki. Functor](https://wiki.haskell.org/index.php?title=Functor) <br> - [Haskellwiki. Typeclassopedia#Functor](https://wiki.haskell.org/index.php?title=Typeclassopedia#Functor) <br> - [Haskellwiki. Typeclassopedia#Category](https://wiki.haskell.org/index.php?title=Typeclassopedia#Category) |
|
56 |
+
|
57 |
+
**Authors.**
|
58 |
+
|
59 |
+
Thanks to all our notebook authors!
|
60 |
+
|
61 |
+
- [métaboulie](https://github.com/metaboulie)
|
optimization/05_portfolio_optimization.py
CHANGED
@@ -47,7 +47,7 @@ def _(mo):
|
|
47 |
r"""
|
48 |
## Asset returns and risk
|
49 |
|
50 |
-
We will only model investments held for one period. The initial prices are $p_i > 0$. The end of period prices are $p_i^+ >0$. The asset (fractional) returns are $r_i = (p_i^+-p_i)/p_i$. The
|
51 |
|
52 |
A common model is that $r$ is a random variable with mean ${\bf E}r = \mu$ and covariance ${\bf E{(r-\mu)(r-\mu)^T}} = \Sigma$.
|
53 |
It follows that $R$ is a random variable with ${\bf E}R = \mu^T w$ and ${\bf var}(R) = w^T\Sigma w$. In real-world applications, $\mu$ and $\Sigma$ are estimated from data and models, and $w$ is chosen using a library like CVXPY.
|
|
|
47 |
r"""
|
48 |
## Asset returns and risk
|
49 |
|
50 |
+
We will only model investments held for one period. The initial prices are $p_i > 0$. The end of period prices are $p_i^+ >0$. The asset (fractional) returns are $r_i = (p_i^+-p_i)/p_i$. The portfolio (fractional) return is $R = r^Tw$.
|
51 |
|
52 |
A common model is that $r$ is a random variable with mean ${\bf E}r = \mu$ and covariance ${\bf E{(r-\mu)(r-\mu)^T}} = \Sigma$.
|
53 |
It follows that $R$ is a random variable with ${\bf E}R = \mu^T w$ and ${\bf var}(R) = w^T\Sigma w$. In real-world applications, $\mu$ and $\Sigma$ are estimated from data and models, and $w$ is chosen using a library like CVXPY.
|
polars/04_basic_operations.py
ADDED
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.13"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "polars==1.23.0",
|
6 |
+
# ]
|
7 |
+
# ///
|
8 |
+
|
9 |
+
import marimo
|
10 |
+
|
11 |
+
__generated_with = "0.11.13"
|
12 |
+
app = marimo.App(width="medium")
|
13 |
+
|
14 |
+
|
15 |
+
@app.cell
|
16 |
+
def _():
|
17 |
+
import marimo as mo
|
18 |
+
return (mo,)
|
19 |
+
|
20 |
+
|
21 |
+
@app.cell(hide_code=True)
|
22 |
+
def _(mo):
|
23 |
+
mo.md(
|
24 |
+
r"""
|
25 |
+
# Basic operations on data
|
26 |
+
_By [Joram Mutenge](https://www.udemy.com/user/joram-mutenge/)._
|
27 |
+
|
28 |
+
In this notebook, you'll learn how to perform arithmetic operations, comparisons, and conditionals on a Polars dataframe. We'll work with a DataFrame that tracks software usage by year, categorized as either Vintage (old) or Modern (new).
|
29 |
+
"""
|
30 |
+
)
|
31 |
+
return
|
32 |
+
|
33 |
+
|
34 |
+
@app.cell
|
35 |
+
def _():
|
36 |
+
import polars as pl
|
37 |
+
|
38 |
+
df = pl.DataFrame(
|
39 |
+
{
|
40 |
+
"software": [
|
41 |
+
"Lotus-123",
|
42 |
+
"WordStar",
|
43 |
+
"dBase III",
|
44 |
+
"VisiCalc",
|
45 |
+
"WinZip",
|
46 |
+
"MS-DOS",
|
47 |
+
"HyperCard",
|
48 |
+
"WordPerfect",
|
49 |
+
"Excel",
|
50 |
+
"Photoshop",
|
51 |
+
"Visual Studio",
|
52 |
+
"Slack",
|
53 |
+
"Zoom",
|
54 |
+
"Notion",
|
55 |
+
"Figma",
|
56 |
+
"Spotify",
|
57 |
+
"VSCode",
|
58 |
+
"Docker",
|
59 |
+
],
|
60 |
+
"users": [
|
61 |
+
10000,
|
62 |
+
4500,
|
63 |
+
2500,
|
64 |
+
3000,
|
65 |
+
1800,
|
66 |
+
17000,
|
67 |
+
2200,
|
68 |
+
1900,
|
69 |
+
500000,
|
70 |
+
12000000,
|
71 |
+
1500000,
|
72 |
+
3000000,
|
73 |
+
4000000,
|
74 |
+
2000000,
|
75 |
+
2500000,
|
76 |
+
4500000,
|
77 |
+
6000000,
|
78 |
+
3500000,
|
79 |
+
],
|
80 |
+
"category": ["Vintage"] * 8 + ["Modern"] * 10,
|
81 |
+
"year": [
|
82 |
+
1985,
|
83 |
+
1980,
|
84 |
+
1984,
|
85 |
+
1979,
|
86 |
+
1991,
|
87 |
+
1981,
|
88 |
+
1987,
|
89 |
+
1982,
|
90 |
+
1987,
|
91 |
+
1990,
|
92 |
+
1997,
|
93 |
+
2013,
|
94 |
+
2011,
|
95 |
+
2016,
|
96 |
+
2016,
|
97 |
+
2008,
|
98 |
+
2015,
|
99 |
+
2013,
|
100 |
+
],
|
101 |
+
}
|
102 |
+
)
|
103 |
+
|
104 |
+
df
|
105 |
+
return df, pl
|
106 |
+
|
107 |
+
|
108 |
+
@app.cell(hide_code=True)
|
109 |
+
def _(mo):
|
110 |
+
mo.md(
|
111 |
+
r"""
|
112 |
+
## Arithmetic
|
113 |
+
### Addition
|
114 |
+
Let's add 42 users to each piece of software. This means adding 42 to each value under **users**.
|
115 |
+
"""
|
116 |
+
)
|
117 |
+
return
|
118 |
+
|
119 |
+
|
120 |
+
@app.cell
|
121 |
+
def _(df, pl):
|
122 |
+
df.with_columns(pl.col("users") + 42)
|
123 |
+
return
|
124 |
+
|
125 |
+
|
126 |
+
@app.cell(hide_code=True)
|
127 |
+
def _(mo):
|
128 |
+
mo.md(r"""Another way to perform the above operation is using the built-in function.""")
|
129 |
+
return
|
130 |
+
|
131 |
+
|
132 |
+
@app.cell
|
133 |
+
def _(df, pl):
|
134 |
+
df.with_columns(pl.col("users").add(42))
|
135 |
+
return
|
136 |
+
|
137 |
+
|
138 |
+
@app.cell(hide_code=True)
|
139 |
+
def _(mo):
|
140 |
+
mo.md(
|
141 |
+
r"""
|
142 |
+
### Subtraction
|
143 |
+
Let's subtract 42 users to each piece of software.
|
144 |
+
"""
|
145 |
+
)
|
146 |
+
return
|
147 |
+
|
148 |
+
|
149 |
+
@app.cell
|
150 |
+
def _(df, pl):
|
151 |
+
df.with_columns(pl.col("users") - 42)
|
152 |
+
return
|
153 |
+
|
154 |
+
|
155 |
+
@app.cell(hide_code=True)
|
156 |
+
def _(mo):
|
157 |
+
mo.md(r"""Alternatively, you could subtract like this:""")
|
158 |
+
return
|
159 |
+
|
160 |
+
|
161 |
+
@app.cell
|
162 |
+
def _(df, pl):
|
163 |
+
df.with_columns(pl.col("users").sub(42))
|
164 |
+
return
|
165 |
+
|
166 |
+
|
167 |
+
@app.cell(hide_code=True)
|
168 |
+
def _(mo):
|
169 |
+
mo.md(
|
170 |
+
r"""
|
171 |
+
### Division
|
172 |
+
Suppose the **users** values are inflated, we can reduce them by dividing by 1000. Here's how to do it.
|
173 |
+
"""
|
174 |
+
)
|
175 |
+
return
|
176 |
+
|
177 |
+
|
178 |
+
@app.cell
|
179 |
+
def _(df, pl):
|
180 |
+
df.with_columns(pl.col("users") / 1000)
|
181 |
+
return
|
182 |
+
|
183 |
+
|
184 |
+
@app.cell(hide_code=True)
|
185 |
+
def _(mo):
|
186 |
+
mo.md(r"""Or we could do it with a built-in expression.""")
|
187 |
+
return
|
188 |
+
|
189 |
+
|
190 |
+
@app.cell
|
191 |
+
def _(df, pl):
|
192 |
+
df.with_columns(pl.col("users").truediv(1000))
|
193 |
+
return
|
194 |
+
|
195 |
+
|
196 |
+
@app.cell(hide_code=True)
|
197 |
+
def _(mo):
|
198 |
+
mo.md(r"""If we didn't care about the remainder after division (i.e remove numbers after decimal point) we could do it like this.""")
|
199 |
+
return
|
200 |
+
|
201 |
+
|
202 |
+
@app.cell
|
203 |
+
def _(df, pl):
|
204 |
+
df.with_columns(pl.col("users").floordiv(1000))
|
205 |
+
return
|
206 |
+
|
207 |
+
|
208 |
+
@app.cell(hide_code=True)
|
209 |
+
def _(mo):
|
210 |
+
mo.md(
|
211 |
+
r"""
|
212 |
+
### Multiplication
|
213 |
+
Let's pretend the *user* values are deflated and increase them by multiplying by 100.
|
214 |
+
"""
|
215 |
+
)
|
216 |
+
return
|
217 |
+
|
218 |
+
|
219 |
+
@app.cell
|
220 |
+
def _(df, pl):
|
221 |
+
(df.with_columns(pl.col("users") * 100))
|
222 |
+
return
|
223 |
+
|
224 |
+
|
225 |
+
@app.cell(hide_code=True)
|
226 |
+
def _(mo):
|
227 |
+
mo.md(r"""Polars also has a built-in function for multiplication.""")
|
228 |
+
return
|
229 |
+
|
230 |
+
|
231 |
+
@app.cell
|
232 |
+
def _(df, pl):
|
233 |
+
df.with_columns(pl.col("users").mul(100))
|
234 |
+
return
|
235 |
+
|
236 |
+
|
237 |
+
@app.cell(hide_code=True)
|
238 |
+
def _(mo):
|
239 |
+
mo.md(r"""So far, we've only modified the values in an existing column. Let's create a column **decade** that will represent the years as decades. Thus 1985 will be 1980 and 2008 will be 2000.""")
|
240 |
+
return
|
241 |
+
|
242 |
+
|
243 |
+
@app.cell
|
244 |
+
def _(df, pl):
|
245 |
+
(df.with_columns(decade=pl.col("year").floordiv(10).mul(10)))
|
246 |
+
return
|
247 |
+
|
248 |
+
|
249 |
+
@app.cell(hide_code=True)
|
250 |
+
def _(mo):
|
251 |
+
mo.md(r"""We could create a new column another way as follows:""")
|
252 |
+
return
|
253 |
+
|
254 |
+
|
255 |
+
@app.cell
|
256 |
+
def _(df, pl):
|
257 |
+
df.with_columns((pl.col("year").floordiv(10).mul(10)).alias("decade"))
|
258 |
+
return
|
259 |
+
|
260 |
+
|
261 |
+
@app.cell(hide_code=True)
|
262 |
+
def _(mo):
|
263 |
+
mo.md(
|
264 |
+
r"""
|
265 |
+
**Tip**
|
266 |
+
Polars encounrages you to perform your operations as a chain. This enables you to take advantage of the query optimizer. We'll build upon the above code as a chain.
|
267 |
+
|
268 |
+
## Comparison
|
269 |
+
### Equal
|
270 |
+
Let's get all the software categorized as Vintage.
|
271 |
+
"""
|
272 |
+
)
|
273 |
+
return
|
274 |
+
|
275 |
+
|
276 |
+
@app.cell
|
277 |
+
def _(df, pl):
|
278 |
+
(
|
279 |
+
df.with_columns(decade=pl.col("year").floordiv(10).mul(10))
|
280 |
+
.filter(pl.col("category") == "Vintage")
|
281 |
+
)
|
282 |
+
return
|
283 |
+
|
284 |
+
|
285 |
+
@app.cell(hide_code=True)
|
286 |
+
def _(mo):
|
287 |
+
mo.md(r"""We could also do a double comparison. VisiCal is the only software that's vintage and in the decade 1970s. Let's perform this comparison operation.""")
|
288 |
+
return
|
289 |
+
|
290 |
+
|
291 |
+
@app.cell
|
292 |
+
def _(df, pl):
|
293 |
+
(
|
294 |
+
df.with_columns(decade=pl.col("year").floordiv(10).mul(10))
|
295 |
+
.filter(pl.col("category") == "Vintage")
|
296 |
+
.filter(pl.col("decade") == 1970)
|
297 |
+
)
|
298 |
+
return
|
299 |
+
|
300 |
+
|
301 |
+
@app.cell(hide_code=True)
|
302 |
+
def _(mo):
|
303 |
+
mo.md(
|
304 |
+
r"""
|
305 |
+
We could also do this comparison in one line, if readability is not a concern
|
306 |
+
|
307 |
+
**Notice** that we must enclose the two expressions between the `&` with parenthesis.
|
308 |
+
"""
|
309 |
+
)
|
310 |
+
return
|
311 |
+
|
312 |
+
|
313 |
+
@app.cell
|
314 |
+
def _(df, pl):
|
315 |
+
(
|
316 |
+
df.with_columns(decade=pl.col("year").floordiv(10).mul(10))
|
317 |
+
.filter((pl.col("category") == "Vintage") & (pl.col("decade") == 1970))
|
318 |
+
)
|
319 |
+
return
|
320 |
+
|
321 |
+
|
322 |
+
@app.cell(hide_code=True)
|
323 |
+
def _(mo):
|
324 |
+
mo.md(r"""We can also use the built-in function for equal to comparisons.""")
|
325 |
+
return
|
326 |
+
|
327 |
+
|
328 |
+
@app.cell
|
329 |
+
def _(df, pl):
|
330 |
+
(df
|
331 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
332 |
+
.filter(pl.col('category').eq('Vintage'))
|
333 |
+
)
|
334 |
+
return
|
335 |
+
|
336 |
+
|
337 |
+
@app.cell(hide_code=True)
|
338 |
+
def _(mo):
|
339 |
+
mo.md(
|
340 |
+
r"""
|
341 |
+
### Not equal
|
342 |
+
We can also compare if something is `not` equal to something. In this case, category is not vintage.
|
343 |
+
"""
|
344 |
+
)
|
345 |
+
return
|
346 |
+
|
347 |
+
|
348 |
+
@app.cell
|
349 |
+
def _(df, pl):
|
350 |
+
(df
|
351 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
352 |
+
.filter(pl.col('category') != 'Vintage')
|
353 |
+
)
|
354 |
+
return
|
355 |
+
|
356 |
+
|
357 |
+
@app.cell(hide_code=True)
|
358 |
+
def _(mo):
|
359 |
+
mo.md(r"""Or with the built-in function.""")
|
360 |
+
return
|
361 |
+
|
362 |
+
|
363 |
+
@app.cell
|
364 |
+
def _(df, pl):
|
365 |
+
(df
|
366 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
367 |
+
.filter(pl.col('category').ne('Vintage'))
|
368 |
+
)
|
369 |
+
return
|
370 |
+
|
371 |
+
|
372 |
+
@app.cell(hide_code=True)
|
373 |
+
def _(mo):
|
374 |
+
mo.md(r"""Or if you want to be extra clever, you can use the negation symbol `~` used in logic.""")
|
375 |
+
return
|
376 |
+
|
377 |
+
|
378 |
+
@app.cell
|
379 |
+
def _(df, pl):
|
380 |
+
(df
|
381 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
382 |
+
.filter(~pl.col('category').eq('Vintage'))
|
383 |
+
)
|
384 |
+
return
|
385 |
+
|
386 |
+
|
387 |
+
@app.cell(hide_code=True)
|
388 |
+
def _(mo):
|
389 |
+
mo.md(
|
390 |
+
r"""
|
391 |
+
### Greater than
|
392 |
+
Let's get the software where the year is greater than 2008 from the above dataframe.
|
393 |
+
"""
|
394 |
+
)
|
395 |
+
return
|
396 |
+
|
397 |
+
|
398 |
+
@app.cell
|
399 |
+
def _(df, pl):
|
400 |
+
(df
|
401 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
402 |
+
.filter(~pl.col('category').eq('Vintage'))
|
403 |
+
.filter(pl.col('year') > 2008)
|
404 |
+
)
|
405 |
+
return
|
406 |
+
|
407 |
+
|
408 |
+
@app.cell(hide_code=True)
|
409 |
+
def _(mo):
|
410 |
+
mo.md(r"""Or if we wanted the year 2008 to be included, we could use great or equal to.""")
|
411 |
+
return
|
412 |
+
|
413 |
+
|
414 |
+
@app.cell
|
415 |
+
def _(df, pl):
|
416 |
+
(df
|
417 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
418 |
+
.filter(~pl.col('category').eq('Vintage'))
|
419 |
+
.filter(pl.col('year') >= 2008)
|
420 |
+
)
|
421 |
+
return
|
422 |
+
|
423 |
+
|
424 |
+
@app.cell(hide_code=True)
|
425 |
+
def _(mo):
|
426 |
+
mo.md(r"""We could do the previous two operations with built-in functions. Here's with greater than.""")
|
427 |
+
return
|
428 |
+
|
429 |
+
|
430 |
+
@app.cell
|
431 |
+
def _(df, pl):
|
432 |
+
(df
|
433 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
434 |
+
.filter(~pl.col('category').eq('Vintage'))
|
435 |
+
.filter(pl.col('year').gt(2008))
|
436 |
+
)
|
437 |
+
return
|
438 |
+
|
439 |
+
|
440 |
+
@app.cell(hide_code=True)
|
441 |
+
def _(mo):
|
442 |
+
mo.md(r"""And here's with greater or equal to""")
|
443 |
+
return
|
444 |
+
|
445 |
+
|
446 |
+
@app.cell
|
447 |
+
def _(df, pl):
|
448 |
+
(df
|
449 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
450 |
+
.filter(~pl.col('category').eq('Vintage'))
|
451 |
+
.filter(pl.col('year').ge(2008))
|
452 |
+
)
|
453 |
+
return
|
454 |
+
|
455 |
+
|
456 |
+
@app.cell(hide_code=True)
|
457 |
+
def _(mo):
|
458 |
+
mo.md(
|
459 |
+
r"""
|
460 |
+
**Note**: For "less than", and "less or equal to" you can use the operators `<` or `<=`. Alternatively, you can use built-in functions `lt` or `le` respectively.
|
461 |
+
|
462 |
+
### Is between
|
463 |
+
Polars also allows us to filter between a range of values. Let's get the modern software were the year is between 2013 and 2016. This is inclusive on both ends (i.e. both years are part of the result).
|
464 |
+
"""
|
465 |
+
)
|
466 |
+
return
|
467 |
+
|
468 |
+
|
469 |
+
@app.cell
|
470 |
+
def _(df, pl):
|
471 |
+
(df
|
472 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
473 |
+
.filter(pl.col('category').eq('Modern'))
|
474 |
+
.filter(pl.col('year').is_between(2013, 2016))
|
475 |
+
)
|
476 |
+
return
|
477 |
+
|
478 |
+
|
479 |
+
@app.cell(hide_code=True)
|
480 |
+
def _(mo):
|
481 |
+
mo.md(
|
482 |
+
r"""
|
483 |
+
### Or operator
|
484 |
+
If we only want either one of the conditions in the comparison to be met, we could use `|`, which is the `or` operator.
|
485 |
+
|
486 |
+
Let's get software that is either modern or used in the decade 1980s.
|
487 |
+
"""
|
488 |
+
)
|
489 |
+
return
|
490 |
+
|
491 |
+
|
492 |
+
@app.cell
|
493 |
+
def _(df, pl):
|
494 |
+
(df
|
495 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
496 |
+
.filter((pl.col('category') == 'Modern') | (pl.col('decade') == 1980))
|
497 |
+
)
|
498 |
+
return
|
499 |
+
|
500 |
+
|
501 |
+
@app.cell(hide_code=True)
|
502 |
+
def _(mo):
|
503 |
+
mo.md(
|
504 |
+
r"""
|
505 |
+
## Conditionals
|
506 |
+
Polars also allows you create new columns based on a condition. Let's create a column *status* that will indicate if the software is "discontinued" or "in use".
|
507 |
+
|
508 |
+
Here's a list of products that are no longer in use.
|
509 |
+
"""
|
510 |
+
)
|
511 |
+
return
|
512 |
+
|
513 |
+
|
514 |
+
@app.cell
|
515 |
+
def _():
|
516 |
+
discontinued_list = ['Lotus-123', 'WordStar', 'dBase III', 'VisiCalc', 'MS-DOS', 'HyperCard']
|
517 |
+
return (discontinued_list,)
|
518 |
+
|
519 |
+
|
520 |
+
@app.cell(hide_code=True)
|
521 |
+
def _(mo):
|
522 |
+
mo.md(r"""Here's how we can get a dataframe of the products that are discontinued.""")
|
523 |
+
return
|
524 |
+
|
525 |
+
|
526 |
+
@app.cell
|
527 |
+
def _(df, discontinued_list, pl):
|
528 |
+
(df
|
529 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
530 |
+
.filter(pl.col('software').is_in(discontinued_list))
|
531 |
+
)
|
532 |
+
return
|
533 |
+
|
534 |
+
|
535 |
+
@app.cell(hide_code=True)
|
536 |
+
def _(mo):
|
537 |
+
mo.md(r"""Now, let's create the **status** column.""")
|
538 |
+
return
|
539 |
+
|
540 |
+
|
541 |
+
@app.cell
|
542 |
+
def _(df, discontinued_list, pl):
|
543 |
+
(df
|
544 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
545 |
+
.with_columns(pl.when(pl.col('software').is_in(discontinued_list))
|
546 |
+
.then(pl.lit('Discontinued'))
|
547 |
+
.otherwise(pl.lit('In use'))
|
548 |
+
.alias('status')
|
549 |
+
)
|
550 |
+
)
|
551 |
+
return
|
552 |
+
|
553 |
+
|
554 |
+
@app.cell(hide_code=True)
|
555 |
+
def _(mo):
|
556 |
+
mo.md(
|
557 |
+
r"""
|
558 |
+
## Unique counts
|
559 |
+
Sometimes you may want to see only the unique values in a column. Let's check the unique decades we have in our DataFrame.
|
560 |
+
"""
|
561 |
+
)
|
562 |
+
return
|
563 |
+
|
564 |
+
|
565 |
+
@app.cell
|
566 |
+
def _(df, discontinued_list, pl):
|
567 |
+
(df
|
568 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
569 |
+
.with_columns(pl.when(pl.col('software').is_in(discontinued_list))
|
570 |
+
.then(pl.lit('Discontinued'))
|
571 |
+
.otherwise(pl.lit('In use'))
|
572 |
+
.alias('status')
|
573 |
+
)
|
574 |
+
.select('decade').unique()
|
575 |
+
)
|
576 |
+
return
|
577 |
+
|
578 |
+
|
579 |
+
@app.cell(hide_code=True)
|
580 |
+
def _(mo):
|
581 |
+
mo.md(r"""Finally, let's find out the number of software used in each decade.""")
|
582 |
+
return
|
583 |
+
|
584 |
+
|
585 |
+
@app.cell
|
586 |
+
def _(df, discontinued_list, pl):
|
587 |
+
(df
|
588 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
589 |
+
.with_columns(pl.when(pl.col('software').is_in(discontinued_list))
|
590 |
+
.then(pl.lit('Discontinued'))
|
591 |
+
.otherwise(pl.lit('In use'))
|
592 |
+
.alias('status')
|
593 |
+
)
|
594 |
+
['decade'].value_counts()
|
595 |
+
)
|
596 |
+
return
|
597 |
+
|
598 |
+
|
599 |
+
@app.cell(hide_code=True)
|
600 |
+
def _(mo):
|
601 |
+
mo.md(r"""We could also rewrite the above code as follows:""")
|
602 |
+
return
|
603 |
+
|
604 |
+
|
605 |
+
@app.cell
|
606 |
+
def _(df, discontinued_list, pl):
|
607 |
+
(df
|
608 |
+
.with_columns(decade=pl.col('year').floordiv(10).mul(10))
|
609 |
+
.with_columns(pl.when(pl.col('software').is_in(discontinued_list))
|
610 |
+
.then(pl.lit('Discontinued'))
|
611 |
+
.otherwise(pl.lit('In use'))
|
612 |
+
.alias('status')
|
613 |
+
)
|
614 |
+
.select('decade').to_series().value_counts()
|
615 |
+
)
|
616 |
+
return
|
617 |
+
|
618 |
+
|
619 |
+
@app.cell(hide_code=True)
|
620 |
+
def _(mo):
|
621 |
+
mo.md(r"""Hopefully, we've picked your interest to try out Polars the next time you analyze your data.""")
|
622 |
+
return
|
623 |
+
|
624 |
+
|
625 |
+
@app.cell
|
626 |
+
def _():
|
627 |
+
return
|
628 |
+
|
629 |
+
|
630 |
+
if __name__ == "__main__":
|
631 |
+
app.run()
|
polars/10_strings.py
ADDED
@@ -0,0 +1,1004 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "altair==5.5.0",
#     "marimo",
#     "numpy==2.2.3",
#     "polars==1.24.0",
# ]
# ///

# Marimo notebook: interactive tutorial on the Polars `str` expression
# namespace (parsing, lengths, casing, padding, replacing, searching,
# slicing, splitting, concatenation, and pattern extraction).
import marimo

__generated_with = "0.11.17"
app = marimo.App(width="medium")
|
15 |
+
|
16 |
+
|
17 |
+
@app.cell(hide_code=True)
|
18 |
+
def _(mo):
|
19 |
+
mo.md(
|
20 |
+
r"""
|
21 |
+
# Strings
|
22 |
+
|
23 |
+
_By [Péter Ferenc Gyarmati](http://github.com/peter-gy)_.
|
24 |
+
|
25 |
+
In this chapter we're going to dig into string manipulation. For a fun twist, we'll be mostly playing around with a dataset that every Polars user has bumped into without really thinking about it—the source code of the `polars` module itself. More precisely, we'll use a dataframe that pulls together all the Polars expressions and their docstrings, giving us a cool, hands-on way to explore the expression API in a truly data-driven manner.
|
26 |
+
|
27 |
+
We'll cover parsing, length calculation, case conversion, and much more, with practical examples and visualizations. Finally, we will combine various techniques you learned in prior chapters to build a fully interactive playground in which you can execute the official code examples of Polars expressions.
|
28 |
+
"""
|
29 |
+
)
|
30 |
+
return
|
31 |
+
|
32 |
+
|
33 |
+
@app.cell(hide_code=True)
|
34 |
+
def _(mo):
|
35 |
+
mo.md(
|
36 |
+
r"""
|
37 |
+
## 🛠️ Parsing & Conversion
|
38 |
+
|
39 |
+
Let's warm up with one of the most frequent use cases: parsing raw strings into various formats.
|
40 |
+
We'll take a tiny dataframe with metadata about Python packages represented as raw JSON strings and we'll use Polars string expressions to parse the attributes into their true data types.
|
41 |
+
"""
|
42 |
+
)
|
43 |
+
return
|
44 |
+
|
45 |
+
|
46 |
+
@app.cell
def _(pl):
    # Two PyPI release records, each stored as a single raw JSON string.
    raw_records = [
        '{"package": "polars", "version": "1.24.0", "released_at": "2025-03-02T20:31:12+0000", "size_mb": "30.9"}',
        '{"package": "marimo", "version": "0.11.14", "released_at": "2025-03-04T00:28:57+0000", "size_mb": "10.7"}',
    ]
    pip_metadata_raw_df = pl.DataFrame(raw_records, schema={"raw_json": pl.String})
    pip_metadata_raw_df
    return (pip_metadata_raw_df,)
|
57 |
+
|
58 |
+
|
59 |
+
@app.cell(hide_code=True)
|
60 |
+
def _(mo):
|
61 |
+
mo.md(r"""We can use the [`json_decode`](https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.str.json_decode.html) expression to parse the raw JSON strings into Polars-native structs and we can use the [unnest](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.unnest.html) dataframe operation to have a dedicated column per parsed attribute.""")
|
62 |
+
return
|
63 |
+
|
64 |
+
|
65 |
+
@app.cell
def _(pip_metadata_raw_df, pl):
    # Decode each raw JSON string into a struct, then promote the struct
    # fields into dedicated columns with `unnest`.
    decoded = pip_metadata_raw_df.select(json=pl.col('raw_json').str.json_decode())
    pip_metadata_df = decoded.unnest('json')
    pip_metadata_df
    return (pip_metadata_df,)
|
70 |
+
|
71 |
+
|
72 |
+
@app.cell(hide_code=True)
|
73 |
+
def _(mo):
|
74 |
+
mo.md(r"""This is already a much friendlier representation of the data we started out with, but note that since the JSON entries had only string attributes, all values are strings, even the temporal `released_at` and numerical `size_mb` columns.""")
|
75 |
+
return
|
76 |
+
|
77 |
+
|
78 |
+
@app.cell(hide_code=True)
|
79 |
+
def _(mo):
|
80 |
+
mo.md(r"""As we know that the `size_mb` column should have a decimal representation, we go ahead and use [`to_decimal`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_decimal.html#polars.Expr.str.to_decimal) to perform the conversion.""")
|
81 |
+
return
|
82 |
+
|
83 |
+
|
84 |
+
@app.cell
def _(pip_metadata_df, pl):
    # Convert the textual size column to a Decimal; other columns pass through.
    size_as_decimal = pl.col('size_mb').str.to_decimal()
    pip_metadata_df.select('package', 'version', size_as_decimal)
    return
|
92 |
+
|
93 |
+
|
94 |
+
@app.cell(hide_code=True)
|
95 |
+
def _(mo):
|
96 |
+
mo.md(
|
97 |
+
r"""
|
98 |
+
Moving on to the `released_at` attribute which indicates the exact time when a given Python package got released, we have a bit more options to consider. We can convert to `Date`, `DateTime`, and `Time` types based on the desired temporal granularity. The [`to_date`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_date.html), [`to_datetime`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_datetime.html), and [`to_time`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_time.html) expressions are here to help us with the conversion, all we need is to provide the desired format string.
|
99 |
+
|
100 |
+
Since Polars uses Rust under the hood to implement all its expressions, we need to consult the [`chrono::format`](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) reference to come up with appropriate format strings.
|
101 |
+
|
102 |
+
Here's a quick reference:
|
103 |
+
|
104 |
+
| Specifier | Meaning |
|
105 |
+
|-----------|--------------------|
|
106 |
+
| `%Y` | Year (e.g., 2025) |
|
107 |
+
| `%m` | Month (01-12) |
|
108 |
+
| `%d` | Day (01-31) |
|
109 |
+
| `%H` | Hour (00-23) |
|
110 |
+
| `%z` | UTC offset |
|
111 |
+
|
112 |
+
The raw strings we are working with look like `"2025-03-02T20:31:12+0000"`. We can match this using the `"%Y-%m-%dT%H:%M:%S%z"` format string.
|
113 |
+
"""
|
114 |
+
)
|
115 |
+
return
|
116 |
+
|
117 |
+
|
118 |
+
@app.cell
def _(pip_metadata_df, pl):
    # One chrono format string matches inputs like "2025-03-02T20:31:12+0000".
    fmt = '%Y-%m-%dT%H:%M:%S%z'
    released = pl.col('released_at')
    pip_metadata_df.select(
        'package',
        'version',
        release_date=released.str.to_date(fmt),
        release_datetime=released.str.to_datetime(fmt),
        release_time=released.str.to_time(fmt),
    )
    return
|
128 |
+
|
129 |
+
|
130 |
+
@app.cell(hide_code=True)
|
131 |
+
def _(mo):
|
132 |
+
mo.md(r"""Alternatively, instead of using three different functions to perform the conversion to date, we can use a single one, [`strptime`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strptime.html) which takes the desired temporal data type as its first parameter.""")
|
133 |
+
return
|
134 |
+
|
135 |
+
|
136 |
+
@app.cell
def _(pip_metadata_df, pl):
    # `strptime` unifies the three converters: the target temporal dtype
    # becomes a parameter instead of a separate method.
    fmt = '%Y-%m-%dT%H:%M:%S%z'
    released = pl.col('released_at')
    pip_metadata_df.select(
        'package',
        'version',
        release_date=released.str.strptime(pl.Date, fmt),
        release_datetime=released.str.strptime(pl.Datetime, fmt),
        release_time=released.str.strptime(pl.Time, fmt),
    )
    return
|
146 |
+
|
147 |
+
|
148 |
+
@app.cell(hide_code=True)
|
149 |
+
def _(mo):
|
150 |
+
mo.md(r"""And to wrap up this section on parsing and conversion, let's consider a final scenario. What if we don't want to parse the entire raw JSON string, because we only need a subset of its attributes? Well, in this case we can leverage the [`json_path_match`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.json_path_match.html) expression to extract only the desired attributes using standard [JSONPath](https://goessner.net/articles/JsonPath/) syntax.""")
|
151 |
+
return
|
152 |
+
|
153 |
+
|
154 |
+
@app.cell
def _(pip_metadata_raw_df, pl):
    # Pull individual attributes straight out of the raw JSON with JSONPath
    # queries, without decoding the whole document first.
    pip_metadata_raw_df.select(
        package=pl.col("raw_json").str.json_path_match("$.package"),
        version=pl.col("raw_json").str.json_path_match("$.version"),
        # Fix: this expression extracts `$.size_mb`, so the resulting column
        # is the package size — it was previously mislabeled `release_date`.
        size_mb=pl.col("raw_json")
        .str.json_path_match("$.size_mb")
        .str.to_decimal(),
    )
    return
|
164 |
+
|
165 |
+
|
166 |
+
@app.cell(hide_code=True)
|
167 |
+
def _(mo):
|
168 |
+
mo.md(
|
169 |
+
r"""
|
170 |
+
## 📊 Dataset Overview
|
171 |
+
|
172 |
+
Now that we got our hands dirty, let's consider a somewhat wilder dataset for the subsequent sections: a dataframe of metadata about every single expression in your current Polars module.
|
173 |
+
|
174 |
+
At the risk of stating the obvious, in the previous section, when we typed `pl.col('raw_json').str.json_decode()`, we accessed the `json_decode` member of the `str` expression namespace through the `pl.col('raw_json')` expression *instance*. Under the hood, deep inside the Polars source code, there is a corresponding `def json_decode(...)` method with a carefully authored docstring explaining the purpose and signature of the member.
|
175 |
+
|
176 |
+
Since Python makes module introspection simple, we can easily enumerate all Polars expressions and organize their metadata in `expressions_df`, to be used for all the upcoming string manipulation examples.
|
177 |
+
"""
|
178 |
+
)
|
179 |
+
return
|
180 |
+
|
181 |
+
|
182 |
+
@app.cell(hide_code=True)
def _(pl):
    def list_members(expr, namespace) -> list[dict]:
        """Collect metadata for every public, non-namespace attribute of `expr`."""
        collected = []
        for name in expr.__dir__():
            # Skip accessor namespaces (str, dt, ...) and private attributes.
            if name.startswith("_") or name in pl.Expr._accessors:
                continue
            member = getattr(expr, name)
            collected.append(
                {
                    "namespace": namespace,
                    "member": name,
                    "docstring": member.__doc__,
                }
            )
        return collected

    def list_expr_meta() -> list[dict]:
        """Crawl a dummy expression: root members first, then each namespace."""
        expr = pl.lit("")
        meta = list_members(expr, "root")
        for ns in pl.Expr._accessors:
            meta.extend(list_members(getattr(expr, ns), ns))
        return meta

    expressions_df = pl.from_dicts(list_expr_meta(), infer_schema_length=None).sort('namespace', 'member')
    expressions_df
    return expressions_df, list_expr_meta, list_members
|
218 |
+
|
219 |
+
|
220 |
+
@app.cell(hide_code=True)
|
221 |
+
def _(mo):
|
222 |
+
mo.md(r"""As the following visualization shows, `str` is one of the richest Polars expression namespaces with multiple dozens of functions in it.""")
|
223 |
+
return
|
224 |
+
|
225 |
+
|
226 |
+
@app.cell(hide_code=True)
|
227 |
+
def _(alt, expressions_df):
|
228 |
+
expressions_df.plot.bar(
|
229 |
+
x=alt.X("count(member):Q", title='Count of Expressions'),
|
230 |
+
y=alt.Y("namespace:N", title='Namespace').sort("-x"),
|
231 |
+
)
|
232 |
+
return
|
233 |
+
|
234 |
+
|
235 |
+
@app.cell(hide_code=True)
|
236 |
+
def _(mo):
|
237 |
+
mo.md(
|
238 |
+
r"""
|
239 |
+
## 📏 Length Calculation
|
240 |
+
|
241 |
+
A common use case is to compute the length of a string. Most people associate string length exclusively with the number of characters the said string consists of; however, in certain scenarios it is useful to also know how much memory is required for storing, so how many bytes are required to represent the textual data.
|
242 |
+
|
243 |
+
The expressions [`len_chars`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.len_chars.html) and [`len_bytes`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.len_bytes.html) are here to help us with these calculations.
|
244 |
+
|
245 |
+
Below, we compute `docstring_len_chars` and `docstring_len_bytes` columns to see how many characters and bytes the documentation of each expression is made up of.
|
246 |
+
"""
|
247 |
+
)
|
248 |
+
return
|
249 |
+
|
250 |
+
|
251 |
+
@app.cell
def _(expressions_df, pl):
    # Character count vs. byte count of each member's docstring.
    doc = pl.col("docstring")
    docstring_length_df = expressions_df.select(
        'namespace',
        'member',
        docstring_len_chars=doc.str.len_chars(),
        docstring_len_bytes=doc.str.len_bytes(),
    )
    docstring_length_df
    return (docstring_length_df,)
|
261 |
+
|
262 |
+
|
263 |
+
@app.cell(hide_code=True)
|
264 |
+
def _(mo):
|
265 |
+
mo.md(r"""As the dataframe preview above and the scatterplot below show, the docstring length measured in bytes is almost always bigger than the length expressed in characters. This is due to the fact that the docstrings include characters which require more than a single byte to represent, such as "╞" for displaying dataframe header and body separators.""")
|
266 |
+
return
|
267 |
+
|
268 |
+
|
269 |
+
@app.cell
|
270 |
+
def _(alt, docstring_length_df):
|
271 |
+
docstring_length_df.plot.point(
|
272 |
+
x=alt.X('docstring_len_chars', title='Docstring Length (Chars)'),
|
273 |
+
y=alt.Y('docstring_len_bytes', title='Docstring Length (Bytes)'),
|
274 |
+
tooltip=['namespace', 'member', 'docstring_len_chars', 'docstring_len_bytes'],
|
275 |
+
)
|
276 |
+
return
|
277 |
+
|
278 |
+
|
279 |
+
@app.cell(hide_code=True)
|
280 |
+
def _(mo):
|
281 |
+
mo.md(
|
282 |
+
r"""
|
283 |
+
## 🔠 Case Conversion
|
284 |
+
|
285 |
+
Another frequent string transformation is lowercasing, uppercasing, and titlecasing. We can use [`to_lowercase`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_lowercase.html), [`to_uppercase`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_uppercase.html) and [`to_titlecase`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_titlecase.html) for doing so.
|
286 |
+
"""
|
287 |
+
)
|
288 |
+
return
|
289 |
+
|
290 |
+
|
291 |
+
@app.cell
def _(expressions_df, pl):
    # Derive lower / UPPER / Title variants of every member name.
    member = pl.col('member')
    expressions_df.select(
        member_lower=member.str.to_lowercase(),
        member_upper=member.str.to_uppercase(),
        member_title=member.str.to_titlecase(),
    )
    return
|
299 |
+
|
300 |
+
|
301 |
+
@app.cell(hide_code=True)
|
302 |
+
def _(mo):
|
303 |
+
mo.md(
|
304 |
+
r"""
|
305 |
+
## ➕ Padding
|
306 |
+
|
307 |
+
Sometimes we need to ensure that strings have a fixed-size character length. [`pad_start`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.pad_start.html) and [`pad_end`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.pad_end.html) can be used to fill the "front" or "back" of a string with a supplied character, while [`zfill`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.zfill.html) is a utility for padding the start of a string with `"0"` until it reaches a particular length. In other words, `zfill` is a more specific version of `pad_start`, where the `fill_char` parameter is explicitly set to `"0"`.
|
308 |
+
|
309 |
+
In the example below we take the unique Polars expression namespaces and pad them so that they have a uniform length which you can control via a slider.
|
310 |
+
"""
|
311 |
+
)
|
312 |
+
return
|
313 |
+
|
314 |
+
|
315 |
+
@app.cell(hide_code=True)
|
316 |
+
def _(mo):
|
317 |
+
padding = mo.ui.slider(0, 16, step=1, value=8, label="Padding Size")
|
318 |
+
return (padding,)
|
319 |
+
|
320 |
+
|
321 |
+
@app.cell
def _(expressions_df, padding, pl):
    # Pad each distinct namespace name to the slider-selected width.
    width = padding.value
    ns = pl.col("namespace")
    unique_namespaces = expressions_df.select("namespace").unique()
    padded_df = unique_namespaces.select(
        "namespace",
        namespace_front_padded=ns.str.pad_start(width, "_"),
        namespace_back_padded=ns.str.pad_end(width, "_"),
        namespace_zfilled=ns.str.zfill(width),
    )
    return (padded_df,)
|
330 |
+
|
331 |
+
|
332 |
+
@app.cell(hide_code=True)
|
333 |
+
def _(mo, padded_df, padding):
|
334 |
+
mo.vstack([
|
335 |
+
padding,
|
336 |
+
padded_df,
|
337 |
+
])
|
338 |
+
return
|
339 |
+
|
340 |
+
|
341 |
+
@app.cell(hide_code=True)
|
342 |
+
def _(mo):
|
343 |
+
mo.md(
|
344 |
+
r"""
|
345 |
+
## 🔄 Replacing
|
346 |
+
|
347 |
+
Let's say we want to convert from `snake_case` API member names to `kebab-case`, that is, we need to replace the underscore character with a hyphen. For operations like that, we can use [`replace`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.replace.html) and [`replace_all`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.replace_all.html).
|
348 |
+
|
349 |
+
As the example below demonstrates, `replace` stops after the first occurrence of the to-be-replaced pattern, while `replace_all` goes all the way through and changes all underscores to hyphens resulting in the `kebab-case` representation we were looking for.
|
350 |
+
"""
|
351 |
+
)
|
352 |
+
return
|
353 |
+
|
354 |
+
|
355 |
+
@app.cell
def _(expressions_df, pl):
    # `replace` rewrites only the first underscore; `replace_all` rewrites
    # every one, yielding the full kebab-case form.
    member = pl.col("member")
    kebabbed = expressions_df.select(
        "member",
        member_kebab_case_partial=member.str.replace("_", "-"),
        member_kebab_case=member.str.replace_all("_", "-"),
    )
    kebabbed.sort(member.str.len_chars(), descending=True)
    return
|
363 |
+
|
364 |
+
|
365 |
+
@app.cell(hide_code=True)
|
366 |
+
def _(mo):
|
367 |
+
mo.md(
|
368 |
+
r"""
|
369 |
+
A related expression is [`replace_many`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.replace_many.html), which accepts *many* pairs of to-be-matched patterns and corresponding replacements and uses the [Aho–Corasick algorithm](https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm) to carry out the operation with great performance.
|
370 |
+
|
371 |
+
In the example below we replace all instances of `"min"` with `"minimum"` and `"max"` with `"maximum"` using a single expression.
|
372 |
+
"""
|
373 |
+
)
|
374 |
+
return
|
375 |
+
|
376 |
+
|
377 |
+
@app.cell
def _(expressions_df, pl):
    # A single Aho-Corasick pass replaces every listed pattern at once.
    substitutions = {
        "min": "minimum",
        "max": "maximum",
    }
    expressions_df.select(
        "member",
        member_modified=pl.col("member").str.replace_many(substitutions),
    )
    return
|
389 |
+
|
390 |
+
|
391 |
+
@app.cell(hide_code=True)
|
392 |
+
def _(mo):
|
393 |
+
mo.md(
|
394 |
+
r"""
|
395 |
+
## 🔍 Searching & Matching
|
396 |
+
|
397 |
+
A common need when working with strings is to determine whether their content satisfies some condition: whether it starts or ends with a particular substring or contains a certain pattern.
|
398 |
+
|
399 |
+
Let's suppose we want to determine whether a member of the Polars expression API is a "converter", such as `to_decimal`, identified by its `"to_"` prefix. We can use [`starts_with`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.starts_with.html) to perform this check.
|
400 |
+
"""
|
401 |
+
)
|
402 |
+
return
|
403 |
+
|
404 |
+
|
405 |
+
@app.cell
|
406 |
+
def _(expressions_df, pl):
|
407 |
+
expressions_df.select(
|
408 |
+
"namespace",
|
409 |
+
"member",
|
410 |
+
is_converter=pl.col("member").str.starts_with("to_"),
|
411 |
+
).sort(-pl.col("is_converter").cast(pl.Int8))
|
412 |
+
return
|
413 |
+
|
414 |
+
|
415 |
+
@app.cell(hide_code=True)
|
416 |
+
def _(mo):
|
417 |
+
mo.md(
|
418 |
+
r"""
|
419 |
+
Throughout this course as you have gained familiarity with the expression API you might have noticed that some members end with an underscore such as `or_`, since their "body" is a reserved Python keyword.
|
420 |
+
|
421 |
+
Let's use [`ends_with`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.ends_with.html) to find all the members which are named after such keywords.
|
422 |
+
"""
|
423 |
+
)
|
424 |
+
return
|
425 |
+
|
426 |
+
|
427 |
+
@app.cell
|
428 |
+
def _(expressions_df, pl):
|
429 |
+
expressions_df.select(
|
430 |
+
"namespace",
|
431 |
+
"member",
|
432 |
+
is_escaped_keyword=pl.col("member").str.ends_with("_"),
|
433 |
+
).sort(-pl.col("is_escaped_keyword").cast(pl.Int8))
|
434 |
+
return
|
435 |
+
|
436 |
+
|
437 |
+
@app.cell(hide_code=True)
|
438 |
+
def _(mo):
|
439 |
+
mo.md(
|
440 |
+
r"""
|
441 |
+
Now let's move on to analyzing the docstrings in a bit more detail. Based on their content we can determine whether a member is deprecated, accepts parameters, comes with examples, or references external URL(s) & related members.
|
442 |
+
|
443 |
+
As demonstrated below, we can compute all these boolean attributes using [`contains`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.contains.html) to check whether the docstring includes a particular substring.
|
444 |
+
"""
|
445 |
+
)
|
446 |
+
return
|
447 |
+
|
448 |
+
|
449 |
+
@app.cell
def _(expressions_df, pl):
    # Flag docstring features via substring / regex containment checks.
    doc = pl.col('docstring')
    expressions_df.select(
        'namespace',
        'member',
        # literal=True: ".. deprecated" would otherwise be parsed as a regex.
        is_deprecated=doc.str.contains('.. deprecated', literal=True),
        has_parameters=doc.str.contains('Parameters'),
        has_examples=doc.str.contains('Examples'),
        has_related_members=doc.str.contains('See Also'),
        has_url=doc.str.contains('https?://'),
    )
    return
|
461 |
+
|
462 |
+
|
463 |
+
@app.cell(hide_code=True)
|
464 |
+
def _(mo):
|
465 |
+
    mo.md(r"""For scenarios where we want to combine multiple substrings to check for, we can use the [`contains_any`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.contains_any.html) expression to check for the presence of any of several literal patterns.""")
|
466 |
+
return
|
467 |
+
|
468 |
+
|
469 |
+
@app.cell
|
470 |
+
def _(expressions_df, pl):
|
471 |
+
expressions_df.select(
|
472 |
+
'namespace',
|
473 |
+
'member',
|
474 |
+
has_reference=pl.col('docstring').str.contains_any(['See Also', 'https://'])
|
475 |
+
)
|
476 |
+
return
|
477 |
+
|
478 |
+
|
479 |
+
@app.cell(hide_code=True)
|
480 |
+
def _(mo):
|
481 |
+
mo.md(
|
482 |
+
r"""
|
483 |
+
From the above analysis we could see that almost all the members come with code examples. It would be interesting to know how many variable assignments are going on within each of these examples, right? That's not as simple as checking for a pre-defined literal string containment though, because variables can have arbitrary names - any valid Python identifier is allowed. While the `contains` function supports checking for regular expressions instead of literal strings too, it would not suffice for this exercise because it only tells us whether there is at least a single occurrence of the sought pattern rather than telling us the exact number of matches.
|
484 |
+
|
485 |
+
Fortunately, we can take advantage of [`count_matches`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.count_matches.html) to achieve exactly what we want. We specify the regular expression `r'[a-zA-Z_][a-zA-Z0-9_]* = '` according to the [`regex` Rust crate](https://docs.rs/regex/latest/regex/) to match Python identifiers and we leave the rest to Polars.
|
486 |
+
|
487 |
+
In `count_matches(r'[a-zA-Z_][a-zA-Z0-9_]* = ')`:
|
488 |
+
|
489 |
+
- `[a-zA-Z_]` matches a letter or underscore (start of a Python identifier).
|
490 |
+
- `[a-zA-Z0-9_]*` matches zero or more letters, digits, or underscores.
|
491 |
+
- ` = ` matches a space, equals sign, and space (indicating assignment).
|
492 |
+
|
493 |
+
This finds variable assignments like `x = ` or `df_result = ` in docstrings.
|
494 |
+
"""
|
495 |
+
)
|
496 |
+
return
|
497 |
+
|
498 |
+
|
499 |
+
@app.cell
def _(expressions_df, pl):
    # Count `identifier = ` occurrences, i.e. variable assignments inside
    # the docstring's code examples.
    assignment_pattern = r'[a-zA-Z_][a-zA-Z0-9_]* = '
    expressions_df.select(
        'namespace',
        'member',
        variable_assignment_count=pl.col('docstring').str.count_matches(assignment_pattern),
    )
    return
|
507 |
+
|
508 |
+
|
509 |
+
@app.cell(hide_code=True)
|
510 |
+
def _(mo):
|
511 |
+
mo.md(r"""A related application example is to *find* the first index where a particular pattern is present, so that it can be used for downstream processing such as slicing. Below we use the [`find`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.find.html) expression to determine the index at which a code example starts in the docstring - identified by the Python shell substring `">>>"`.""")
|
512 |
+
return
|
513 |
+
|
514 |
+
|
515 |
+
@app.cell
|
516 |
+
def _(expressions_df, pl):
|
517 |
+
expressions_df.select(
|
518 |
+
'namespace',
|
519 |
+
'member',
|
520 |
+
code_example_start=pl.col('docstring').str.find('>>>'),
|
521 |
+
)
|
522 |
+
return
|
523 |
+
|
524 |
+
|
525 |
+
@app.cell(hide_code=True)
|
526 |
+
def _(mo):
|
527 |
+
mo.md(
|
528 |
+
r"""
|
529 |
+
## ✂️ Slicing and Substrings
|
530 |
+
|
531 |
+
Sometimes we are only interested in a particular substring. We can use [`head`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.head.html), [`tail`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.tail.html) and [`slice`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.slice.html) to extract a substring from the start, end, or between arbitrary indices.
|
532 |
+
"""
|
533 |
+
)
|
534 |
+
return
|
535 |
+
|
536 |
+
|
537 |
+
@app.cell
|
538 |
+
def _(mo):
|
539 |
+
slice = mo.ui.slider(1, 50, step=1, value=25, label="Slice Size")
|
540 |
+
return (slice,)
|
541 |
+
|
542 |
+
|
543 |
+
@app.cell
def _(expressions_df, pl, slice):
    # Slice sizes follow the slider, not a fixed 25 characters.
    n = slice.value
    doc = pl.col("docstring")
    sliced_df = expressions_df.select(
        # Leading `n` characters.
        docstring_head=doc.str.head(n),
        # `2 * n` characters starting right after the leading `n`.
        docstring_slice=doc.str.slice(n, 2 * n),
        # Trailing `n` characters.
        docstring_tail=doc.str.tail(n),
    )
    return (sliced_df,)
|
554 |
+
|
555 |
+
|
556 |
+
@app.cell
|
557 |
+
def _(mo, slice, sliced_df):
|
558 |
+
mo.vstack([
|
559 |
+
slice,
|
560 |
+
sliced_df,
|
561 |
+
])
|
562 |
+
return
|
563 |
+
|
564 |
+
|
565 |
+
@app.cell(hide_code=True)
|
566 |
+
def _(mo):
|
567 |
+
mo.md(
|
568 |
+
r"""
|
569 |
+
## ➗ Splitting
|
570 |
+
|
571 |
+
Certain strings follow a well-defined structure and we might be only interested in some parts of them. For example, when dealing with `snake_cased_expression` member names we might be curious to get only the first, second, or $n^{\text{th}}$ word before an underscore. We would need to *split* the string at a particular pattern for downstream processing.
|
572 |
+
|
573 |
+
The [`split`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.split.html), [`split_exact`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.split_exact.html) and [`splitn`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.splitn.html) expressions enable us to achieve this.
|
574 |
+
|
575 |
+
The primary difference between these string splitting utilities is that `split` produces a list of variadic length based on the number of resulting segments, `splitn` returns a struct with at least `0` and at most `n` fields while `split_exact` returns a struct of exactly `n` fields.
|
576 |
+
"""
|
577 |
+
)
|
578 |
+
return
|
579 |
+
|
580 |
+
|
581 |
+
@app.cell
def _(expressions_df, pl):
    # Compare the three splitters: `split` yields a variable-length list,
    # `splitn` a struct of at most n fields, `split_exact` exactly n fields.
    member = pl.col('member')
    expressions_df.select(
        'member',
        member_name_parts=member.str.split('_'),
        member_name_parts_n=member.str.splitn('_', n=2),
        member_name_parts_exact=member.str.split_exact('_', n=2),
    )
    return
|
590 |
+
|
591 |
+
|
592 |
+
@app.cell(hide_code=True)
|
593 |
+
def _(mo):
|
594 |
+
mo.md(r"""As a more practical example, we can use the `split` expression with some aggregation to count the number of times a particular word occurs in member names across all namespaces. This enables us to create a word cloud of the API members' constituents!""")
|
595 |
+
return
|
596 |
+
|
597 |
+
|
598 |
+
@app.cell(hide_code=True)
|
599 |
+
def _(mo, wordcloud, wordcloud_height, wordcloud_width):
|
600 |
+
mo.vstack([
|
601 |
+
wordcloud_width,
|
602 |
+
wordcloud_height,
|
603 |
+
wordcloud,
|
604 |
+
])
|
605 |
+
return
|
606 |
+
|
607 |
+
|
608 |
+
@app.cell(hide_code=True)
|
609 |
+
def _(mo):
|
610 |
+
wordcloud_width = mo.ui.slider(0, 64, step=1, value=32, label="Word Cloud Width")
|
611 |
+
wordcloud_height = mo.ui.slider(0, 32, step=1, value=16, label="Word Cloud Height")
|
612 |
+
return wordcloud_height, wordcloud_width
|
613 |
+
|
614 |
+
|
615 |
+
@app.cell(hide_code=True)
def _(alt, expressions_df, pl, random, wordcloud_height, wordcloud_width):
    # Word frequencies: split member names on "_", explode into one row per
    # word, and count occurrences per word with group_by/agg(len).
    wordcloud_df = (
        expressions_df.select(pl.col("member").str.split("_"))
        .explode("member")
        .group_by("member")
        .agg(pl.len())
        # Generating random x and y coordinates to distribute the words in the 2D space
        # NOTE(review): coordinates are re-randomized on every slider change,
        # so the layout is not stable between runs.
        .with_columns(
            x=pl.col("member").map_elements(
                lambda e: random.randint(0, wordcloud_width.value),
                return_dtype=pl.UInt8,
            ),
            y=pl.col("member").map_elements(
                lambda e: random.randint(0, wordcloud_height.value),
                return_dtype=pl.UInt8,
            ),
        )
    )

    # Render each word at its random (x, y); color and size both encode the
    # word's frequency, with exact counts available via the tooltip.
    wordcloud = alt.Chart(wordcloud_df).mark_text(baseline="middle").encode(
        x=alt.X("x:O", axis=None),
        y=alt.Y("y:O", axis=None),
        text="member:N",
        color=alt.Color("len:Q", scale=alt.Scale(scheme="bluepurple")),
        size=alt.Size("len:Q", legend=None),
        tooltip=["member", "len"],
    ).configure_view(strokeWidth=0)
    return wordcloud, wordcloud_df
|
644 |
+
|
645 |
+
|
646 |
+
@app.cell(hide_code=True)
|
647 |
+
def _(mo):
|
648 |
+
mo.md(
|
649 |
+
r"""
|
650 |
+
## 🔗 Concatenation & Joining
|
651 |
+
|
652 |
+
Often we would like to create longer strings from strings we already have. We might want to create a formatted, sentence-like string or join multiple existing strings in our dataframe into a single one.
|
653 |
+
|
654 |
+
The top-level [`concat_str`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.concat_str.html) expression enables us to combine strings *horizontally* in a dataframe. As the example below shows, we can take the `member` and `namespace` column of each row and construct a `description` column in which each row will correspond to the value ``f"- Expression `{member}` belongs to namespace `{namespace}`"``.
|
655 |
+
"""
|
656 |
+
)
|
657 |
+
return
|
658 |
+
|
659 |
+
|
660 |
+
@app.cell
def _(expressions_df, pl):
    # Horizontally stitch literals and columns into one markdown bullet per
    # row; bare strings in the list are treated as column selectors.
    backtick = pl.lit("`")
    bullet_parts = [
        pl.lit("- Expression "),
        backtick,
        "member",
        backtick,
        pl.lit(" belongs to namespace "),
        backtick,
        "namespace",
        backtick,
    ]
    descriptions_df = expressions_df.sample(5).select(
        description=pl.concat_str(bullet_parts)
    )
    descriptions_df
    return (descriptions_df,)
|
678 |
+
|
679 |
+
|
680 |
+
@app.cell(hide_code=True)
|
681 |
+
def _(mo):
|
682 |
+
mo.md(
|
683 |
+
r"""
|
684 |
+
Now that we have constructed these bullet points through *horizontal* concatenation of strings, we can perform a *vertical* one so that we end up with a single string in which we have a bullet point on each line.
|
685 |
+
|
686 |
+
We will use the [`join`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.join.html) expression to do so.
|
687 |
+
"""
|
688 |
+
)
|
689 |
+
return
|
690 |
+
|
691 |
+
|
692 |
+
@app.cell
|
693 |
+
def _(descriptions_df, pl):
|
694 |
+
descriptions_df.select(pl.col('description').str.join('\n'))
|
695 |
+
return
|
696 |
+
|
697 |
+
|
698 |
+
@app.cell(hide_code=True)
|
699 |
+
def _(descriptions_df, mo, pl):
|
700 |
+
mo.md(f"""In fact, since the string we constructed dynamically is valid markdown, we can display it dynamically using Marimo's `mo.md` utility!
|
701 |
+
|
702 |
+
---
|
703 |
+
|
704 |
+
{descriptions_df.select(pl.col('description').str.join('\n')).to_numpy().squeeze().tolist()}
|
705 |
+
""")
|
706 |
+
return
|
707 |
+
|
708 |
+
|
709 |
+
@app.cell(hide_code=True)
|
710 |
+
def _(mo):
|
711 |
+
mo.md(
|
712 |
+
r"""
|
713 |
+
## 🔍 Pattern-based Extraction
|
714 |
+
|
715 |
+
In the vast majority of the cases, when dealing with unstructured text data, all we really want is to extract something structured from it. A common use case is to extract URLs from text to get a better understanding of related content.
|
716 |
+
|
717 |
+
In the example below that's exactly what we do. We scan the `docstring` of each API member and extract URLs from them using [`extract`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.extract.html) and [`extract_all`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.extract_all.html) using a simple regular expression to match http and https URLs.
|
718 |
+
|
719 |
+
Note that `extract` stops after a first match and returns a scalar result (or `null` if there was no match) while `extract_all` returns a - potentially empty - list of matches.
|
720 |
+
"""
|
721 |
+
)
|
722 |
+
return
|
723 |
+
|
724 |
+
|
725 |
+
@app.cell
|
726 |
+
def _(expressions_df, pl):
|
727 |
+
url_pattern = r'(https?://[^\s>]+)'
|
728 |
+
expressions_df.select(
|
729 |
+
'namespace',
|
730 |
+
'member',
|
731 |
+
url_match=pl.col('docstring').str.extract(url_pattern),
|
732 |
+
url_matches=pl.col('docstring').str.extract_all(url_pattern),
|
733 |
+
).filter(pl.col('url_match').is_not_null())
|
734 |
+
return (url_pattern,)
|
735 |
+
|
736 |
+
|
737 |
+
@app.cell(hide_code=True)
|
738 |
+
def _(mo):
|
739 |
+
mo.md(
|
740 |
+
r"""
|
741 |
+
Note that in each `docstring` where a code example involving dataframes is present, we will see an output such as "shape: (5, 2)" indicating the number of rows and columns of the dataframe produced by the sample code. Let's say we would like to *capture* this information in a structured way.
|
742 |
+
|
743 |
+
[`extract_groups`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.extract_groups.html) is a really powerful expression allowing us to achieve exactly that.
|
744 |
+
|
745 |
+
Below we define the regular expression `r"shape:\s*\((?<height>\S+),\s*(?<width>\S+)\)"` with two capture groups, named `height` and `width` and pass it as the parameter of `extract_groups`. After execution, for each `docstring`, we end up with fully structured data we can further process downstream!
|
746 |
+
"""
|
747 |
+
)
|
748 |
+
return
|
749 |
+
|
750 |
+
|
751 |
+
@app.cell
|
752 |
+
def _(expressions_df, pl):
|
753 |
+
expressions_df.select(
|
754 |
+
'namespace',
|
755 |
+
'member',
|
756 |
+
example_df_shape=pl.col('docstring').str.extract_groups(r"shape:\s*\((?<height>\S+),\s*(?<width>\S+)\)"),
|
757 |
+
)
|
758 |
+
return
|
759 |
+
|
760 |
+
|
761 |
+
@app.cell(hide_code=True)
|
762 |
+
def _(mo):
|
763 |
+
mo.md(
|
764 |
+
r"""
|
765 |
+
## 🧹 Stripping
|
766 |
+
|
767 |
+
Strings might require some cleaning before further processing, such as the removal of some characters from the beginning or end of the text. [`strip_chars_start`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars_start.html), [`strip_chars_end`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars_end.html) and [`strip_chars`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars.html) are here to facilitate this.
|
768 |
+
|
769 |
+
All we need to do is to specify a set of characters we would like to get rid of and Polars handles the rest for us.
|
770 |
+
"""
|
771 |
+
)
|
772 |
+
return
|
773 |
+
|
774 |
+
|
775 |
+
@app.cell
|
776 |
+
def _(expressions_df, pl):
|
777 |
+
expressions_df.select(
|
778 |
+
"member",
|
779 |
+
member_front_stripped=pl.col("member").str.strip_chars_start("a"),
|
780 |
+
member_back_stripped=pl.col("member").str.strip_chars_end("n"),
|
781 |
+
member_fully_stripped=pl.col("member").str.strip_chars("na"),
|
782 |
+
)
|
783 |
+
return
|
784 |
+
|
785 |
+
|
786 |
+
@app.cell(hide_code=True)
|
787 |
+
def _(mo):
|
788 |
+
mo.md(
|
789 |
+
r"""
|
790 |
+
Note that when using the above expressions, the specified characters do not need to form a sequence; they are handled as a set. However, in certain use cases we only want to strip complete substrings, so we would need our input to be strictly treated as a sequence rather than as a set.
|
791 |
+
|
792 |
+
That's exactly the rationale behind [`strip_prefix`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_prefix.html) and [`strip_suffix`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_suffix.html).
|
793 |
+
|
794 |
+
Below we use these to remove the `"to_"` prefixes and `"_with"` suffixes from each member name.
|
795 |
+
"""
|
796 |
+
)
|
797 |
+
return
|
798 |
+
|
799 |
+
|
800 |
+
@app.cell
|
801 |
+
def _(expressions_df, pl):
|
802 |
+
expressions_df.select(
|
803 |
+
"member",
|
804 |
+
member_prefix_stripped=pl.col("member").str.strip_prefix("to_"),
|
805 |
+
member_suffix_stripped=pl.col("member").str.strip_suffix("_with"),
|
806 |
+
).slice(20)
|
807 |
+
return
|
808 |
+
|
809 |
+
|
810 |
+
@app.cell(hide_code=True)
|
811 |
+
def _(mo):
|
812 |
+
mo.md(
|
813 |
+
r"""
|
814 |
+
## 🔑 Encoding & Decoding
|
815 |
+
|
816 |
+
Should you find yourself in the need of encoding your strings into [base64](https://en.wikipedia.org/wiki/Base64) or [hexadecimal](https://en.wikipedia.org/wiki/Hexadecimal) format, then Polars has your back with its [`encode`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.encode.html) expression.
|
817 |
+
"""
|
818 |
+
)
|
819 |
+
return
|
820 |
+
|
821 |
+
|
822 |
+
@app.cell
|
823 |
+
def _(expressions_df, pl):
|
824 |
+
encoded_df = expressions_df.select(
|
825 |
+
"member",
|
826 |
+
member_base64=pl.col('member').str.encode('base64'),
|
827 |
+
member_hex=pl.col('member').str.encode('hex'),
|
828 |
+
)
|
829 |
+
encoded_df
|
830 |
+
return (encoded_df,)
|
831 |
+
|
832 |
+
|
833 |
+
@app.cell(hide_code=True)
|
834 |
+
def _(mo):
|
835 |
+
mo.md(r"""And of course, you can convert back into a human-readable representation using the [`decode`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.decode.html) expression.""")
|
836 |
+
return
|
837 |
+
|
838 |
+
|
839 |
+
@app.cell
|
840 |
+
def _(encoded_df, pl):
|
841 |
+
encoded_df.with_columns(
|
842 |
+
member_base64_decoded=pl.col('member_base64').str.decode('base64').cast(pl.String),
|
843 |
+
member_hex_decoded=pl.col('member_hex').str.decode('hex').cast(pl.String),
|
844 |
+
)
|
845 |
+
return
|
846 |
+
|
847 |
+
|
848 |
+
@app.cell(hide_code=True)
|
849 |
+
def _(mo):
|
850 |
+
mo.md(
|
851 |
+
r"""
|
852 |
+
## 🚀 Application: Dynamic Execution of Polars Examples
|
853 |
+
|
854 |
+
Now that we are familiar with string expressions, we can combine them with other Polars operations to build a fully interactive playground where code examples of Polars expressions can be explored.
|
855 |
+
|
856 |
+
We make use of string expressions to extract the raw Python source code of examples from the docstrings and we leverage the interactive Marimo environment to enable the selection of expressions via a searchable dropdown and a fully functional code editor whose output is rendered with Marimo's rich display utilities.
|
857 |
+
|
858 |
+
In other words, we will use Polars to execute Polars. ❄️ How cool is that?
|
859 |
+
|
860 |
+
---
|
861 |
+
"""
|
862 |
+
)
|
863 |
+
return
|
864 |
+
|
865 |
+
|
866 |
+
@app.cell(hide_code=True)
|
867 |
+
def _(
|
868 |
+
example_editor,
|
869 |
+
execution_result,
|
870 |
+
expression,
|
871 |
+
expression_description,
|
872 |
+
expression_docs_link,
|
873 |
+
mo,
|
874 |
+
):
|
875 |
+
mo.vstack(
|
876 |
+
[
|
877 |
+
mo.md(f'### {expression.value}'),
|
878 |
+
expression,
|
879 |
+
mo.hstack([expression_description, expression_docs_link]),
|
880 |
+
example_editor,
|
881 |
+
execution_result,
|
882 |
+
]
|
883 |
+
)
|
884 |
+
return
|
885 |
+
|
886 |
+
|
887 |
+
@app.cell(hide_code=True)
|
888 |
+
def _(mo, selected_expression_record):
|
889 |
+
expression_description = mo.md(selected_expression_record["description"])
|
890 |
+
expression_docs_link = mo.md(
|
891 |
+
f"🐻❄️ [Official Docs](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.{selected_expression_record['expr']}.html)"
|
892 |
+
)
|
893 |
+
return expression_description, expression_docs_link
|
894 |
+
|
895 |
+
|
896 |
+
@app.cell(hide_code=True)
|
897 |
+
def _(example_editor, execute_code):
|
898 |
+
execution_result = execute_code(example_editor.value)
|
899 |
+
return (execution_result,)
|
900 |
+
|
901 |
+
|
902 |
+
@app.cell(hide_code=True)
|
903 |
+
def _(code_df, mo):
|
904 |
+
expression = mo.ui.dropdown(code_df.get_column('expr'), value='arr.all', searchable=True)
|
905 |
+
return (expression,)
|
906 |
+
|
907 |
+
|
908 |
+
@app.cell(hide_code=True)
|
909 |
+
def _(code_df, expression):
|
910 |
+
selected_expression_record = code_df.filter(expr=expression.value).to_dicts()[0]
|
911 |
+
return (selected_expression_record,)
|
912 |
+
|
913 |
+
|
914 |
+
@app.cell(hide_code=True)
|
915 |
+
def _(mo, selected_expression_record):
|
916 |
+
example_editor = mo.ui.code_editor(value=selected_expression_record["code"])
|
917 |
+
return (example_editor,)
|
918 |
+
|
919 |
+
|
920 |
+
@app.cell(hide_code=True)
|
921 |
+
def _(expressions_df, pl):
|
922 |
+
code_df = (
|
923 |
+
expressions_df.select(
|
924 |
+
expr=pl.when(pl.col("namespace") == "root")
|
925 |
+
.then("member")
|
926 |
+
.otherwise(pl.concat_str(["namespace", "member"], separator=".")),
|
927 |
+
description=pl.col("docstring")
|
928 |
+
.str.split("\n\n")
|
929 |
+
.list.get(0)
|
930 |
+
.str.slice(9),
|
931 |
+
docstring_lines=pl.col("docstring").str.split("\n"),
|
932 |
+
)
|
933 |
+
.with_row_index()
|
934 |
+
.explode("docstring_lines")
|
935 |
+
.rename({"docstring_lines": "docstring_line"})
|
936 |
+
.with_columns(pl.col("docstring_line").str.strip_chars(" "))
|
937 |
+
.filter(pl.col("docstring_line").str.contains_any([">>> ", "... "]))
|
938 |
+
.with_columns(pl.col("docstring_line").str.slice(4))
|
939 |
+
.group_by(pl.exclude("docstring_line"), maintain_order=True)
|
940 |
+
.agg(code=pl.col("docstring_line").str.join("\n"))
|
941 |
+
.drop("index")
|
942 |
+
)
|
943 |
+
return (code_df,)
|
944 |
+
|
945 |
+
|
946 |
+
@app.cell(hide_code=True)
|
947 |
+
def _():
|
948 |
+
def execute_code(code: str):
|
949 |
+
import ast
|
950 |
+
|
951 |
+
# Create a new local namespace for execution
|
952 |
+
local_namespace = {}
|
953 |
+
|
954 |
+
# Parse the code into an AST to identify the last expression
|
955 |
+
parsed_code = ast.parse(code)
|
956 |
+
|
957 |
+
# Check if there's at least one statement
|
958 |
+
if not parsed_code.body:
|
959 |
+
return None
|
960 |
+
|
961 |
+
# If the last statement is an expression, we'll need to get its value
|
962 |
+
last_is_expr = isinstance(parsed_code.body[-1], ast.Expr)
|
963 |
+
|
964 |
+
if last_is_expr:
|
965 |
+
# Split the code: everything except the last statement, and the last statement
|
966 |
+
last_expr = ast.Expression(parsed_code.body[-1].value)
|
967 |
+
|
968 |
+
# Remove the last statement from the parsed code
|
969 |
+
parsed_code.body = parsed_code.body[:-1]
|
970 |
+
|
971 |
+
# Execute everything except the last statement
|
972 |
+
if parsed_code.body:
|
973 |
+
exec(
|
974 |
+
compile(parsed_code, "<string>", "exec"),
|
975 |
+
globals(),
|
976 |
+
local_namespace,
|
977 |
+
)
|
978 |
+
|
979 |
+
# Execute the last statement and get its value
|
980 |
+
result = eval(
|
981 |
+
compile(last_expr, "<string>", "eval"), globals(), local_namespace
|
982 |
+
)
|
983 |
+
return result
|
984 |
+
else:
|
985 |
+
# If the last statement is not an expression (e.g., an assignment),
|
986 |
+
# execute the entire code and return None
|
987 |
+
exec(code, globals(), local_namespace)
|
988 |
+
return None
|
989 |
+
return (execute_code,)
|
990 |
+
|
991 |
+
|
992 |
+
@app.cell(hide_code=True)
|
993 |
+
def _():
|
994 |
+
import polars as pl
|
995 |
+
import marimo as mo
|
996 |
+
import altair as alt
|
997 |
+
import random
|
998 |
+
|
999 |
+
random.seed(42)
|
1000 |
+
return alt, mo, pl, random
|
1001 |
+
|
1002 |
+
|
1003 |
+
if __name__ == "__main__":
|
1004 |
+
app.run()
|
polars/12_aggregations.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.13"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "polars==1.23.0",
|
6 |
+
# ]
|
7 |
+
# ///
|
8 |
+
|
9 |
+
import marimo
|
10 |
+
|
11 |
+
__generated_with = "0.11.14"
|
12 |
+
app = marimo.App(width="medium")
|
13 |
+
|
14 |
+
|
15 |
+
@app.cell
|
16 |
+
def _():
|
17 |
+
import marimo as mo
|
18 |
+
return (mo,)
|
19 |
+
|
20 |
+
|
21 |
+
@app.cell(hide_code=True)
|
22 |
+
def _(mo):
|
23 |
+
mo.md(
|
24 |
+
r"""
|
25 |
+
# Aggregations
|
26 |
+
_By [Joram Mutenge](https://www.udemy.com/user/joram-mutenge/)._
|
27 |
+
|
28 |
+
In this notebook, you'll learn how to perform different types of aggregations in Polars, including grouping by categories and time. We'll analyze sales data from a clothing store, focusing on three product categories: hats, socks, and sweaters.
|
29 |
+
"""
|
30 |
+
)
|
31 |
+
return
|
32 |
+
|
33 |
+
|
34 |
+
@app.cell
|
35 |
+
def _():
|
36 |
+
import polars as pl
|
37 |
+
|
38 |
+
df = (pl.read_csv('https://raw.githubusercontent.com/jorammutenge/learn-rust/refs/heads/main/sample_sales.csv', try_parse_dates=True)
|
39 |
+
.rename(lambda col: col.replace(' ','_').lower())
|
40 |
+
)
|
41 |
+
df
|
42 |
+
return df, pl
|
43 |
+
|
44 |
+
|
45 |
+
@app.cell(hide_code=True)
|
46 |
+
def _(mo):
|
47 |
+
mo.md(
|
48 |
+
r"""
|
49 |
+
## Grouping by category
|
50 |
+
### With single category
|
51 |
+
Let's find out how many of each product category we sold.
|
52 |
+
"""
|
53 |
+
)
|
54 |
+
return
|
55 |
+
|
56 |
+
|
57 |
+
@app.cell
|
58 |
+
def _(df, pl):
|
59 |
+
(df
|
60 |
+
.group_by('category')
|
61 |
+
.agg(pl.sum('quantity'))
|
62 |
+
)
|
63 |
+
return
|
64 |
+
|
65 |
+
|
66 |
+
@app.cell(hide_code=True)
|
67 |
+
def _(mo):
|
68 |
+
mo.md(
|
69 |
+
r"""
|
70 |
+
It looks like we sold more sweaters. Maybe this was a winter season.
|
71 |
+
|
72 |
+
Let's add another aggregate to see how much was spent on the total units for each product.
|
73 |
+
"""
|
74 |
+
)
|
75 |
+
return
|
76 |
+
|
77 |
+
|
78 |
+
@app.cell
|
79 |
+
def _(df, pl):
|
80 |
+
(df
|
81 |
+
.group_by('category')
|
82 |
+
.agg(pl.sum('quantity'),
|
83 |
+
pl.sum('ext_price'))
|
84 |
+
)
|
85 |
+
return
|
86 |
+
|
87 |
+
|
88 |
+
@app.cell(hide_code=True)
|
89 |
+
def _(mo):
|
90 |
+
mo.md(r"""We could also write aggregate code for the two columns as a single line.""")
|
91 |
+
return
|
92 |
+
|
93 |
+
|
94 |
+
@app.cell
|
95 |
+
def _(df, pl):
|
96 |
+
(df
|
97 |
+
.group_by('category')
|
98 |
+
.agg(pl.sum('quantity','ext_price'))
|
99 |
+
)
|
100 |
+
return
|
101 |
+
|
102 |
+
|
103 |
+
@app.cell(hide_code=True)
|
104 |
+
def _(mo):
|
105 |
+
mo.md(r"""Actually, the way we've been writing the aggregate lines is syntactic sugar. Here's a longer way of doing it as shown in the [Polars documentation](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.dataframe.group_by.GroupBy.agg.html).""")
|
106 |
+
return
|
107 |
+
|
108 |
+
|
109 |
+
@app.cell
|
110 |
+
def _(df, pl):
|
111 |
+
(df
|
112 |
+
.group_by('category')
|
113 |
+
.agg(pl.col('quantity').sum(),
|
114 |
+
pl.col('ext_price').sum())
|
115 |
+
)
|
116 |
+
return
|
117 |
+
|
118 |
+
|
119 |
+
@app.cell(hide_code=True)
|
120 |
+
def _(mo):
|
121 |
+
mo.md(
|
122 |
+
r"""
|
123 |
+
### With multiple categories
|
124 |
+
We can also group by multiple categories. Let's find out how many items we sold in each product category for each SKU. This more detailed aggregation will produce more rows than the previous DataFrame.
|
125 |
+
"""
|
126 |
+
)
|
127 |
+
return
|
128 |
+
|
129 |
+
|
130 |
+
@app.cell
|
131 |
+
def _(df, pl):
|
132 |
+
(df
|
133 |
+
.group_by('category','sku')
|
134 |
+
.agg(pl.sum('quantity'))
|
135 |
+
)
|
136 |
+
return
|
137 |
+
|
138 |
+
|
139 |
+
@app.cell(hide_code=True)
|
140 |
+
def _(mo):
|
141 |
+
mo.md(
|
142 |
+
r"""
|
143 |
+
Aggregations when grouping data are not limited to sums. You can also use functions like [`max`, `min`, `median`, `first`, and `last`](https://docs.pola.rs/user-guide/expressions/aggregation/#basic-aggregations).
|
144 |
+
|
145 |
+
Let's find the largest sale quantity for each product category.
|
146 |
+
"""
|
147 |
+
)
|
148 |
+
return
|
149 |
+
|
150 |
+
|
151 |
+
@app.cell
|
152 |
+
def _(df, pl):
|
153 |
+
(df
|
154 |
+
.group_by('category')
|
155 |
+
.agg(pl.max('quantity'))
|
156 |
+
)
|
157 |
+
return
|
158 |
+
|
159 |
+
|
160 |
+
@app.cell(hide_code=True)
|
161 |
+
def _(mo):
|
162 |
+
mo.md(
|
163 |
+
r"""
|
164 |
+
Let's make the aggregation more interesting. We'll identify the first customer to purchase each item, along with the quantity they bought and the amount they spent.
|
165 |
+
|
166 |
+
**Note:** To make this work, we'll have to sort the date from earliest to latest.
|
167 |
+
"""
|
168 |
+
)
|
169 |
+
return
|
170 |
+
|
171 |
+
|
172 |
+
@app.cell
|
173 |
+
def _(df, pl):
|
174 |
+
(df
|
175 |
+
.sort('date')
|
176 |
+
.group_by('category')
|
177 |
+
.agg(pl.first('account_name','quantity','ext_price'))
|
178 |
+
)
|
179 |
+
return
|
180 |
+
|
181 |
+
|
182 |
+
@app.cell(hide_code=True)
|
183 |
+
def _(mo):
|
184 |
+
mo.md(
|
185 |
+
r"""
|
186 |
+
## Grouping by time
|
187 |
+
Since `datetime` is a special data type in Polars, we can perform various group-by aggregations on it.
|
188 |
+
|
189 |
+
Our dataset spans a two-year period. Let's calculate the total dollar sales for each year. We'll do it the naive way first so you can appreciate grouping with time.
|
190 |
+
"""
|
191 |
+
)
|
192 |
+
return
|
193 |
+
|
194 |
+
|
195 |
+
@app.cell
|
196 |
+
def _(df, pl):
|
197 |
+
(df
|
198 |
+
.with_columns(year=pl.col('date').dt.year())
|
199 |
+
.group_by('year')
|
200 |
+
.agg(pl.sum('ext_price').round(2))
|
201 |
+
)
|
202 |
+
return
|
203 |
+
|
204 |
+
|
205 |
+
@app.cell(hide_code=True)
|
206 |
+
def _(mo):
|
207 |
+
mo.md(
|
208 |
+
r"""
|
209 |
+
We had more sales in 2014.
|
210 |
+
|
211 |
+
Now let's perform the above operation by groupin with time. This requires sorting the dataframe first.
|
212 |
+
"""
|
213 |
+
)
|
214 |
+
return
|
215 |
+
|
216 |
+
|
217 |
+
@app.cell
|
218 |
+
def _(df, pl):
|
219 |
+
(df
|
220 |
+
.sort('date')
|
221 |
+
.group_by_dynamic('date', every='1y')
|
222 |
+
.agg(pl.sum('ext_price'))
|
223 |
+
)
|
224 |
+
return
|
225 |
+
|
226 |
+
|
227 |
+
@app.cell(hide_code=True)
|
228 |
+
def _(mo):
|
229 |
+
mo.md(
|
230 |
+
r"""
|
231 |
+
The beauty of grouping with time is that it allows us to resample the data by selecting whatever time interval we want.
|
232 |
+
|
233 |
+
Let's find out what the quarterly sales were for 2014
|
234 |
+
"""
|
235 |
+
)
|
236 |
+
return
|
237 |
+
|
238 |
+
|
239 |
+
@app.cell
|
240 |
+
def _(df, pl):
|
241 |
+
(df
|
242 |
+
.filter(pl.col('date').dt.year() == 2014)
|
243 |
+
.sort('date')
|
244 |
+
.group_by_dynamic('date', every='1q')
|
245 |
+
.agg(pl.sum('ext_price'))
|
246 |
+
)
|
247 |
+
return
|
248 |
+
|
249 |
+
|
250 |
+
@app.cell(hide_code=True)
|
251 |
+
def _(mo):
|
252 |
+
mo.md(
|
253 |
+
r"""
|
254 |
+
Here's an interesting question we can answer that takes advantage of grouping by time.
|
255 |
+
|
256 |
+
Let's find the hour of the day where we had the most sales in dollars.
|
257 |
+
"""
|
258 |
+
)
|
259 |
+
return
|
260 |
+
|
261 |
+
|
262 |
+
@app.cell
|
263 |
+
def _(df, pl):
|
264 |
+
(df
|
265 |
+
.sort('date')
|
266 |
+
.group_by_dynamic('date', every='1h')
|
267 |
+
.agg(pl.max('ext_price'))
|
268 |
+
.filter(pl.col('ext_price') == pl.col('ext_price').max())
|
269 |
+
)
|
270 |
+
return
|
271 |
+
|
272 |
+
|
273 |
+
@app.cell(hide_code=True)
|
274 |
+
def _(mo):
|
275 |
+
mo.md(r"""Just for fun, let's find the median number of items sold in each SKU and the total dollar amount in each SKU every six days.""")
|
276 |
+
return
|
277 |
+
|
278 |
+
|
279 |
+
@app.cell
|
280 |
+
def _(df, pl):
|
281 |
+
(df
|
282 |
+
.sort('date')
|
283 |
+
.group_by_dynamic('date', every='6d')
|
284 |
+
.agg(pl.first('sku'),
|
285 |
+
pl.median('quantity'),
|
286 |
+
pl.sum('ext_price'))
|
287 |
+
)
|
288 |
+
return
|
289 |
+
|
290 |
+
|
291 |
+
@app.cell(hide_code=True)
|
292 |
+
def _(mo):
|
293 |
+
mo.md(r"""Let's rename the columns to clearly indicate the type of aggregation performed. This will help us identify the aggregation method used on a column without needing to check the code.""")
|
294 |
+
return
|
295 |
+
|
296 |
+
|
297 |
+
@app.cell
|
298 |
+
def _(df, pl):
|
299 |
+
(df
|
300 |
+
.sort('date')
|
301 |
+
.group_by_dynamic('date', every='6d')
|
302 |
+
.agg(pl.first('sku'),
|
303 |
+
pl.median('quantity').alias('median_qty'),
|
304 |
+
pl.sum('ext_price').alias('total_dollars'))
|
305 |
+
)
|
306 |
+
return
|
307 |
+
|
308 |
+
|
309 |
+
@app.cell(hide_code=True)
|
310 |
+
def _(mo):
|
311 |
+
mo.md(
|
312 |
+
r"""
|
313 |
+
## Grouping with over
|
314 |
+
|
315 |
+
Sometimes, we may want to perform an aggregation but also keep all the columns and rows of the dataframe.
|
316 |
+
|
317 |
+
Let's assign a value to indicate the number of times each customer visited and bought something.
|
318 |
+
"""
|
319 |
+
)
|
320 |
+
return
|
321 |
+
|
322 |
+
|
323 |
+
@app.cell
|
324 |
+
def _(df, pl):
|
325 |
+
(df
|
326 |
+
.with_columns(buy_freq=pl.col('account_name').len().over('account_name'))
|
327 |
+
)
|
328 |
+
return
|
329 |
+
|
330 |
+
|
331 |
+
@app.cell(hide_code=True)
|
332 |
+
def _(mo):
|
333 |
+
mo.md(r"""Finally, let's determine which customers visited the store the most and bought something.""")
|
334 |
+
return
|
335 |
+
|
336 |
+
|
337 |
+
@app.cell
|
338 |
+
def _(df, pl):
|
339 |
+
(df
|
340 |
+
.with_columns(buy_freq=pl.col('account_name').len().over('account_name'))
|
341 |
+
.filter(pl.col('buy_freq') == pl.col('buy_freq').max())
|
342 |
+
.select('account_name','buy_freq')
|
343 |
+
.unique()
|
344 |
+
)
|
345 |
+
return
|
346 |
+
|
347 |
+
|
348 |
+
@app.cell(hide_code=True)
|
349 |
+
def _(mo):
|
350 |
+
mo.md(r"""There's more you can do with aggregations in Polars such as [sorting with aggregations](https://docs.pola.rs/user-guide/expressions/aggregation/#sorting). We hope that in this notebook, we've armed you with the tools to get started.""")
|
351 |
+
return
|
352 |
+
|
353 |
+
|
354 |
+
if __name__ == "__main__":
|
355 |
+
app.run()
|
polars/14_user_defined_functions.py
ADDED
@@ -0,0 +1,946 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.12"
|
3 |
+
# dependencies = [
|
4 |
+
# "altair==5.5.0",
|
5 |
+
# "beautifulsoup4==4.13.3",
|
6 |
+
# "httpx==0.28.1",
|
7 |
+
# "marimo",
|
8 |
+
# "nest-asyncio==1.6.0",
|
9 |
+
# "numba==0.61.0",
|
10 |
+
# "numpy==2.1.3",
|
11 |
+
# "polars==1.24.0",
|
12 |
+
# ]
|
13 |
+
# ///
|
14 |
+
|
15 |
+
import marimo
|
16 |
+
|
17 |
+
__generated_with = "0.11.17"
|
18 |
+
app = marimo.App(width="medium")
|
19 |
+
|
20 |
+
|
21 |
+
@app.cell(hide_code=True)
|
22 |
+
def _(mo):
|
23 |
+
mo.md(
|
24 |
+
r"""
|
25 |
+
# User-Defined Functions
|
26 |
+
|
27 |
+
_By [Péter Ferenc Gyarmati](http://github.com/peter-gy)_.
|
28 |
+
|
29 |
+
Throughout the previous chapters, you've seen how Polars provides a comprehensive set of built-in expressions for flexible data transformation. But what happens when you need something *more*? Perhaps your project has unique requirements, or you need to integrate functionality from an external Python library. This is where User-Defined Functions (UDFs) come into play, allowing you to extend Polars with your own custom logic.
|
30 |
+
|
31 |
+
In this chapter, we'll weigh the performance trade-offs of UDFs, pinpoint situations where they're truly beneficial, and explore different ways to effectively incorporate them into your Polars workflows. We'll walk through a complete, practical example.
|
32 |
+
"""
|
33 |
+
)
|
34 |
+
return
|
35 |
+
|
36 |
+
|
37 |
+
@app.cell(hide_code=True)
|
38 |
+
def _(mo):
|
39 |
+
mo.md(
|
40 |
+
r"""
|
41 |
+
## ⚖️ The Cost of UDFs
|
42 |
+
|
43 |
+
> Performance vs. Flexibility
|
44 |
+
|
45 |
+
Polars' built-in expressions are highly optimized for speed and parallel processing. User-defined functions (UDFs), however, introduce a significant performance overhead because they rely on standard Python code, which often runs in a single thread and bypasses Polars' logical optimizations. Therefore, always prioritize native Polars operations *whenever possible*.
|
46 |
+
|
47 |
+
However, UDFs become inevitable when you need to:
|
48 |
+
|
49 |
+
- **Integrate external libraries:** Use functionality not directly available in Polars.
|
50 |
+
- **Implement custom logic:** Handle complex transformations that can't be easily expressed with Polars' built-in functions.
|
51 |
+
|
52 |
+
Let's dive into a real-world project where UDFs were the only way to get the job done, demonstrating a scenario where native Polars expressions simply weren't sufficient.
|
53 |
+
"""
|
54 |
+
)
|
55 |
+
return
|
56 |
+
|
57 |
+
|
58 |
+
@app.cell(hide_code=True)
|
59 |
+
def _(mo):
|
60 |
+
mo.md(
|
61 |
+
r"""
|
62 |
+
## 📊 Project Overview
|
63 |
+
|
64 |
+
> Scraping and Analyzing Observable Notebook Statistics
|
65 |
+
|
66 |
+
If you're into data visualization, you've probably seen [D3.js](https://d3js.org/) and [Observable Plot](https://observablehq.com/plot/). Both have extensive galleries showcasing amazing visualizations. Each gallery item is a standalone [Observable notebook](https://observablehq.com/documentation/notebooks/), with metrics like stars, comments, and forks – indicators of popularity. But getting and analyzing these statistics directly isn't straightforward. We'll need to scrape the web.
|
67 |
+
"""
|
68 |
+
)
|
69 |
+
return
|
70 |
+
|
71 |
+
|
72 |
+
@app.cell(hide_code=True)
|
73 |
+
def _(mo):
|
74 |
+
mo.hstack(
|
75 |
+
[
|
76 |
+
mo.image(
|
77 |
+
"https://minio.peter.gy/static/assets/marimo/learn/polars/14_d3-gallery.png?0",
|
78 |
+
width=600,
|
79 |
+
caption="Screenshot of https://observablehq.com/@d3/gallery",
|
80 |
+
),
|
81 |
+
mo.image(
|
82 |
+
"https://minio.peter.gy/static/assets/marimo/learn/polars/14_plot-gallery.png?0",
|
83 |
+
width=600,
|
84 |
+
caption="Screenshot of https://observablehq.com/@observablehq/plot-gallery",
|
85 |
+
),
|
86 |
+
]
|
87 |
+
)
|
88 |
+
return
|
89 |
+
|
90 |
+
|
91 |
+
@app.cell(hide_code=True)
|
92 |
+
def _(mo):
|
93 |
+
mo.md(r"""Our goal is to use Polars UDFs to fetch the HTML content of these gallery pages. Then, we'll use the `BeautifulSoup` Python library to parse the HTML and extract the relevant metadata. After some data wrangling with native Polars expressions, we'll have a DataFrame listing each visualization notebook. Then, we'll use another UDF to retrieve the number of likes, forks, and comments for each notebook. Finally, we will create our own high-performance UDF to implement a custom notebook ranking scheme. This will involve multiple steps, showcasing different UDF approaches.""")
|
94 |
+
return
|
95 |
+
|
96 |
+
|
97 |
+
@app.cell(hide_code=True)
|
98 |
+
def _(mo):
|
99 |
+
mo.mermaid('''
|
100 |
+
graph LR;
|
101 |
+
url_df --> |"UDF: Fetch HTML"| html_df
|
102 |
+
html_df --> |"UDF: Parse with BeautifulSoup"| parsed_html_df
|
103 |
+
parsed_html_df --> |"Native Polars: Extract Data"| notebooks_df
|
104 |
+
notebooks_df --> |"UDF: Get Notebook Stats"| notebook_stats_df
|
105 |
+
notebook_stats_df --> |"Numba UDF: Compute Popularity"| notebook_popularity_df
|
106 |
+
''')
|
107 |
+
return
|
108 |
+
|
109 |
+
|
110 |
+
@app.cell(hide_code=True)
|
111 |
+
def _(mo):
|
112 |
+
mo.md(r"""Our starting point, `url_df`, is a simple DataFrame with a single `url` column containing the URLs of the D3 and Observable Plot gallery notebooks.""")
|
113 |
+
return
|
114 |
+
|
115 |
+
|
116 |
+
@app.cell(hide_code=True)
|
117 |
+
def _(pl):
|
118 |
+
url_df = pl.from_dict(
|
119 |
+
{
|
120 |
+
"url": [
|
121 |
+
"https://observablehq.com/@d3/gallery",
|
122 |
+
"https://observablehq.com/@observablehq/plot-gallery",
|
123 |
+
]
|
124 |
+
}
|
125 |
+
)
|
126 |
+
url_df
|
127 |
+
return (url_df,)
|
128 |
+
|
129 |
+
|
130 |
+
@app.cell(hide_code=True)
|
131 |
+
def _(mo):
|
132 |
+
mo.md(
|
133 |
+
r"""
|
134 |
+
## 🔂 Element-Wise UDFs
|
135 |
+
|
136 |
+
> Processing Value by Value
|
137 |
+
|
138 |
+
The most common way to use UDFs is to apply them element-wise. This means our custom function will execute for *each individual row* in a specified column. Our first task is to fetch the HTML content for each URL in `url_df`.
|
139 |
+
|
140 |
+
We'll define a Python function that takes a `url` (a string) as input, uses the `httpx` library (an HTTP client) to fetch the content, and returns the HTML as a string. We then integrate this function into Polars using the [`map_elements`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.map_elements.html) expression.
|
141 |
+
|
142 |
+
You'll notice we have to explicitly specify the `return_dtype`. This is *crucial*. Polars doesn't automatically know what our custom function will return. We're responsible for defining the function's logic and, therefore, its output type. By providing the `return_dtype`, we help Polars maintain its internal representation of the DataFrame's schema, enabling query optimization. Think of it as giving Polars a "heads-up" about the data type it should expect.
|
143 |
+
"""
|
144 |
+
)
|
145 |
+
return
|
146 |
+
|
147 |
+
|
148 |
+
@app.cell(hide_code=True)
|
149 |
+
def _(httpx, pl, url_df):
|
150 |
+
html_df = url_df.with_columns(
|
151 |
+
html=pl.col("url").map_elements(
|
152 |
+
lambda url: httpx.get(url).text,
|
153 |
+
return_dtype=pl.String,
|
154 |
+
)
|
155 |
+
)
|
156 |
+
html_df
|
157 |
+
return (html_df,)
|
158 |
+
|
159 |
+
|
160 |
+
@app.cell(hide_code=True)
|
161 |
+
def _(mo):
|
162 |
+
mo.md(
|
163 |
+
r"""
|
164 |
+
Now, `html_df` holds the HTML for each URL. We need to parse it. Again, a UDF is the way to go. Parsing HTML with native Polars expressions would be a nightmare! Instead, we'll use the [`beautifulsoup4`](https://pypi.org/project/beautifulsoup4/) library, a standard tool for this.
|
165 |
+
|
166 |
+
These Observable pages are built with [Next.js](https://nextjs.org/), which helpfully serializes page properties as JSON within the HTML. This simplifies our UDF: we'll extract the raw JSON from the `<script id="__NEXT_DATA__" type="application/json">` tag. We'll use [`map_elements`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.map_elements.html) again. For clarity, we'll define this UDF as a named function, `extract_nextjs_data`, since it's a bit more complex than a simple HTTP request.
|
167 |
+
"""
|
168 |
+
)
|
169 |
+
return
|
170 |
+
|
171 |
+
|
172 |
+
@app.cell(hide_code=True)
|
173 |
+
def _(BeautifulSoup):
|
174 |
+
def extract_nextjs_data(html: str) -> str:
|
175 |
+
soup = BeautifulSoup(html, "html.parser")
|
176 |
+
script_tag = soup.find("script", id="__NEXT_DATA__")
|
177 |
+
return script_tag.text
|
178 |
+
return (extract_nextjs_data,)
|
179 |
+
|
180 |
+
|
181 |
+
@app.cell(hide_code=True)
|
182 |
+
def _(extract_nextjs_data, html_df, pl):
|
183 |
+
parsed_html_df = html_df.select(
|
184 |
+
"url",
|
185 |
+
next_data=pl.col("html").map_elements(
|
186 |
+
extract_nextjs_data,
|
187 |
+
return_dtype=pl.String,
|
188 |
+
),
|
189 |
+
)
|
190 |
+
parsed_html_df
|
191 |
+
return (parsed_html_df,)
|
192 |
+
|
193 |
+
|
194 |
+
@app.cell(hide_code=True)
|
195 |
+
def _(mo):
|
196 |
+
mo.md(r"""With some data wrangling of the raw JSON (using *native* Polars expressions!), we get `notebooks_df`, containing the metadata for each notebook.""")
|
197 |
+
return
|
198 |
+
|
199 |
+
|
200 |
+
@app.cell(hide_code=True)
|
201 |
+
def _(parsed_html_df, pl):
|
202 |
+
notebooks_df = (
|
203 |
+
parsed_html_df.select(
|
204 |
+
"url",
|
205 |
+
# We extract the content of every cell present in the gallery notebooks
|
206 |
+
cell=pl.col("next_data")
|
207 |
+
.str.json_path_match("$.props.pageProps.initialNotebook.nodes")
|
208 |
+
.str.json_decode()
|
209 |
+
.list.eval(pl.element().struct.field("value")),
|
210 |
+
)
|
211 |
+
# We want one row per cell
|
212 |
+
.explode("cell")
|
213 |
+
# Only keep categorized notebook listing cells starting with H3
|
214 |
+
.filter(pl.col("cell").str.starts_with("### "))
|
215 |
+
# Split up the cells into [heading, description, config] sections
|
216 |
+
.with_columns(pl.col("cell").str.split("\n\n"))
|
217 |
+
.select(
|
218 |
+
gallery_url="url",
|
219 |
+
# Text after the '### ' heading, ignore '<!--' comments'
|
220 |
+
category=pl.col("cell").list.get(0).str.extract(r"###\s+(.*?)(?:\s+<!--.*?-->|$)"),
|
221 |
+
# Paragraph after heading
|
222 |
+
description=pl.col("cell")
|
223 |
+
.list.get(1)
|
224 |
+
.str.strip_chars(" ")
|
225 |
+
.str.replace_all("](/", "](https://observablehq.com/", literal=True),
|
226 |
+
# Parsed notebook config from ${preview([{...}])}
|
227 |
+
notebooks=pl.col("cell")
|
228 |
+
.list.get(2)
|
229 |
+
.str.strip_prefix("${previews([")
|
230 |
+
.str.strip_suffix("]})}")
|
231 |
+
.str.strip_chars(" \n")
|
232 |
+
.str.split("},")
|
233 |
+
# Simple regex-based attribute extraction from JS/JSON objects like
|
234 |
+
# ```js
|
235 |
+
# {
|
236 |
+
# path: "@d3/spilhaus-shoreline-map",
|
237 |
+
# "thumbnail": "66a87355e205d820...",
|
238 |
+
# title: "Spilhaus shoreline map",
|
239 |
+
# "author": "D3"
|
240 |
+
# }
|
241 |
+
# ```
|
242 |
+
.list.eval(
|
243 |
+
pl.struct(
|
244 |
+
*(
|
245 |
+
pl.element()
|
246 |
+
.str.extract(f'(?:"{key}"|{key})\s*:\s*"([^"]*)"')
|
247 |
+
.alias(key)
|
248 |
+
for key in ["path", "thumbnail", "title"]
|
249 |
+
)
|
250 |
+
)
|
251 |
+
),
|
252 |
+
)
|
253 |
+
.explode("notebooks")
|
254 |
+
.unnest("notebooks")
|
255 |
+
.filter(pl.col("path").is_not_null())
|
256 |
+
# Final projection to end up with directly usable values
|
257 |
+
.select(
|
258 |
+
pl.concat_str(
|
259 |
+
[
|
260 |
+
pl.lit("https://static.observableusercontent.com/thumbnail/"),
|
261 |
+
"thumbnail",
|
262 |
+
pl.lit(".jpg"),
|
263 |
+
],
|
264 |
+
).alias("notebook_thumbnail_src"),
|
265 |
+
"category",
|
266 |
+
"title",
|
267 |
+
"description",
|
268 |
+
pl.concat_str(
|
269 |
+
[pl.lit("https://observablehq.com"), "path"], separator="/"
|
270 |
+
).alias("notebook_url"),
|
271 |
+
)
|
272 |
+
)
|
273 |
+
notebooks_df
|
274 |
+
return (notebooks_df,)
|
275 |
+
|
276 |
+
|
277 |
+
@app.cell(hide_code=True)
|
278 |
+
def _(mo):
|
279 |
+
mo.md(
|
280 |
+
r"""
|
281 |
+
## 📦 Batch-Wise UDFs
|
282 |
+
|
283 |
+
> Processing Entire Series
|
284 |
+
|
285 |
+
`map_elements` calls the UDF for *each row*. Fine for our tiny, two-rows-tall `url_df`. But `notebooks_df` has almost 400 rows! Individual HTTP requests for each would be painfully slow.
|
286 |
+
|
287 |
+
We want stats for each notebook in `notebooks_df`. To avoid sequential requests, we'll use Polars' [`map_batches`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.map_batches.html). This lets us process an *entire Series* (a column) at once.
|
288 |
+
|
289 |
+
Our UDF, `fetch_html_batch`, will take a *Series* of URLs and use `asyncio` to make concurrent requests – a huge performance boost.
|
290 |
+
"""
|
291 |
+
)
|
292 |
+
return
|
293 |
+
|
294 |
+
|
295 |
+
@app.cell(hide_code=True)
|
296 |
+
def _(Iterable, asyncio, httpx, mo):
|
297 |
+
async def _fetch_html_batch(urls: Iterable[str]) -> tuple[str, ...]:
|
298 |
+
async with httpx.AsyncClient(timeout=15) as client:
|
299 |
+
res = await asyncio.gather(*(client.get(url) for url in urls))
|
300 |
+
return tuple((r.text for r in res))
|
301 |
+
|
302 |
+
|
303 |
+
@mo.cache
|
304 |
+
def fetch_html_batch(urls: Iterable[str]) -> tuple[str, ...]:
|
305 |
+
return asyncio.run(_fetch_html_batch(urls))
|
306 |
+
return (fetch_html_batch,)
|
307 |
+
|
308 |
+
|
309 |
+
@app.cell(hide_code=True)
|
310 |
+
def _(mo):
|
311 |
+
mo.callout(
|
312 |
+
mo.md("""
|
313 |
+
Since `fetch_html_batch` is a pure Python function and performs multiple network requests, it's a good candidate for caching. We use [`mo.cache`](https://docs.marimo.io/api/caching/#marimo.cache) to avoid redundant requests to the same URL. This is a simple way to improve performance without modifying the core logic.
|
314 |
+
"""
|
315 |
+
),
|
316 |
+
kind="info",
|
317 |
+
)
|
318 |
+
return
|
319 |
+
|
320 |
+
|
321 |
+
@app.cell(hide_code=True)
|
322 |
+
def _(mo, notebooks_df):
|
323 |
+
category = mo.ui.dropdown(
|
324 |
+
notebooks_df.sort("category").get_column("category"),
|
325 |
+
value="Maps",
|
326 |
+
)
|
327 |
+
return (category,)
|
328 |
+
|
329 |
+
|
330 |
+
@app.cell(hide_code=True)
|
331 |
+
def _(category, extract_nextjs_data, fetch_html_batch, notebooks_df, pl):
|
332 |
+
notebook_stats_df = (
|
333 |
+
# Setting filter upstream to limit number of concurrent HTTP requests
|
334 |
+
notebooks_df.filter(category=category.value)
|
335 |
+
.with_columns(
|
336 |
+
notebook_html=pl.col("notebook_url")
|
337 |
+
.map_batches(fetch_html_batch, return_dtype=pl.List(pl.String))
|
338 |
+
.explode()
|
339 |
+
)
|
340 |
+
.with_columns(
|
341 |
+
notebook_data=pl.col("notebook_html")
|
342 |
+
.map_elements(
|
343 |
+
extract_nextjs_data,
|
344 |
+
return_dtype=pl.String,
|
345 |
+
)
|
346 |
+
.str.json_path_match("$.props.pageProps.initialNotebook")
|
347 |
+
.str.json_decode()
|
348 |
+
)
|
349 |
+
.drop("notebook_html")
|
350 |
+
.with_columns(
|
351 |
+
*[
|
352 |
+
pl.col("notebook_data").struct.field(key).alias(key)
|
353 |
+
for key in ["likes", "forks", "comments", "license"]
|
354 |
+
]
|
355 |
+
)
|
356 |
+
.drop("notebook_data")
|
357 |
+
.with_columns(pl.col("comments").list.len())
|
358 |
+
.select(
|
359 |
+
pl.exclude("description", "notebook_url"),
|
360 |
+
"description",
|
361 |
+
"notebook_url",
|
362 |
+
)
|
363 |
+
.sort("likes", descending=True)
|
364 |
+
)
|
365 |
+
return (notebook_stats_df,)
|
366 |
+
|
367 |
+
|
368 |
+
@app.cell(hide_code=True)
|
369 |
+
def _(mo, notebook_stats_df):
|
370 |
+
notebooks = mo.ui.table(notebook_stats_df, selection='single', initial_selection=[2], page_size=5)
|
371 |
+
notebook_height = mo.ui.slider(start=400, stop=2000, value=825, step=25, show_value=True, label='Notebook Height')
|
372 |
+
return notebook_height, notebooks
|
373 |
+
|
374 |
+
|
375 |
+
@app.cell(hide_code=True)
|
376 |
+
def _():
|
377 |
+
def nb_iframe(notebook_url: str, height=825) -> str:
|
378 |
+
embed_url = notebook_url.replace(
|
379 |
+
"https://observablehq.com", "https://observablehq.com/embed"
|
380 |
+
)
|
381 |
+
return f'<iframe width="100%" height="{height}" frameborder="0" src="{embed_url}?cell=*"></iframe>'
|
382 |
+
return (nb_iframe,)
|
383 |
+
|
384 |
+
|
385 |
+
@app.cell(hide_code=True)
|
386 |
+
def _(mo):
|
387 |
+
mo.md(r"""Now that we have access to notebook-level statistics, we can rank the visualizations by the number of likes they received & display them interactively.""")
|
388 |
+
return
|
389 |
+
|
390 |
+
|
391 |
+
@app.cell(hide_code=True)
|
392 |
+
def _(mo):
|
393 |
+
mo.callout("💡 Explore the visualizations by paging through the table below and selecting any of its rows.")
|
394 |
+
return
|
395 |
+
|
396 |
+
|
397 |
+
@app.cell(hide_code=True)
|
398 |
+
def _(category, mo, nb_iframe, notebook_height, notebooks):
|
399 |
+
notebook = notebooks.value.to_dicts()[0]
|
400 |
+
mo.vstack(
|
401 |
+
[
|
402 |
+
mo.hstack([category, notebook_height]),
|
403 |
+
notebooks,
|
404 |
+
mo.md(f"{notebook['description']}"),
|
405 |
+
mo.md('---'),
|
406 |
+
mo.md(nb_iframe(notebook["notebook_url"], notebook_height.value)),
|
407 |
+
]
|
408 |
+
)
|
409 |
+
return (notebook,)
|
410 |
+
|
411 |
+
|
412 |
+
@app.cell(hide_code=True)
|
413 |
+
def _(mo):
|
414 |
+
mo.md(
|
415 |
+
r"""
|
416 |
+
## ⚙️ Row-Wise UDFs
|
417 |
+
|
418 |
+
> Accessing All Columns at Once
|
419 |
+
|
420 |
+
Sometimes, you need to work with *all* columns of a row at once. This is where [`map_rows`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.map_rows.html) comes in. It operates directly on the DataFrame, passing each row to your UDF *as a tuple*.
|
421 |
+
|
422 |
+
Below, `create_notebook_summary` takes a row from `notebook_stats_df` (as a tuple) and returns a formatted Markdown string summarizing the notebook's key stats. We're essentially reducing the DataFrame to a single column. While this *could* be done with native Polars expressions, it would be much more cumbersome. This example demonstrates a case where a row-wise UDF simplifies the code, even if the underlying operation isn't inherently complex.
|
423 |
+
"""
|
424 |
+
)
|
425 |
+
return
|
426 |
+
|
427 |
+
|
428 |
+
@app.cell(hide_code=True)
|
429 |
+
def _():
|
430 |
+
def create_notebook_summary(row: tuple) -> str:
|
431 |
+
(
|
432 |
+
thumbnail_src,
|
433 |
+
category,
|
434 |
+
title,
|
435 |
+
likes,
|
436 |
+
forks,
|
437 |
+
comments,
|
438 |
+
license,
|
439 |
+
description,
|
440 |
+
notebook_url,
|
441 |
+
) = row
|
442 |
+
return (
|
443 |
+
f"""
|
444 |
+
### [{title}]({notebook_url})
|
445 |
+
|
446 |
+
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px; margin: 12px 0;">
|
447 |
+
<div>⭐ <strong>Likes:</strong> {likes}</div>
|
448 |
+
<div>↗️ <strong>Forks:</strong> {forks}</div>
|
449 |
+
<div>💬 <strong>Comments:</strong> {comments}</div>
|
450 |
+
<div>⚖️ <strong>License:</strong> {license}</div>
|
451 |
+
</div>
|
452 |
+
|
453 |
+
<a href="{notebook_url}" target="_blank">
|
454 |
+
<img src="{thumbnail_src}" style="height: 300px;" />
|
455 |
+
<a/>
|
456 |
+
""".strip('\n')
|
457 |
+
)
|
458 |
+
return (create_notebook_summary,)
|
459 |
+
|
460 |
+
|
461 |
+
@app.cell(hide_code=True)
|
462 |
+
def _(create_notebook_summary, notebook_stats_df, pl):
|
463 |
+
notebook_summary_df = notebook_stats_df.map_rows(
|
464 |
+
create_notebook_summary,
|
465 |
+
return_dtype=pl.String,
|
466 |
+
).rename({"map": "summary"})
|
467 |
+
notebook_summary_df.head(1)
|
468 |
+
return (notebook_summary_df,)
|
469 |
+
|
470 |
+
|
471 |
+
@app.cell(hide_code=True)
|
472 |
+
def _(mo):
|
473 |
+
mo.callout("💡 You can explore individual notebook statistics through the carousel. Discover the visualization's source code by clicking the notebook title or the thumbnail.")
|
474 |
+
return
|
475 |
+
|
476 |
+
|
477 |
+
@app.cell(hide_code=True)
|
478 |
+
def _(mo, notebook_summary_df):
|
479 |
+
mo.carousel(
|
480 |
+
[
|
481 |
+
mo.lazy(mo.md(summary))
|
482 |
+
for summary in notebook_summary_df.get_column("summary")
|
483 |
+
]
|
484 |
+
)
|
485 |
+
return
|
486 |
+
|
487 |
+
|
488 |
+
@app.cell(hide_code=True)
|
489 |
+
def _(mo):
|
490 |
+
mo.md(
|
491 |
+
r"""
|
492 |
+
## 🚀 Higher-performance UDFs
|
493 |
+
|
494 |
+
> Leveraging Numba to Make Python Fast
|
495 |
+
|
496 |
+
Python code doesn't *always* mean slow code. While UDFs *often* introduce performance overhead, there are exceptions. NumPy's universal functions ([`ufuncs`](https://numpy.org/doc/stable/reference/ufuncs.html)) and generalized universal functions ([`gufuncs`](https://numpy.org/neps/nep-0005-generalized-ufuncs.html)) provide high-performance operations on NumPy arrays, thanks to low-level implementations.
|
497 |
+
|
498 |
+
But NumPy's built-in functions are predefined. We can't easily use them for *custom* logic. Enter [`numba`](https://numba.pydata.org/). Numba is a just-in-time (JIT) compiler that translates Python functions into optimized machine code *at runtime*. It provides decorators like [`numba.guvectorize`](https://numba.readthedocs.io/en/stable/user/vectorize.html#the-guvectorize-decorator) that let us create our *own* high-performance `gufuncs` – *without* writing low-level code!
|
499 |
+
"""
|
500 |
+
)
|
501 |
+
return
|
502 |
+
|
503 |
+
|
504 |
+
@app.cell(hide_code=True)
|
505 |
+
def _(mo):
|
506 |
+
mo.md(
|
507 |
+
r"""
|
508 |
+
Let's create a custom popularity metric to rank notebooks, considering likes, forks, *and* comments (not just likes). We'll define `weighted_popularity_numba`, decorated with `@numba.guvectorize`. The decorator arguments specify that we're taking three integer vectors of length `n` and returning a float vector of length `n`.
|
509 |
+
|
510 |
+
The weighted popularity score for each notebook is calculated using the following formula:
|
511 |
+
|
512 |
+
$$
|
513 |
+
\begin{equation}
|
514 |
+
\text{score}_i = w_l \cdot l_i^{f} + w_f \cdot f_i^{f} + w_c \cdot c_i^{f}
|
515 |
+
\end{equation}
|
516 |
+
$$
|
517 |
+
|
518 |
+
with:
|
519 |
+
"""
|
520 |
+
)
|
521 |
+
return
|
522 |
+
|
523 |
+
|
524 |
+
@app.cell(hide_code=True)
|
525 |
+
def _(mo, non_linear_factor, weight_comments, weight_forks, weight_likes):
|
526 |
+
mo.md(rf"""
|
527 |
+
| Symbol | Description |
|
528 |
+
|--------|-------------|
|
529 |
+
| $\text{{score}}_i$ | Popularity score for the *i*-th notebook |
|
530 |
+
| $w_l = {weight_likes.value}$ | Weight for likes |
|
531 |
+
| $l_i$ | Number of likes for the *i*-th notebook |
|
532 |
+
| $w_f = {weight_forks.value}$ | Weight for forks |
|
533 |
+
| $f_i$ | Number of forks for the *i*-th notebook |
|
534 |
+
| $w_c = {weight_comments.value}$ | Weight for comments |
|
535 |
+
| $c_i$ | Number of comments for the *i*-th notebook |
|
536 |
+
| $f = {non_linear_factor.value}$ | Non-linear factor (exponent) |
|
537 |
+
""")
|
538 |
+
return
|
539 |
+
|
540 |
+
|
541 |
+
@app.cell(hide_code=True)
|
542 |
+
def _(mo):
|
543 |
+
weight_likes = mo.ui.slider(
|
544 |
+
start=0.1,
|
545 |
+
stop=1,
|
546 |
+
value=0.5,
|
547 |
+
step=0.1,
|
548 |
+
show_value=True,
|
549 |
+
label="⭐ Weight for Likes",
|
550 |
+
)
|
551 |
+
weight_forks = mo.ui.slider(
|
552 |
+
start=0.1,
|
553 |
+
stop=1,
|
554 |
+
value=0.3,
|
555 |
+
step=0.1,
|
556 |
+
show_value=True,
|
557 |
+
label="↗️ Weight for Forks",
|
558 |
+
)
|
559 |
+
weight_comments = mo.ui.slider(
|
560 |
+
start=0.1,
|
561 |
+
stop=1,
|
562 |
+
value=0.5,
|
563 |
+
step=0.1,
|
564 |
+
show_value=True,
|
565 |
+
label="💬 Weight for Comments",
|
566 |
+
)
|
567 |
+
non_linear_factor = mo.ui.slider(
|
568 |
+
start=1,
|
569 |
+
stop=2,
|
570 |
+
value=1.2,
|
571 |
+
step=0.1,
|
572 |
+
show_value=True,
|
573 |
+
label="🎢 Non-Linear Factor",
|
574 |
+
)
|
575 |
+
return non_linear_factor, weight_comments, weight_forks, weight_likes
|
576 |
+
|
577 |
+
|
578 |
+
@app.cell(hide_code=True)
|
579 |
+
def _(
|
580 |
+
non_linear_factor,
|
581 |
+
np,
|
582 |
+
numba,
|
583 |
+
weight_comments,
|
584 |
+
weight_forks,
|
585 |
+
weight_likes,
|
586 |
+
):
|
587 |
+
w_l = weight_likes.value
|
588 |
+
w_f = weight_forks.value
|
589 |
+
w_c = weight_comments.value
|
590 |
+
nlf = non_linear_factor.value
|
591 |
+
|
592 |
+
|
593 |
+
@numba.guvectorize(
|
594 |
+
[(numba.int64[:], numba.int64[:], numba.int64[:], numba.float64[:])],
|
595 |
+
"(n), (n), (n) -> (n)",
|
596 |
+
)
|
597 |
+
def weighted_popularity_numba(
|
598 |
+
likes: np.ndarray,
|
599 |
+
forks: np.ndarray,
|
600 |
+
comments: np.ndarray,
|
601 |
+
out: np.ndarray,
|
602 |
+
):
|
603 |
+
for i in range(likes.shape[0]):
|
604 |
+
out[i] = (
|
605 |
+
w_l * (likes[i] ** nlf)
|
606 |
+
+ w_f * (forks[i] ** nlf)
|
607 |
+
+ w_c * (comments[i] ** nlf)
|
608 |
+
)
|
609 |
+
return nlf, w_c, w_f, w_l, weighted_popularity_numba
|
610 |
+
|
611 |
+
|
612 |
+
@app.cell(hide_code=True)
|
613 |
+
def _(mo):
|
614 |
+
mo.md(r"""We apply our JIT-compiled UDF using `map_batches`, as before. The key is that we're passing entire columns directly to `weighted_popularity_numba`. Polars and Numba handle the conversion to NumPy arrays behind the scenes. This direct integration is a major benefit of using `guvectorize`.""")
|
615 |
+
return
|
616 |
+
|
617 |
+
|
618 |
+
@app.cell(hide_code=True)
|
619 |
+
def _(notebook_stats_df, pl, weighted_popularity_numba):
|
620 |
+
notebook_popularity_df = (
|
621 |
+
notebook_stats_df.select(
|
622 |
+
pl.col("notebook_thumbnail_src").alias("thumbnail"),
|
623 |
+
"title",
|
624 |
+
"likes",
|
625 |
+
"forks",
|
626 |
+
"comments",
|
627 |
+
popularity=pl.struct(["likes", "forks", "comments"]).map_batches(
|
628 |
+
lambda obj: weighted_popularity_numba(
|
629 |
+
obj.struct.field("likes"),
|
630 |
+
obj.struct.field("forks"),
|
631 |
+
obj.struct.field("comments"),
|
632 |
+
),
|
633 |
+
return_dtype=pl.Float64,
|
634 |
+
),
|
635 |
+
url="notebook_url",
|
636 |
+
)
|
637 |
+
)
|
638 |
+
return (notebook_popularity_df,)
|
639 |
+
|
640 |
+
|
641 |
+
@app.cell(hide_code=True)
|
642 |
+
def _(mo):
|
643 |
+
mo.callout("💡 Adjust the hyperparameters of the popularity ranking UDF. How do the weights and non-linear factor affect the notebook rankings?")
|
644 |
+
return
|
645 |
+
|
646 |
+
|
647 |
+
@app.cell(hide_code=True)
|
648 |
+
def _(
|
649 |
+
mo,
|
650 |
+
non_linear_factor,
|
651 |
+
notebook_popularity_df,
|
652 |
+
weight_comments,
|
653 |
+
weight_forks,
|
654 |
+
weight_likes,
|
655 |
+
):
|
656 |
+
mo.vstack(
|
657 |
+
[
|
658 |
+
mo.hstack([weight_likes, weight_forks]),
|
659 |
+
mo.hstack([weight_comments, non_linear_factor]),
|
660 |
+
notebook_popularity_df,
|
661 |
+
]
|
662 |
+
)
|
663 |
+
return
|
664 |
+
|
665 |
+
|
666 |
+
@app.cell(hide_code=True)
|
667 |
+
def _(mo):
|
668 |
+
mo.md(r"""As the slope chart below demonstrates, this new ranking strategy significantly changes the notebook order, as it considers forks and comments, not just likes.""")
|
669 |
+
return
|
670 |
+
|
671 |
+
|
672 |
+
@app.cell(hide_code=True)
|
673 |
+
def _(alt, notebook_popularity_df, pl):
|
674 |
+
notebook_ranks_df = (
|
675 |
+
notebook_popularity_df.sort("likes", descending=True)
|
676 |
+
.with_row_index("rank_by_likes")
|
677 |
+
.with_columns(pl.col("rank_by_likes") + 1)
|
678 |
+
.sort("popularity", descending=True)
|
679 |
+
.with_row_index("rank_by_popularity")
|
680 |
+
.with_columns(pl.col("rank_by_popularity") + 1)
|
681 |
+
.select("thumbnail", "title", "rank_by_popularity", "rank_by_likes")
|
682 |
+
.unpivot(
|
683 |
+
["rank_by_popularity", "rank_by_likes"],
|
684 |
+
index="title",
|
685 |
+
variable_name="strategy",
|
686 |
+
value_name="rank",
|
687 |
+
)
|
688 |
+
)
|
689 |
+
|
690 |
+
# Slope chart to visualize rank differences by strategy
|
691 |
+
lines = notebook_ranks_df.plot.line(
|
692 |
+
x="strategy:O",
|
693 |
+
y="rank:Q",
|
694 |
+
color="title:N",
|
695 |
+
)
|
696 |
+
points = notebook_ranks_df.plot.point(
|
697 |
+
x="strategy:O",
|
698 |
+
y="rank:Q",
|
699 |
+
color=alt.Color("title:N", legend=None),
|
700 |
+
fill="title:N",
|
701 |
+
)
|
702 |
+
(points + lines).properties(width=400)
|
703 |
+
return lines, notebook_ranks_df, points
|
704 |
+
|
705 |
+
|
706 |
+
@app.cell(hide_code=True)
|
707 |
+
def _(mo):
|
708 |
+
mo.md(
|
709 |
+
r"""
|
710 |
+
## ⏱️ Quantifying the Overhead
|
711 |
+
|
712 |
+
> UDF Performance Comparison
|
713 |
+
|
714 |
+
To truly understand the performance implications of using UDFs, let's conduct a benchmark. We'll create a DataFrame with random numbers and perform the same numerical operation using four different methods:
|
715 |
+
|
716 |
+
1. **Native Polars:** Using Polars' built-in expressions.
|
717 |
+
2. **`map_elements`:** Applying a Python function element-wise.
|
718 |
+
3. **`map_batches`:** **Applying** a Python function to the entire Series.
|
719 |
+
4. **`map_batches` with Numba:** Applying a JIT-compiled function to batches, similar to a generalized universal function.
|
720 |
+
|
721 |
+
We'll use a simple, but non-trivial, calculation: `result = (x * 2.5 + 5) / (x + 1)`. This involves multiplication, addition, and division, giving us a realistic representation of a common numerical operation. We'll use the `timeit` module, to accurately measure execution times over multiple trials.
|
722 |
+
"""
|
723 |
+
)
|
724 |
+
return
|
725 |
+
|
726 |
+
|
727 |
+
@app.cell(hide_code=True)
|
728 |
+
def _(mo):
|
729 |
+
mo.callout("💡 Tweak the benchmark parameters to explore how execution times change with different sample sizes and trial counts. Do you notice anything surprising as you decrease the number of samples?")
|
730 |
+
return
|
731 |
+
|
732 |
+
|
733 |
+
@app.cell(hide_code=True)
|
734 |
+
def _(benchmark_plot, mo, num_samples, num_trials):
|
735 |
+
mo.vstack(
|
736 |
+
[
|
737 |
+
mo.hstack([num_samples, num_trials]),
|
738 |
+
mo.md(
|
739 |
+
f"""---
|
740 |
+
Performance comparison over **{num_trials.value:,} trials** with **{num_samples.value:,} samples**.
|
741 |
+
|
742 |
+
> Lower execution times are better.
|
743 |
+
"""
|
744 |
+
),
|
745 |
+
benchmark_plot,
|
746 |
+
]
|
747 |
+
)
|
748 |
+
return
|
749 |
+
|
750 |
+
|
751 |
+
@app.cell(hide_code=True)
|
752 |
+
def _(mo):
|
753 |
+
mo.md(
|
754 |
+
r"""
|
755 |
+
As anticipated, the `Batch-Wise UDF (Python)` and `Element-Wise UDF` exhibit significantly worse performance, essentially acting as pure-Python for-each loops.
|
756 |
+
|
757 |
+
However, when Python serves as an interface to lower-level, high-performance libraries, we observe substantial improvements. The `Batch-Wise UDF (NumPy)` lags behind both `Batch-Wise UDF (Numba)` and `Native Polars`, but it still represents a considerable improvement over pure-Python UDFs due to its vectorized computations.
|
758 |
+
|
759 |
+
Numba's Just-In-Time (JIT) compilation delivers a dramatic performance boost, achieving speeds comparable to native Polars expressions. This demonstrates that UDFs, particularly when combined with tools like Numba, don't inevitably lead to bottlenecks in numerical computations.
|
760 |
+
"""
|
761 |
+
)
|
762 |
+
return
|
763 |
+
|
764 |
+
|
765 |
+
@app.cell(hide_code=True)
|
766 |
+
def _(mo):
|
767 |
+
num_samples = mo.ui.slider(
|
768 |
+
start=1_000,
|
769 |
+
stop=1_000_000,
|
770 |
+
value=250_000,
|
771 |
+
step=1000,
|
772 |
+
show_value=True,
|
773 |
+
debounce=True,
|
774 |
+
label="Number of Samples",
|
775 |
+
)
|
776 |
+
num_trials = mo.ui.slider(
|
777 |
+
start=50,
|
778 |
+
stop=1_000,
|
779 |
+
value=100,
|
780 |
+
step=50,
|
781 |
+
show_value=True,
|
782 |
+
debounce=True,
|
783 |
+
label="Number of Trials",
|
784 |
+
)
|
785 |
+
return num_samples, num_trials
|
786 |
+
|
787 |
+
|
788 |
+
@app.cell(hide_code=True)
|
789 |
+
def _(np, num_samples, pl):
|
790 |
+
rng = np.random.default_rng(42)
|
791 |
+
sample_df = pl.from_dict({"x": rng.random(num_samples.value)})
|
792 |
+
return rng, sample_df
|
793 |
+
|
794 |
+
|
795 |
+
@app.cell(hide_code=True)
|
796 |
+
def _(np, num_trials, numba, pl, sample_df, timeit):
|
797 |
+
def run_native():
|
798 |
+
sample_df.with_columns(
|
799 |
+
result_native=(pl.col("x") * 2.5 + 5) / (pl.col("x") + 1)
|
800 |
+
)
|
801 |
+
|
802 |
+
|
803 |
+
def _calculate_elementwise(x: float) -> float:
|
804 |
+
return (x * 2.5 + 5) / (x + 1)
|
805 |
+
|
806 |
+
|
807 |
+
def run_map_elements():
|
808 |
+
sample_df.with_columns(
|
809 |
+
result_map_elements=pl.col("x").map_elements(
|
810 |
+
_calculate_elementwise,
|
811 |
+
return_dtype=pl.Float64,
|
812 |
+
)
|
813 |
+
)
|
814 |
+
|
815 |
+
|
816 |
+
def _calculate_batchwise_numpy(x_series: pl.Series) -> pl.Series:
|
817 |
+
x_array = x_series.to_numpy()
|
818 |
+
result_array = (x_array * 2.5 + 5) / (x_array + 1)
|
819 |
+
return pl.Series(result_array)
|
820 |
+
|
821 |
+
|
822 |
+
def run_map_batches_numpy():
|
823 |
+
sample_df.with_columns(
|
824 |
+
result_map_batches_numpy=pl.col("x").map_batches(
|
825 |
+
_calculate_batchwise_numpy,
|
826 |
+
return_dtype=pl.Float64,
|
827 |
+
)
|
828 |
+
)
|
829 |
+
|
830 |
+
|
831 |
+
def _calculate_batchwise_python(x_series: pl.Series) -> pl.Series:
|
832 |
+
x_array = x_series.to_list()
|
833 |
+
result_array = [_calculate_elementwise(x) for x in x_array]
|
834 |
+
return pl.Series(result_array)
|
835 |
+
|
836 |
+
|
837 |
+
def run_map_batches_python():
|
838 |
+
sample_df.with_columns(
|
839 |
+
result_map_batches_python=pl.col("x").map_batches(
|
840 |
+
_calculate_batchwise_python,
|
841 |
+
return_dtype=pl.Float64,
|
842 |
+
)
|
843 |
+
)
|
844 |
+
|
845 |
+
|
846 |
+
@numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n) -> (n)")
|
847 |
+
def _calculate_batchwise_numba(x: np.ndarray, out: np.ndarray):
|
848 |
+
for i in range(x.shape[0]):
|
849 |
+
out[i] = (x[i] * 2.5 + 5) / (x[i] + 1)
|
850 |
+
|
851 |
+
|
852 |
+
def run_map_batches_numba():
|
853 |
+
sample_df.with_columns(
|
854 |
+
result_map_batches_numba=pl.col("x").map_batches(
|
855 |
+
_calculate_batchwise_numba,
|
856 |
+
return_dtype=pl.Float64,
|
857 |
+
)
|
858 |
+
)
|
859 |
+
|
860 |
+
|
861 |
+
def time_method(callable_name: str, number=num_trials.value) -> float:
|
862 |
+
fn = globals()[callable_name]
|
863 |
+
return timeit.timeit(fn, number=number)
|
864 |
+
return (
|
865 |
+
run_map_batches_numba,
|
866 |
+
run_map_batches_numpy,
|
867 |
+
run_map_batches_python,
|
868 |
+
run_map_elements,
|
869 |
+
run_native,
|
870 |
+
time_method,
|
871 |
+
)
|
872 |
+
|
873 |
+
|
874 |
+
@app.cell(hide_code=True)
|
875 |
+
def _(alt, pl, time_method):
|
876 |
+
benchmark_df = pl.from_dicts(
|
877 |
+
[
|
878 |
+
{
|
879 |
+
"title": "Native Polars",
|
880 |
+
"callable_name": "run_native",
|
881 |
+
},
|
882 |
+
{
|
883 |
+
"title": "Element-Wise UDF",
|
884 |
+
"callable_name": "run_map_elements",
|
885 |
+
},
|
886 |
+
{
|
887 |
+
"title": "Batch-Wise UDF (NumPy)",
|
888 |
+
"callable_name": "run_map_batches_numpy",
|
889 |
+
},
|
890 |
+
{
|
891 |
+
"title": "Batch-Wise UDF (Python)",
|
892 |
+
"callable_name": "run_map_batches_python",
|
893 |
+
},
|
894 |
+
{
|
895 |
+
"title": "Batch-Wise UDF (Numba)",
|
896 |
+
"callable_name": "run_map_batches_numba",
|
897 |
+
},
|
898 |
+
]
|
899 |
+
).with_columns(
|
900 |
+
time=pl.col("callable_name").map_elements(
|
901 |
+
time_method, return_dtype=pl.Float64
|
902 |
+
)
|
903 |
+
)
|
904 |
+
|
905 |
+
benchmark_plot = benchmark_df.plot.bar(
|
906 |
+
x=alt.X("title:N", title="Method", sort="-y"),
|
907 |
+
y=alt.Y("time:Q", title="Execution Time (s)", axis=alt.Axis(format=".3f")),
|
908 |
+
).properties(width=400)
|
909 |
+
return benchmark_df, benchmark_plot
|
910 |
+
|
911 |
+
|
912 |
+
@app.cell(hide_code=True)
|
913 |
+
def _():
|
914 |
+
import asyncio
|
915 |
+
import timeit
|
916 |
+
from typing import Iterable
|
917 |
+
|
918 |
+
import altair as alt
|
919 |
+
import httpx
|
920 |
+
import marimo as mo
|
921 |
+
import nest_asyncio
|
922 |
+
import numba
|
923 |
+
import numpy as np
|
924 |
+
from bs4 import BeautifulSoup
|
925 |
+
|
926 |
+
import polars as pl
|
927 |
+
|
928 |
+
# Fixes RuntimeError: asyncio.run() cannot be called from a running event loop
|
929 |
+
nest_asyncio.apply()
|
930 |
+
return (
|
931 |
+
BeautifulSoup,
|
932 |
+
Iterable,
|
933 |
+
alt,
|
934 |
+
asyncio,
|
935 |
+
httpx,
|
936 |
+
mo,
|
937 |
+
nest_asyncio,
|
938 |
+
np,
|
939 |
+
numba,
|
940 |
+
pl,
|
941 |
+
timeit,
|
942 |
+
)
|
943 |
+
|
944 |
+
|
945 |
+
if __name__ == "__main__":
|
946 |
+
app.run()
|
polars/README.md
CHANGED
@@ -23,3 +23,4 @@ You can also open notebooks in our online playground by appending marimo.app/ to
|
|
23 |
Thanks to all our notebook authors!
|
24 |
|
25 |
* [Koushik Khan](https://github.com/koushikkhan)
|
|
|
|
23 |
Thanks to all our notebook authors!
|
24 |
|
25 |
* [Koushik Khan](https://github.com/koushikkhan)
|
26 |
+
* [Péter Gyarmati](https://github.com/peter-gy)
|
probability/08_bayes_theorem.py
CHANGED
@@ -307,7 +307,7 @@ def _(mo):
|
|
307 |
mo.md(
|
308 |
r"""
|
309 |
|
310 |
-
_This interactive
|
311 |
|
312 |
Bayes theorem provides a convenient way to calculate the probability
|
313 |
of a hypothesis event $H$ given evidence $E$:
|
|
|
307 |
mo.md(
|
308 |
r"""
|
309 |
|
310 |
+
_This interactive example was made with [marimo](https://github.com/marimo-team/marimo/blob/main/examples/misc/bayes_theorem.py), and is [based on an explanation of Bayes' Theorem by Grant Sanderson](https://www.youtube.com/watch?v=HZGCoVF3YvM&list=PLzq7odmtfKQw2KIbQq0rzWrqgifHKkPG1&index=1&t=3s)_.
|
311 |
|
312 |
Bayes theorem provides a convenient way to calculate the probability
|
313 |
of a hypothesis event $H$ given evidence $E$:
|
probability/10_probability_mass_function.py
ADDED
@@ -0,0 +1,711 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.3",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# ]
|
9 |
+
# ///
|
10 |
+
|
11 |
+
import marimo
|
12 |
+
|
13 |
+
__generated_with = "0.11.17"
|
14 |
+
app = marimo.App(width="medium", app_title="Probability Mass Functions")
|
15 |
+
|
16 |
+
|
17 |
+
@app.cell(hide_code=True)
|
18 |
+
def _(mo):
|
19 |
+
mo.md(
|
20 |
+
r"""
|
21 |
+
# Probability Mass Functions
|
22 |
+
|
23 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/pmf/), by Stanford professor Chris Piech._
|
24 |
+
|
25 |
+
For a random variable, the most important thing to know is: how likely is each outcome? For a discrete random variable, this information is called the "**Probability Mass Function**". The probability mass function (PMF) provides the "mass" (i.e. amount) of "probability" for each possible assignment of the random variable.
|
26 |
+
|
27 |
+
Formally, the Probability Mass Function is a mapping between the values that the random variable could take on and the probability of the random variable taking on said value. In mathematics, we call these associations functions. There are many different ways of representing functions: you can write an equation, you can make a graph, you can even store many samples in a list.
|
28 |
+
"""
|
29 |
+
)
|
30 |
+
return
|
31 |
+
|
32 |
+
|
33 |
+
@app.cell(hide_code=True)
|
34 |
+
def _(mo):
|
35 |
+
mo.md(
|
36 |
+
r"""
|
37 |
+
## Properties of a PMF
|
38 |
+
|
39 |
+
For a function $p_X(x)$ to be a valid PMF, it must satisfy:
|
40 |
+
|
41 |
+
1. **Non-negativity**: $p_X(x) \geq 0$ for all $x$
|
42 |
+
2. **Unit total probability**: $\sum_x p_X(x) = 1$
|
43 |
+
|
44 |
+
### Probabilities Must Sum to 1
|
45 |
+
|
46 |
+
For a variable (call it $X$) to be a proper random variable, it must be the case that if you summed up the values of $P(X=x)$ for all possible values $x$ that $X$ can take on, the result must be 1:
|
47 |
+
|
48 |
+
$$\sum_x P(X=x) = 1$$
|
49 |
+
|
50 |
+
This is because a random variable taking on a value is an event (for example $X=3$). Each of those events is mutually exclusive because a random variable will take on exactly one value. Those mutually exclusive cases define an entire sample space. Why? Because $X$ must take on some value.
|
51 |
+
"""
|
52 |
+
)
|
53 |
+
return
|
54 |
+
|
55 |
+
|
56 |
+
@app.cell(hide_code=True)
|
57 |
+
def _(mo):
|
58 |
+
mo.md(
|
59 |
+
r"""
|
60 |
+
## PMFs as Graphs
|
61 |
+
|
62 |
+
Let's start by looking at PMFs as graphs where the $x$-axis is the values that the random variable could take on and the $y$-axis is the probability of the random variable taking on said value.
|
63 |
+
|
64 |
+
In the following example, we show two PMFs:
|
65 |
+
|
66 |
+
- On the left: PMF for the random variable $X$ = the value of a single six-sided die roll
|
67 |
+
- On the right: PMF for the random variable $Y$ = value of the sum of two dice rolls
|
68 |
+
"""
|
69 |
+
)
|
70 |
+
return
|
71 |
+
|
72 |
+
|
73 |
+
@app.cell(hide_code=True)
|
74 |
+
def _(np, plt):
|
75 |
+
# Single die PMF
|
76 |
+
single_die_values = np.arange(1, 7)
|
77 |
+
single_die_probs = np.ones(6) / 6
|
78 |
+
|
79 |
+
# Two dice sum PMF
|
80 |
+
two_dice_values = np.arange(2, 13)
|
81 |
+
two_dice_probs = []
|
82 |
+
|
83 |
+
for dice_sum in two_dice_values:
|
84 |
+
if dice_sum <= 7:
|
85 |
+
dice_prob = (dice_sum-1) / 36
|
86 |
+
else:
|
87 |
+
dice_prob = (13-dice_sum) / 36
|
88 |
+
two_dice_probs.append(dice_prob)
|
89 |
+
|
90 |
+
# Create side-by-side plots
|
91 |
+
dice_fig, (dice_ax1, dice_ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
92 |
+
|
93 |
+
# Single die plot
|
94 |
+
dice_ax1.bar(single_die_values, single_die_probs, width=0.4)
|
95 |
+
dice_ax1.set_xticks(single_die_values)
|
96 |
+
dice_ax1.set_xlabel('Value of die roll (x)')
|
97 |
+
dice_ax1.set_ylabel('Probability: P(X = x)')
|
98 |
+
dice_ax1.set_title('PMF of a Single Die Roll')
|
99 |
+
dice_ax1.grid(alpha=0.3)
|
100 |
+
|
101 |
+
# Two dice sum plot
|
102 |
+
dice_ax2.bar(two_dice_values, two_dice_probs, width=0.4)
|
103 |
+
dice_ax2.set_xticks(two_dice_values)
|
104 |
+
dice_ax2.set_xlabel('Sum of two dice (y)')
|
105 |
+
dice_ax2.set_ylabel('Probability: P(Y = y)')
|
106 |
+
dice_ax2.set_title('PMF of Sum of Two Dice')
|
107 |
+
dice_ax2.grid(alpha=0.3)
|
108 |
+
|
109 |
+
plt.tight_layout()
|
110 |
+
plt.gca()
|
111 |
+
return (
|
112 |
+
dice_ax1,
|
113 |
+
dice_ax2,
|
114 |
+
dice_fig,
|
115 |
+
dice_prob,
|
116 |
+
dice_sum,
|
117 |
+
single_die_probs,
|
118 |
+
single_die_values,
|
119 |
+
two_dice_probs,
|
120 |
+
two_dice_values,
|
121 |
+
)
|
122 |
+
|
123 |
+
|
124 |
+
@app.cell(hide_code=True)
|
125 |
+
def _(mo):
|
126 |
+
mo.md(
|
127 |
+
r"""
|
128 |
+
The information provided in these graphs shows the likelihood of a random variable taking on different values.
|
129 |
+
|
130 |
+
In the graph on the right, the value "6" on the $x$-axis is associated with the probability $\frac{5}{36}$ on the $y$-axis. This $x$-axis refers to the event "the sum of two dice is 6" or $Y = 6$. The $y$-axis tells us that the probability of that event is $\frac{5}{36}$. In full: $P(Y = 6) = \frac{5}{36}$.
|
131 |
+
|
132 |
+
The value "2" is associated with "$\frac{1}{36}$" which tells us that, $P(Y = 2) = \frac{1}{36}$, the probability that two dice sum to 2 is $\frac{1}{36}$. There is no value associated with "1" because the sum of two dice cannot be 1.
|
133 |
+
"""
|
134 |
+
)
|
135 |
+
return
|
136 |
+
|
137 |
+
|
138 |
+
@app.cell(hide_code=True)
|
139 |
+
def _(mo):
|
140 |
+
mo.md(
|
141 |
+
r"""
|
142 |
+
## PMFs as Equations
|
143 |
+
|
144 |
+
Here is the exact same information in equation form:
|
145 |
+
|
146 |
+
For a single die roll $X$:
|
147 |
+
$$P(X=x) = \frac{1}{6} \quad \text{ if } 1 \leq x \leq 6$$
|
148 |
+
|
149 |
+
For the sum of two dice $Y$:
|
150 |
+
$$P(Y=y) = \begin{cases}
|
151 |
+
\frac{(y-1)}{36} & \text{ if } 2 \leq y \leq 7\\
|
152 |
+
\frac{(13-y)}{36} & \text{ if } 8 \leq y \leq 12
|
153 |
+
\end{cases}$$
|
154 |
+
|
155 |
+
Let's implement the PMF for $Y$, the sum of two dice, in Python code:
|
156 |
+
"""
|
157 |
+
)
|
158 |
+
return
|
159 |
+
|
160 |
+
|
161 |
+
@app.cell
|
162 |
+
def _():
|
163 |
+
def pmf_sum_two_dice(y_val):
|
164 |
+
"""Returns the probability that the sum of two dice is y"""
|
165 |
+
if y_val < 2 or y_val > 12:
|
166 |
+
return 0
|
167 |
+
if y_val <= 7:
|
168 |
+
return (y_val-1) / 36
|
169 |
+
else:
|
170 |
+
return (13-y_val) / 36
|
171 |
+
|
172 |
+
# Test the function for a few values
|
173 |
+
test_values = [1, 2, 7, 12, 13]
|
174 |
+
for test_y in test_values:
|
175 |
+
print(f"P(Y = {test_y}) = {pmf_sum_two_dice(test_y)}")
|
176 |
+
return pmf_sum_two_dice, test_values, test_y
|
177 |
+
|
178 |
+
|
179 |
+
@app.cell(hide_code=True)
|
180 |
+
def _(mo):
|
181 |
+
mo.md(r"""Now, let's verify that our PMF satisfies the property that the sum of all probabilities equals 1:""")
|
182 |
+
return
|
183 |
+
|
184 |
+
|
185 |
+
@app.cell
|
186 |
+
def _(pmf_sum_two_dice):
|
187 |
+
# Verify that probabilities sum to 1
|
188 |
+
verify_total_prob = sum(pmf_sum_two_dice(y_val) for y_val in range(2, 13))
|
189 |
+
# Round to 10 decimal places to handle floating-point precision
|
190 |
+
verify_total_prob_rounded = round(verify_total_prob, 10)
|
191 |
+
print(f"Sum of all probabilities: {verify_total_prob_rounded}")
|
192 |
+
return verify_total_prob, verify_total_prob_rounded
|
193 |
+
|
194 |
+
|
195 |
+
@app.cell(hide_code=True)
|
196 |
+
def _(plt, pmf_sum_two_dice):
|
197 |
+
# Create a visual verification
|
198 |
+
verify_y_values = list(range(2, 13))
|
199 |
+
verify_probabilities = [pmf_sum_two_dice(y_val) for y_val in verify_y_values]
|
200 |
+
|
201 |
+
plt.figure(figsize=(10, 4))
|
202 |
+
plt.bar(verify_y_values, verify_probabilities, width=0.4)
|
203 |
+
plt.xticks(verify_y_values)
|
204 |
+
plt.xlabel('Sum of two dice (y)')
|
205 |
+
plt.ylabel('Probability: P(Y = y)')
|
206 |
+
plt.title('PMF of Sum of Two Dice (Total Probability = 1)')
|
207 |
+
plt.grid(alpha=0.3)
|
208 |
+
|
209 |
+
# Add probability values on top of bars
|
210 |
+
for verify_i, verify_prob in enumerate(verify_probabilities):
|
211 |
+
plt.text(verify_y_values[verify_i], verify_prob + 0.001, f'{verify_prob:.3f}', ha='center')
|
212 |
+
|
213 |
+
plt.gca() # Return the current axes to ensure proper display
|
214 |
+
return verify_i, verify_prob, verify_probabilities, verify_y_values
|
215 |
+
|
216 |
+
|
217 |
+
@app.cell(hide_code=True)
|
218 |
+
def _(mo):
|
219 |
+
mo.md(
|
220 |
+
r"""
|
221 |
+
## Data to Histograms to Probability Mass Functions
|
222 |
+
|
223 |
+
One surprising way to store a likelihood function (recall that a PMF is the name of the likelihood function for discrete random variables) is simply a list of data. Let's simulate summing two dice many times to create an empirical PMF:
|
224 |
+
"""
|
225 |
+
)
|
226 |
+
return
|
227 |
+
|
228 |
+
|
229 |
+
@app.cell
|
230 |
+
def _(np):
|
231 |
+
# Simulate rolling two dice many times
|
232 |
+
sim_num_trials = 10000
|
233 |
+
np.random.seed(42) # For reproducibility
|
234 |
+
|
235 |
+
# Generate random dice rolls
|
236 |
+
sim_die1 = np.random.randint(1, 7, size=sim_num_trials)
|
237 |
+
sim_die2 = np.random.randint(1, 7, size=sim_num_trials)
|
238 |
+
|
239 |
+
# Calculate the sum
|
240 |
+
sim_dice_sums = sim_die1 + sim_die2
|
241 |
+
|
242 |
+
# Display a small sample of the data
|
243 |
+
print(f"First 20 dice sums: {sim_dice_sums[:20]}")
|
244 |
+
print(f"Total number of trials: {sim_num_trials}")
|
245 |
+
return sim_dice_sums, sim_die1, sim_die2, sim_num_trials
|
246 |
+
|
247 |
+
|
248 |
+
@app.cell(hide_code=True)
|
249 |
+
def _(collections, np, plt, sim_dice_sums):
|
250 |
+
# Count the frequency of each sum
|
251 |
+
sim_counter = collections.Counter(sim_dice_sums)
|
252 |
+
|
253 |
+
# Sort the values
|
254 |
+
sim_sorted_values = sorted(sim_counter.keys())
|
255 |
+
|
256 |
+
# Calculate the empirical PMF
|
257 |
+
sim_empirical_pmf = [sim_counter[x] / len(sim_dice_sums) for x in sim_sorted_values]
|
258 |
+
|
259 |
+
# Calculate the theoretical PMF
|
260 |
+
sim_theoretical_values = np.arange(2, 13)
|
261 |
+
sim_theoretical_pmf = []
|
262 |
+
for sim_y in sim_theoretical_values:
|
263 |
+
if sim_y <= 7:
|
264 |
+
sim_prob = (sim_y-1) / 36
|
265 |
+
else:
|
266 |
+
sim_prob = (13-sim_y) / 36
|
267 |
+
sim_theoretical_pmf.append(sim_prob)
|
268 |
+
|
269 |
+
# Create a comparison plot
|
270 |
+
sim_fig, (sim_ax1, sim_ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
271 |
+
|
272 |
+
# Empirical PMF (normalized histogram)
|
273 |
+
sim_ax1.bar(sim_sorted_values, sim_empirical_pmf, width=0.4)
|
274 |
+
sim_ax1.set_xticks(sim_sorted_values)
|
275 |
+
sim_ax1.set_xlabel('Sum of two dice')
|
276 |
+
sim_ax1.set_ylabel('Empirical Probability')
|
277 |
+
sim_ax1.set_title(f'Empirical PMF from {len(sim_dice_sums)} Trials')
|
278 |
+
sim_ax1.grid(alpha=0.3)
|
279 |
+
|
280 |
+
# Theoretical PMF
|
281 |
+
sim_ax2.bar(sim_theoretical_values, sim_theoretical_pmf, width=0.4)
|
282 |
+
sim_ax2.set_xticks(sim_theoretical_values)
|
283 |
+
sim_ax2.set_xlabel('Sum of two dice')
|
284 |
+
sim_ax2.set_ylabel('Theoretical Probability')
|
285 |
+
sim_ax2.set_title('Theoretical PMF')
|
286 |
+
sim_ax2.grid(alpha=0.3)
|
287 |
+
|
288 |
+
plt.tight_layout()
|
289 |
+
|
290 |
+
# Let's also look at the raw counts (histogram)
|
291 |
+
plt.figure(figsize=(10, 4))
|
292 |
+
sim_counts = [sim_counter[x] for x in sim_sorted_values]
|
293 |
+
plt.bar(sim_sorted_values, sim_counts, width=0.4)
|
294 |
+
plt.xticks(sim_sorted_values)
|
295 |
+
plt.xlabel('Sum of two dice')
|
296 |
+
plt.ylabel('Frequency')
|
297 |
+
plt.title('Histogram of Dice Sum Frequencies')
|
298 |
+
plt.grid(alpha=0.3)
|
299 |
+
|
300 |
+
# Add count values on top of bars
|
301 |
+
for sim_i, sim_count in enumerate(sim_counts):
|
302 |
+
plt.text(sim_sorted_values[sim_i], sim_count + 19, str(sim_count), ha='center')
|
303 |
+
|
304 |
+
plt.gca() # Return the current axes to ensure proper display
|
305 |
+
return (
|
306 |
+
sim_ax1,
|
307 |
+
sim_ax2,
|
308 |
+
sim_count,
|
309 |
+
sim_counter,
|
310 |
+
sim_counts,
|
311 |
+
sim_empirical_pmf,
|
312 |
+
sim_fig,
|
313 |
+
sim_i,
|
314 |
+
sim_prob,
|
315 |
+
sim_sorted_values,
|
316 |
+
sim_theoretical_pmf,
|
317 |
+
sim_theoretical_values,
|
318 |
+
sim_y,
|
319 |
+
)
|
320 |
+
|
321 |
+
|
322 |
+
@app.cell(hide_code=True)
|
323 |
+
def _(mo):
|
324 |
+
mo.md(
|
325 |
+
r"""
|
326 |
+
A normalized histogram (where each value is divided by the length of your data list) is an approximation of the PMF. For a dataset of discrete numbers, a histogram shows the count of each value. By the definition of probability, if you divide this count by the number of experiments run, you arrive at an approximation of the probability of the event $P(Y=y)$.
|
327 |
+
|
328 |
+
Let's look at a specific example. If we want to approximate $P(Y=3)$ (the probability that the sum of two dice is 3), we can count the number of times that "3" occurs in our data and divide by the total number of trials:
|
329 |
+
"""
|
330 |
+
)
|
331 |
+
return
|
332 |
+
|
333 |
+
|
334 |
+
@app.cell
|
335 |
+
def _(sim_counter, sim_dice_sums):
|
336 |
+
# Calculate P(Y=3) empirically
|
337 |
+
sim_count_of_3 = sim_counter[3]
|
338 |
+
sim_empirical_prob = sim_count_of_3 / len(sim_dice_sums)
|
339 |
+
|
340 |
+
# Calculate P(Y=3) theoretically
|
341 |
+
sim_theoretical_prob = 2/36 # There are 2 ways to get a sum of 3 out of 36 possible outcomes
|
342 |
+
|
343 |
+
print(f"Count of sum=3: {sim_count_of_3}")
|
344 |
+
print(f"Empirical P(Y=3): {sim_count_of_3}/{len(sim_dice_sums)} = {sim_empirical_prob:.4f}")
|
345 |
+
print(f"Theoretical P(Y=3): 2/36 = {sim_theoretical_prob:.4f}")
|
346 |
+
print(f"Difference: {abs(sim_empirical_prob - sim_theoretical_prob):.4f}")
|
347 |
+
return sim_count_of_3, sim_empirical_prob, sim_theoretical_prob
|
348 |
+
|
349 |
+
|
350 |
+
@app.cell(hide_code=True)
|
351 |
+
def _(mo):
|
352 |
+
mo.md(
|
353 |
+
r"""
|
354 |
+
As we can see, with a large number of trials, the empirical PMF becomes a very good approximation of the theoretical PMF. This is an example of the [Law of Large Numbers](https://en.wikipedia.org/wiki/Law_of_large_numbers) in action.
|
355 |
+
|
356 |
+
## Interactive Example: Exploring PMFs
|
357 |
+
|
358 |
+
Let's create an interactive tool to explore different PMFs:
|
359 |
+
"""
|
360 |
+
)
|
361 |
+
return
|
362 |
+
|
363 |
+
|
364 |
+
@app.cell
|
365 |
+
def _(dist_param1, dist_param2, dist_selection, mo):
|
366 |
+
mo.hstack([dist_selection, dist_param1, dist_param2], justify="space-around")
|
367 |
+
return
|
368 |
+
|
369 |
+
|
370 |
+
@app.cell(hide_code=True)
|
371 |
+
def _(mo):
|
372 |
+
dist_selection = mo.ui.dropdown(
|
373 |
+
options=[
|
374 |
+
"bernoulli",
|
375 |
+
"binomial",
|
376 |
+
"geometric",
|
377 |
+
"poisson"
|
378 |
+
],
|
379 |
+
value="bernoulli",
|
380 |
+
label="Select a distribution"
|
381 |
+
)
|
382 |
+
|
383 |
+
# Parameters for different distributions
|
384 |
+
dist_param1 = mo.ui.slider(
|
385 |
+
start=0.05,
|
386 |
+
stop=0.95,
|
387 |
+
step=0.05,
|
388 |
+
value=0.5,
|
389 |
+
label="p (success probability)"
|
390 |
+
)
|
391 |
+
|
392 |
+
dist_param2 = mo.ui.slider(
|
393 |
+
start=1,
|
394 |
+
stop=20,
|
395 |
+
step=1,
|
396 |
+
value=10,
|
397 |
+
label="n (trials) or λ (rate)"
|
398 |
+
)
|
399 |
+
return dist_param1, dist_param2, dist_selection
|
400 |
+
|
401 |
+
|
402 |
+
@app.cell(hide_code=True)
|
403 |
+
def _(dist_param1, dist_param2, dist_selection, np, plt, stats):
|
404 |
+
# Set up the plot based on the selected distribution
|
405 |
+
if dist_selection.value == "bernoulli":
|
406 |
+
# Bernoulli distribution
|
407 |
+
dist_p = dist_param1.value
|
408 |
+
dist_x_values = np.array([0, 1])
|
409 |
+
dist_pmf_values = [1-dist_p, dist_p]
|
410 |
+
dist_title = f"Bernoulli PMF (p = {dist_p:.2f})"
|
411 |
+
dist_x_label = "Outcome (0 = Failure, 1 = Success)"
|
412 |
+
dist_max_x = 1
|
413 |
+
|
414 |
+
elif dist_selection.value == "binomial":
|
415 |
+
# Binomial distribution
|
416 |
+
dist_n = int(dist_param2.value)
|
417 |
+
dist_p = dist_param1.value
|
418 |
+
dist_x_values = np.arange(0, dist_n+1)
|
419 |
+
dist_pmf_values = stats.binom.pmf(dist_x_values, dist_n, dist_p)
|
420 |
+
dist_title = f"Binomial PMF (n = {dist_n}, p = {dist_p:.2f})"
|
421 |
+
dist_x_label = "Number of Successes"
|
422 |
+
dist_max_x = dist_n
|
423 |
+
|
424 |
+
elif dist_selection.value == "geometric":
|
425 |
+
# Geometric distribution
|
426 |
+
dist_p = dist_param1.value
|
427 |
+
dist_max_x = min(int(5/dist_p), 50) # Limit the range for visualization
|
428 |
+
dist_x_values = np.arange(1, dist_max_x+1)
|
429 |
+
dist_pmf_values = stats.geom.pmf(dist_x_values, dist_p)
|
430 |
+
dist_title = f"Geometric PMF (p = {dist_p:.2f})"
|
431 |
+
dist_x_label = "Number of Trials Until First Success"
|
432 |
+
|
433 |
+
else: # Poisson
|
434 |
+
# Poisson distribution
|
435 |
+
dist_lam = dist_param2.value
|
436 |
+
dist_max_x = int(dist_lam*3) + 1 # Reasonable range for visualization
|
437 |
+
dist_x_values = np.arange(0, dist_max_x)
|
438 |
+
dist_pmf_values = stats.poisson.pmf(dist_x_values, dist_lam)
|
439 |
+
dist_title = f"Poisson PMF (λ = {dist_lam})"
|
440 |
+
dist_x_label = "Number of Events"
|
441 |
+
|
442 |
+
# Create the plot
|
443 |
+
plt.figure(figsize=(10, 5))
|
444 |
+
|
445 |
+
# For discrete distributions, use stem plot for clarity
|
446 |
+
dist_markerline, dist_stemlines, dist_baseline = plt.stem(
|
447 |
+
dist_x_values, dist_pmf_values, markerfmt='o', basefmt=' '
|
448 |
+
)
|
449 |
+
plt.setp(dist_markerline, markersize=6)
|
450 |
+
plt.setp(dist_stemlines, linewidth=1.5)
|
451 |
+
|
452 |
+
# Add a bar plot for better visibility
|
453 |
+
plt.bar(dist_x_values, dist_pmf_values, alpha=0.3, width=0.4)
|
454 |
+
|
455 |
+
plt.xlabel(dist_x_label)
|
456 |
+
plt.ylabel("Probability: P(X = x)")
|
457 |
+
plt.title(dist_title)
|
458 |
+
plt.grid(alpha=0.3)
|
459 |
+
|
460 |
+
# Calculate and display expected value and variance
|
461 |
+
if dist_selection.value == "bernoulli":
|
462 |
+
dist_mean = dist_p
|
463 |
+
dist_variance = dist_p * (1-dist_p)
|
464 |
+
elif dist_selection.value == "binomial":
|
465 |
+
dist_mean = dist_n * dist_p
|
466 |
+
dist_variance = dist_n * dist_p * (1-dist_p)
|
467 |
+
elif dist_selection.value == "geometric":
|
468 |
+
dist_mean = 1/dist_p
|
469 |
+
dist_variance = (1-dist_p)/(dist_p**2)
|
470 |
+
else: # Poisson
|
471 |
+
dist_mean = dist_lam
|
472 |
+
dist_variance = dist_lam
|
473 |
+
|
474 |
+
dist_std_dev = np.sqrt(dist_variance)
|
475 |
+
|
476 |
+
# Add text with distribution properties
|
477 |
+
dist_props_text = (
|
478 |
+
f"Mean: {dist_mean:.3f}\n"
|
479 |
+
f"Variance: {dist_variance:.3f}\n"
|
480 |
+
f"Std Dev: {dist_std_dev:.3f}\n"
|
481 |
+
f"Sum of probabilities: {sum(dist_pmf_values):.6f}"
|
482 |
+
)
|
483 |
+
|
484 |
+
plt.text(0.95, 0.95, dist_props_text,
|
485 |
+
transform=plt.gca().transAxes,
|
486 |
+
verticalalignment='top',
|
487 |
+
horizontalalignment='right',
|
488 |
+
bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
|
489 |
+
|
490 |
+
plt.gca() # Return the current axes to ensure proper display
|
491 |
+
return (
|
492 |
+
dist_baseline,
|
493 |
+
dist_lam,
|
494 |
+
dist_markerline,
|
495 |
+
dist_max_x,
|
496 |
+
dist_mean,
|
497 |
+
dist_n,
|
498 |
+
dist_p,
|
499 |
+
dist_pmf_values,
|
500 |
+
dist_props_text,
|
501 |
+
dist_std_dev,
|
502 |
+
dist_stemlines,
|
503 |
+
dist_title,
|
504 |
+
dist_variance,
|
505 |
+
dist_x_label,
|
506 |
+
dist_x_values,
|
507 |
+
)
|
508 |
+
|
509 |
+
|
510 |
+
@app.cell(hide_code=True)
|
511 |
+
def _(mo):
|
512 |
+
mo.md(
|
513 |
+
r"""
|
514 |
+
## Expected Value from a PMF
|
515 |
+
|
516 |
+
The expected value (or mean) of a discrete random variable is calculated using its PMF:
|
517 |
+
|
518 |
+
$$E[X] = \sum_x x \cdot p_X(x)$$
|
519 |
+
|
520 |
+
This represents the long-run average value of the random variable.
|
521 |
+
"""
|
522 |
+
)
|
523 |
+
return
|
524 |
+
|
525 |
+
|
526 |
+
@app.cell
|
527 |
+
def _(dist_pmf_values, dist_x_values):
|
528 |
+
def calc_expected_value(x_values, pmf_values):
|
529 |
+
"""Calculate the expected value of a discrete random variable."""
|
530 |
+
return sum(x * p for x, p in zip(x_values, pmf_values))
|
531 |
+
|
532 |
+
# Calculate expected value for the current distribution
|
533 |
+
ev_dist_mean = calc_expected_value(dist_x_values, dist_pmf_values)
|
534 |
+
|
535 |
+
print(f"Expected value: {ev_dist_mean:.4f}")
|
536 |
+
return calc_expected_value, ev_dist_mean
|
537 |
+
|
538 |
+
|
539 |
+
@app.cell(hide_code=True)
|
540 |
+
def _(mo):
|
541 |
+
mo.md(
|
542 |
+
r"""
|
543 |
+
## Variance from a PMF
|
544 |
+
|
545 |
+
The variance measures the spread or dispersion of a random variable around its mean:
|
546 |
+
|
547 |
+
$$\text{Var}(X) = E[(X - E[X])^2] = \sum_x (x - E[X])^2 \cdot p_X(x)$$
|
548 |
+
|
549 |
+
An alternative formula is:
|
550 |
+
|
551 |
+
$$\text{Var}(X) = E[X^2] - (E[X])^2 = \sum_x x^2 \cdot p_X(x) - \left(\sum_x x \cdot p_X(x)\right)^2$$
|
552 |
+
"""
|
553 |
+
)
|
554 |
+
return
|
555 |
+
|
556 |
+
|
557 |
+
@app.cell
|
558 |
+
def _(dist_pmf_values, dist_x_values, ev_dist_mean, np):
|
559 |
+
def calc_variance(x_values, pmf_values, mean_value):
|
560 |
+
"""Calculate the variance of a discrete random variable."""
|
561 |
+
return sum((x - mean_value)**2 * p for x, p in zip(x_values, pmf_values))
|
562 |
+
|
563 |
+
# Calculate variance for the current distribution
|
564 |
+
var_dist_var = calc_variance(dist_x_values, dist_pmf_values, ev_dist_mean)
|
565 |
+
var_dist_std_dev = np.sqrt(var_dist_var)
|
566 |
+
|
567 |
+
print(f"Variance: {var_dist_var:.4f}")
|
568 |
+
print(f"Standard deviation: {var_dist_std_dev:.4f}")
|
569 |
+
return calc_variance, var_dist_std_dev, var_dist_var
|
570 |
+
|
571 |
+
|
572 |
+
@app.cell(hide_code=True)
|
573 |
+
def _(mo):
|
574 |
+
mo.md(
|
575 |
+
r"""
|
576 |
+
## PMF vs. CDF
|
577 |
+
|
578 |
+
The **Cumulative Distribution Function (CDF)** is related to the PMF but gives the probability that the random variable $X$ is less than or equal to a value $x$:
|
579 |
+
|
580 |
+
$$F_X(x) = P(X \leq x) = \sum_{k \leq x} p_X(k)$$
|
581 |
+
|
582 |
+
While the PMF gives the probability mass at each point, the CDF accumulates these probabilities.
|
583 |
+
"""
|
584 |
+
)
|
585 |
+
return
|
586 |
+
|
587 |
+
|
588 |
+
@app.cell(hide_code=True)
|
589 |
+
def _(dist_pmf_values, dist_x_values, np, plt):
|
590 |
+
# Calculate the CDF from the PMF
|
591 |
+
cdf_dist_values = np.cumsum(dist_pmf_values)
|
592 |
+
|
593 |
+
# Create a plot comparing PMF and CDF
|
594 |
+
cdf_fig, (cdf_ax1, cdf_ax2) = plt.subplots(1, 2, figsize=(12, 4))
|
595 |
+
|
596 |
+
# PMF plot
|
597 |
+
cdf_ax1.bar(dist_x_values, dist_pmf_values, width=0.4, alpha=0.7)
|
598 |
+
cdf_ax1.set_xlabel('x')
|
599 |
+
cdf_ax1.set_ylabel('P(X = x)')
|
600 |
+
cdf_ax1.set_title('Probability Mass Function (PMF)')
|
601 |
+
cdf_ax1.grid(alpha=0.3)
|
602 |
+
|
603 |
+
# CDF plot - using step function with 'post' style for proper discrete representation
|
604 |
+
cdf_ax2.step(dist_x_values, cdf_dist_values, where='post', linewidth=2, color='blue')
|
605 |
+
cdf_ax2.scatter(dist_x_values, cdf_dist_values, s=50, color='blue')
|
606 |
+
|
607 |
+
# Set appropriate limits for better visualization
|
608 |
+
if len(dist_x_values) > 0:
|
609 |
+
x_min = min(dist_x_values) - 0.5
|
610 |
+
x_max = max(dist_x_values) + 0.5
|
611 |
+
cdf_ax2.set_xlim(x_min, x_max)
|
612 |
+
cdf_ax2.set_ylim(0, 1.05) # CDF goes from 0 to 1
|
613 |
+
|
614 |
+
cdf_ax2.set_xlabel('x')
|
615 |
+
cdf_ax2.set_ylabel('P(X ≤ x)')
|
616 |
+
cdf_ax2.set_title('Cumulative Distribution Function (CDF)')
|
617 |
+
cdf_ax2.grid(alpha=0.3)
|
618 |
+
|
619 |
+
plt.tight_layout()
|
620 |
+
plt.gca() # Return the current axes to ensure proper display
|
621 |
+
return cdf_ax1, cdf_ax2, cdf_dist_values, cdf_fig, x_max, x_min
|
622 |
+
|
623 |
+
|
624 |
+
@app.cell(hide_code=True)
|
625 |
+
def _(mo):
|
626 |
+
mo.md(
|
627 |
+
r"""
|
628 |
+
The graphs above illustrate the key difference between PMF and CDF:
|
629 |
+
|
630 |
+
- **PMF (left)**: Shows the probability of the random variable taking each specific value: P(X = x)
|
631 |
+
- **CDF (right)**: Shows the probability of the random variable being less than or equal to each value: P(X ≤ x)
|
632 |
+
|
633 |
+
The CDF at any point is the sum of all PMF values up to and including that point. This is why the CDF is always non-decreasing and eventually reaches 1. For discrete distributions like this one, the CDF forms a step function that jumps at each value in the support of the random variable.
|
634 |
+
"""
|
635 |
+
)
|
636 |
+
return
|
637 |
+
|
638 |
+
|
639 |
+
@app.cell(hide_code=True)
|
640 |
+
def _(mo):
|
641 |
+
mo.md(
|
642 |
+
r"""
|
643 |
+
## Test Your Understanding
|
644 |
+
|
645 |
+
Choose what you believe are the correct options in the questions below:
|
646 |
+
|
647 |
+
<details>
|
648 |
+
<summary>If X is a discrete random variable with PMF p(x), then p(x) must always be less than 1</summary>
|
649 |
+
❌ False! While most values in a PMF are typically less than 1, a PMF can have p(x) = 1 for a specific value if the random variable always takes that value (with 100% probability).
|
650 |
+
</details>
|
651 |
+
|
652 |
+
<details>
|
653 |
+
<summary>The sum of all probabilities in a PMF must equal exactly 1</summary>
|
654 |
+
✅ True! This is a fundamental property of any valid PMF. The total probability across all possible values must be 1, as the random variable must take some value.
|
655 |
+
</details>
|
656 |
+
|
657 |
+
<details>
|
658 |
+
<summary>A PMF can be estimated from data by creating a normalized histogram</summary>
|
659 |
+
✅ True! Counting the frequency of each value and dividing by the total number of observations gives an empirical PMF.
|
660 |
+
</details>
|
661 |
+
|
662 |
+
<details>
|
663 |
+
<summary>The expected value of a discrete random variable is always one of the possible values of the variable</summary>
|
664 |
+
❌ False! The expected value is a weighted average and may not be a value the random variable can actually take. For example, the expected value of a fair die roll is 3.5, which is not a possible outcome.
|
665 |
+
</details>
|
666 |
+
"""
|
667 |
+
)
|
668 |
+
return
|
669 |
+
|
670 |
+
|
671 |
+
@app.cell(hide_code=True)
|
672 |
+
def _(mo):
|
673 |
+
mo.md(
|
674 |
+
r"""
|
675 |
+
## Practical Applications of PMFs
|
676 |
+
|
677 |
+
PMFs pop up everywhere - network engineers use them to model traffic patterns, reliability teams predict equipment failures, and marketers analyze purchase behavior. In finance, they help price options; in gaming, they're behind every dice roll. Machine learning algorithms like Naive Bayes rely on them, and they're essential for modeling rare events like genetic mutations or system failures.
|
678 |
+
"""
|
679 |
+
)
|
680 |
+
return
|
681 |
+
|
682 |
+
|
683 |
+
@app.cell(hide_code=True)
|
684 |
+
def _(mo):
|
685 |
+
mo.md(
|
686 |
+
r"""
|
687 |
+
## Key Takeaways
|
688 |
+
|
689 |
+
PMFs give us the probability picture for discrete random variables - they tell us how likely each value is, must be non-negative, and always sum to 1. We can write them as equations, draw them as graphs, or estimate them from data. They're the foundation for calculating expected values and variances, which we'll explore in our next notebook on Expectation, where we'll learn how to summarize random variables with a single, most "expected" value.
|
690 |
+
"""
|
691 |
+
)
|
692 |
+
return
|
693 |
+
|
694 |
+
|
695 |
+
@app.cell
|
696 |
+
def _():
|
697 |
+
import marimo as mo
|
698 |
+
return (mo,)
|
699 |
+
|
700 |
+
|
701 |
+
@app.cell
|
702 |
+
def _():
|
703 |
+
import matplotlib.pyplot as plt
|
704 |
+
import numpy as np
|
705 |
+
from scipy import stats
|
706 |
+
import collections
|
707 |
+
return collections, np, plt, stats
|
708 |
+
|
709 |
+
|
710 |
+
if __name__ == "__main__":
|
711 |
+
app.run()
|
probability/11_expectation.py
ADDED
@@ -0,0 +1,860 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.3",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# ]
|
9 |
+
# ///
|
10 |
+
|
11 |
+
import marimo
|
12 |
+
|
13 |
+
__generated_with = "0.11.19"
|
14 |
+
app = marimo.App(width="medium", app_title="Expectation")
|
15 |
+
|
16 |
+
|
17 |
+
@app.cell(hide_code=True)
|
18 |
+
def _(mo):
|
19 |
+
mo.md(
|
20 |
+
r"""
|
21 |
+
# Expectation
|
22 |
+
|
23 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/expectation/), by Stanford professor Chris Piech._
|
24 |
+
|
25 |
+
A random variable is fully represented by its Probability Mass Function (PMF), which describes each value the random variable can take on and the corresponding probabilities. However, a PMF can contain a lot of information. Sometimes it's useful to summarize a random variable with a single value!
|
26 |
+
|
27 |
+
The most common, and arguably the most useful, summary of a random variable is its **Expectation** (also called the expected value or mean).
|
28 |
+
"""
|
29 |
+
)
|
30 |
+
return
|
31 |
+
|
32 |
+
|
33 |
+
@app.cell(hide_code=True)
|
34 |
+
def _(mo):
|
35 |
+
mo.md(
|
36 |
+
r"""
|
37 |
+
## Definition of Expectation
|
38 |
+
|
39 |
+
The expectation of a random variable $X$, written $E[X]$, is the average of all the values the random variable can take on, each weighted by the probability that the random variable will take on that value.
|
40 |
+
|
41 |
+
$$E[X] = \sum_x x \cdot P(X=x)$$
|
42 |
+
|
43 |
+
Expectation goes by many other names: Mean, Weighted Average, Center of Mass, 1st Moment. All of these are calculated using the same formula.
|
44 |
+
"""
|
45 |
+
)
|
46 |
+
return
|
47 |
+
|
48 |
+
|
49 |
+
@app.cell(hide_code=True)
|
50 |
+
def _(mo):
|
51 |
+
mo.md(
|
52 |
+
r"""
|
53 |
+
## Intuition Behind Expectation
|
54 |
+
|
55 |
+
The expected value represents the long-run average value of a random variable over many independent repetitions of an experiment.
|
56 |
+
|
57 |
+
For example, if you roll a fair six-sided die many times and calculate the average of all rolls, that average will approach the expected value of 3.5 as the number of rolls increases.
|
58 |
+
|
59 |
+
Let's visualize this concept:
|
60 |
+
"""
|
61 |
+
)
|
62 |
+
return
|
63 |
+
|
64 |
+
|
65 |
+
@app.cell(hide_code=True)
|
66 |
+
def _(np, plt):
|
67 |
+
# Set random seed for reproducibility
|
68 |
+
np.random.seed(42)
|
69 |
+
|
70 |
+
# Simulate rolling a die many times
|
71 |
+
exp_num_rolls = 1000
|
72 |
+
exp_die_rolls = np.random.randint(1, 7, size=exp_num_rolls)
|
73 |
+
|
74 |
+
# Calculate the running average
|
75 |
+
exp_running_avg = np.cumsum(exp_die_rolls) / np.arange(1, exp_num_rolls + 1)
|
76 |
+
|
77 |
+
# Create the plot
|
78 |
+
plt.figure(figsize=(10, 5))
|
79 |
+
plt.plot(range(1, exp_num_rolls + 1), exp_running_avg, label='Running Average')
|
80 |
+
plt.axhline(y=3.5, color='r', linestyle='--', label='Expected Value (3.5)')
|
81 |
+
plt.xlabel('Number of Rolls')
|
82 |
+
plt.ylabel('Average Value')
|
83 |
+
plt.title('Running Average of Die Rolls Approaching Expected Value')
|
84 |
+
plt.legend()
|
85 |
+
plt.grid(alpha=0.3)
|
86 |
+
plt.xscale('log') # Log scale to better see convergence
|
87 |
+
|
88 |
+
# Add annotations
|
89 |
+
plt.annotate('As the number of rolls increases,\nthe average approaches the expected value',
|
90 |
+
xy=(exp_num_rolls, exp_running_avg[-1]), xytext=(exp_num_rolls/3, 4),
|
91 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))
|
92 |
+
|
93 |
+
plt.gca()
|
94 |
+
return exp_die_rolls, exp_num_rolls, exp_running_avg
|
95 |
+
|
96 |
+
|
97 |
+
@app.cell(hide_code=True)
|
98 |
+
def _(mo):
|
99 |
+
mo.md(r"""## Properties of Expectation""")
|
100 |
+
return
|
101 |
+
|
102 |
+
|
103 |
+
@app.cell(hide_code=True)
|
104 |
+
def _(mo):
|
105 |
+
mo.accordion(
|
106 |
+
{
|
107 |
+
"1. Linearity of Expectation": mo.md(
|
108 |
+
r"""
|
109 |
+
$$E[aX + b] = a \cdot E[X] + b$$
|
110 |
+
|
111 |
+
Where $a$ and $b$ are constants (not random variables).
|
112 |
+
|
113 |
+
This means that if you multiply a random variable by a constant, the expectation is multiplied by that constant. And if you add a constant to a random variable, the expectation increases by that constant.
|
114 |
+
"""
|
115 |
+
),
|
116 |
+
"2. Expectation of the Sum of Random Variables": mo.md(
|
117 |
+
r"""
|
118 |
+
$$E[X + Y] = E[X] + E[Y]$$
|
119 |
+
|
120 |
+
This is true regardless of the relationship between $X$ and $Y$. They can be dependent, and they can have different distributions. This also applies with more than two random variables:
|
121 |
+
|
122 |
+
$$E\left[\sum_{i=1}^n X_i\right] = \sum_{i=1}^n E[X_i]$$
|
123 |
+
"""
|
124 |
+
),
|
125 |
+
"3. Law of the Unconscious Statistician (LOTUS)": mo.md(
|
126 |
+
r"""
|
127 |
+
$$E[g(X)] = \sum_x g(x) \cdot P(X=x)$$
|
128 |
+
|
129 |
+
This allows us to calculate the expected value of a function $g(X)$ of a random variable $X$ when we know the probability distribution of $X$ but don't explicitly know the distribution of $g(X)$.
|
130 |
+
|
131 |
+
This theorem has the humorous name "Law of the Unconscious Statistician" (LOTUS) because it's so useful that you should be able to employ it unconsciously.
|
132 |
+
"""
|
133 |
+
),
|
134 |
+
"4. Expectation of a Constant": mo.md(
|
135 |
+
r"""
|
136 |
+
$$E[a] = a$$
|
137 |
+
|
138 |
+
Sometimes in proofs, you'll end up with the expectation of a constant (rather than a random variable). Since a constant doesn't change, its expected value is just the constant itself.
|
139 |
+
"""
|
140 |
+
),
|
141 |
+
}
|
142 |
+
)
|
143 |
+
return
|
144 |
+
|
145 |
+
|
146 |
+
@app.cell(hide_code=True)
|
147 |
+
def _(mo):
|
148 |
+
mo.md(
|
149 |
+
r"""
|
150 |
+
## Calculating Expectation
|
151 |
+
|
152 |
+
Let's calculate the expected value for some common examples:
|
153 |
+
|
154 |
+
### Example 1: Fair Die Roll
|
155 |
+
|
156 |
+
For a fair six-sided die, the PMF is:
|
157 |
+
|
158 |
+
$$P(X=x) = \frac{1}{6} \text{ for } x \in \{1, 2, 3, 4, 5, 6\}$$
|
159 |
+
|
160 |
+
The expected value is:
|
161 |
+
|
162 |
+
$$E[X] = 1 \cdot \frac{1}{6} + 2 \cdot \frac{1}{6} + 3 \cdot \frac{1}{6} + 4 \cdot \frac{1}{6} + 5 \cdot \frac{1}{6} + 6 \cdot \frac{1}{6} = \frac{21}{6} = 3.5$$
|
163 |
+
|
164 |
+
Let's implement this calculation in Python:
|
165 |
+
"""
|
166 |
+
)
|
167 |
+
return
|
168 |
+
|
169 |
+
|
170 |
+
@app.cell
|
171 |
+
def _():
|
172 |
+
def calc_expectation_die():
    """Return E[X] for a single roll of a fair six-sided die.

    Applies the definition E[X] = sum_x x * P(X = x) with the uniform
    PMF P(X = x) = 1/6 for every face x in {1, ..., 6}; the result is 3.5.
    """
    uniform_prob = 1 / 6  # each face of a fair die is equally likely
    expected = 0
    # Accumulate the probability-weighted faces in ascending order,
    # matching the left-to-right summation of the definition.
    for face in range(1, 7):
        expected += face * uniform_prob
    return expected
|
179 |
+
|
180 |
+
exp_die_result = calc_expectation_die()
|
181 |
+
print(f"Expected value of a fair die roll: {exp_die_result}")
|
182 |
+
return calc_expectation_die, exp_die_result
|
183 |
+
|
184 |
+
|
185 |
+
@app.cell(hide_code=True)
|
186 |
+
def _(mo):
|
187 |
+
mo.md(
|
188 |
+
r"""
|
189 |
+
### Example 2: Sum of Two Dice
|
190 |
+
|
191 |
+
Now let's calculate the expected value for the sum of two fair dice. First, we need the PMF:
|
192 |
+
"""
|
193 |
+
)
|
194 |
+
return
|
195 |
+
|
196 |
+
|
197 |
+
@app.cell
|
198 |
+
def _():
|
199 |
+
def pmf_sum_two_dice(y_val):
    """Return P(Y = y_val) where Y is the sum of two fair six-sided dice.

    Enumerates all 6 x 6 = 36 equally likely ordered outcomes and counts
    those whose faces add up to ``y_val``. Values outside the support
    {2, ..., 12} therefore get probability 0.
    """
    faces = range(1, 7)
    favorable = sum(
        1 for first in faces for second in faces if first + second == y_val
    )
    # Classical probability: favorable outcomes over the 36 total outcomes.
    return favorable / 36
|
208 |
+
|
209 |
+
# Test the function for a few values
|
210 |
+
exp_test_values = [2, 7, 12]
|
211 |
+
for exp_test_y in exp_test_values:
|
212 |
+
print(f"P(Y = {exp_test_y}) = {pmf_sum_two_dice(exp_test_y)}")
|
213 |
+
return exp_test_values, exp_test_y, pmf_sum_two_dice
|
214 |
+
|
215 |
+
|
216 |
+
@app.cell
|
217 |
+
def _(pmf_sum_two_dice):
|
218 |
+
def calc_expectation_sum_two_dice():
    """Return E[Y] where Y is the sum of two fair six-sided dice.

    Applies E[Y] = sum_y y * P(Y = y) over the support 2..12, looking up
    each probability via the sibling ``pmf_sum_two_dice``. The result is
    exactly 7 (up to floating-point rounding).
    """
    # sum() over the generator accumulates left-to-right, preserving the
    # same addition order as an explicit loop over range(2, 13).
    return sum(total * pmf_sum_two_dice(total) for total in range(2, 13))
|
226 |
+
|
227 |
+
exp_sum_result = calc_expectation_sum_two_dice()
|
228 |
+
|
229 |
+
# Round to 2 decimal places for display
|
230 |
+
exp_sum_result_rounded = round(exp_sum_result, 2)
|
231 |
+
|
232 |
+
print(f"Expected value of the sum of two dice: {exp_sum_result_rounded}")
|
233 |
+
|
234 |
+
# Let's also verify this with a direct calculation
|
235 |
+
exp_direct_calc = sum(x * pmf_sum_two_dice(x) for x in range(2, 13))
|
236 |
+
exp_direct_calc_rounded = round(exp_direct_calc, 2)
|
237 |
+
|
238 |
+
print(f"Direct calculation: {exp_direct_calc_rounded}")
|
239 |
+
|
240 |
+
# Verify that this equals 7
|
241 |
+
print(f"Is the expected value exactly 7? {abs(exp_sum_result - 7) < 1e-10}")
|
242 |
+
return (
|
243 |
+
calc_expectation_sum_two_dice,
|
244 |
+
exp_direct_calc,
|
245 |
+
exp_direct_calc_rounded,
|
246 |
+
exp_sum_result,
|
247 |
+
exp_sum_result_rounded,
|
248 |
+
)
|
249 |
+
|
250 |
+
|
251 |
+
@app.cell(hide_code=True)
|
252 |
+
def _(mo):
|
253 |
+
mo.md(
|
254 |
+
r"""
|
255 |
+
### Visualizing Expectation
|
256 |
+
|
257 |
+
Let's visualize the expectation for the sum of two dice. The expected value is the "center of mass" of the PMF:
|
258 |
+
"""
|
259 |
+
)
|
260 |
+
return
|
261 |
+
|
262 |
+
|
263 |
+
@app.cell(hide_code=True)
|
264 |
+
def _(plt, pmf_sum_two_dice):
|
265 |
+
# Create the visualization
|
266 |
+
exp_y_values = list(range(2, 13))
|
267 |
+
exp_probabilities = [pmf_sum_two_dice(y) for y in exp_y_values]
|
268 |
+
|
269 |
+
dice_fig, dice_ax = plt.subplots(figsize=(10, 5))
|
270 |
+
dice_ax.bar(exp_y_values, exp_probabilities, width=0.4)
|
271 |
+
dice_ax.axvline(x=7, color='r', linestyle='--', linewidth=2, label='Expected Value (7)')
|
272 |
+
|
273 |
+
dice_ax.set_xticks(exp_y_values)
|
274 |
+
dice_ax.set_xlabel('Sum of two dice (y)')
|
275 |
+
dice_ax.set_ylabel('Probability: P(Y = y)')
|
276 |
+
dice_ax.set_title('PMF of Sum of Two Dice with Expected Value')
|
277 |
+
dice_ax.grid(alpha=0.3)
|
278 |
+
dice_ax.legend()
|
279 |
+
|
280 |
+
# Add probability values on top of bars
|
281 |
+
for exp_i, exp_prob in enumerate(exp_probabilities):
|
282 |
+
dice_ax.text(exp_y_values[exp_i], exp_prob + 0.001, f'{exp_prob:.3f}', ha='center')
|
283 |
+
|
284 |
+
plt.tight_layout()
|
285 |
+
plt.gca()
|
286 |
+
return dice_ax, dice_fig, exp_i, exp_prob, exp_probabilities, exp_y_values
|
287 |
+
|
288 |
+
|
289 |
+
@app.cell(hide_code=True)
|
290 |
+
def _(mo):
|
291 |
+
mo.md(
|
292 |
+
r"""
|
293 |
+
## Demonstrating the Properties of Expectation
|
294 |
+
|
295 |
+
Let's demonstrate some of these properties with examples:
|
296 |
+
"""
|
297 |
+
)
|
298 |
+
return
|
299 |
+
|
300 |
+
|
301 |
+
@app.cell
|
302 |
+
def _(exp_die_result):
|
303 |
+
# Demonstrate linearity of expectation (1)
|
304 |
+
# E[aX + b] = a*E[X] + b
|
305 |
+
|
306 |
+
# For a die roll X with E[X] = 3.5
|
307 |
+
prop_a = 2
|
308 |
+
prop_b = 10
|
309 |
+
|
310 |
+
# Calculate E[2X + 10] using the property
|
311 |
+
prop_expected_using_property = prop_a * exp_die_result + prop_b
|
312 |
+
prop_expected_using_property_rounded = round(prop_expected_using_property, 2)
|
313 |
+
|
314 |
+
print(f"Using linearity property: E[{prop_a}X + {prop_b}] = {prop_a} * E[X] + {prop_b} = {prop_expected_using_property_rounded}")
|
315 |
+
|
316 |
+
# Calculate E[2X + 10] directly
|
317 |
+
prop_expected_direct = sum((prop_a * x + prop_b) * (1/6) for x in range(1, 7))
|
318 |
+
prop_expected_direct_rounded = round(prop_expected_direct, 2)
|
319 |
+
|
320 |
+
print(f"Direct calculation: E[{prop_a}X + {prop_b}] = {prop_expected_direct_rounded}")
|
321 |
+
|
322 |
+
# Verify they match
|
323 |
+
print(f"Do they match? {abs(prop_expected_using_property - prop_expected_direct) < 1e-10}")
|
324 |
+
return (
|
325 |
+
prop_a,
|
326 |
+
prop_b,
|
327 |
+
prop_expected_direct,
|
328 |
+
prop_expected_direct_rounded,
|
329 |
+
prop_expected_using_property,
|
330 |
+
prop_expected_using_property_rounded,
|
331 |
+
)
|
332 |
+
|
333 |
+
|
334 |
+
@app.cell(hide_code=True)
|
335 |
+
def _(mo):
|
336 |
+
mo.md(
|
337 |
+
r"""
|
338 |
+
### Law of the Unconscious Statistician (LOTUS)
|
339 |
+
|
340 |
+
Let's use LOTUS to calculate $E[X^2]$ for a die roll, which will be useful when we study variance:
|
341 |
+
"""
|
342 |
+
)
|
343 |
+
return
|
344 |
+
|
345 |
+
|
346 |
+
@app.cell
def _():
    # Demonstrate LOTUS (property 3): E[g(X)] = sum_x g(x) * P(X=x),
    # here with g(x) = x^2 for a fair six-sided die.
    # Calculate E[X^2] for a die roll using LOTUS (3)
    lotus_die_values = range(1, 7)  # support of X: faces 1..6
    lotus_die_probs = [1/6] * 6  # uniform PMF of a fair die

    # Using LOTUS: E[X^2] = sum(x^2 * P(X=x)) = 91/6, about 15.17
    lotus_expected_x_squared = sum(x**2 * p for x, p in zip(lotus_die_values, lotus_die_probs))
    lotus_expected_x_squared_rounded = round(lotus_expected_x_squared, 2)

    # (E[X])^2 = 3.5**2 = 12.25, computed for contrast with E[X^2] above;
    # the two differ, which the next markdown cell calls out.
    expected_x_squared = 3.5**2
    expected_x_squared_rounded = round(expected_x_squared, 2)

    print(f"E[X^2] for a die roll = {lotus_expected_x_squared_rounded}")
    print(f"(E[X])^2 for a die roll = {expected_x_squared_rounded}")
    # marimo cells expose locals to other cells through the return tuple.
    return (
        expected_x_squared,
        expected_x_squared_rounded,
        lotus_die_probs,
        lotus_die_values,
        lotus_expected_x_squared,
        lotus_expected_x_squared_rounded,
    )
|
369 |
+
|
370 |
+
|
371 |
+
@app.cell(hide_code=True)
def _(mo):
    # Callout cell. Fix: marimo "///" admonitions must be terminated by a
    # bare "///" line, which the original markdown omitted; the inequality
    # is also set in LaTeX to match the rest of the notebook.
    mo.md(
        r"""
        /// note
        Note that $E[X^2] \neq (E[X])^2$.
        ///
        """
    )
    return
|
380 |
+
|
381 |
+
|
382 |
+
@app.cell(hide_code=True)
|
383 |
+
def _(mo):
|
384 |
+
mo.md(
|
385 |
+
r"""
|
386 |
+
## Interactive Example
|
387 |
+
|
388 |
+
Let's explore how the expected value changes as we adjust the parameters of common probability distributions. This interactive visualization focuses specifically on the relationship between distribution parameters and expected values.
|
389 |
+
|
390 |
+
Use the controls below to select a distribution and adjust its parameters. The graph will show how the expected value changes across a range of parameter values.
|
391 |
+
"""
|
392 |
+
)
|
393 |
+
return
|
394 |
+
|
395 |
+
|
396 |
+
@app.cell(hide_code=True)
|
397 |
+
def _(mo):
|
398 |
+
# Create UI elements for distribution selection
|
399 |
+
dist_selection = mo.ui.dropdown(
|
400 |
+
options=[
|
401 |
+
"bernoulli",
|
402 |
+
"binomial",
|
403 |
+
"geometric",
|
404 |
+
"poisson"
|
405 |
+
],
|
406 |
+
value="bernoulli",
|
407 |
+
label="Select a distribution"
|
408 |
+
)
|
409 |
+
return (dist_selection,)
|
410 |
+
|
411 |
+
|
412 |
+
@app.cell(hide_code=True)
|
413 |
+
def _(dist_selection):
|
414 |
+
dist_selection.center()
|
415 |
+
return
|
416 |
+
|
417 |
+
|
418 |
+
@app.cell(hide_code=True)
|
419 |
+
def _(dist_description):
|
420 |
+
dist_description
|
421 |
+
return
|
422 |
+
|
423 |
+
|
424 |
+
@app.cell(hide_code=True)
|
425 |
+
def _(mo):
|
426 |
+
mo.md("""### Adjust Parameters""")
|
427 |
+
return
|
428 |
+
|
429 |
+
|
430 |
+
@app.cell(hide_code=True)
|
431 |
+
def _(controls):
|
432 |
+
controls
|
433 |
+
return
|
434 |
+
|
435 |
+
|
436 |
+
@app.cell(hide_code=True)
|
437 |
+
def _(
|
438 |
+
dist_selection,
|
439 |
+
lambda_range,
|
440 |
+
np,
|
441 |
+
param_lambda,
|
442 |
+
param_n,
|
443 |
+
param_p,
|
444 |
+
param_range,
|
445 |
+
plt,
|
446 |
+
):
|
447 |
+
# Calculate expected values based on the selected distribution
|
448 |
+
if dist_selection.value == "bernoulli":
|
449 |
+
# Get parameter range for visualization
|
450 |
+
p_min, p_max = param_range.value
|
451 |
+
param_values = np.linspace(p_min, p_max, 100)
|
452 |
+
|
453 |
+
# E[X] = p for Bernoulli
|
454 |
+
expected_values = param_values
|
455 |
+
current_param = param_p.value
|
456 |
+
current_expected = round(current_param, 2)
|
457 |
+
x_label = "p (probability of success)"
|
458 |
+
title = "Expected Value of Bernoulli Distribution"
|
459 |
+
formula = "E[X] = p"
|
460 |
+
|
461 |
+
elif dist_selection.value == "binomial":
|
462 |
+
p_min, p_max = param_range.value
|
463 |
+
param_values = np.linspace(p_min, p_max, 100)
|
464 |
+
|
465 |
+
# E[X] = np for Binomial
|
466 |
+
n = int(param_n.value)
|
467 |
+
expected_values = [n * p for p in param_values]
|
468 |
+
current_param = param_p.value
|
469 |
+
current_expected = round(n * current_param, 2)
|
470 |
+
x_label = "p (probability of success)"
|
471 |
+
title = f"Expected Value of Binomial Distribution (n={n})"
|
472 |
+
formula = f"E[X] = n × p = {n} × p"
|
473 |
+
|
474 |
+
elif dist_selection.value == "geometric":
|
475 |
+
p_min, p_max = param_range.value
|
476 |
+
# Ensure p is not 0 for geometric distribution
|
477 |
+
p_min = max(0.01, p_min)
|
478 |
+
param_values = np.linspace(p_min, p_max, 100)
|
479 |
+
|
480 |
+
# E[X] = 1/p for Geometric
|
481 |
+
expected_values = [1/p for p in param_values]
|
482 |
+
current_param = param_p.value
|
483 |
+
current_expected = round(1 / current_param, 2)
|
484 |
+
x_label = "p (probability of success)"
|
485 |
+
title = "Expected Value of Geometric Distribution"
|
486 |
+
formula = "E[X] = 1/p"
|
487 |
+
|
488 |
+
else: # Poisson
|
489 |
+
lambda_min, lambda_max = lambda_range.value
|
490 |
+
param_values = np.linspace(lambda_min, lambda_max, 100)
|
491 |
+
|
492 |
+
# E[X] = lambda for Poisson
|
493 |
+
expected_values = param_values
|
494 |
+
current_param = param_lambda.value
|
495 |
+
current_expected = round(current_param, 2)
|
496 |
+
x_label = "λ (rate parameter)"
|
497 |
+
title = "Expected Value of Poisson Distribution"
|
498 |
+
formula = "E[X] = λ"
|
499 |
+
|
500 |
+
# Create the plot
|
501 |
+
dist_fig, dist_ax = plt.subplots(figsize=(10, 6))
|
502 |
+
|
503 |
+
# Plot the expected value function
|
504 |
+
dist_ax.plot(param_values, expected_values, 'b-', linewidth=2, label="Expected Value Function")
|
505 |
+
|
506 |
+
dist_ax.plot(current_param, current_expected, 'ro', markersize=10, label=f"Current Value: E[X] = {current_expected}")
|
507 |
+
|
508 |
+
dist_ax.hlines(current_expected, param_values[0], current_param, colors='r', linestyles='dashed')
|
509 |
+
|
510 |
+
dist_ax.vlines(current_param, 0, current_expected, colors='r', linestyles='dashed')
|
511 |
+
|
512 |
+
dist_ax.fill_between(param_values, 0, expected_values, alpha=0.2, color='blue')
|
513 |
+
|
514 |
+
dist_ax.set_xlabel(x_label, fontsize=12)
|
515 |
+
dist_ax.set_ylabel("Expected Value: E[X]", fontsize=12)
|
516 |
+
dist_ax.set_title(title, fontsize=14, fontweight='bold')
|
517 |
+
dist_ax.grid(True, alpha=0.3)
|
518 |
+
|
519 |
+
# Move legend to lower right to avoid overlap with formula
|
520 |
+
dist_ax.legend(loc='lower right', fontsize=10)
|
521 |
+
|
522 |
+
# Add formula text box in upper left
|
523 |
+
dist_props = dict(boxstyle='round', facecolor='white', alpha=0.8)
|
524 |
+
dist_ax.text(0.02, 0.95, formula, transform=dist_ax.transAxes, fontsize=12,
|
525 |
+
verticalalignment='top', bbox=dist_props)
|
526 |
+
|
527 |
+
if dist_selection.value == "geometric":
|
528 |
+
max_y = min(50, 2/max(0.01, param_values[0]))
|
529 |
+
dist_ax.set_ylim(0, max_y)
|
530 |
+
elif dist_selection.value == "binomial":
|
531 |
+
dist_ax.set_ylim(0, int(param_n.value) + 1)
|
532 |
+
else:
|
533 |
+
dist_ax.set_ylim(0, max(expected_values) * 1.1)
|
534 |
+
|
535 |
+
annotation_x = current_param + (param_values[-1] - param_values[0]) * 0.05
|
536 |
+
annotation_y = current_expected
|
537 |
+
|
538 |
+
# Adjust annotation position if it would go off the chart
|
539 |
+
if annotation_x > param_values[-1] * 0.9:
|
540 |
+
annotation_x = current_param - (param_values[-1] - param_values[0]) * 0.2
|
541 |
+
|
542 |
+
dist_ax.annotate(
|
543 |
+
f"Parameter: {current_param:.2f}\nE[X] = {current_expected}",
|
544 |
+
xy=(current_param, current_expected),
|
545 |
+
xytext=(annotation_x, annotation_y),
|
546 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, alpha=0.7),
|
547 |
+
bbox=dist_props
|
548 |
+
)
|
549 |
+
|
550 |
+
plt.tight_layout()
|
551 |
+
plt.gca()
|
552 |
+
return (
|
553 |
+
annotation_x,
|
554 |
+
annotation_y,
|
555 |
+
current_expected,
|
556 |
+
current_param,
|
557 |
+
dist_ax,
|
558 |
+
dist_fig,
|
559 |
+
dist_props,
|
560 |
+
expected_values,
|
561 |
+
formula,
|
562 |
+
lambda_max,
|
563 |
+
lambda_min,
|
564 |
+
max_y,
|
565 |
+
n,
|
566 |
+
p_max,
|
567 |
+
p_min,
|
568 |
+
param_values,
|
569 |
+
title,
|
570 |
+
x_label,
|
571 |
+
)
|
572 |
+
|
573 |
+
|
574 |
+
@app.cell(hide_code=True)
|
575 |
+
def _(mo):
|
576 |
+
mo.md(
|
577 |
+
r"""
|
578 |
+
## Expectation vs. Mode
|
579 |
+
|
580 |
+
The expected value (mean) of a random variable is not always the same as its most likely value (mode). Let's explore this with an example:
|
581 |
+
"""
|
582 |
+
)
|
583 |
+
return
|
584 |
+
|
585 |
+
|
586 |
+
@app.cell(hide_code=True)
|
587 |
+
def _(np, plt, stats):
|
588 |
+
# Create a skewed distribution
|
589 |
+
skew_n = 10
|
590 |
+
skew_p = 0.25
|
591 |
+
|
592 |
+
# Binomial PMF
|
593 |
+
skew_x_values = np.arange(0, skew_n+1)
|
594 |
+
skew_pmf_values = stats.binom.pmf(skew_x_values, skew_n, skew_p)
|
595 |
+
|
596 |
+
# Find the mode (most likely value)
|
597 |
+
skew_mode = skew_x_values[np.argmax(skew_pmf_values)]
|
598 |
+
|
599 |
+
# Calculate the expected value
|
600 |
+
skew_expected = skew_n * skew_p
|
601 |
+
skew_expected_rounded = round(skew_expected, 2)
|
602 |
+
|
603 |
+
skew_fig, skew_ax = plt.subplots(figsize=(10, 5))
|
604 |
+
skew_ax.bar(skew_x_values, skew_pmf_values, alpha=0.7, width=0.4)
|
605 |
+
|
606 |
+
# Add vertical lines for mode and expected value
|
607 |
+
skew_ax.axvline(x=skew_mode, color='g', linestyle='--', linewidth=2,
|
608 |
+
label=f'Mode = {skew_mode} (Most likely value)')
|
609 |
+
skew_ax.axvline(x=skew_expected, color='r', linestyle='--', linewidth=2,
|
610 |
+
label=f'Expected Value = {skew_expected_rounded} (Mean)')
|
611 |
+
|
612 |
+
skew_ax.annotate('Mode', xy=(skew_mode, 0.05), xytext=(skew_mode-2.0, 0.1),
|
613 |
+
arrowprops=dict(facecolor='green', shrink=0.05, width=1.5), color='green')
|
614 |
+
skew_ax.annotate('Expected Value', xy=(skew_expected, 0.05), xytext=(skew_expected+1, 0.15),
|
615 |
+
arrowprops=dict(facecolor='red', shrink=0.05, width=1.5), color='red')
|
616 |
+
|
617 |
+
if skew_mode != int(skew_expected):
|
618 |
+
min_x = min(skew_mode, skew_expected)
|
619 |
+
max_x = max(skew_mode, skew_expected)
|
620 |
+
skew_ax.axvspan(min_x, max_x, alpha=0.2, color='purple')
|
621 |
+
|
622 |
+
# Add text explaining the difference
|
623 |
+
mid_x = (skew_mode + skew_expected) / 2
|
624 |
+
skew_ax.text(mid_x, max(skew_pmf_values) * 0.5,
|
625 |
+
f"Difference: {abs(skew_mode - skew_expected_rounded):.2f}",
|
626 |
+
ha='center', va='center', bbox=dict(facecolor='white', alpha=0.7))
|
627 |
+
|
628 |
+
skew_ax.set_xlabel('Number of Successes')
|
629 |
+
skew_ax.set_ylabel('Probability')
|
630 |
+
skew_ax.set_title(f'Binomial Distribution (n={skew_n}, p={skew_p})')
|
631 |
+
skew_ax.grid(alpha=0.3)
|
632 |
+
skew_ax.legend()
|
633 |
+
|
634 |
+
plt.tight_layout()
|
635 |
+
plt.gca()
|
636 |
+
return (
|
637 |
+
max_x,
|
638 |
+
mid_x,
|
639 |
+
min_x,
|
640 |
+
skew_ax,
|
641 |
+
skew_expected,
|
642 |
+
skew_expected_rounded,
|
643 |
+
skew_fig,
|
644 |
+
skew_mode,
|
645 |
+
skew_n,
|
646 |
+
skew_p,
|
647 |
+
skew_pmf_values,
|
648 |
+
skew_x_values,
|
649 |
+
)
|
650 |
+
|
651 |
+
|
652 |
+
@app.cell(hide_code=True)
def _(mo):
    # Callout cell. Fix: the "/// NOTE" admonition was never closed; marimo
    # requires a terminating bare "///" line (and lowercase admonition kind).
    mo.md(
        r"""
        /// note
        For the sum of two dice we calculated earlier, we found the expected value to be exactly 7. In that case, 7 also happens to be the mode (most likely outcome) of the distribution. However, this is just a coincidence for this particular example!

        As we can see from the binomial distribution above, the expected value (2.50) and the mode (2) are often different values (this is common in skewed distributions). The expected value represents the "center of mass" of the distribution, while the mode represents the most likely single outcome.
        ///
        """
    )
    return
|
663 |
+
|
664 |
+
|
665 |
+
@app.cell(hide_code=True)
|
666 |
+
def _(mo):
|
667 |
+
mo.md(
|
668 |
+
r"""
|
669 |
+
## 🤔 Test Your Understanding
|
670 |
+
|
671 |
+
Choose what you believe are the correct options in the questions below:
|
672 |
+
|
673 |
+
<details>
|
674 |
+
<summary>The expected value of a random variable is always one of the possible values the random variable can take.</summary>
|
675 |
+
❌ False! The expected value is a weighted average and may not be a value the random variable can actually take. For example, the expected value of a fair die roll is 3.5, which is not a possible outcome.
|
676 |
+
</details>
|
677 |
+
|
678 |
+
<details>
|
679 |
+
<summary>If X and Y are independent random variables, then E[X·Y] = E[X]·E[Y].</summary>
|
680 |
+
✅ True! For independent random variables, the expectation of their product equals the product of their expectations.
|
681 |
+
</details>
|
682 |
+
|
683 |
+
<details>
|
684 |
+
<summary>The expected value of a constant random variable (one that always takes the same value) is that constant.</summary>
|
685 |
+
✅ True! If X = c with probability 1, then E[X] = c.
|
686 |
+
</details>
|
687 |
+
|
688 |
+
<details>
|
689 |
+
<summary>The expected value of the sum of two random variables is always the sum of their expected values, regardless of whether they are independent.</summary>
|
690 |
+
✅ True! This is the linearity of expectation property: E[X + Y] = E[X] + E[Y], which holds regardless of dependence.
|
691 |
+
</details>
|
692 |
+
"""
|
693 |
+
)
|
694 |
+
return
|
695 |
+
|
696 |
+
|
697 |
+
@app.cell(hide_code=True)
|
698 |
+
def _(mo):
|
699 |
+
mo.md(
|
700 |
+
r"""
|
701 |
+
## Practical Applications of Expectation
|
702 |
+
|
703 |
+
Expected values show up everywhere - from investment decisions and insurance pricing to machine learning algorithms and game design. Engineers use them to predict system reliability, data scientists to understand customer behavior, and economists to model market outcomes. They're essential for risk assessment in project management and for optimizing resource allocation in operations research.
|
704 |
+
"""
|
705 |
+
)
|
706 |
+
return
|
707 |
+
|
708 |
+
|
709 |
+
@app.cell(hide_code=True)
|
710 |
+
def _(mo):
|
711 |
+
mo.md(
|
712 |
+
r"""
|
713 |
+
## Key Takeaways
|
714 |
+
|
715 |
+
Expectation gives us a single value that summarizes a random variable's central tendency - it's the weighted average of all possible outcomes, where the weights are probabilities. The linearity property makes expectations easy to work with, even for complex combinations of random variables. While a PMF gives the complete probability picture, expectation provides an essential summary that helps us make decisions under uncertainty. In our next notebook, we'll explore variance, which measures how spread out a random variable's values are around its expectation.
|
716 |
+
"""
|
717 |
+
)
|
718 |
+
return
|
719 |
+
|
720 |
+
|
721 |
+
@app.cell(hide_code=True)
|
722 |
+
def _(mo):
|
723 |
+
mo.md(r"""#### Appendix (containing helper code)""")
|
724 |
+
return
|
725 |
+
|
726 |
+
|
727 |
+
@app.cell(hide_code=True)
|
728 |
+
def _():
|
729 |
+
import marimo as mo
|
730 |
+
return (mo,)
|
731 |
+
|
732 |
+
|
733 |
+
@app.cell(hide_code=True)
|
734 |
+
def _():
|
735 |
+
import matplotlib.pyplot as plt
|
736 |
+
import numpy as np
|
737 |
+
from scipy import stats
|
738 |
+
import collections
|
739 |
+
return collections, np, plt, stats
|
740 |
+
|
741 |
+
|
742 |
+
@app.cell(hide_code=True)
|
743 |
+
def _(dist_selection, mo):
|
744 |
+
# Parameter controls for probability-based distributions
|
745 |
+
param_p = mo.ui.slider(
|
746 |
+
start=0.01,
|
747 |
+
stop=0.99,
|
748 |
+
step=0.01,
|
749 |
+
value=0.5,
|
750 |
+
label="p (probability of success)",
|
751 |
+
full_width=True
|
752 |
+
)
|
753 |
+
|
754 |
+
# Parameter control for binomial distribution
|
755 |
+
param_n = mo.ui.slider(
|
756 |
+
start=1,
|
757 |
+
stop=50,
|
758 |
+
step=1,
|
759 |
+
value=10,
|
760 |
+
label="n (number of trials)",
|
761 |
+
full_width=True
|
762 |
+
)
|
763 |
+
|
764 |
+
# Parameter control for Poisson distribution
|
765 |
+
param_lambda = mo.ui.slider(
|
766 |
+
start=0.1,
|
767 |
+
stop=20,
|
768 |
+
step=0.1,
|
769 |
+
value=5,
|
770 |
+
label="λ (rate parameter)",
|
771 |
+
full_width=True
|
772 |
+
)
|
773 |
+
|
774 |
+
# Parameter range sliders for visualization
|
775 |
+
param_range = mo.ui.range_slider(
|
776 |
+
start=0,
|
777 |
+
stop=1,
|
778 |
+
step=0.01,
|
779 |
+
value=[0, 1],
|
780 |
+
label="Parameter range to visualize",
|
781 |
+
full_width=True
|
782 |
+
)
|
783 |
+
|
784 |
+
lambda_range = mo.ui.range_slider(
|
785 |
+
start=0,
|
786 |
+
stop=20,
|
787 |
+
step=0.1,
|
788 |
+
value=[0, 20],
|
789 |
+
label="λ range to visualize",
|
790 |
+
full_width=True
|
791 |
+
)
|
792 |
+
|
793 |
+
# Display appropriate controls based on the selected distribution
|
794 |
+
if dist_selection.value == "bernoulli":
|
795 |
+
controls = mo.hstack([param_p, param_range], justify="space-around")
|
796 |
+
elif dist_selection.value == "binomial":
|
797 |
+
controls = mo.hstack([param_p, param_n, param_range], justify="space-around")
|
798 |
+
elif dist_selection.value == "geometric":
|
799 |
+
controls = mo.hstack([param_p, param_range], justify="space-around")
|
800 |
+
else: # poisson
|
801 |
+
controls = mo.hstack([param_lambda, lambda_range], justify="space-around")
|
802 |
+
return controls, lambda_range, param_lambda, param_n, param_p, param_range
|
803 |
+
|
804 |
+
|
805 |
+
@app.cell(hide_code=True)
|
806 |
+
def _(dist_selection, mo):
|
807 |
+
# Create distribution descriptions based on selection
|
808 |
+
if dist_selection.value == "bernoulli":
|
809 |
+
dist_description = mo.md(
|
810 |
+
r"""
|
811 |
+
**Bernoulli Distribution**
|
812 |
+
|
813 |
+
A Bernoulli distribution models a single trial with two possible outcomes: success (1) or failure (0).
|
814 |
+
|
815 |
+
- Parameter: $p$ = probability of success
|
816 |
+
- Expected Value: $E[X] = p$
|
817 |
+
- Example: Flipping a coin once (p = 0.5 for a fair coin)
|
818 |
+
"""
|
819 |
+
)
|
820 |
+
elif dist_selection.value == "binomial":
|
821 |
+
dist_description = mo.md(
|
822 |
+
r"""
|
823 |
+
**Binomial Distribution**
|
824 |
+
|
825 |
+
A Binomial distribution models the number of successes in $n$ independent trials.
|
826 |
+
|
827 |
+
- Parameters: $n$ = number of trials, $p$ = probability of success
|
828 |
+
- Expected Value: $E[X] = np$
|
829 |
+
- Example: Number of heads in 10 coin flips
|
830 |
+
"""
|
831 |
+
)
|
832 |
+
elif dist_selection.value == "geometric":
|
833 |
+
dist_description = mo.md(
|
834 |
+
r"""
|
835 |
+
**Geometric Distribution**
|
836 |
+
|
837 |
+
A Geometric distribution models the number of trials until the first success.
|
838 |
+
|
839 |
+
- Parameter: $p$ = probability of success
|
840 |
+
- Expected Value: $E[X] = \frac{1}{p}$
|
841 |
+
- Example: Number of coin flips until first heads
|
842 |
+
"""
|
843 |
+
)
|
844 |
+
else: # poisson
|
845 |
+
dist_description = mo.md(
|
846 |
+
r"""
|
847 |
+
**Poisson Distribution**
|
848 |
+
|
849 |
+
A Poisson distribution models the number of events occurring in a fixed interval.
|
850 |
+
|
851 |
+
- Parameter: $\lambda$ = average rate of events
|
852 |
+
- Expected Value: $E[X] = \lambda$
|
853 |
+
- Example: Number of emails received per hour
|
854 |
+
"""
|
855 |
+
)
|
856 |
+
return (dist_description,)
|
857 |
+
|
858 |
+
|
859 |
+
if __name__ == "__main__":
|
860 |
+
app.run()
|
probability/12_variance.py
ADDED
@@ -0,0 +1,631 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.3",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# "wigglystuff==0.1.10",
|
9 |
+
# ]
|
10 |
+
# ///
|
11 |
+
|
12 |
+
import marimo
|
13 |
+
|
14 |
+
__generated_with = "0.11.20"
|
15 |
+
app = marimo.App(width="medium", app_title="Variance")
|
16 |
+
|
17 |
+
|
18 |
+
@app.cell(hide_code=True)
|
19 |
+
def _(mo):
|
20 |
+
mo.md(
|
21 |
+
r"""
|
22 |
+
# Variance
|
23 |
+
|
24 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/variance/), by Stanford professor Chris Piech._
|
25 |
+
|
26 |
+
In our previous exploration of random variables, we learned about expectation - a measure of central tendency. However, knowing the average value alone doesn't tell us everything about a distribution. Consider these questions:
|
27 |
+
|
28 |
+
- How spread out are the values around the mean?
|
29 |
+
- How reliable is the expectation as a predictor of individual outcomes?
|
30 |
+
- How much do individual samples typically deviate from the average?
|
31 |
+
|
32 |
+
This is where **variance** comes in - it measures the spread or dispersion of a random variable around its expected value.
|
33 |
+
"""
|
34 |
+
)
|
35 |
+
return
|
36 |
+
|
37 |
+
|
38 |
+
@app.cell(hide_code=True)
|
39 |
+
def _(mo):
|
40 |
+
mo.md(
|
41 |
+
r"""
|
42 |
+
## Definition of Variance
|
43 |
+
|
44 |
+
The variance of a random variable $X$ with expected value $\mu = E[X]$ is defined as:
|
45 |
+
|
46 |
+
$$\text{Var}(X) = E[(X-\mu)^2]$$
|
47 |
+
|
48 |
+
This definition captures the average squared deviation from the mean. There's also an equivalent, often more convenient formula:
|
49 |
+
|
50 |
+
$$\text{Var}(X) = E[X^2] - (E[X])^2$$
|
51 |
+
|
52 |
+
/// tip
|
53 |
+
The second formula is usually easier to compute, as it only requires calculating $E[X^2]$ and $E[X]$, rather than working with deviations from the mean.
|
54 |
+
"""
|
55 |
+
)
|
56 |
+
return
|
57 |
+
|
58 |
+
|
59 |
+
@app.cell(hide_code=True)
|
60 |
+
def _(mo):
|
61 |
+
mo.md(
|
62 |
+
r"""
|
63 |
+
## Intuition Through Example
|
64 |
+
|
65 |
+
Let's look at a real-world example that illustrates why variance is important. Consider three different groups of graders evaluating assignments in a massive online course. Each grader has their own "grading distribution" - their pattern of assigning scores to work that deserves a 70/100.
|
66 |
+
|
67 |
+
The visualization below shows the probability distributions for three types of graders. Try clicking and dragging the blue numbers to adjust the parameters and see how they affect the variance.
|
68 |
+
"""
|
69 |
+
)
|
70 |
+
return
|
71 |
+
|
72 |
+
|
73 |
+
@app.cell(hide_code=True)
|
74 |
+
def _(mo):
|
75 |
+
mo.md(
|
76 |
+
r"""
|
77 |
+
/// TIP
|
78 |
+
Try adjusting the blue numbers above to see how:
|
79 |
+
|
80 |
+
- Increasing spread increases variance
|
81 |
+
- The mixture ratio affects how many outliers appear in Grader C's distribution
|
82 |
+
- Changing the true grade shifts all distributions but maintains their relative variances
|
83 |
+
"""
|
84 |
+
)
|
85 |
+
return
|
86 |
+
|
87 |
+
|
88 |
+
@app.cell(hide_code=True)
|
89 |
+
def _(controls):
|
90 |
+
controls
|
91 |
+
return
|
92 |
+
|
93 |
+
|
94 |
+
@app.cell(hide_code=True)
|
95 |
+
def _(
|
96 |
+
grader_a_spread,
|
97 |
+
grader_b_spread,
|
98 |
+
grader_c_mix,
|
99 |
+
np,
|
100 |
+
plt,
|
101 |
+
stats,
|
102 |
+
true_grade,
|
103 |
+
):
|
104 |
+
# Create data for three grader distributions
|
105 |
+
_grader_x = np.linspace(40, 100, 200)
|
106 |
+
|
107 |
+
# Calculate actual variances
|
108 |
+
var_a = grader_a_spread.amount**2
|
109 |
+
var_b = grader_b_spread.amount**2
|
110 |
+
var_c = (1-grader_c_mix.amount) * 3**2 + grader_c_mix.amount * 8**2 + \
|
111 |
+
grader_c_mix.amount * (1-grader_c_mix.amount) * (8-3)**2 # Mixture variance formula
|
112 |
+
|
113 |
+
# Grader A: Wide spread around true grade
|
114 |
+
grader_a = stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=grader_a_spread.amount)
|
115 |
+
|
116 |
+
# Grader B: Narrow spread around true grade
|
117 |
+
grader_b = stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=grader_b_spread.amount)
|
118 |
+
|
119 |
+
# Grader C: Mixture of distributions
|
120 |
+
grader_c = (1-grader_c_mix.amount) * stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=3) + \
|
121 |
+
grader_c_mix.amount * stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=8)
|
122 |
+
|
123 |
+
grader_fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
|
124 |
+
|
125 |
+
# Plot each distribution
|
126 |
+
ax1.fill_between(_grader_x, grader_a, alpha=0.3, color='green', label=f'Var ≈ {var_a:.2f}')
|
127 |
+
ax1.axvline(x=true_grade.amount, color='black', linestyle='--', label='True Grade')
|
128 |
+
ax1.set_title('Grader A: High Variance')
|
129 |
+
ax1.set_xlabel('Grade')
|
130 |
+
ax1.set_ylabel('Pr(G = g)')
|
131 |
+
ax1.set_ylim(0, max(grader_a)*1.1)
|
132 |
+
|
133 |
+
ax2.fill_between(_grader_x, grader_b, alpha=0.3, color='blue', label=f'Var ≈ {var_b:.2f}')
|
134 |
+
ax2.axvline(x=true_grade.amount, color='black', linestyle='--')
|
135 |
+
ax2.set_title('Grader B: Low Variance')
|
136 |
+
ax2.set_xlabel('Grade')
|
137 |
+
ax2.set_ylim(0, max(grader_b)*1.1)
|
138 |
+
|
139 |
+
ax3.fill_between(_grader_x, grader_c, alpha=0.3, color='purple', label=f'Var ≈ {var_c:.2f}')
|
140 |
+
ax3.axvline(x=true_grade.amount, color='black', linestyle='--')
|
141 |
+
ax3.set_title('Grader C: Mixed Distribution')
|
142 |
+
ax3.set_xlabel('Grade')
|
143 |
+
ax3.set_ylim(0, max(grader_c)*1.1)
|
144 |
+
|
145 |
+
# Add annotations to explain what's happening
|
146 |
+
ax1.annotate('Wide spread = high variance',
|
147 |
+
xy=(true_grade.amount, max(grader_a)*0.5),
|
148 |
+
xytext=(true_grade.amount-15, max(grader_a)*0.7),
|
149 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
150 |
+
|
151 |
+
ax2.annotate('Narrow spread = low variance',
|
152 |
+
xy=(true_grade.amount, max(grader_b)*0.5),
|
153 |
+
xytext=(true_grade.amount+8, max(grader_b)*0.7),
|
154 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
155 |
+
|
156 |
+
ax3.annotate('Mixture creates outliers',
|
157 |
+
xy=(true_grade.amount+15, grader_c[np.where(_grader_x >= true_grade.amount+15)[0][0]]),
|
158 |
+
xytext=(true_grade.amount+5, max(grader_c)*0.7),
|
159 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
160 |
+
|
161 |
+
# Add legends and adjust layout
|
162 |
+
for _ax in [ax1, ax2, ax3]:
|
163 |
+
_ax.legend()
|
164 |
+
_ax.grid(alpha=0.2)
|
165 |
+
|
166 |
+
plt.tight_layout()
|
167 |
+
plt.gca()
|
168 |
+
return (
|
169 |
+
ax1,
|
170 |
+
ax2,
|
171 |
+
ax3,
|
172 |
+
grader_a,
|
173 |
+
grader_b,
|
174 |
+
grader_c,
|
175 |
+
grader_fig,
|
176 |
+
var_a,
|
177 |
+
var_b,
|
178 |
+
var_c,
|
179 |
+
)
|
180 |
+
|
181 |
+
|
182 |
+
@app.cell(hide_code=True)
|
183 |
+
def _(mo):
|
184 |
+
mo.md(
|
185 |
+
r"""
|
186 |
+
/// note
|
187 |
+
All three distributions have the same expected value (the true grade), but they differ significantly in their spread:
|
188 |
+
|
189 |
+
- **Grader A** has high variance - grades vary widely from the true value
|
190 |
+
- **Grader B** has low variance - grades consistently stay close to the true value
|
191 |
+
- **Grader C** has a mixture distribution - mostly consistent but with occasional extreme values
|
192 |
+
|
193 |
+
This illustrates why variance is crucial: two distributions can have the same mean but behave very differently in practice.
|
194 |
+
"""
|
195 |
+
)
|
196 |
+
return
|
197 |
+
|
198 |
+
|
199 |
+
@app.cell(hide_code=True)
|
200 |
+
def _(mo):
|
201 |
+
mo.md(
|
202 |
+
r"""
|
203 |
+
## Computing Variance
|
204 |
+
|
205 |
+
Let's work through some concrete examples to understand how to calculate variance.
|
206 |
+
|
207 |
+
### Example 1: Fair Die Roll
|
208 |
+
|
209 |
+
Consider rolling a fair six-sided die. We'll calculate its variance step by step:
|
210 |
+
"""
|
211 |
+
)
|
212 |
+
return
|
213 |
+
|
214 |
+
|
215 |
+
@app.cell
def _(np):
    # Fair six-sided die: outcomes 1..6, each with probability 1/6.
    die_values = np.arange(1, 7)
    die_probs = np.full(6, 1 / 6)

    # E[X]: probability-weighted average of the outcomes.
    expected_value = np.sum(die_values * die_probs)

    # E[X^2]: second moment, needed for the shortcut variance formula.
    expected_square = np.sum(die_values**2 * die_probs)

    # Var(X) = E[X^2] - (E[X])^2; the standard deviation is its square root.
    variance = expected_square - expected_value**2
    std_dev = np.sqrt(variance)

    print(f"E[X] = {expected_value:.2f}")
    print(f"E[X^2] = {expected_square:.2f}")
    print(f"Var(X) = {variance:.2f}")
    print(f"Standard Deviation = {std_dev:.2f}")
    return (
        die_probs,
        die_values,
        expected_square,
        expected_value,
        std_dev,
        variance,
    )
|
245 |
+
|
246 |
+
|
247 |
+
@app.cell(hide_code=True)
|
248 |
+
def _(mo):
|
249 |
+
mo.md(
|
250 |
+
r"""
|
251 |
+
/// NOTE
|
252 |
+
For a fair die:
|
253 |
+
|
254 |
+
- The expected value (3.50) tells us the average roll
|
255 |
+
- The variance (2.92) tells us how much typical rolls deviate from this average
|
256 |
+
- The standard deviation (1.71) gives us this spread in the original units
|
257 |
+
"""
|
258 |
+
)
|
259 |
+
return
|
260 |
+
|
261 |
+
|
262 |
+
@app.cell(hide_code=True)
|
263 |
+
def _(mo):
|
264 |
+
mo.md(
|
265 |
+
r"""
|
266 |
+
## Properties of Variance
|
267 |
+
|
268 |
+
Variance has several important properties that make it useful for analyzing random variables:
|
269 |
+
|
270 |
+
1. **Non-negativity**: $\text{Var}(X) \geq 0$ for any random variable $X$
|
271 |
+
2. **Variance of a constant**: $\text{Var}(c) = 0$ for any constant $c$
|
272 |
+
3. **Scaling**: $\text{Var}(aX) = a^2\text{Var}(X)$ for any constant $a$
|
273 |
+
4. **Translation**: $\text{Var}(X + b) = \text{Var}(X)$ for any constant $b$
|
274 |
+
5. **Independence**: If $X$ and $Y$ are independent, then $\text{Var}(X + Y) = \text{Var}(X) + \text{Var}(Y)$
|
275 |
+
|
276 |
+
Let's verify a property with an example.
|
277 |
+
"""
|
278 |
+
)
|
279 |
+
return
|
280 |
+
|
281 |
+
|
282 |
+
@app.cell(hide_code=True)
|
283 |
+
def _(mo):
|
284 |
+
mo.md(
|
285 |
+
r"""
|
286 |
+
## Proof of Variance Formula
|
287 |
+
|
288 |
+
The equivalence of the two variance formulas is a fundamental result in probability theory. Here's the proof:
|
289 |
+
|
290 |
+
Starting with the definition $\text{Var}(X) = E[(X-\mu)^2]$ where $\mu = E[X]$:
|
291 |
+
|
292 |
+
\begin{align}
|
293 |
+
\text{Var}(X) &= E[(X-\mu)^2] \\
|
294 |
+
&= \sum_x(x-\mu)^2P(x) && \text{Definition of Expectation}\\
|
295 |
+
&= \sum_x (x^2 -2\mu x + \mu^2)P(x) && \text{Expanding the square}\\
|
296 |
+
&= \sum_x x^2P(x)- 2\mu \sum_x xP(x) + \mu^2 \sum_x P(x) && \text{Distributing the sum}\\
|
297 |
+
&= E[X^2]- 2\mu E[X] + \mu^2 && \text{Definition of expectation}\\
|
298 |
+
&= E[X^2]- 2(E[X])^2 + (E[X])^2 && \text{Since }\mu = E[X]\\
|
299 |
+
&= E[X^2]- (E[X])^2 && \text{Simplifying}
|
300 |
+
\end{align}
|
301 |
+
|
302 |
+
/// tip
|
303 |
+
This proof shows why the formula $\text{Var}(X) = E[X^2] - (E[X])^2$ is so useful - it's much easier to compute $E[X^2]$ and $E[X]$ separately than to work with deviations directly.
|
304 |
+
"""
|
305 |
+
)
|
306 |
+
return
|
307 |
+
|
308 |
+
|
309 |
+
@app.cell
def _(die_probs, die_values, np):
    # Empirically verify the scaling property Var(aX) = a^2 * Var(X).
    a = 2  # Scale factor

    # Var(X) for the original die via E[X^2] - (E[X])^2.
    _mean = np.sum(die_values * die_probs)
    original_var = np.sum(die_values**2 * die_probs) - _mean**2

    # Scale the outcomes and recompute the variance of aX directly.
    scaled_values = a * die_values
    _scaled_mean = np.sum(scaled_values * die_probs)
    scaled_var = np.sum(scaled_values**2 * die_probs) - _scaled_mean**2

    print(f"Original Variance: {original_var:.2f}")
    print(f"Scaled Variance (a={a}): {scaled_var:.2f}")
    print(f"a^2 * Original Variance: {a**2 * original_var:.2f}")
    print(f"Property holds: {abs(scaled_var - a**2 * original_var) < 1e-10}")
    return a, original_var, scaled_values, scaled_var
|
326 |
+
|
327 |
+
|
328 |
+
@app.cell
|
329 |
+
def _():
|
330 |
+
# DIY : Prove more properties as shown above
|
331 |
+
return
|
332 |
+
|
333 |
+
|
334 |
+
@app.cell(hide_code=True)
|
335 |
+
def _(mo):
|
336 |
+
mo.md(
|
337 |
+
r"""
|
338 |
+
## Standard Deviation
|
339 |
+
|
340 |
+
While variance is mathematically convenient, it has one practical drawback: its units are squared. For example, if we're measuring grades (0-100), the variance is in "grade points squared." This makes it hard to interpret intuitively.
|
341 |
+
|
342 |
+
The **standard deviation**, denoted by $\sigma$ or $\text{SD}(X)$, is the square root of variance:
|
343 |
+
|
344 |
+
$$\sigma = \sqrt{\text{Var}(X)}$$
|
345 |
+
|
346 |
+
/// tip
|
347 |
+
Standard deviation is often more intuitive because it's in the same units as the original data. For a normal distribution, approximately:
|
348 |
+
- 68% of values fall within 1 standard deviation of the mean
|
349 |
+
- 95% of values fall within 2 standard deviations
|
350 |
+
- 99.7% of values fall within 3 standard deviations
|
351 |
+
"""
|
352 |
+
)
|
353 |
+
return
|
354 |
+
|
355 |
+
|
356 |
+
@app.cell(hide_code=True)
|
357 |
+
def _(controls1):
|
358 |
+
controls1
|
359 |
+
return
|
360 |
+
|
361 |
+
|
362 |
+
@app.cell(hide_code=True)
|
363 |
+
def _(TangleSlider, mo):
|
364 |
+
normal_mean = mo.ui.anywidget(TangleSlider(
|
365 |
+
amount=0,
|
366 |
+
min_value=-5,
|
367 |
+
max_value=5,
|
368 |
+
step=0.5,
|
369 |
+
digits=1,
|
370 |
+
suffix=" units"
|
371 |
+
))
|
372 |
+
|
373 |
+
normal_std = mo.ui.anywidget(TangleSlider(
|
374 |
+
amount=1,
|
375 |
+
min_value=0.1,
|
376 |
+
max_value=3,
|
377 |
+
step=0.1,
|
378 |
+
digits=1,
|
379 |
+
suffix=" units"
|
380 |
+
))
|
381 |
+
|
382 |
+
# Create a grid layout for the controls
|
383 |
+
controls1 = mo.vstack([
|
384 |
+
mo.md("### Interactive Normal Distribution"),
|
385 |
+
mo.hstack([
|
386 |
+
mo.md("Adjust the parameters to see how standard deviation affects the shape of the distribution:"),
|
387 |
+
]),
|
388 |
+
mo.hstack([
|
389 |
+
mo.md("Mean (μ): "),
|
390 |
+
normal_mean,
|
391 |
+
mo.md(" Standard deviation (σ): "),
|
392 |
+
normal_std
|
393 |
+
], justify="start"),
|
394 |
+
])
|
395 |
+
return controls1, normal_mean, normal_std
|
396 |
+
|
397 |
+
|
398 |
+
@app.cell(hide_code=True)
|
399 |
+
def _(normal_mean, normal_std, np, plt, stats):
|
400 |
+
# data for normal distribution
|
401 |
+
_normal_x = np.linspace(-10, 10, 1000)
|
402 |
+
_normal_y = stats.norm.pdf(_normal_x, loc=normal_mean.amount, scale=normal_std.amount)
|
403 |
+
|
404 |
+
# ranges for standard deviation intervals
|
405 |
+
one_sigma_left = normal_mean.amount - normal_std.amount
|
406 |
+
one_sigma_right = normal_mean.amount + normal_std.amount
|
407 |
+
two_sigma_left = normal_mean.amount - 2 * normal_std.amount
|
408 |
+
two_sigma_right = normal_mean.amount + 2 * normal_std.amount
|
409 |
+
three_sigma_left = normal_mean.amount - 3 * normal_std.amount
|
410 |
+
three_sigma_right = normal_mean.amount + 3 * normal_std.amount
|
411 |
+
|
412 |
+
# Create the plot
|
413 |
+
normal_fig, normal_ax = plt.subplots(figsize=(10, 6))
|
414 |
+
|
415 |
+
# Plot the distribution
|
416 |
+
normal_ax.plot(_normal_x, _normal_y, 'b-', linewidth=2)
|
417 |
+
|
418 |
+
# stdev intervals
|
419 |
+
normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= one_sigma_left) & (_normal_x <= one_sigma_right),
|
420 |
+
alpha=0.3, color='red', label='68% (±1σ)')
|
421 |
+
normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= two_sigma_left) & (_normal_x <= two_sigma_right),
|
422 |
+
alpha=0.2, color='green', label='95% (±2σ)')
|
423 |
+
normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= three_sigma_left) & (_normal_x <= three_sigma_right),
|
424 |
+
alpha=0.1, color='blue', label='99.7% (±3σ)')
|
425 |
+
|
426 |
+
# vertical lines for the mean and standard deviations
|
427 |
+
normal_ax.axvline(x=normal_mean.amount, color='black', linestyle='-', linewidth=1.5, label='Mean (μ)')
|
428 |
+
normal_ax.axvline(x=one_sigma_left, color='red', linestyle='--', linewidth=1)
|
429 |
+
normal_ax.axvline(x=one_sigma_right, color='red', linestyle='--', linewidth=1)
|
430 |
+
normal_ax.axvline(x=two_sigma_left, color='green', linestyle='--', linewidth=1)
|
431 |
+
normal_ax.axvline(x=two_sigma_right, color='green', linestyle='--', linewidth=1)
|
432 |
+
|
433 |
+
# annotations
|
434 |
+
normal_ax.annotate(f'μ = {normal_mean.amount:.2f}',
|
435 |
+
xy=(normal_mean.amount, max(_normal_y)*0.5),
|
436 |
+
xytext=(normal_mean.amount + 0.5, max(_normal_y)*0.8),
|
437 |
+
arrowprops=dict(facecolor='black', shrink=0.05, width=1))
|
438 |
+
|
439 |
+
normal_ax.annotate(f'σ = {normal_std.amount:.2f}',
|
440 |
+
xy=(one_sigma_right, stats.norm.pdf(one_sigma_right, loc=normal_mean.amount, scale=normal_std.amount)),
|
441 |
+
xytext=(one_sigma_right + 0.5, max(_normal_y)*0.6),
|
442 |
+
arrowprops=dict(facecolor='red', shrink=0.05, width=1))
|
443 |
+
|
444 |
+
# labels and title
|
445 |
+
normal_ax.set_xlabel('Value')
|
446 |
+
normal_ax.set_ylabel('Probability Density')
|
447 |
+
normal_ax.set_title(f'Normal Distribution with μ = {normal_mean.amount:.2f} and σ = {normal_std.amount:.2f}')
|
448 |
+
|
449 |
+
# legend and grid
|
450 |
+
normal_ax.legend()
|
451 |
+
normal_ax.grid(alpha=0.3)
|
452 |
+
|
453 |
+
plt.tight_layout()
|
454 |
+
plt.gca()
|
455 |
+
return (
|
456 |
+
normal_ax,
|
457 |
+
normal_fig,
|
458 |
+
one_sigma_left,
|
459 |
+
one_sigma_right,
|
460 |
+
three_sigma_left,
|
461 |
+
three_sigma_right,
|
462 |
+
two_sigma_left,
|
463 |
+
two_sigma_right,
|
464 |
+
)
|
465 |
+
|
466 |
+
|
467 |
+
@app.cell(hide_code=True)
|
468 |
+
def _(mo):
|
469 |
+
mo.md(
|
470 |
+
r"""
|
471 |
+
/// tip
|
472 |
+
The interactive visualization above demonstrates how standard deviation (σ) affects the shape of a normal distribution:
|
473 |
+
|
474 |
+
- The **red region** covers μ ± 1σ, containing approximately 68% of the probability
|
475 |
+
- The **green region** covers μ ± 2σ, containing approximately 95% of the probability
|
476 |
+
- The **blue region** covers μ ± 3σ, containing approximately 99.7% of the probability
|
477 |
+
|
478 |
+
This is known as the "68-95-99.7 rule" or the "empirical rule" and is a useful heuristic for understanding the spread of data.
|
479 |
+
"""
|
480 |
+
)
|
481 |
+
return
|
482 |
+
|
483 |
+
|
484 |
+
@app.cell(hide_code=True)
|
485 |
+
def _(mo):
|
486 |
+
mo.md(
|
487 |
+
r"""
|
488 |
+
## 🤔 Test Your Understanding
|
489 |
+
|
490 |
+
Choose what you believe are the correct options in the questions below:
|
491 |
+
|
492 |
+
<details>
|
493 |
+
<summary>The variance of a random variable can be negative.</summary>
|
494 |
+
❌ False! Variance is defined as an expected value of squared deviations, and squares are always non-negative.
|
495 |
+
</details>
|
496 |
+
|
497 |
+
<details>
|
498 |
+
<summary>If X and Y are independent random variables, then Var(X + Y) = Var(X) + Var(Y).</summary>
|
499 |
+
✅ True! This is one of the key properties of variance for independent random variables.
|
500 |
+
</details>
|
501 |
+
|
502 |
+
<details>
|
503 |
+
<summary>Multiplying a random variable by 2 multiplies its variance by 2.</summary>
|
504 |
+
❌ False! Multiplying a random variable by a constant a multiplies its variance by a². So multiplying by 2 multiplies variance by 4.
|
505 |
+
</details>
|
506 |
+
|
507 |
+
<details>
|
508 |
+
<summary>Standard deviation is always equal to the square root of variance.</summary>
|
509 |
+
✅ True! By definition, standard deviation σ = √Var(X).
|
510 |
+
</details>
|
511 |
+
|
512 |
+
<details>
|
513 |
+
<summary>If Var(X) = 0, then X must be a constant.</summary>
|
514 |
+
✅ True! Zero variance means there is no spread around the mean, so X can only take one value.
|
515 |
+
</details>
|
516 |
+
"""
|
517 |
+
)
|
518 |
+
return
|
519 |
+
|
520 |
+
|
521 |
+
@app.cell(hide_code=True)
|
522 |
+
def _(mo):
|
523 |
+
mo.md(
|
524 |
+
r"""
|
525 |
+
## Key Takeaways
|
526 |
+
|
527 |
+
Variance gives us a way to measure how spread out a random variable is around its mean. It's like the "uncertainty" in our expectation - a high variance means individual outcomes can differ widely from what we expect on average.
|
528 |
+
|
529 |
+
Standard deviation brings this measure back to the original units, making it easier to interpret. For grades, a standard deviation of 10 points means typical grades fall within about 10 points of the average.
|
530 |
+
|
531 |
+
Variance pops up everywhere - from weather forecasts (how reliable is the predicted temperature?) to financial investments (how risky is this stock?) to quality control (how consistent is our manufacturing process?).
|
532 |
+
|
533 |
+
In our next notebook, we'll explore more properties of random variables and see how they combine to form more complex distributions.
|
534 |
+
"""
|
535 |
+
)
|
536 |
+
return
|
537 |
+
|
538 |
+
|
539 |
+
@app.cell(hide_code=True)
|
540 |
+
def _(mo):
|
541 |
+
mo.md(r"""Appendix (containing helper code):""")
|
542 |
+
return
|
543 |
+
|
544 |
+
|
545 |
+
@app.cell(hide_code=True)
|
546 |
+
def _():
|
547 |
+
import marimo as mo
|
548 |
+
return (mo,)
|
549 |
+
|
550 |
+
|
551 |
+
@app.cell(hide_code=True)
|
552 |
+
def _():
|
553 |
+
import numpy as np
|
554 |
+
import scipy.stats as stats
|
555 |
+
import matplotlib.pyplot as plt
|
556 |
+
from wigglystuff import TangleSlider
|
557 |
+
return TangleSlider, np, plt, stats
|
558 |
+
|
559 |
+
|
560 |
+
@app.cell(hide_code=True)
|
561 |
+
def _(TangleSlider, mo):
|
562 |
+
# Create interactive elements using TangleSlider for a more inline experience
|
563 |
+
true_grade = mo.ui.anywidget(TangleSlider(
|
564 |
+
amount=70,
|
565 |
+
min_value=50,
|
566 |
+
max_value=90,
|
567 |
+
step=5,
|
568 |
+
digits=0,
|
569 |
+
suffix=" points"
|
570 |
+
))
|
571 |
+
|
572 |
+
grader_a_spread = mo.ui.anywidget(TangleSlider(
|
573 |
+
amount=10,
|
574 |
+
min_value=5,
|
575 |
+
max_value=20,
|
576 |
+
step=1,
|
577 |
+
digits=0,
|
578 |
+
suffix=" points"
|
579 |
+
))
|
580 |
+
|
581 |
+
grader_b_spread = mo.ui.anywidget(TangleSlider(
|
582 |
+
amount=2,
|
583 |
+
min_value=1,
|
584 |
+
max_value=5,
|
585 |
+
step=0.5,
|
586 |
+
digits=1,
|
587 |
+
suffix=" points"
|
588 |
+
))
|
589 |
+
|
590 |
+
grader_c_mix = mo.ui.anywidget(TangleSlider(
|
591 |
+
amount=0.2,
|
592 |
+
min_value=0,
|
593 |
+
max_value=1,
|
594 |
+
step=0.05,
|
595 |
+
digits=2,
|
596 |
+
suffix=" proportion"
|
597 |
+
))
|
598 |
+
return grader_a_spread, grader_b_spread, grader_c_mix, true_grade
|
599 |
+
|
600 |
+
|
601 |
+
@app.cell(hide_code=True)
|
602 |
+
def _(grader_a_spread, grader_b_spread, grader_c_mix, mo, true_grade):
|
603 |
+
# Create a grid layout for the interactive controls
|
604 |
+
controls = mo.vstack([
|
605 |
+
mo.md("### Adjust Parameters to See How Variance Changes"),
|
606 |
+
mo.hstack([
|
607 |
+
mo.md("**True grade:** The correct score that should be assigned is "),
|
608 |
+
true_grade,
|
609 |
+
mo.md(" out of 100.")
|
610 |
+
], justify="start"),
|
611 |
+
mo.hstack([
|
612 |
+
mo.md("**Grader A:** Has a wide spread with standard deviation of "),
|
613 |
+
grader_a_spread,
|
614 |
+
mo.md(" points.")
|
615 |
+
], justify="start"),
|
616 |
+
mo.hstack([
|
617 |
+
mo.md("**Grader B:** Has a narrow spread with standard deviation of "),
|
618 |
+
grader_b_spread,
|
619 |
+
mo.md(" points.")
|
620 |
+
], justify="start"),
|
621 |
+
mo.hstack([
|
622 |
+
mo.md("**Grader C:** Has a mixture distribution with "),
|
623 |
+
grader_c_mix,
|
624 |
+
mo.md(" proportion of outliers.")
|
625 |
+
], justify="start"),
|
626 |
+
])
|
627 |
+
return (controls,)
|
628 |
+
|
629 |
+
|
630 |
+
if __name__ == "__main__":
|
631 |
+
app.run()
|
probability/13_bernoulli_distribution.py
ADDED
@@ -0,0 +1,427 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.3",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# ]
|
9 |
+
# ///
|
10 |
+
|
11 |
+
import marimo
|
12 |
+
|
13 |
+
__generated_with = "0.11.22"
|
14 |
+
app = marimo.App(width="medium", app_title="Bernoulli Distribution")
|
15 |
+
|
16 |
+
|
17 |
+
@app.cell(hide_code=True)
|
18 |
+
def _(mo):
|
19 |
+
mo.md(
|
20 |
+
r"""
|
21 |
+
# Bernoulli Distribution
|
22 |
+
|
23 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/bernoulli/), by Stanford professor Chris Piech._
|
24 |
+
|
25 |
+
## Parametric Random Variables
|
26 |
+
|
27 |
+
There are many classic and commonly-seen random variable abstractions that show up in the world of probability. At this point, we'll learn about several of the most significant parametric discrete distributions.
|
28 |
+
|
29 |
+
When solving problems, if you can recognize that a random variable fits one of these formats, then you can use its pre-derived Probability Mass Function (PMF), expectation, variance, and other properties. Random variables of this sort are called **parametric random variables**. If you can argue that a random variable falls under one of the studied parametric types, you simply need to provide parameters.
|
30 |
+
|
31 |
+
> A good analogy is a `class` in programming. Creating a parametric random variable is very similar to calling a constructor with input parameters.
|
32 |
+
"""
|
33 |
+
)
|
34 |
+
return
|
35 |
+
|
36 |
+
|
37 |
+
@app.cell(hide_code=True)
|
38 |
+
def _(mo):
|
39 |
+
mo.md(
|
40 |
+
r"""
|
41 |
+
## Bernoulli Random Variables
|
42 |
+
|
43 |
+
A **Bernoulli random variable** (also called a boolean or indicator random variable) is the simplest kind of parametric random variable. It can take on two values: 1 and 0.
|
44 |
+
|
45 |
+
It takes on a 1 if an experiment with probability $p$ resulted in success and a 0 otherwise.
|
46 |
+
|
47 |
+
Some example uses include:
|
48 |
+
|
49 |
+
- A coin flip (heads = 1, tails = 0)
|
50 |
+
- A random binary digit
|
51 |
+
- Whether a disk drive crashed
|
52 |
+
- Whether someone likes a Netflix movie
|
53 |
+
|
54 |
+
Here $p$ is the parameter, but different instances of Bernoulli random variables might have different values of $p$.
|
55 |
+
"""
|
56 |
+
)
|
57 |
+
return
|
58 |
+
|
59 |
+
|
60 |
+
@app.cell(hide_code=True)
|
61 |
+
def _(mo):
|
62 |
+
mo.md(
|
63 |
+
r"""
|
64 |
+
## Key Properties of a Bernoulli Random Variable
|
65 |
+
|
66 |
+
If $X$ is declared to be a Bernoulli random variable with parameter $p$, denoted $X \sim \text{Bern}(p)$, it has the following properties:
|
67 |
+
"""
|
68 |
+
)
|
69 |
+
return
|
70 |
+
|
71 |
+
|
72 |
+
@app.cell
def _(stats):
    def Bern(p):
        """Construct a (frozen) Bernoulli distribution with success probability p."""
        return stats.bernoulli(p)

    return (Bern,)
|
78 |
+
|
79 |
+
|
80 |
+
@app.cell(hide_code=True)
|
81 |
+
def _(mo):
|
82 |
+
mo.md(
|
83 |
+
r"""
|
84 |
+
## Bernoulli Distribution Properties
|
85 |
+
|
86 |
+
$\begin{array}{lll}
|
87 |
+
\text{Notation:} & X \sim \text{Bern}(p) \\
|
88 |
+
\text{Description:} & \text{A boolean variable that is 1 with probability } p \\
|
89 |
+
\text{Parameters:} & p, \text{ the probability that } X = 1 \\
|
90 |
+
\text{Support:} & x \text{ is either 0 or 1} \\
|
91 |
+
\text{PMF equation:} & P(X = x) =
|
92 |
+
\begin{cases}
|
93 |
+
p & \text{if }x = 1\\
|
94 |
+
1-p & \text{if }x = 0
|
95 |
+
\end{cases} \\
|
96 |
+
\text{PMF (smooth):} & P(X = x) = p^x(1-p)^{1-x} \\
|
97 |
+
\text{Expectation:} & E[X] = p \\
|
98 |
+
\text{Variance:} & \text{Var}(X) = p(1-p) \\
|
99 |
+
\end{array}$
|
100 |
+
"""
|
101 |
+
)
|
102 |
+
return
|
103 |
+
|
104 |
+
|
105 |
+
@app.cell(hide_code=True)
|
106 |
+
def _(mo, p_slider):
|
107 |
+
# Visualization of the Bernoulli PMF
|
108 |
+
_p = p_slider.value
|
109 |
+
|
110 |
+
# Values for PMF
|
111 |
+
values = [0, 1]
|
112 |
+
probabilities = [1 - _p, _p]
|
113 |
+
|
114 |
+
# Relevant statistics
|
115 |
+
expected_value = _p
|
116 |
+
variance = _p * (1 - _p)
|
117 |
+
|
118 |
+
mo.md(f"""
|
119 |
+
## PMF Graph for Bernoulli($p={_p:.2f}$)
|
120 |
+
|
121 |
+
Parameter $p$: {p_slider}
|
122 |
+
|
123 |
+
Expected value: $E[X] = {expected_value:.2f}$
|
124 |
+
|
125 |
+
Variance: $\\text{{Var}}(X) = {variance:.2f}$
|
126 |
+
""")
|
127 |
+
return expected_value, probabilities, values, variance
|
128 |
+
|
129 |
+
|
130 |
+
@app.cell(hide_code=True)
def _(expected_value, p_slider, plt, probabilities, values, variance):
    # Bar chart of the Bernoulli PMF, annotated with E[X] and Var(X).
    _p = p_slider.value
    fig, ax = plt.subplots(figsize=(10, 6))

    # The two support points (0 and 1) with their probabilities.
    ax.bar(values, probabilities, width=0.4, color='blue', alpha=0.7)

    ax.set_xlabel('Values that X can take on')
    ax.set_ylabel('Probability')
    ax.set_title(f'PMF of Bernoulli Distribution with p = {_p:.2f}')

    # Restrict the view to the support, with a little vertical headroom.
    ax.set_xticks([0, 1])
    ax.set_xlim(-0.5, 1.5)
    ax.set_ylim(0, max(probabilities) * 1.1)

    # Dashed vertical line marking the expectation.
    ax.axvline(x=expected_value, color='red', linestyle='--',
               label=f'E[X] = {expected_value:.2f}')

    # Variance shown as an in-plot text box.
    ax.text(0.5, max(probabilities) * 0.8,
            f'Var(X) = {variance:.3f}',
            horizontalalignment='center',
            bbox=dict(facecolor='white', alpha=0.7))

    ax.legend()
    plt.tight_layout()
    plt.gca()
    return ax, fig
|
164 |
+
|
165 |
+
|
166 |
+
@app.cell(hide_code=True)
def _(mo):
    # Derivations of E[X] and Var(X) for X ~ Bern(p).
    # Corrected LaTeX: the PMF factor is written P(X=x) (the "P" was missing),
    # and the variance derivation starts from \text{Var}(X), not a bare (X).
    mo.md(
        r"""
        ## Proof: Expectation of a Bernoulli

        If $X$ is a Bernoulli with parameter $p$, $X \sim \text{Bern}(p)$:

        \begin{align}
        E[X] &= \sum_x x \cdot P(X=x) && \text{Definition of expectation} \\
        &= 1 \cdot p + 0 \cdot (1-p) &&
        X \text{ can take on values 0 and 1} \\
        &= p && \text{Remove the 0 term}
        \end{align}

        ## Proof: Variance of a Bernoulli

        If $X$ is a Bernoulli with parameter $p$, $X \sim \text{Bern}(p)$:

        To compute variance, first compute $E[X^2]$:

        \begin{align}
        E[X^2]
        &= \sum_x x^2 \cdot P(X=x) &&\text{LOTUS}\\
        &= 0^2 \cdot (1-p) + 1^2 \cdot p\\
        &= p
        \end{align}

        \begin{align}
        \text{Var}(X)
        &= E[X^2] - E[X]^2&& \text{Def of variance} \\
        &= p - p^2 && \text{Substitute }E[X^2]=p, E[X] = p \\
        &= p (1-p) && \text{Factor out }p
        \end{align}
        """
    )
    return
|
203 |
+
|
204 |
+
|
205 |
+
@app.cell(hide_code=True)
|
206 |
+
def _(mo):
|
207 |
+
mo.md(
|
208 |
+
r"""
|
209 |
+
## Indicator Random Variable
|
210 |
+
|
211 |
+
> **Definition**: An indicator variable is a Bernoulli random variable which takes on the value 1 if an **underlying event occurs**, and 0 _otherwise_.
|
212 |
+
|
213 |
+
Indicator random variables are a convenient way to convert the "true/false" outcome of an event into a number. That number may be easier to incorporate into an equation.
|
214 |
+
|
215 |
+
A random variable $I$ is an indicator variable for an event $A$ if $I = 1$ when $A$ occurs and $I = 0$ if $A$ does not occur. Indicator random variables are Bernoulli random variables, with $p = P(A)$. $I_A$ is a common choice of name for an indicator random variable.
|
216 |
+
|
217 |
+
Here are some properties of indicator random variables:
|
218 |
+
|
219 |
+
- $P(I=1)=P(A)$
|
220 |
+
- $E[I]=P(A)$
|
221 |
+
"""
|
222 |
+
)
|
223 |
+
return
|
224 |
+
|
225 |
+
|
226 |
+
@app.cell(hide_code=True)
|
227 |
+
def _(mo):
|
228 |
+
# Simulation of Bernoulli trials
|
229 |
+
mo.md(r"""
|
230 |
+
## Simulation of Bernoulli Trials
|
231 |
+
|
232 |
+
Let's simulate Bernoulli trials to see the law of large numbers in action. We'll flip a biased coin repeatedly and observe how the proportion of successes approaches the true probability $p$.
|
233 |
+
""")
|
234 |
+
|
235 |
+
# UI element for simulation parameters
|
236 |
+
num_trials_slider = mo.ui.slider(10, 10000, value=1000, step=10, label="Number of trials")
|
237 |
+
p_sim_slider = mo.ui.slider(0.01, 0.99, value=0.65, step=0.01, label="Success probability (p)")
|
238 |
+
return num_trials_slider, p_sim_slider
|
239 |
+
|
240 |
+
|
241 |
+
@app.cell(hide_code=True)
|
242 |
+
def _(mo):
|
243 |
+
mo.md(r"""## Simulation""")
|
244 |
+
return
|
245 |
+
|
246 |
+
|
247 |
+
@app.cell(hide_code=True)
|
248 |
+
def _(mo, num_trials_slider, p_sim_slider):
|
249 |
+
mo.hstack([num_trials_slider, p_sim_slider], justify='space-around')
|
250 |
+
return
|
251 |
+
|
252 |
+
|
253 |
+
@app.cell(hide_code=True)
def _(np, num_trials_slider, p_sim_slider, plt):
    # Simulate Bernoulli trials and plot how the running success proportion
    # converges to the true parameter p (law of large numbers).
    _n_trials = num_trials_slider.value
    p = p_sim_slider.value

    # One draw per trial: 1 with probability p, else 0.
    trials = np.random.binomial(1, p, size=_n_trials)

    # Running proportion of successes after each trial.
    cumulative_mean = np.cumsum(trials) / np.arange(1, _n_trials + 1)

    plt.figure(figsize=(10, 6))
    plt.plot(range(1, _n_trials + 1), cumulative_mean, label='Proportion of successes')
    plt.axhline(y=p, color='r', linestyle='--', label=f'True probability (p={p})')

    # Log x-axis so both early fluctuations and late convergence are visible.
    plt.xscale('log')
    plt.xlabel('Number of trials')
    plt.ylabel('Proportion of successes')
    plt.title('Convergence of Sample Proportion to True Probability')
    plt.legend()
    plt.grid(True, alpha=0.3)

    plt.annotate('As the number of trials increases,\nthe proportion approaches p',
                 xy=(_n_trials, cumulative_mean[-1]),
                 xytext=(_n_trials/5, p + 0.1),
                 arrowprops=dict(facecolor='black', shrink=0.05, width=1))

    plt.tight_layout()
    plt.gca()
    return cumulative_mean, p, trials
|
286 |
+
|
287 |
+
|
288 |
+
@app.cell(hide_code=True)
|
289 |
+
def _(mo, np, trials):
|
290 |
+
# Calculate statistics from the simulation
|
291 |
+
num_successes = np.sum(trials)
|
292 |
+
num_trials = len(trials)
|
293 |
+
proportion = num_successes / num_trials
|
294 |
+
|
295 |
+
# Display the results
|
296 |
+
mo.md(f"""
|
297 |
+
### Simulation Results
|
298 |
+
|
299 |
+
- Number of trials: {num_trials}
|
300 |
+
- Number of successes: {num_successes}
|
301 |
+
- Proportion of successes: {proportion:.4f}
|
302 |
+
|
303 |
+
This demonstrates how the sample proportion approaches the true probability $p$ as the number of trials increases.
|
304 |
+
""")
|
305 |
+
return num_successes, num_trials, proportion
|
306 |
+
|
307 |
+
|
308 |
+
@app.cell(hide_code=True)
|
309 |
+
def _(mo):
|
310 |
+
mo.md(
|
311 |
+
r"""
|
312 |
+
## 🤔 Test Your Understanding
|
313 |
+
|
314 |
+
Pick which of these statements about Bernoulli random variables you think are correct:
|
315 |
+
|
316 |
+
/// details | The variance of a Bernoulli random variable is always less than or equal to 0.25
|
317 |
+
✅ Correct! The variance $p(1-p)$ reaches its maximum value of 0.25 when $p = 0.5$.
|
318 |
+
///
|
319 |
+
|
320 |
+
/// details | The expected value of a Bernoulli random variable must be either 0 or 1
|
321 |
+
❌ Incorrect! The expected value is $p$, which can be any value between 0 and 1.
|
322 |
+
///
|
323 |
+
|
324 |
+
/// details | If $X \sim \text{Bern}(0.3)$ and $Y \sim \text{Bern}(0.7)$, then $X$ and $Y$ have the same variance
|
325 |
+
✅ Correct! $\text{Var}(X) = 0.3 \times 0.7 = 0.21$ and $\text{Var}(Y) = 0.7 \times 0.3 = 0.21$.
|
326 |
+
///
|
327 |
+
|
328 |
+
/// details | Two independent coin flips can be modeled as the sum of two Bernoulli random variables
|
329 |
+
✅ Correct! The sum would follow a Binomial distribution with $n=2$.
|
330 |
+
///
|
331 |
+
"""
|
332 |
+
)
|
333 |
+
return
|
334 |
+
|
335 |
+
|
336 |
+
@app.cell(hide_code=True)
|
337 |
+
def _(mo):
|
338 |
+
mo.md(
|
339 |
+
r"""
|
340 |
+
## Applications of Bernoulli Random Variables
|
341 |
+
|
342 |
+
Bernoulli random variables are used in many real-world scenarios:
|
343 |
+
|
344 |
+
1. **Quality Control**: Testing if a manufactured item is defective (1) or not (0)
|
345 |
+
|
346 |
+
2. **A/B Testing**: Determining if a user clicks (1) or doesn't click (0) on a website button
|
347 |
+
|
348 |
+
3. **Medical Testing**: Checking if a patient tests positive (1) or negative (0) for a disease
|
349 |
+
|
350 |
+
4. **Election Modeling**: Modeling if a particular voter votes for candidate A (1) or not (0)
|
351 |
+
|
352 |
+
5. **Financial Markets**: Modeling if a stock price goes up (1) or down (0) in a simplified model
|
353 |
+
|
354 |
+
Because Bernoulli random variables are parametric, as soon as you declare a random variable to be of type Bernoulli, you automatically know all of its pre-derived properties!
|
355 |
+
"""
|
356 |
+
)
|
357 |
+
return
|
358 |
+
|
359 |
+
|
360 |
+
@app.cell(hide_code=True)
|
361 |
+
def _(mo):
|
362 |
+
mo.md(
|
363 |
+
r"""
|
364 |
+
## Summary
|
365 |
+
|
366 |
+
And that's a wrap on Bernoulli distributions! We've learnt the simplest of all probability distributions — the one that only has two possible outcomes. Flip a coin, check if an email is spam, see if your blind date shows up — these are all Bernoulli trials with success probability $p$.
|
367 |
+
|
368 |
+
The beauty of Bernoulli is in its simplicity: just set $p$ (the probability of success) and you're good to go! The PMF gives us $P(X=1) = p$ and $P(X=0) = 1-p$, while expectation is simply $p$ and variance is $p(1-p)$. Oh, and when you're tracking whether specific events happen or not? That's an indicator random variable — just another Bernoulli in disguise!
|
369 |
+
|
370 |
+
Two key things to remember:
|
371 |
+
|
372 |
+
/// note
|
373 |
+
💡 **Maximum Variance**: A Bernoulli's variance $p(1-p)$ reaches its maximum at $p=0.5$, making a fair coin the most "unpredictable" Bernoulli random variable.
|
374 |
+
|
375 |
+
💡 **Instant Properties**: When you identify a random variable as Bernoulli, you instantly know all its properties—expectation, variance, PMF—without additional calculations.
|
376 |
+
///
|
377 |
+
|
378 |
+
Next up: Binomial distribution—where we'll see what happens when we let Bernoulli trials have a party and add themselves together!
|
379 |
+
"""
|
380 |
+
)
|
381 |
+
return
|
382 |
+
|
383 |
+
|
384 |
+
@app.cell(hide_code=True)
|
385 |
+
def _(mo):
|
386 |
+
mo.md(r"""#### Appendix (containing helper code for the notebook)""")
|
387 |
+
return
|
388 |
+
|
389 |
+
|
390 |
+
@app.cell
|
391 |
+
def _():
|
392 |
+
import marimo as mo
|
393 |
+
return (mo,)
|
394 |
+
|
395 |
+
|
396 |
+
@app.cell(hide_code=True)
|
397 |
+
def _():
|
398 |
+
from marimo import Html
|
399 |
+
return (Html,)
|
400 |
+
|
401 |
+
|
402 |
+
@app.cell(hide_code=True)
def _():
    import math

    import matplotlib.pyplot as plt
    import numpy as np
    from scipy import stats

    # One consistent look for every figure in the notebook.
    # (style.use must come first: it resets rcParams.)
    plt.style.use('seaborn-v0_8-whitegrid')
    plt.rcParams.update({'figure.figsize': [10, 6], 'font.size': 12})

    # Fixed seed so the simulation cells are reproducible across runs.
    np.random.seed(42)
    return math, np, plt, stats
|
417 |
+
|
418 |
+
|
419 |
+
@app.cell(hide_code=True)
def _(mo):
    # Slider driving the Bernoulli parameter p used throughout the notebook.
    p_slider = mo.ui.slider(0.01, 0.99, value=0.65, step=0.01, label="Parameter p")
    return (p_slider,)
|
424 |
+
|
425 |
+
|
426 |
+
if __name__ == "__main__":
|
427 |
+
app.run()
|
probability/14_binomial_distribution.py
ADDED
@@ -0,0 +1,545 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.4",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# "altair==5.2.0",
|
9 |
+
# "wigglystuff==0.1.10",
|
10 |
+
# "pandas==2.2.3",
|
11 |
+
# ]
|
12 |
+
# ///
|
13 |
+
|
14 |
+
import marimo
|
15 |
+
|
16 |
+
__generated_with = "0.11.24"
|
17 |
+
app = marimo.App(width="medium", app_title="Binomial Distribution")
|
18 |
+
|
19 |
+
|
20 |
+
@app.cell(hide_code=True)
|
21 |
+
def _(mo):
|
22 |
+
mo.md(
|
23 |
+
r"""
|
24 |
+
# Binomial Distribution
|
25 |
+
|
26 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/binomial/), by Stanford professor Chris Piech._
|
27 |
+
|
28 |
+
In this section, we will discuss the binomial distribution. To start, imagine the following example:
|
29 |
+
|
30 |
+
Consider $n$ independent trials of an experiment where each trial is a "success" with probability $p$. Let $X$ be the number of successes in $n$ trials.
|
31 |
+
|
32 |
+
This situation is truly common in the natural world, and as such, there has been a lot of research into such phenomena. Random variables like $X$ are called **binomial random variables**. If you can identify that a process fits this description, you can inherit many already proved properties such as the PMF formula, expectation, and variance!
|
33 |
+
"""
|
34 |
+
)
|
35 |
+
return
|
36 |
+
|
37 |
+
|
38 |
+
@app.cell(hide_code=True)
|
39 |
+
def _(mo):
|
40 |
+
mo.md(
|
41 |
+
r"""
|
42 |
+
## Binomial Random Variable Definition
|
43 |
+
|
44 |
+
$X \sim \text{Bin}(n, p)$ represents a binomial random variable where:
|
45 |
+
|
46 |
+
- $X$ is our random variable (number of successes)
|
47 |
+
- $\text{Bin}$ indicates it follows a binomial distribution
|
48 |
+
- $n$ is the number of trials
|
49 |
+
- $p$ is the probability of success in each trial
|
50 |
+
|
51 |
+
```
|
52 |
+
X ~ Bin(n, p)
|
53 |
+
↑ ↑ ↑
|
54 |
+
| | +-- Probability of
|
55 |
+
| | success on each
|
56 |
+
| | trial
|
57 |
+
| +-- Number of trials
|
58 |
+
|
|
59 |
+
Our random variable
|
60 |
+
is distributed
|
61 |
+
as a Binomial
|
62 |
+
```
|
63 |
+
|
64 |
+
Here are a few examples of binomial random variables:
|
65 |
+
|
66 |
+
- Number of heads in $n$ coin flips
|
67 |
+
- Number of 1's in randomly generated length $n$ bit string
|
68 |
+
- Number of disk drives crashed in 1000 computer cluster, assuming disks crash independently
|
69 |
+
"""
|
70 |
+
)
|
71 |
+
return
|
72 |
+
|
73 |
+
|
74 |
+
@app.cell(hide_code=True)
|
75 |
+
def _(mo):
|
76 |
+
mo.md(
|
77 |
+
r"""
|
78 |
+
## Properties of Binomial Distribution
|
79 |
+
|
80 |
+
| Property | Formula |
|
81 |
+
|----------|---------|
|
82 |
+
| Notation | $X \sim \text{Bin}(n, p)$ |
|
83 |
+
| Description | Number of "successes" in $n$ identical, independent experiments each with probability of success $p$ |
|
84 |
+
| Parameters | $n \in \{0, 1, \dots\}$, the number of experiments<br>$p \in [0, 1]$, the probability that a single experiment gives a "success" |
|
85 |
+
| Support | $x \in \{0, 1, \dots, n\}$ |
|
86 |
+
| PMF equation | $P(X=x) = {n \choose x}p^x(1-p)^{n-x}$ |
|
87 |
+
| Expectation | $E[X] = n \cdot p$ |
|
88 |
+
| Variance | $\text{Var}(X) = n \cdot p \cdot (1-p)$ |
|
89 |
+
|
90 |
+
Let's explore how the binomial distribution changes with different parameters.
|
91 |
+
"""
|
92 |
+
)
|
93 |
+
return
|
94 |
+
|
95 |
+
|
96 |
+
@app.cell(hide_code=True)
|
97 |
+
def _(TangleSlider, mo):
|
98 |
+
# Interactive elements using TangleSlider
|
99 |
+
n_slider = mo.ui.anywidget(TangleSlider(
|
100 |
+
amount=10,
|
101 |
+
min_value=1,
|
102 |
+
max_value=30,
|
103 |
+
step=1,
|
104 |
+
digits=0,
|
105 |
+
suffix=" trials"
|
106 |
+
))
|
107 |
+
|
108 |
+
p_slider = mo.ui.anywidget(TangleSlider(
|
109 |
+
amount=0.5,
|
110 |
+
min_value=0.01,
|
111 |
+
max_value=0.99,
|
112 |
+
step=0.01,
|
113 |
+
digits=2,
|
114 |
+
suffix=" probability"
|
115 |
+
))
|
116 |
+
|
117 |
+
# Grid layout for the interactive controls
|
118 |
+
controls = mo.vstack([
|
119 |
+
mo.md("### Adjust Parameters to See How Binomial Distribution Changes"),
|
120 |
+
mo.hstack([
|
121 |
+
mo.md("**Number of trials (n):** "),
|
122 |
+
n_slider
|
123 |
+
], justify="start"),
|
124 |
+
mo.hstack([
|
125 |
+
mo.md("**Probability of success (p):** "),
|
126 |
+
p_slider
|
127 |
+
], justify="start"),
|
128 |
+
])
|
129 |
+
return controls, n_slider, p_slider
|
130 |
+
|
131 |
+
|
132 |
+
@app.cell(hide_code=True)
|
133 |
+
def _(controls):
|
134 |
+
controls
|
135 |
+
return
|
136 |
+
|
137 |
+
|
138 |
+
@app.cell(hide_code=True)
def _(n_slider, np, p_slider, plt, stats):
    # Plot the Bin(n, p) PMF for the slider-selected parameters,
    # annotated with its mean, variance and a ±1 std-dev band.
    _n = int(n_slider.amount)
    _p = p_slider.amount

    # PMF over the full support {0, ..., n}.
    _ks = np.arange(0, _n + 1)
    _probs = stats.binom.pmf(_ks, _n, _p)

    # Closed-form moments of the binomial.
    _mean = _n * _p
    _variance = _n * _p * (1 - _p)
    _std_dev = np.sqrt(_variance)

    _figure, _axes = plt.subplots(figsize=(10, 6))

    # PMF as bars, plus a connecting line for readability.
    _axes.bar(_ks, _probs, color='royalblue', alpha=0.7, label=f'PMF: P(X=k)')
    _axes.plot(_ks, _probs, 'ro-', alpha=0.6, label='PMF line')

    # Mean marker and shaded ±1 standard-deviation region.
    _axes.axvline(x=_mean, color='green', linestyle='--', linewidth=2,
                  label=f'Mean: {_mean:.2f}')
    _axes.axvspan(_mean - _std_dev, _mean + _std_dev, alpha=0.2, color='green',
                  label=f'±1 Std Dev: {_std_dev:.2f}')

    _axes.set_xlabel('Number of Successes (k)')
    _axes.set_ylabel('Probability: P(X=k)')
    _axes.set_title(f'Binomial Distribution with n={_n}, p={_p:.2f}')

    # Call out the expectation and variance with arrows.
    _axes.annotate(f'E[X] = {_mean:.2f}',
                   xy=(_mean, stats.binom.pmf(int(_mean), _n, _p)),
                   xytext=(_mean + 1, max(_probs) * 0.8),
                   arrowprops=dict(facecolor='black', shrink=0.05, width=1))
    _axes.annotate(f'Var(X) = {_variance:.2f}',
                   xy=(_mean, stats.binom.pmf(int(_mean), _n, _p) / 2),
                   xytext=(_mean + 1, max(_probs) * 0.6),
                   arrowprops=dict(facecolor='black', shrink=0.05, width=1))

    _axes.grid(alpha=0.3)
    _axes.legend()

    plt.tight_layout()
    plt.gca()
    return
|
192 |
+
|
193 |
+
|
194 |
+
@app.cell(hide_code=True)
|
195 |
+
def _(mo):
|
196 |
+
mo.md(
|
197 |
+
r"""
|
198 |
+
## Relationship to Bernoulli Random Variables
|
199 |
+
|
200 |
+
One way to think of the binomial is as the sum of $n$ Bernoulli variables. Say that $Y_i$ is an indicator Bernoulli random variable which is 1 if experiment $i$ is a success. Then if $X$ is the total number of successes in $n$ experiments, $X \sim \text{Bin}(n, p)$:
|
201 |
+
|
202 |
+
$$X = \sum_{i=1}^n Y_i$$
|
203 |
+
|
204 |
+
Recall that the outcome of $Y_i$ will be 1 or 0, so one way to think of $X$ is as the sum of those 1s and 0s.
|
205 |
+
"""
|
206 |
+
)
|
207 |
+
return
|
208 |
+
|
209 |
+
|
210 |
+
@app.cell(hide_code=True)
|
211 |
+
def _(mo):
|
212 |
+
mo.md(
|
213 |
+
r"""
|
214 |
+
## Binomial Probability Mass Function (PMF)
|
215 |
+
|
216 |
+
The most important property to know about a binomial is its [Probability Mass Function](https://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/10_probability_mass_function.py):
|
217 |
+
|
218 |
+
$$P(X=k) = {n \choose k}p^k(1-p)^{n-k}$$
|
219 |
+
|
220 |
+
```
|
221 |
+
P(X = k) = (n) p^k(1-p)^(n-k)
|
222 |
+
↑ (k)
|
223 |
+
| ↑
|
224 |
+
| +-- Binomial coefficient:
|
225 |
+
| number of ways to choose
|
226 |
+
| k successes from n trials
|
227 |
+
|
|
228 |
+
Probability that our
|
229 |
+
variable takes on the
|
230 |
+
value k
|
231 |
+
```
|
232 |
+
|
233 |
+
Recall, we derived this formula in Part 1. There is a complete example on the probability of $k$ heads in $n$ coin flips, where each flip is heads with probability $p$.
|
234 |
+
|
235 |
+
To briefly review, if you think of each experiment as being distinct, then there are ${n \choose k}$ ways of permuting $k$ successes from $n$ experiments. For any of the mutually exclusive permutations, the probability of that permutation is $p^k \cdot (1-p)^{n-k}$.
|
236 |
+
|
237 |
+
The name binomial comes from the term ${n \choose k}$ which is formally called the binomial coefficient.
|
238 |
+
"""
|
239 |
+
)
|
240 |
+
return
|
241 |
+
|
242 |
+
|
243 |
+
@app.cell(hide_code=True)
|
244 |
+
def _(mo):
|
245 |
+
mo.md(
|
246 |
+
r"""
|
247 |
+
## Expectation of Binomial
|
248 |
+
|
249 |
+
There is an easy way to calculate the expectation of a binomial and a hard way. The easy way is to leverage the fact that a binomial is the sum of Bernoulli indicator random variables $X = \sum_{i=1}^{n} Y_i$ where $Y_i$ is an indicator of whether the $i$-th experiment was a success: $Y_i \sim \text{Bernoulli}(p)$.
|
250 |
+
|
251 |
+
Since the [expectation of the sum](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/11_expectation.py) of random variables is the sum of expectations, we can add the expectation, $E[Y_i] = p$, of each of the Bernoulli's:
|
252 |
+
|
253 |
+
\begin{align}
|
254 |
+
E[X] &= E\Big[\sum_{i=1}^{n} Y_i\Big] && \text{Since }X = \sum_{i=1}^{n} Y_i \\
|
255 |
+
&= \sum_{i=1}^{n}E[ Y_i] && \text{Expectation of sum} \\
|
256 |
+
&= \sum_{i=1}^{n}p && \text{Expectation of Bernoulli} \\
|
257 |
+
&= n \cdot p && \text{Sum $n$ times}
|
258 |
+
\end{align}
|
259 |
+
|
260 |
+
The hard way is to use the definition of expectation:
|
261 |
+
|
262 |
+
\begin{align}
|
263 |
+
E[X] &= \sum_{i=0}^n i \cdot P(X = i) && \text{Def of expectation} \\
|
264 |
+
&= \sum_{i=0}^n i \cdot {n \choose i} p^i(1-p)^{n-i} && \text{Sub in PMF} \\
|
265 |
+
& \cdots && \text{Many steps later} \\
|
266 |
+
&= n \cdot p
|
267 |
+
\end{align}
|
268 |
+
"""
|
269 |
+
)
|
270 |
+
return
|
271 |
+
|
272 |
+
|
273 |
+
@app.cell(hide_code=True)
|
274 |
+
def _(mo):
|
275 |
+
mo.md(
|
276 |
+
r"""
|
277 |
+
## Binomial Distribution in Python
|
278 |
+
|
279 |
+
As you might expect, you can use binomial distributions in code. The standardized library for binomials is `scipy.stats.binom`.
|
280 |
+
|
281 |
+
One of the most helpful methods that this package provides is a way to calculate the PMF. For example, say $n=5$, $p=0.6$ and you want to find $P(X=2)$, you could use the following code:
|
282 |
+
"""
|
283 |
+
)
|
284 |
+
return
|
285 |
+
|
286 |
+
|
287 |
+
@app.cell
def _(stats):
    # Worked example: P(X = 2) for X ~ Bin(5, 0.6).
    _n = 5  # number of trials
    _p = 0.6  # per-trial success probability
    _x = 2  # query point

    # Evaluate the PMF via a frozen distribution object (equivalent to
    # stats.binom.pmf(x, n, p)).
    p_x = stats.binom(_n, _p).pmf(_x)

    print(f'P(X = {_x}) = {p_x:.4f}')
    return (p_x,)
|
300 |
+
|
301 |
+
|
302 |
+
@app.cell(hide_code=True)
|
303 |
+
def _(mo):
|
304 |
+
mo.md(r"""Another particularly helpful function is the ability to generate a random sample from a binomial. For example, say $X$ represents the number of requests to a website. We can draw 100 samples from this distribution using the following code:""")
|
305 |
+
return
|
306 |
+
|
307 |
+
|
308 |
+
@app.cell
def _(n, p, stats):
    # The slider value arrives as a float; rvs needs an integer trial count.
    n_int = int(n)

    # Draw 100 independent samples from Bin(n_int, p).
    samples = stats.binom.rvs(n_int, p, size=100)

    print(samples)
    return n_int, samples
|
318 |
+
|
319 |
+
|
320 |
+
@app.cell(hide_code=True)
def _(n_int, np, p, plt, samples, stats):
    # Plot histogram of samples; bin edges at -0.5, 0.5, ... center each bar
    # on an integer outcome, and density=True makes it comparable to the PMF.
    plt.figure(figsize=(10, 5))
    plt.hist(samples, bins=np.arange(-0.5, n_int+1.5, 1), alpha=0.7, color='royalblue',
             edgecolor='black', density=True)

    # Overlay the theoretical PMF at every integer outcome 0..n
    x_values = np.arange(0, n_int+1)
    pmf_values = stats.binom.pmf(x_values, n_int, p)
    plt.plot(x_values, pmf_values, 'ro-', ms=8, label='Theoretical PMF')

    # Add labels and title
    plt.xlabel('Number of Successes')
    plt.ylabel('Relative Frequency / Probability')
    plt.title(f'Histogram of 100 Samples from Bin({n_int}, {p})')
    plt.legend()
    plt.grid(alpha=0.3)

    # Annotate: compare the empirical sample mean with the theoretical mean n*p
    plt.annotate('Sample mean: %.2f' % np.mean(samples),
                 xy=(0.7, 0.9), xycoords='axes fraction',
                 bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
    plt.annotate('Theoretical mean: %.2f' % (n_int*p),
                 xy=(0.7, 0.8), xycoords='axes fraction',
                 bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.3))

    plt.tight_layout()
    # Last expression is the Axes object, which marimo renders as the cell output.
    plt.gca()
    return pmf_values, x_values
|
350 |
+
|
351 |
+
|
352 |
+
@app.cell(hide_code=True)
|
353 |
+
def _(mo):
|
354 |
+
mo.md(
|
355 |
+
r"""
|
356 |
+
You might be wondering what a random sample is! A random sample is a randomly chosen assignment for our random variable. Above we have 100 such assignments. The probability that value $k$ is chosen is given by the PMF: $P(X=k)$.
|
357 |
+
|
358 |
+
There are also functions for getting the mean, the variance, and more. You can read the [scipy.stats.binom documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binom.html), especially the list of methods.
|
359 |
+
"""
|
360 |
+
)
|
361 |
+
return
|
362 |
+
|
363 |
+
|
364 |
+
@app.cell(hide_code=True)
|
365 |
+
def _(mo):
|
366 |
+
mo.md(
|
367 |
+
r"""
|
368 |
+
## Interactive Exploration of Binomial vs. Negative Binomial
|
369 |
+
|
370 |
+
The standard binomial distribution is a special case of a broader family of distributions. One related distribution is the negative binomial, which can model count data with overdispersion (where the variance is larger than the mean).
|
371 |
+
|
372 |
+
Below, you can explore how the negative binomial distribution compares to a Poisson distribution (which can be seen as a limiting case of the binomial as $n$ gets large and $p$ gets small, with $np$ held constant).
|
373 |
+
|
374 |
+
Adjust the sliders to see how the parameters affect the distribution:
|
375 |
+
|
376 |
+
*Note: The interactive visualization in this section was inspired by work from [liquidcarbon on GitHub](https://github.com/liquidcarbon).*
|
377 |
+
"""
|
378 |
+
)
|
379 |
+
return
|
380 |
+
|
381 |
+
|
382 |
+
@app.cell(hide_code=True)
|
383 |
+
def _(alpha_slider, chart, equation, mo, mu_slider):
|
384 |
+
mo.vstack(
|
385 |
+
[
|
386 |
+
mo.md(f"## Negative Binomial Distribution (Poisson + Overdispersion)\n{equation}"),
|
387 |
+
mo.hstack([mu_slider, alpha_slider], justify="start"),
|
388 |
+
chart,
|
389 |
+
], justify='space-around'
|
390 |
+
).center()
|
391 |
+
return
|
392 |
+
|
393 |
+
|
394 |
+
@app.cell(hide_code=True)
|
395 |
+
def _(mo):
|
396 |
+
mo.md(
|
397 |
+
r"""
|
398 |
+
## 🤔 Test Your Understanding
|
399 |
+
Pick which of these statements about binomial distributions you think are correct:
|
400 |
+
|
401 |
+
/// details | The variance of a binomial distribution is always equal to its mean
|
402 |
+
❌ Incorrect! The variance is $np(1-p)$ while the mean is $np$. They're only equal when $p=1$ (which is a degenerate case).
|
403 |
+
///
|
404 |
+
|
405 |
+
/// details | If $X \sim \text{Bin}(n, p)$ and $Y \sim \text{Bin}(n, 1-p)$, then $X$ and $Y$ have the same variance
|
406 |
+
✅ Correct! $\text{Var}(X) = np(1-p)$ and $\text{Var}(Y) = n(1-p)p$, which are the same.
|
407 |
+
///
|
408 |
+
|
409 |
+
/// details | As the number of trials increases, the binomial distribution approaches a normal distribution
|
410 |
+
✅ Correct! For large $n$, the binomial distribution can be approximated by a normal distribution with the same mean and variance.
|
411 |
+
///
|
412 |
+
|
413 |
+
/// details | The PMF of a binomial distribution is symmetric when $p = 0.5$
|
414 |
+
✅ Correct! When $p = 0.5$, the PMF is symmetric around $n/2$.
|
415 |
+
///
|
416 |
+
|
417 |
+
/// details | The sum of two independent binomial random variables with the same $p$ is also a binomial random variable
|
418 |
+
✅ Correct! If $X \sim \text{Bin}(n_1, p)$ and $Y \sim \text{Bin}(n_2, p)$ are independent, then $X + Y \sim \text{Bin}(n_1 + n_2, p)$.
|
419 |
+
///
|
420 |
+
|
421 |
+
/// details | The maximum value of the PMF for $\text{Bin}(n,p)$ always occurs at $k = np$
|
422 |
+
❌ Incorrect! The mode (maximum value of PMF) is either $\lfloor (n+1)p \rfloor$ or $\lceil (n+1)p-1 \rceil$ depending on whether $(n+1)p$ is an integer.
|
423 |
+
///
|
424 |
+
"""
|
425 |
+
)
|
426 |
+
return
|
427 |
+
|
428 |
+
|
429 |
+
@app.cell(hide_code=True)
|
430 |
+
def _(mo):
|
431 |
+
mo.md(
|
432 |
+
r"""
|
433 |
+
## Summary
|
434 |
+
|
435 |
+
So we've explored the binomial distribution, and honestly, it's one of the most practical probability distributions you'll encounter. Think about it — anytime you're counting successes in a fixed number of trials (like those coin flips we discussed), this is your go-to distribution.
|
436 |
+
|
437 |
+
I find it fascinating how the expectation is simply $np$. Such a clean, intuitive formula! And remember that neat visualization we saw earlier? When we adjusted the parameters, you could actually see how the distribution shape changes—becoming more symmetric as $n$ increases.
|
438 |
+
|
439 |
+
The key things to take away:
|
440 |
+
|
441 |
+
- The binomial distribution models the number of successes in $n$ independent trials, each with probability $p$ of success
|
442 |
+
|
443 |
+
- Its PMF is given by the formula $P(X=k) = {n \choose k}p^k(1-p)^{n-k}$, which lets us calculate exactly how likely any specific number of successes is
|
444 |
+
|
445 |
+
- The expected value is $E[X] = np$ and the variance is $Var(X) = np(1-p)$
|
446 |
+
|
447 |
+
- It's related to other distributions: it's essentially a sum of Bernoulli random variables, and connects to both the negative binomial and Poisson distributions
|
448 |
+
|
449 |
+
- In Python, the `scipy.stats.binom` module makes working with binomial distributions straightforward—you can generate random samples and calculate probabilities with just a few lines of code
|
450 |
+
|
451 |
+
You'll see the binomial distribution pop up everywhere—from computer science to quality control, epidemiology, and data science. Any time you have scenarios with binary outcomes over multiple trials, this distribution has you covered.
|
452 |
+
"""
|
453 |
+
)
|
454 |
+
return
|
455 |
+
|
456 |
+
|
457 |
+
@app.cell(hide_code=True)
|
458 |
+
def _(mo):
|
459 |
+
mo.md(r"""Appendix code (helper functions, variables, etc.):""")
|
460 |
+
return
|
461 |
+
|
462 |
+
|
463 |
+
@app.cell
def _():
    # marimo is imported in its own cell so every other cell can depend on
    # `mo` for markdown and UI output.
    import marimo as mo
    return (mo,)
|
467 |
+
|
468 |
+
|
469 |
+
@app.cell(hide_code=True)
def _():
    # Third-party dependencies used throughout the notebook; the pinned
    # versions live in the script metadata header at the top of the file.
    import numpy as np
    import matplotlib.pyplot as plt
    import scipy.stats as stats
    import pandas as pd
    import altair as alt
    from wigglystuff import TangleSlider
    return TangleSlider, alt, np, pd, plt, stats
|
478 |
+
|
479 |
+
|
480 |
+
@app.cell(hide_code=True)
def _(mo):
    # Overdispersion parameter α for the negative binomial; a discrete set of
    # steps is used so α=0 (the Poisson limit) is exactly selectable.
    alpha_slider = mo.ui.slider(
        value=0.1,
        steps=[0, 0.01, 0.02, 0.03, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1],
        label="α (overdispersion)",
        show_value=True,
    )
    # Mean parameter μ of the distribution.
    mu_slider = mo.ui.slider(
        value=100, start=1, stop=100, step=1, label="μ (mean)", show_value=True
    )
    return alpha_slider, mu_slider
|
492 |
+
|
493 |
+
|
494 |
+
@app.cell(hide_code=True)
def _():
    # LaTeX for the negative-binomial PMF and its variance, parameterized by
    # mean μ and overdispersion α. A raw string avoids doubling every
    # backslash; the resulting markdown text is identical.
    equation = r"""
    $$
    P(X = k) = \frac{\Gamma(k + \frac{1}{\alpha})}{\Gamma(k + 1) \Gamma(\frac{1}{\alpha})} \left( \frac{1}{\mu \alpha + 1} \right)^{\frac{1}{\alpha}} \left( \frac{\mu \alpha}{\mu \alpha + 1} \right)^k
    $$

    $$
    \sigma^2 = \mu + \alpha \mu^2
    $$
    """
    return (equation,)
|
506 |
+
|
507 |
+
|
508 |
+
@app.cell(hide_code=True)
def _(alpha_slider, alt, mu_slider, np, pd, stats):
    # Read the current slider values.
    mu = mu_slider.value
    alpha = alpha_slider.value
    # Convert (μ, α) to scipy's (n, p) parameterization of the negative
    # binomial. When α == 0 the exact Poisson limit is approximated with a
    # very large n (1000 - μ) instead of dividing by zero.
    n = 1000 - mu if alpha == 0 else 1 / alpha
    p = n / (mu + n)
    # Support to display: outcomes 0 .. 3μ.
    x = np.arange(0, mu * 3 + 1, 1)
    df = pd.DataFrame(
        {
            "x": x,
            # Negative-binomial PMF at the chosen (n, p).
            "y": stats.nbinom.pmf(x, n, p),
            # Poisson-like reference curve (negative binomial with large n).
            "y_poi": stats.nbinom.pmf(x, 1000 - mu, 1 - mu / 1000),
        }
    )
    # 1000 random draws, used only to estimate an empirical 95% interval.
    # NOTE(review): unseeded, so the highlighted CI band varies between runs.
    r1k = stats.nbinom.rvs(n, p, size=1000)
    df["in 95% CI"] = df["x"].between(*np.percentile(r1k, q=[2.5, 97.5]))
    base = alt.Chart(df)

    # Faint magenta bars: the Poisson reference distribution.
    chart_poi = base.mark_bar(
        fillOpacity=0.25, width=100 / mu, fill="magenta"
    ).encode(
        x=alt.X("x").scale(domain=(-0.4, x.max() + 0.4), nice=False),
        y=alt.Y("y_poi").scale(domain=(0, df.y_poi.max() * 1.1)).title(None),
    )
    # Solid bars: the negative binomial, colored by CI membership.
    chart_nb = base.mark_bar(fillOpacity=0.75, width=100 / mu).encode(
        x="x",
        y="y",
        fill=alt.Fill("in 95% CI")
        .scale(domain=[False, True], range=["#aaa", "#7c7"])
        .legend(orient="bottom-right"),
    )

    # Overlay both layers into the chart consumed by the display cell.
    chart = (chart_poi + chart_nb).configure_view(continuousWidth=450)
    return alpha, base, chart, chart_nb, chart_poi, df, mu, n, p, r1k, x
|
542 |
+
|
543 |
+
|
544 |
+
# Entry point: run the notebook as a standalone marimo app when the file is
# executed directly.
if __name__ == "__main__":
    app.run()
|
probability/15_poisson_distribution.py
ADDED
@@ -0,0 +1,805 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# /// script
|
2 |
+
# requires-python = ">=3.10"
|
3 |
+
# dependencies = [
|
4 |
+
# "marimo",
|
5 |
+
# "matplotlib==3.10.0",
|
6 |
+
# "numpy==2.2.4",
|
7 |
+
# "scipy==1.15.2",
|
8 |
+
# "altair==5.2.0",
|
9 |
+
# "wigglystuff==0.1.10",
|
10 |
+
# "pandas==2.2.3",
|
11 |
+
# ]
|
12 |
+
# ///
|
13 |
+
|
14 |
+
import marimo
|
15 |
+
|
16 |
+
__generated_with = "0.11.25"
|
17 |
+
app = marimo.App(width="medium", app_title="Poisson Distribution")
|
18 |
+
|
19 |
+
|
20 |
+
@app.cell(hide_code=True)
|
21 |
+
def _(mo):
|
22 |
+
mo.md(
|
23 |
+
r"""
|
24 |
+
# Poisson Distribution
|
25 |
+
|
26 |
+
_This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/poisson/), by Stanford professor Chris Piech._
|
27 |
+
|
28 |
+
A Poisson random variable gives the probability of a given number of events in a fixed interval of time (or space). It makes the Poisson assumption that events occur with a known constant mean rate and independently of the time since the last event.
|
29 |
+
"""
|
30 |
+
)
|
31 |
+
return
|
32 |
+
|
33 |
+
|
34 |
+
@app.cell(hide_code=True)
|
35 |
+
def _(mo):
|
36 |
+
mo.md(
|
37 |
+
r"""
|
38 |
+
## Poisson Random Variable Definition
|
39 |
+
|
40 |
+
$X \sim \text{Poisson}(\lambda)$ represents a Poisson random variable where:
|
41 |
+
|
42 |
+
- $X$ is our random variable (number of events)
|
43 |
+
- $\text{Poisson}$ indicates it follows a Poisson distribution
|
44 |
+
- $\lambda$ is the rate parameter (average number of events per time interval)
|
45 |
+
|
46 |
+
```
|
47 |
+
X ~ Poisson(λ)
|
48 |
+
↑ ↑ ↑
|
49 |
+
| | +-- Rate parameter:
|
50 |
+
| | average number of
|
51 |
+
| | events per interval
|
52 |
+
| +-- Indicates Poisson
|
53 |
+
| distribution
|
54 |
+
|
|
55 |
+
Our random variable
|
56 |
+
counting number of events
|
57 |
+
```
|
58 |
+
|
59 |
+
The Poisson distribution is particularly useful when:
|
60 |
+
|
61 |
+
1. Events occur independently of each other
|
62 |
+
2. The average rate of occurrence is constant
|
63 |
+
3. Two events cannot occur at exactly the same instant
|
64 |
+
4. The probability of an event is proportional to the length of the time interval
|
65 |
+
"""
|
66 |
+
)
|
67 |
+
return
|
68 |
+
|
69 |
+
|
70 |
+
@app.cell(hide_code=True)
|
71 |
+
def _(mo):
|
72 |
+
mo.md(
|
73 |
+
r"""
|
74 |
+
## Properties of Poisson Distribution
|
75 |
+
|
76 |
+
| Property | Formula |
|
77 |
+
|----------|---------|
|
78 |
+
| Notation | $X \sim \text{Poisson}(\lambda)$ |
|
79 |
+
| Description | Number of events in a fixed time frame if (a) events occur with a constant mean rate and (b) they occur independently of time since last event |
|
80 |
+
| Parameters | $\lambda \in \mathbb{R}^{+}$, the constant average rate |
|
81 |
+
| Support | $x \in \{0, 1, \dots\}$ |
|
82 |
+
| PMF equation | $P(X=x) = \frac{\lambda^x e^{-\lambda}}{x!}$ |
|
83 |
+
| Expectation | $E[X] = \lambda$ |
|
84 |
+
| Variance | $\text{Var}(X) = \lambda$ |
|
85 |
+
|
86 |
+
Note that unlike many other distributions, the Poisson distribution's mean and variance are equal, both being $\lambda$.
|
87 |
+
|
88 |
+
Let's explore how the Poisson distribution changes with different rate parameters.
|
89 |
+
"""
|
90 |
+
)
|
91 |
+
return
|
92 |
+
|
93 |
+
|
94 |
+
@app.cell(hide_code=True)
def _(TangleSlider, mo):
    # interactive elements using TangleSlider (inline draggable number);
    # `amount` is the initial value, read elsewhere via `lambda_slider.amount`.
    lambda_slider = mo.ui.anywidget(TangleSlider(
        amount=5,
        min_value=0.1,
        max_value=20,
        step=0.1,
        digits=1,
        suffix=" events"
    ))

    # interactive controls laid out as an inline sentence around the slider
    _controls = mo.vstack([
        mo.md("### Adjust the Rate Parameter to See How Poisson Distribution Changes"),
        mo.hstack([
            mo.md("**Rate parameter (λ):** "),
            lambda_slider,
            mo.md("**events per interval.** Higher values shift the distribution rightward and make it more spread out.")
        ], justify="start"),
    ])
    # Last expression renders the controls as the cell output.
    _controls
    return (lambda_slider,)
|
117 |
+
|
118 |
+
|
119 |
+
@app.cell(hide_code=True)
def _(lambda_slider, np, plt, stats):
    def create_poisson_pmf_plot(lambda_value):
        """Plot the Poisson(λ) PMF with the mean and ±1 std-dev band annotated.

        Args:
            lambda_value: rate parameter λ (> 0).

        Returns:
            The matplotlib Axes holding the plot (marimo renders it).
        """
        # Evaluate the PMF on 0..max_x; show at least up to 20 and at least
        # 3λ so the right tail is visible for large λ.
        max_x = max(20, int(lambda_value * 3))
        x = np.arange(0, max_x + 1)
        pmf = stats.poisson.pmf(x, lambda_value)

        # For a Poisson distribution both the mean and the variance equal λ.
        mean = lambda_value
        variance = lambda_value
        std_dev = np.sqrt(variance)

        fig, ax = plt.subplots(figsize=(10, 6))

        # PMF as bars, plus a connecting line to emphasize the shape.
        # (Plain string: there are no placeholders to interpolate.)
        ax.bar(x, pmf, color='royalblue', alpha=0.7, label='PMF: P(X=k)')
        ax.plot(x, pmf, 'ro-', alpha=0.6, label='PMF line')

        # Vertical line at the mean and a shaded ±1 standard-deviation band.
        ax.axvline(x=mean, color='green', linestyle='--', linewidth=2,
                   label=f'Mean: {mean:.2f}')
        ax.axvspan(mean - std_dev, mean + std_dev, alpha=0.2, color='green',
                   label=f'±1 Std Dev: {std_dev:.2f}')

        ax.set_xlabel('Number of Events (k)')
        ax.set_ylabel('Probability: P(X=k)')
        ax.set_title(f'Poisson Distribution with λ={lambda_value:.1f}')

        # Arrow annotations for E[X] and Var(X), anchored near the mode.
        ax.annotate(f'E[X] = {mean:.2f}',
                    xy=(mean, stats.poisson.pmf(int(mean), lambda_value)),
                    xytext=(mean + 1, max(pmf) * 0.8),
                    arrowprops=dict(facecolor='black', shrink=0.05, width=1))
        ax.annotate(f'Var(X) = {variance:.2f}',
                    xy=(mean, stats.poisson.pmf(int(mean), lambda_value) / 2),
                    xytext=(mean + 1, max(pmf) * 0.6),
                    arrowprops=dict(facecolor='black', shrink=0.05, width=1))

        ax.grid(alpha=0.3)
        ax.legend()

        plt.tight_layout()
        return plt.gca()

    # Read λ from the slider and render the plot as the cell output.
    _lambda = lambda_slider.amount
    create_poisson_pmf_plot(_lambda)
    return (create_poisson_pmf_plot,)
|
175 |
+
|
176 |
+
|
177 |
+
@app.cell(hide_code=True)
|
178 |
+
def _(mo):
|
179 |
+
mo.md(
|
180 |
+
r"""
|
181 |
+
## Poisson Intuition: Relation to Binomial Distribution
|
182 |
+
|
183 |
+
The Poisson distribution can be derived as a limiting case of the [binomial distribution](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/14_binomial_distribution.py).
|
184 |
+
|
185 |
+
Let's work on a practical example: predicting the number of ride-sharing requests in a specific area over a one-minute interval. From historical data, we know that the average number of requests per minute is $\lambda = 5$.
|
186 |
+
|
187 |
+
We could approximate this using a binomial distribution by dividing our minute into smaller intervals. For example, we can divide a minute into 60 seconds and treat each second as a [Bernoulli trial](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/13_bernoulli_distribution.py) - either there's a request (success) or there isn't (failure).
|
188 |
+
|
189 |
+
Let's visualize this concept:
|
190 |
+
"""
|
191 |
+
)
|
192 |
+
return
|
193 |
+
|
194 |
+
|
195 |
+
@app.cell(hide_code=True)
|
196 |
+
def _(fig_to_image, mo, plt):
|
197 |
+
def create_time_division_visualization():
|
198 |
+
# visualization of dividing a minute into 60 seconds
|
199 |
+
fig, ax = plt.subplots(figsize=(12, 2))
|
200 |
+
|
201 |
+
# Example events hardcoded at 2.75s and 7.12s
|
202 |
+
events = [2.75, 7.12]
|
203 |
+
|
204 |
+
# array of 60 rectangles
|
205 |
+
for i in range(60):
|
206 |
+
color = 'royalblue' if any(i <= e < i+1 for e in events) else 'lightgray'
|
207 |
+
ax.add_patch(plt.Rectangle((i, 0), 0.9, 1, color=color))
|
208 |
+
|
209 |
+
# markers for events
|
210 |
+
for e in events:
|
211 |
+
ax.plot(e, 0.5, 'ro', markersize=10)
|
212 |
+
|
213 |
+
# labels
|
214 |
+
ax.set_xlim(0, 60)
|
215 |
+
ax.set_ylim(0, 1)
|
216 |
+
ax.set_yticks([])
|
217 |
+
ax.set_xticks([0, 15, 30, 45, 60])
|
218 |
+
ax.set_xticklabels(['0s', '15s', '30s', '45s', '60s'])
|
219 |
+
ax.set_xlabel('Time (seconds)')
|
220 |
+
ax.set_title('One Minute Divided into 60 Second Intervals')
|
221 |
+
|
222 |
+
plt.tight_layout()
|
223 |
+
plt.gca()
|
224 |
+
return fig, events, i
|
225 |
+
|
226 |
+
# Create visualization and convert to image
|
227 |
+
_fig, _events, i = create_time_division_visualization()
|
228 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
229 |
+
|
230 |
+
# explanation
|
231 |
+
_explanation = mo.md(
|
232 |
+
r"""
|
233 |
+
In this visualization:
|
234 |
+
|
235 |
+
- Each rectangle represents a 1-second interval
|
236 |
+
- Blue rectangles indicate intervals where an event occurred
|
237 |
+
- Red dots show the actual event times (2.75s and 7.12s)
|
238 |
+
|
239 |
+
If we treat this as a binomial experiment with 60 trials (seconds), we can calculate probabilities using the binomial PMF. But there's a problem: what if multiple events occur within the same second? To address this, we can divide our minute into smaller intervals.
|
240 |
+
"""
|
241 |
+
)
|
242 |
+
mo.vstack([_fig, _explanation])
|
243 |
+
return create_time_division_visualization, i
|
244 |
+
|
245 |
+
|
246 |
+
@app.cell(hide_code=True)
|
247 |
+
def _(mo):
|
248 |
+
mo.md(
|
249 |
+
r"""
|
250 |
+
The total number of requests received over the minute can be approximated as the sum of the sixty indicator variables, which conveniently matches the description of a binomial — a sum of Bernoullis.
|
251 |
+
|
252 |
+
Specifically, if we define $X$ to be the number of requests in a minute, $X$ is a binomial with $n=60$ trials. What is the probability, $p$, of a success on a single trial? To make the expectation of $X$ equal the observed historical average $\lambda$, we should choose $p$ so that:
|
253 |
+
|
254 |
+
\begin{align}
|
255 |
+
\lambda &= E[X] && \text{Expectation matches historical average} \\
|
256 |
+
\lambda &= n \cdot p && \text{Expectation of a Binomial is } n \cdot p \\
|
257 |
+
p &= \frac{\lambda}{n} && \text{Solving for $p$}
|
258 |
+
\end{align}
|
259 |
+
|
260 |
+
In this case, since $\lambda=5$ and $n=60$, we should choose $p=\frac{5}{60}=\frac{1}{12}$ and state that $X \sim \text{Bin}(n=60, p=\frac{5}{60})$. Now we can calculate the probability of different numbers of requests using the binomial PMF:
|
261 |
+
|
262 |
+
$P(X = x) = {n \choose x} p^x (1-p)^{n-x}$
|
263 |
+
|
264 |
+
For example:
|
265 |
+
|
266 |
+
\begin{align}
|
267 |
+
P(X=1) &= {60 \choose 1} (5/60)^1 (55/60)^{60-1} \approx 0.0295 \\
|
268 |
+
P(X=2) &= {60 \choose 2} (5/60)^2 (55/60)^{60-2} \approx 0.0790 \\
|
269 |
+
P(X=3) &= {60 \choose 3} (5/60)^3 (55/60)^{60-3} \approx 0.1389
|
270 |
+
\end{align}
|
271 |
+
|
272 |
+
This is a good approximation, but it doesn't account for the possibility of multiple events in a single second. One solution is to divide our minute into even more fine-grained intervals. Let's try 600 deciseconds (tenths of a second):
|
273 |
+
"""
|
274 |
+
)
|
275 |
+
return
|
276 |
+
|
277 |
+
|
278 |
+
@app.cell(hide_code=True)
|
279 |
+
def _(fig_to_image, mo, plt):
|
280 |
+
def create_decisecond_visualization(e_value):
|
281 |
+
# (Just showing the first 100 for clarity)
|
282 |
+
fig, ax = plt.subplots(figsize=(12, 2))
|
283 |
+
|
284 |
+
# Example events at 2.75s and 7.12s (convert to deciseconds)
|
285 |
+
events = [27.5, 71.2]
|
286 |
+
|
287 |
+
for i in range(100):
|
288 |
+
color = 'royalblue' if any(i <= event_val < i + 1 for event_val in events) else 'lightgray'
|
289 |
+
ax.add_patch(plt.Rectangle((i, 0), 0.9, 1, color=color))
|
290 |
+
|
291 |
+
# Markers for events
|
292 |
+
for event in events:
|
293 |
+
if event < 100: # Only show events in our visible range
|
294 |
+
ax.plot(event/10, 0.5, 'ro', markersize=10) # Divide by 10 to convert to deciseconds
|
295 |
+
|
296 |
+
# Add labels
|
297 |
+
ax.set_xlim(0, 100)
|
298 |
+
ax.set_ylim(0, 1)
|
299 |
+
ax.set_yticks([])
|
300 |
+
ax.set_xticks([0, 20, 40, 60, 80, 100])
|
301 |
+
ax.set_xticklabels(['0s', '2s', '4s', '6s', '8s', '10s'])
|
302 |
+
ax.set_xlabel('Time (first 10 seconds shown)')
|
303 |
+
ax.set_title('One Minute Divided into 600 Decisecond Intervals (first 100 shown)')
|
304 |
+
|
305 |
+
plt.tight_layout()
|
306 |
+
plt.gca()
|
307 |
+
return fig
|
308 |
+
|
309 |
+
# Create viz and convert to image
|
310 |
+
_fig = create_decisecond_visualization(e_value=5)
|
311 |
+
_img = mo.image(fig_to_image(_fig), width="100%")
|
312 |
+
|
313 |
+
# Explanation
|
314 |
+
_explanation = mo.md(
|
315 |
+
r"""
|
316 |
+
With $n=600$ and $p=\frac{5}{600}=\frac{1}{120}$, we can recalculate our probabilities:
|
317 |
+
|
318 |
+
\begin{align}
|
319 |
+
P(X=1) &= {600 \choose 1} (5/600)^1 (595/600)^{600-1} \approx 0.0333 \\
|
320 |
+
P(X=2) &= {600 \choose 2} (5/600)^2 (595/600)^{600-2} \approx 0.0837 \\
|
321 |
+
P(X=3) &= {600 \choose 3} (5/600)^3 (595/600)^{600-3} \approx 0.1402
|
322 |
+
\end{align}
|
323 |
+
|
324 |
+
As we make our intervals smaller (increasing $n$), our approximation becomes more accurate.
|
325 |
+
"""
|
326 |
+
)
|
327 |
+
mo.vstack([_fig, _explanation])
|
328 |
+
return (create_decisecond_visualization,)
|
329 |
+
|
330 |
+
|
331 |
+
@app.cell(hide_code=True)
|
332 |
+
def _(mo):
|
333 |
+
mo.md(
|
334 |
+
r"""
|
335 |
+
## The Binomial Distribution in the Limit
|
336 |
+
|
337 |
+
What happens if we continue dividing our time interval into smaller and smaller pieces? Let's explore how the probabilities change as we increase the number of intervals:
|
338 |
+
"""
|
339 |
+
)
|
340 |
+
return
|
341 |
+
|
342 |
+
|
343 |
+
@app.cell(hide_code=True)
def _(mo):
    # Controls how finely one minute is subdivided (the binomial trial count
    # n) for the binomial-approaches-Poisson demonstration.
    intervals_slider = mo.ui.slider(
        start=60,
        stop=10000,
        step=100,
        value=600,
        label="Number of intervals to divide a minute",
    )
    return (intervals_slider,)
|
352 |
+
|
353 |
+
|
354 |
+
@app.cell(hide_code=True)
|
355 |
+
def _(intervals_slider):
|
356 |
+
intervals_slider
|
357 |
+
return
|
358 |
+
|
359 |
+
|
360 |
+
@app.cell(hide_code=True)
def _(intervals_slider, np, pd, plt, stats):
    def create_comparison_plot(n, lambda_value):
        """Compare Bin(n, λ/n) against Poisson(λ) side by side.

        Args:
            n: number of binomial trials (intervals per minute).
            lambda_value: fixed rate λ, so the per-trial probability is λ/n.

        Returns:
            (df, fig, n, p): comparison DataFrame, the figure, and the
            parameters actually used.
        """
        # Per-trial success probability chosen so that E[X] = n·p = λ.
        p = lambda_value / n

        # PMFs over the outcomes 0..14 (covers the bulk of mass for λ = 5).
        x_values = np.arange(0, 15)
        binom_pmf = stats.binom.pmf(x_values, n, p)
        poisson_pmf = stats.poisson.pmf(x_values, lambda_value)

        # Tabulate both PMFs and their pointwise absolute difference.
        # (Plain string keys: 'Poisson(λ=5)' has no placeholders.)
        df = pd.DataFrame({
            'Events': x_values,
            f'Binomial(n={n}, p={p:.6f})': binom_pmf,
            'Poisson(λ=5)': poisson_pmf,
            'Difference': np.abs(binom_pmf - poisson_pmf)
        })

        # Side-by-side bars: binomial shifted left, Poisson shifted right.
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.bar(x_values - 0.2, binom_pmf, width=0.4, alpha=0.7,
               color='royalblue', label=f'Binomial(n={n}, p={p:.6f})')
        ax.bar(x_values + 0.2, poisson_pmf, width=0.4, alpha=0.7,
               color='crimson', label='Poisson(λ=5)')

        ax.set_xlabel('Number of Events (k)')
        ax.set_ylabel('Probability')
        ax.set_title(f'Comparison of Binomial and Poisson PMFs with n={n}')
        ax.legend()
        ax.set_xticks(x_values)
        ax.grid(alpha=0.3)

        plt.tight_layout()
        return df, fig, n, p

    # Number of intervals from the slider
    n = intervals_slider.value
    _lambda = 5  # Fixed lambda for our example

    # Comparison plot (df/fig/n/p are consumed by the table cell below)
    df, fig, n, p = create_comparison_plot(n, _lambda)
    return create_comparison_plot, df, fig, n, p
|
410 |
+
|
411 |
+
|
412 |
+
@app.cell(hide_code=True)
def _(df, fig, fig_to_image, mo, n, p):
    # table of values, formatted to six decimals per probability column
    # NOTE(review): `_styled_df` is never displayed — `mo.ui.table(df)` below
    # shows the unstyled frame. Either pass the Styler or drop this block.
    _styled_df = df.style.format({
        f'Binomial(n={n}, p={p:.6f})': '{:.6f}',
        f'Poisson(λ=5)': '{:.6f}',
        'Difference': '{:.6f}'
    })

    # Calculate the max absolute difference between the two PMFs
    _max_diff = df['Difference'].max()

    # output: chart image, summary line, and interactive table stacked
    _chart = mo.image(fig_to_image(fig), width="100%")
    _explanation = mo.md(f"**Maximum absolute difference between distributions: {_max_diff:.6f}**")
    _table = mo.ui.table(df)

    mo.vstack([_chart, _explanation, _table])
    return
|
431 |
+
|
432 |
+
|
433 |
+
@app.cell(hide_code=True)
def _(mo):
    # Static markdown: derivation of the Poisson PMF as the n -> infinity
    # limit of the binomial PMF. Math is LaTeX rendered by marimo.
    mo.md(
        r"""
        As you can see from the interactive comparison above, as the number of intervals increases, the binomial distribution approaches the Poisson distribution! This is not a coincidence - the Poisson distribution is actually the limiting case of the binomial distribution when:

        - The number of trials $n$ approaches infinity
        - The probability of success $p$ approaches zero
        - The product $np = \lambda$ remains constant

        This relationship is why the Poisson distribution is so useful - it's easier to work with than a binomial with a very large number of trials and a very small probability of success.

        ## Derivation of the Poisson PMF

        Let's derive the Poisson PMF by taking the limit of the binomial PMF as $n \to \infty$. We start with:

        $P(X=x) = \lim_{n \rightarrow \infty} {n \choose x} (\lambda / n)^x(1-\lambda/n)^{n-x}$

        While this expression looks intimidating, it simplifies nicely:

        \begin{align}
        P(X=x)
        &= \lim_{n \rightarrow \infty} {n \choose x} (\lambda / n)^x(1-\lambda/n)^{n-x}
        && \text{Start: binomial in the limit}\\
        &= \lim_{n \rightarrow \infty}
        {n \choose x} \cdot
        \frac{\lambda^x}{n^x} \cdot
        \frac{(1-\lambda/n)^{n}}{(1-\lambda/n)^{x}}
        && \text{Expanding the power terms} \\
        &= \lim_{n \rightarrow \infty}
        \frac{n!}{(n-x)!x!} \cdot
        \frac{\lambda^x}{n^x} \cdot
        \frac{(1-\lambda/n)^{n}}{(1-\lambda/n)^{x}}
        && \text{Expanding the binomial term} \\
        &= \lim_{n \rightarrow \infty}
        \frac{n!}{(n-x)!x!} \cdot
        \frac{\lambda^x}{n^x} \cdot
        \frac{e^{-\lambda}}{(1-\lambda/n)^{x}}
        && \text{Using limit rule } \lim_{n \rightarrow \infty}(1-\lambda/n)^{n} = e^{-\lambda}\\
        &= \lim_{n \rightarrow \infty}
        \frac{n!}{(n-x)!x!} \cdot
        \frac{\lambda^x}{n^x} \cdot
        \frac{e^{-\lambda}}{1}
        && \text{As } n \to \infty \text{, } \lambda/n \to 0\\
        &= \lim_{n \rightarrow \infty}
        \frac{n!}{(n-x)!} \cdot
        \frac{1}{x!} \cdot
        \frac{\lambda^x}{n^x} \cdot
        e^{-\lambda}
        && \text{Rearranging terms}\\
        &= \lim_{n \rightarrow \infty}
        \frac{n^x}{1} \cdot
        \frac{1}{x!} \cdot
        \frac{\lambda^x}{n^x} \cdot
        e^{-\lambda}
        && \text{As } n \to \infty \text{, } \frac{n!}{(n-x)!} \approx n^x\\
        &= \lim_{n \rightarrow \infty}
        \frac{\lambda^x}{x!} \cdot
        e^{-\lambda}
        && \text{Canceling } n^x\\
        &=
        \frac{\lambda^x \cdot e^{-\lambda}}{x!}
        && \text{Simplifying}\\
        \end{align}

        This gives us our elegant Poisson PMF formula: $P(X=x) = \frac{\lambda^x \cdot e^{-\lambda}}{x!}$
        """
    )
    return
|
502 |
+
|
503 |
+
|
504 |
+
@app.cell(hide_code=True)
def _(mo):
    # Static markdown: introduces the scipy.stats code examples that follow.
    mo.md(
        r"""
        ## Poisson Distribution in Python

        Python's `scipy.stats` module provides functions to work with the Poisson distribution. Let's see how to calculate probabilities and generate random samples.

        First, let's calculate some probabilities for our ride-sharing example with $\lambda = 5$:
        """
    )
    return
|
516 |
+
|
517 |
+
|
518 |
+
@app.cell
def _(stats):
    # Point and tail probabilities for the ride-sharing example,
    # a Poisson distribution with rate 5.
    _lambda = 5

    # P(X=k) for k = 1, 2, 3
    p_1, p_2, p_3 = (stats.poisson.pmf(k, _lambda) for k in (1, 2, 3))
    for _k, _p in zip((1, 2, 3), (p_1, p_2, p_3)):
        print(f"P(X={_k}) = {_p:.5f}")

    # Cumulative probability P(X <= 3)
    p_leq_3 = stats.poisson.cdf(3, _lambda)
    print(f"P(X≤3) = {p_leq_3:.5f}")

    # Upper tail P(X > 10) as the complement of the CDF at 10
    p_gt_10 = 1 - stats.poisson.cdf(10, _lambda)
    print(f"P(X>10) = {p_gt_10:.5f}")
    return p_1, p_2, p_3, p_gt_10, p_leq_3
|
539 |
+
|
540 |
+
|
541 |
+
@app.cell(hide_code=True)
def _(mo):
    # Static markdown: transition sentence before the sampling demo.
    mo.md(r"""We can also generate random samples from a Poisson distribution and visualize their distribution:""")
    return
|
545 |
+
|
546 |
+
|
547 |
+
@app.cell(hide_code=True)
def _(np, plt, stats):
    def create_samples_plot(lambda_value, sample_size=1000):
        """Draw random Poisson samples and overlay their histogram with the PMF.

        Args:
            lambda_value: Poisson rate parameter λ.
            sample_size: number of random draws (default 1000).

        Returns:
            The matplotlib Axes holding the comparison plot.
        """
        # Random samples
        samples = stats.poisson.rvs(lambda_value, size=sample_size)

        # theoretical PMF over the observed support 0..max(samples)
        x_values = np.arange(0, max(samples) + 1)
        pmf_values = stats.poisson.pmf(x_values, lambda_value)

        fig, ax = plt.subplots(figsize=(10, 6))

        # samples as a density histogram; half-integer bin edges center
        # each bar on its integer outcome
        ax.hist(samples, bins=np.arange(-0.5, max(samples) + 1.5, 1),
                alpha=0.7, density=True, label='Random Samples')

        # theoretical PMF
        ax.plot(x_values, pmf_values, 'ro-', label='Theoretical PMF')

        # labels and title
        ax.set_xlabel('Number of Events')
        ax.set_ylabel('Relative Frequency / Probability')
        # Bug fix: the title previously hard-coded "1000" even when a
        # different sample_size was passed.
        ax.set_title(f'{sample_size} Random Samples from Poisson(λ={lambda_value})')
        ax.legend()
        ax.grid(alpha=0.3)

        # annotations comparing the empirical mean with λ
        ax.annotate(f'Sample Mean: {np.mean(samples):.2f}',
                    xy=(0.7, 0.9), xycoords='axes fraction',
                    bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
        ax.annotate(f'Theoretical Mean: {lambda_value:.2f}',
                    xy=(0.7, 0.8), xycoords='axes fraction',
                    bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.3))

        plt.tight_layout()
        return plt.gca()

    # Use a lambda value of 5 for this example
    _lambda = 5
    create_samples_plot(_lambda)
    return (create_samples_plot,)
|
589 |
+
|
590 |
+
|
591 |
+
@app.cell(hide_code=True)
def _(mo):
    # Static markdown: explains that λ scales linearly with the interval length.
    mo.md(
        r"""
        ## Changing Time Frames

        One important property of the Poisson distribution is that the rate parameter $\lambda$ scales linearly with the time interval. If events occur at a rate of $\lambda$ per unit time, then over a period of $t$ units, the rate parameter becomes $\lambda \cdot t$.

        For example, if a website receives an average of 5 requests per minute, what is the distribution of requests over a 20-minute period?

        The rate parameter for the 20-minute period would be $\lambda = 5 \cdot 20 = 100$ requests.
        """
    )
    return
|
605 |
+
|
606 |
+
|
607 |
+
@app.cell(hide_code=True)
def _(mo):
    # Interactive controls for the time-scaling demo: event rate per unit
    # time and the length of the observation window.
    rate_slider = mo.ui.slider(
        start=0.1,
        stop=10,
        step=0.1,
        value=5,
        label="Rate per unit time (λ)",
    )

    time_slider = mo.ui.slider(
        start=1,
        stop=60,
        step=1,
        value=20,
        label="Time period (t units)",
    )

    # Heading above the two sliders, laid out side by side.
    controls = mo.vstack(
        [
            mo.md("### Adjust Parameters to See How Time Scaling Works"),
            mo.hstack([rate_slider, time_slider], justify="space-between"),
        ]
    )
    return controls, rate_slider, time_slider
|
630 |
+
|
631 |
+
|
632 |
+
@app.cell
def _(controls):
    # Render the slider controls, centered on the page.
    controls.center()
    return
|
636 |
+
|
637 |
+
|
638 |
+
@app.cell(hide_code=True)
def _(mo, np, plt, rate_slider, stats, time_slider):
    def create_time_scaling_plot(rate, time_period):
        """Plot Poisson(rate * time_period) and build a markdown summary.

        Returns a tuple (axes, info_text): the matplotlib Axes with the PMF
        plot and a markdown string with mean, variance, and tail probabilities.
        """
        # The rate scales linearly with the window: λ = rate · t.
        lambda_value = rate * time_period

        # PMF over 0..max_x; support is at least 0..30 and grows with λ
        max_x = max(30, int(lambda_value * 1.5))
        x = np.arange(0, max_x + 1)
        pmf = stats.poisson.pmf(x, lambda_value)

        # plot (the `fig` handle itself is unused; the Axes is returned below)
        fig, ax = plt.subplots(figsize=(10, 6))

        # PMF as bars
        ax.bar(x, pmf, color='royalblue', alpha=0.7,
               label=f'PMF: Poisson(λ={lambda_value:.1f})')

        # vertical line marking the mean (= λ)
        ax.axvline(x=lambda_value, color='red', linestyle='--', linewidth=2,
                   label=f'Mean = {lambda_value:.1f}')

        # labels and title
        ax.set_xlabel('Number of Events')
        ax.set_ylabel('Probability')
        ax.set_title(f'Poisson Distribution Over {time_period} Units (Rate = {rate}/unit)')

        # For large λ, zoom to mean ± 4 standard deviations (σ = sqrt(λ))
        # so the bulk of the mass stays visible.
        if lambda_value > 10:
            ax.set_xlim(lambda_value - 4*np.sqrt(lambda_value),
                        lambda_value + 4*np.sqrt(lambda_value))

        ax.legend()
        ax.grid(alpha=0.3)

        plt.tight_layout()

        # Markdown summary: mean and variance are both λ; σ = sqrt(λ).
        info_text = f"""
        When the rate is **{rate}** events per unit time and we observe for **{time_period}** units:

        - The expected number of events is **{lambda_value:.1f}**
        - The variance is also **{lambda_value:.1f}**
        - The standard deviation is **{np.sqrt(lambda_value):.2f}**
        - P(X=0) = {stats.poisson.pmf(0, lambda_value):.4f} (probability of no events)
        - P(X≥10) = {1 - stats.poisson.cdf(9, lambda_value):.4f} (probability of 10 or more events)
        """

        return plt.gca(), info_text

    # parameters from the UI sliders
    _rate = rate_slider.value
    _time = time_slider.value

    # build the plot and its accompanying text
    _plot, _info_text = create_time_scaling_plot(_rate, _time)

    # Display info as markdown
    info = mo.md(_info_text)

    # Last expression of a marimo cell is its rendered output.
    mo.vstack([_plot, info], justify="center")
    return create_time_scaling_plot, info
|
700 |
+
|
701 |
+
|
702 |
+
@app.cell(hide_code=True)
def _(mo):
    # Static markdown quiz; each `/// details` block hides its answer
    # until expanded.
    mo.md(
        r"""
        ## 🤔 Test Your Understanding
        Pick which of these statements about Poisson distributions you think are correct:

        /// details | The variance of a Poisson distribution is always equal to its mean
        ✅ Correct! For a Poisson distribution with parameter $\lambda$, both the mean and variance equal $\lambda$.
        ///

        /// details | The Poisson distribution can be used to model the number of successes in a fixed number of trials
        ❌ Incorrect! That's the binomial distribution. The Poisson distribution models the number of events in a fixed interval of time or space, not a fixed number of trials.
        ///

        /// details | If $X \sim \text{Poisson}(\lambda_1)$ and $Y \sim \text{Poisson}(\lambda_2)$ are independent, then $X + Y \sim \text{Poisson}(\lambda_1 + \lambda_2)$
        ✅ Correct! The sum of independent Poisson random variables is also a Poisson random variable with parameter equal to the sum of the individual parameters.
        ///

        /// details | As $\lambda$ increases, the Poisson distribution approaches a normal distribution
        ✅ Correct! For large values of $\lambda$ (generally $\lambda > 10$), the Poisson distribution is approximately normal with mean $\lambda$ and variance $\lambda$.
        ///

        /// details | The probability of zero events in a Poisson process is always less than the probability of one event
        ❌ Incorrect! For $\lambda < 1$, the probability of zero events ($e^{-\lambda}$) is actually greater than the probability of one event ($\lambda e^{-\lambda}$).
        ///

        /// details | The Poisson distribution has a single parameter $\lambda$, which always equals the average number of events per time period
        ✅ Correct! The parameter $\lambda$ represents the average rate of events, and it uniquely defines the distribution.
        ///
        """
    )
    return
|
735 |
+
|
736 |
+
|
737 |
+
@app.cell(hide_code=True)
def _(mo):
    # Static markdown: closing summary of the notebook's key points.
    mo.md(
        r"""
        ## Summary

        The Poisson distribution is one of those incredibly useful tools that shows up all over the place. I've always found it fascinating how such a simple formula can model so many real-world phenomena - from website traffic to radioactive decay.

        What makes the Poisson really cool is that it emerges naturally as we try to model rare events occurring over a continuous interval. Remember that visualization where we kept dividing time into smaller and smaller chunks? As we showed, when you take a binomial distribution and let the number of trials approach infinity while keeping the expected value constant, you end up with the elegant Poisson formula.

        The key things to remember about the Poisson distribution:

        - It models the number of events occurring in a fixed interval of time or space, assuming events happen at a constant average rate and independently of each other

        - Its PMF is given by the elegantly simple formula $P(X=k) = \frac{\lambda^k e^{-\lambda}}{k!}$

        - Both the mean and variance equal the parameter $\lambda$, which represents the average number of events per interval

        - It's related to the binomial distribution as a limiting case when $n \to \infty$, $p \to 0$, and $np = \lambda$ remains constant

        - The rate parameter scales linearly with the length of the interval - if events occur at rate $\lambda$ per unit time, then over $t$ units, the parameter becomes $\lambda t$

        From modeling website traffic and customer arrivals to defects in manufacturing and radioactive decay, the Poisson distribution provides a powerful and mathematically elegant way to understand random occurrences in our world.
        """
    )
    return
|
763 |
+
|
764 |
+
|
765 |
+
@app.cell(hide_code=True)
def _(mo):
    # Section marker for the appendix cells below.
    mo.md(r"""Appendix code (helper functions, variables, etc.):""")
    return
|
769 |
+
|
770 |
+
|
771 |
+
@app.cell
def _():
    # marimo itself, exported to the rest of the notebook as `mo`.
    import marimo as mo
    return (mo,)
|
775 |
+
|
776 |
+
|
777 |
+
@app.cell(hide_code=True)
def _():
    # Third-party dependencies shared by the notebook's cells.
    import numpy as np
    import matplotlib.pyplot as plt
    import scipy.stats as stats
    import pandas as pd
    import altair as alt
    from wigglystuff import TangleSlider
    return TangleSlider, alt, np, pd, plt, stats
|
786 |
+
|
787 |
+
|
788 |
+
@app.cell(hide_code=True)
def _():
    import io
    import base64
    from matplotlib.figure import Figure

    def fig_to_image(fig):
        """Serialize a matplotlib figure to a base64 PNG data URI for mo.image."""
        buffer = io.BytesIO()
        fig.savefig(buffer, format='png')
        encoded = base64.b64encode(buffer.getvalue()).decode('utf-8')
        return f"data:image/png;base64,{encoded}"
    return Figure, base64, fig_to_image, io
|
802 |
+
|
803 |
+
|
804 |
+
# Run the notebook as a standalone marimo app when executed directly.
if __name__ == "__main__":
    app.run()
|
python/006_dictionaries.py
CHANGED
@@ -196,13 +196,13 @@ def _():
|
|
196 |
|
197 |
@app.cell
|
198 |
def _(mo, nested_data):
|
199 |
-
mo.md(f"Alice's age: {nested_data[
|
200 |
return
|
201 |
|
202 |
|
203 |
@app.cell
|
204 |
def _(mo, nested_data):
|
205 |
-
mo.md(f"Bob's interests: {nested_data[
|
206 |
return
|
207 |
|
208 |
|
|
|
196 |
|
197 |
@app.cell
def _(mo, nested_data):
    # Drill into the nested dict: users -> alice -> age.
    mo.md(f"Alice's age: {nested_data['users']['alice']['age']}")
    return
|
201 |
|
202 |
|
203 |
@app.cell
def _(mo, nested_data):
    # Drill into the nested dict: users -> bob -> interests (a list).
    mo.md(f"Bob's interests: {nested_data['users']['bob']['interests']}")
    return
|
207 |
|
208 |
|
scripts/build.py
ADDED
@@ -0,0 +1,1523 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import argparse
|
6 |
+
import json
|
7 |
+
from typing import List, Dict, Any
|
8 |
+
from pathlib import Path
|
9 |
+
import re
|
10 |
+
|
11 |
+
|
12 |
+
def export_html_wasm(notebook_path: str, output_dir: str, as_app: bool = False) -> bool:
    """Export a single marimo notebook to HTML (WASM) format.

    Args:
        notebook_path: Path of the ``.py`` notebook, relative to the repo root.
        output_dir: Root directory receiving the exported HTML; the notebook's
            relative path is mirrored underneath it.
        as_app: If True, export in app mode (run, code hidden); otherwise
            export as an editable notebook.

    Returns:
        bool: True if export succeeded, False otherwise.
    """
    # Only swap a trailing ".py" for ".html"; str.replace(".py", ".html")
    # would also corrupt a ".py" occurring earlier in the path.
    if notebook_path.endswith(".py"):
        output_path = notebook_path[:-len(".py")] + ".html"
    else:
        output_path = notebook_path + ".html"

    cmd = ["marimo", "export", "html-wasm"]
    if as_app:
        print(f"Exporting {notebook_path} to {output_path} as app")
        cmd.extend(["--mode", "run", "--no-show-code"])
    else:
        print(f"Exporting {notebook_path} to {output_path} as notebook")
        cmd.extend(["--mode", "edit"])

    try:
        output_file = os.path.join(output_dir, output_path)
        os.makedirs(os.path.dirname(output_file), exist_ok=True)

        cmd.extend([notebook_path, "-o", output_file])
        print(f"Running command: {' '.join(cmd)}")

        # Popen (not run) so we can answer marimo's interactive prompt.
        process = subprocess.Popen(
            cmd,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )

        try:
            # Send 'Y' to confirm the export prompt.
            stdout, stderr = process.communicate(input="Y\n", timeout=60)
        except subprocess.TimeoutExpired:
            # Bug fix: kill and reap the child on timeout; previously the
            # process was left running (and unreaped) after the timeout.
            process.kill()
            process.communicate()
            print(f"Timeout exporting {notebook_path} - command took too long to execute")
            return False

        if process.returncode != 0:
            print(f"Error exporting {notebook_path}:")
            print(f"Command: {' '.join(cmd)}")
            print(f"Return code: {process.returncode}")
            print(f"Stdout: {stdout}")
            print(f"Stderr: {stderr}")
            return False

        print(f"Successfully exported {notebook_path} to {output_file}")
        return True
    except Exception as e:
        print(f"Unexpected error exporting {notebook_path}: {e}")
        return False
|
67 |
+
|
68 |
+
|
69 |
+
def get_course_metadata(course_dir: Path) -> Dict[str, Any]:
    """Extract metadata from a course directory.

    Returns a dict with the course id (directory name), a human-readable
    title, a cleaned one-paragraph description pulled from the course's
    README.md (empty string when absent), and an empty notebook list for
    callers to populate.
    """
    metadata: Dict[str, Any] = {
        "id": course_dir.name,
        "title": course_dir.name.replace("_", " ").title(),
        "description": "",
        "notebooks": [],
    }

    readme_path = course_dir / "README.md"
    if not readme_path.exists():
        return metadata

    with open(readme_path, "r", encoding="utf-8") as readme:
        content = readme.read()

    if content:
        lines = content.split("\n")
        # Skip a leading "# Title" heading, if present.
        first = 1 if lines and lines[0].startswith("#") else 0

        # Collect the first run of non-blank, non-heading lines (the
        # README's opening paragraph).
        paragraph = []
        for line in lines[first:]:
            if line.strip() and not line.startswith("#"):
                paragraph.append(line)
            elif paragraph:  # paragraph ended (blank line or next heading)
                break

        description = " ".join(paragraph).strip()
        # Strip markdown italics markers and known boilerplate.
        description = description.replace("_", "")
        description = description.replace("[work in progress]", "")
        description = description.replace("(https://github.com/marimo-team/learn/issues/51)", "")
        # Drop GitHub issue links, both markdown-style and bare URLs.
        description = re.sub(r'\[.*?\]\(https://github\.com/.*?/issues/\d+\)', '', description)
        description = re.sub(r'https://github\.com/.*?/issues/\d+', '', description)
        # Collapse whitespace runs left behind by the removals.
        metadata["description"] = re.sub(r'\s+', ' ', description).strip()

    return metadata
|
107 |
+
|
108 |
+
|
109 |
+
def organize_notebooks_by_course(all_notebooks: List[str]) -> Dict[str, Dict[str, Any]]:
    """Organize notebooks by course.

    Args:
        all_notebooks: notebook paths relative to the repo root; the first
            path component is treated as the course directory.

    Returns:
        Mapping of course id -> course metadata dict (see
        get_course_metadata) with its "notebooks" list filled in and
        sorted by numeric filename prefix.
    """
    courses = {}

    for notebook_path in all_notebooks:
        path = Path(notebook_path)
        course_id = path.parts[0]

        # Lazily create the course entry on first sight of the course dir.
        if course_id not in courses:
            course_dir = Path(course_id)
            courses[course_id] = get_course_metadata(course_dir)

        # Extract notebook info
        # NOTE(review): `filename` is never used below — candidate for removal.
        filename = path.name
        notebook_id = path.stem

        # Try to extract order from filename (e.g., 001_numbers.py -> 1);
        # notebooks without a numeric prefix sort last (999).
        order = 999
        if "_" in notebook_id:
            try:
                order_str = notebook_id.split("_")[0]
                order = int(order_str)
            except ValueError:
                pass

        # Create display name by removing the order prefix
        display_name = notebook_id
        if "_" in notebook_id:
            display_name = "_".join(notebook_id.split("_")[1:])

        # Convert display name to title case. Empty split parts mark where
        # leading/trailing/doubled underscores sat in the filename; those are
        # interpreted as markdown-style italics markers.
        parts = display_name.split("_")
        formatted_parts = []

        i = 0
        while i < len(parts):
            if i + 1 < len(parts) and parts[i] == "" and parts[i+1] == "":
                # Skip empty parts that might come from consecutive underscores
                i += 2
                continue

            if i + 1 < len(parts) and (parts[i] == "" or parts[i+1] == ""):
                # This is an italics marker
                if parts[i] == "":
                    # Opening italics: wrap the following word in <em>.
                    # NOTE(review): parts come from split("_"), so the
                    # .replace("_", " ") here looks like a no-op — confirm.
                    text_part = parts[i+1].replace("_", " ").title()
                    formatted_parts.append(f"<em>{text_part}</em>")
                    i += 2
                else:
                    # Text followed by italics marker
                    text_part = parts[i].replace("_", " ").title()
                    formatted_parts.append(text_part)
                    i += 1
            else:
                # Regular text
                text_part = parts[i].replace("_", " ").title()
                formatted_parts.append(text_part)
                i += 1

        display_name = " ".join(formatted_parts)

        courses[course_id]["notebooks"].append({
            "id": notebook_id,
            "path": notebook_path,
            "display_name": display_name,
            "order": order,
            "original_number": notebook_id.split("_")[0] if "_" in notebook_id else ""
        })

    # Sort each course's notebooks by their numeric prefix.
    for course_id in courses:
        courses[course_id]["notebooks"].sort(key=lambda x: x["order"])

    return courses
|
183 |
+
|
184 |
+
|
185 |
+
def generate_eva_css() -> str:
    """Generate Neon Genesis Evangelion inspired CSS with light/dark mode support.

    Returns:
        One complete stylesheet as a string. Light mode is the default
        (`:root`); dark mode overrides the same CSS custom properties under
        the ``[data-theme="dark"]`` selector, so the JS theme toggle only
        needs to flip the ``data-theme`` attribute on ``<html>``.
    """
    # NOTE: the stylesheet is returned verbatim — nothing is interpolated.
    # Course cards use a "flashcard" pattern: the front face and the
    # notebook list are absolutely positioned and cross-faded when the
    # .active class is toggled (see the .eva-course rules below).
    return """
    :root {
        /* Light mode colors (default) */
        --eva-purple: #7209b7;
        --eva-green: #1c7361;
        --eva-orange: #e65100;
        --eva-blue: #0039cb;
        --eva-red: #c62828;
        --eva-black: #f5f5f5;
        --eva-dark: #e0e0e0;
        --eva-terminal-bg: rgba(255, 255, 255, 0.9);
        --eva-text: #333333;
        --eva-border-radius: 4px;
        --eva-transition: all 0.3s ease;
    }

    /* Dark mode colors */
    [data-theme="dark"] {
        --eva-purple: #9a1eb3;
        --eva-green: #1c7361;
        --eva-orange: #ff6600;
        --eva-blue: #0066ff;
        --eva-red: #ff0000;
        --eva-black: #111111;
        --eva-dark: #222222;
        --eva-terminal-bg: rgba(0, 0, 0, 0.85);
        --eva-text: #e0e0e0;
    }

    body {
        background-color: var(--eva-black);
        color: var(--eva-text);
        font-family: 'Courier New', monospace;
        margin: 0;
        padding: 0;
        line-height: 1.6;
        transition: background-color 0.3s ease, color 0.3s ease;
    }

    .eva-container {
        max-width: 1200px;
        margin: 0 auto;
        padding: 2rem;
    }

    .eva-header {
        border-bottom: 2px solid var(--eva-green);
        padding-bottom: 1rem;
        margin-bottom: 2rem;
        display: flex;
        justify-content: space-between;
        align-items: center;
        position: sticky;
        top: 0;
        background-color: var(--eva-black);
        z-index: 100;
        backdrop-filter: blur(5px);
        padding-top: 1rem;
        transition: background-color 0.3s ease;
    }

    [data-theme="light"] .eva-header {
        background-color: rgba(245, 245, 245, 0.95);
    }

    .eva-logo {
        font-size: 2.5rem;
        font-weight: bold;
        color: var(--eva-green);
        text-transform: uppercase;
        letter-spacing: 2px;
        text-shadow: 0 0 10px rgba(28, 115, 97, 0.5);
    }

    [data-theme="light"] .eva-logo {
        text-shadow: 0 0 10px rgba(28, 115, 97, 0.3);
    }

    .eva-nav {
        display: flex;
        gap: 1.5rem;
        align-items: center;
    }

    .eva-nav a {
        color: var(--eva-text);
        text-decoration: none;
        text-transform: uppercase;
        font-size: 0.9rem;
        letter-spacing: 1px;
        transition: color 0.3s;
        position: relative;
        padding: 0.5rem 0;
    }

    .eva-nav a:hover {
        color: var(--eva-green);
    }

    .eva-nav a:hover::after {
        content: '';
        position: absolute;
        bottom: -5px;
        left: 0;
        width: 100%;
        height: 2px;
        background-color: var(--eva-green);
        animation: scanline 1.5s linear infinite;
    }

    .theme-toggle {
        background: none;
        border: none;
        color: var(--eva-text);
        cursor: pointer;
        font-size: 1.2rem;
        padding: 0.5rem;
        margin-left: 1rem;
        transition: color 0.3s;
    }

    .theme-toggle:hover {
        color: var(--eva-green);
    }

    .eva-hero {
        background-color: var(--eva-terminal-bg);
        border: 1px solid var(--eva-green);
        padding: 3rem 2rem;
        margin-bottom: 3rem;
        position: relative;
        overflow: hidden;
        border-radius: var(--eva-border-radius);
        display: flex;
        flex-direction: column;
        align-items: flex-start;
        background-image: linear-gradient(45deg, rgba(0, 0, 0, 0.9), rgba(0, 0, 0, 0.7)), url('https://raw.githubusercontent.com/marimo-team/marimo/main/docs/_static/marimo-logotype-thick.svg');
        background-size: cover;
        background-position: center;
        background-blend-mode: overlay;
        transition: background-color 0.3s ease, border-color 0.3s ease;
    }

    [data-theme="light"] .eva-hero {
        background-image: linear-gradient(45deg, rgba(255, 255, 255, 0.9), rgba(255, 255, 255, 0.7)), url('https://raw.githubusercontent.com/marimo-team/marimo/main/docs/_static/marimo-logotype-thick.svg');
    }

    .eva-hero::before {
        content: '';
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 2px;
        background-color: var(--eva-green);
        animation: scanline 3s linear infinite;
    }

    .eva-hero h1 {
        font-size: 2.5rem;
        margin-bottom: 1rem;
        color: var(--eva-green);
        text-transform: uppercase;
        letter-spacing: 2px;
        text-shadow: 0 0 10px rgba(28, 115, 97, 0.5);
    }

    [data-theme="light"] .eva-hero h1 {
        text-shadow: 0 0 10px rgba(28, 115, 97, 0.3);
    }

    .eva-hero p {
        font-size: 1.1rem;
        max-width: 800px;
        margin-bottom: 2rem;
        line-height: 1.8;
    }

    .eva-features {
        display: grid;
        grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
        gap: 2rem;
        margin-bottom: 3rem;
    }

    .eva-feature {
        background-color: var(--eva-terminal-bg);
        border: 1px solid var(--eva-blue);
        padding: 1.5rem;
        border-radius: var(--eva-border-radius);
        transition: var(--eva-transition);
        position: relative;
        overflow: hidden;
    }

    .eva-feature:hover {
        transform: translateY(-5px);
        box-shadow: 0 10px 20px rgba(0, 102, 255, 0.2);
    }

    .eva-feature-icon {
        font-size: 2rem;
        margin-bottom: 1rem;
        color: var(--eva-blue);
    }

    .eva-feature h3 {
        font-size: 1.3rem;
        margin-bottom: 1rem;
        color: var(--eva-blue);
    }

    .eva-section-title {
        font-size: 2rem;
        color: var(--eva-green);
        margin-bottom: 2rem;
        text-transform: uppercase;
        letter-spacing: 2px;
        text-align: center;
        position: relative;
        padding-bottom: 1rem;
    }

    .eva-section-title::after {
        content: '';
        position: absolute;
        bottom: 0;
        left: 50%;
        transform: translateX(-50%);
        width: 100px;
        height: 2px;
        background-color: var(--eva-green);
    }

    /* Flashcard view for courses */
    .eva-courses {
        display: grid;
        grid-template-columns: repeat(auto-fill, minmax(350px, 1fr));
        gap: 2rem;
    }

    .eva-course {
        background-color: var(--eva-terminal-bg);
        border: 1px solid var(--eva-purple);
        border-radius: var(--eva-border-radius);
        transition: var(--eva-transition), height 0.4s cubic-bezier(0.19, 1, 0.22, 1);
        position: relative;
        overflow: hidden;
        height: 350px;
        display: flex;
        flex-direction: column;
    }

    .eva-course:hover {
        transform: translateY(-5px);
        box-shadow: 0 10px 20px rgba(154, 30, 179, 0.3);
    }

    .eva-course::after {
        content: '';
        position: absolute;
        bottom: 0;
        left: 0;
        width: 100%;
        height: 2px;
        background-color: var(--eva-purple);
        animation: scanline 2s linear infinite;
    }

    .eva-course-badge {
        position: absolute;
        top: 15px;
        right: -40px;
        background: linear-gradient(135deg, var(--eva-orange) 0%, #ff9500 100%);
        color: var(--eva-black);
        font-size: 0.65rem;
        padding: 0.3rem 2.5rem;
        text-transform: uppercase;
        font-weight: bold;
        z-index: 3;
        letter-spacing: 1px;
        box-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);
        transform: rotate(45deg);
        text-shadow: 0 1px 1px rgba(255, 255, 255, 0.2);
        border-top: 1px solid rgba(255, 255, 255, 0.3);
        border-bottom: 1px solid rgba(0, 0, 0, 0.2);
        white-space: nowrap;
        overflow: hidden;
    }

    .eva-course-badge i {
        margin-right: 4px;
        font-size: 0.7rem;
    }

    [data-theme="light"] .eva-course-badge {
        box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
        text-shadow: 0 1px 1px rgba(255, 255, 255, 0.4);
    }

    .eva-course-badge::before {
        content: '';
        position: absolute;
        left: 0;
        top: 0;
        width: 100%;
        height: 100%;
        background: linear-gradient(to right, transparent, rgba(255, 255, 255, 0.3), transparent);
        animation: scanline 2s linear infinite;
    }

    .eva-course-header {
        padding: 1rem 1.5rem;
        cursor: pointer;
        display: flex;
        justify-content: space-between;
        align-items: center;
        border-bottom: 1px solid rgba(154, 30, 179, 0.3);
        z-index: 2;
        background-color: var(--eva-terminal-bg);
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: 3.5rem;
        box-sizing: border-box;
    }

    .eva-course-title {
        font-size: 1.3rem;
        color: var(--eva-purple);
        text-transform: uppercase;
        letter-spacing: 1px;
        margin: 0;
    }

    .eva-course-toggle {
        color: var(--eva-purple);
        font-size: 1.5rem;
        transition: transform 0.4s cubic-bezier(0.19, 1, 0.22, 1);
    }

    .eva-course.active .eva-course-toggle {
        transform: rotate(180deg);
    }

    .eva-course-front {
        display: flex;
        flex-direction: column;
        justify-content: space-between;
        padding: 1.5rem;
        margin-top: 3.5rem;
        transition: opacity 0.3s ease, transform 0.3s ease;
        position: absolute;
        top: 0;
        left: 0;
        width: 100%;
        height: calc(100% - 3.5rem);
        background-color: var(--eva-terminal-bg);
        z-index: 1;
        box-sizing: border-box;
    }

    .eva-course.active .eva-course-front {
        opacity: 0;
        transform: translateY(-10px);
        pointer-events: none;
    }

    .eva-course-description {
        margin-top: 0.5rem;
        margin-bottom: 1.5rem;
        font-size: 0.9rem;
        line-height: 1.6;
        flex-grow: 1;
        overflow: hidden;
        display: -webkit-box;
        -webkit-line-clamp: 4;
        -webkit-box-orient: vertical;
        max-height: 150px;
    }

    .eva-course-stats {
        display: flex;
        justify-content: space-between;
        font-size: 0.8rem;
        color: var(--eva-text);
        opacity: 0.7;
    }

    .eva-course-content {
        position: absolute;
        top: 3.5rem;
        left: 0;
        width: 100%;
        height: calc(100% - 3.5rem);
        padding: 1.5rem;
        background-color: var(--eva-terminal-bg);
        transition: opacity 0.3s ease, transform 0.3s ease;
        opacity: 0;
        transform: translateY(10px);
        pointer-events: none;
        overflow-y: auto;
        z-index: 1;
        box-sizing: border-box;
    }

    .eva-course.active .eva-course-content {
        opacity: 1;
        transform: translateY(0);
        pointer-events: auto;
    }

    .eva-course.active {
        height: auto;
        min-height: 350px;
        max-height: 800px;
        transition: height 0.4s cubic-bezier(0.19, 1, 0.22, 1), transform 0.3s ease, box-shadow 0.3s ease;
    }

    .eva-notebooks {
        margin-top: 1rem;
        display: grid;
        grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
        gap: 0.75rem;
    }

    .eva-notebook {
        margin-bottom: 0.5rem;
        padding: 0.75rem;
        border-left: 2px solid var(--eva-blue);
        transition: all 0.25s ease;
        display: flex;
        align-items: center;
        background-color: rgba(0, 0, 0, 0.2);
        border-radius: 0 var(--eva-border-radius) var(--eva-border-radius) 0;
        opacity: 1;
        transform: translateX(0);
    }

    [data-theme="light"] .eva-notebook {
        background-color: rgba(0, 0, 0, 0.05);
    }

    .eva-notebook:hover {
        background-color: rgba(0, 102, 255, 0.1);
        padding-left: 1rem;
        transform: translateX(3px);
    }

    .eva-notebook a {
        color: var(--eva-text);
        text-decoration: none;
        display: block;
        font-size: 0.9rem;
        flex-grow: 1;
    }

    .eva-notebook a:hover {
        color: var(--eva-blue);
    }

    .eva-notebook-number {
        color: var(--eva-blue);
        font-size: 0.8rem;
        margin-right: 0.75rem;
        opacity: 0.7;
        min-width: 24px;
        font-weight: bold;
    }

    .eva-button {
        display: inline-block;
        background-color: transparent;
        color: var(--eva-green);
        border: 1px solid var(--eva-green);
        padding: 0.7rem 1.5rem;
        text-decoration: none;
        text-transform: uppercase;
        font-size: 0.9rem;
        letter-spacing: 1px;
        transition: var(--eva-transition);
        cursor: pointer;
        border-radius: var(--eva-border-radius);
        position: relative;
        overflow: hidden;
    }

    .eva-button:hover {
        background-color: var(--eva-green);
        color: var(--eva-black);
    }

    .eva-button::after {
        content: '';
        position: absolute;
        top: 0;
        left: -100%;
        width: 100%;
        height: 100%;
        background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent);
        transition: 0.5s;
    }

    .eva-button:hover::after {
        left: 100%;
    }

    .eva-course-button {
        margin-top: 1rem;
        margin-bottom: 1rem;
        align-self: center;
    }

    .eva-cta {
        background-color: var(--eva-terminal-bg);
        border: 1px solid var(--eva-orange);
        padding: 3rem 2rem;
        margin: 4rem 0;
        text-align: center;
        border-radius: var(--eva-border-radius);
        position: relative;
        overflow: hidden;
    }

    .eva-cta h2 {
        font-size: 2rem;
        color: var(--eva-orange);
        margin-bottom: 1.5rem;
        text-transform: uppercase;
    }

    .eva-cta p {
        max-width: 600px;
        margin: 0 auto 2rem;
        font-size: 1.1rem;
    }

    .eva-cta .eva-button {
        color: var(--eva-orange);
        border-color: var(--eva-orange);
    }

    .eva-cta .eva-button:hover {
        background-color: var(--eva-orange);
        color: var(--eva-black);
    }

    .eva-footer {
        margin-top: 4rem;
        padding-top: 2rem;
        border-top: 2px solid var(--eva-green);
        display: flex;
        flex-direction: column;
        align-items: center;
        gap: 2rem;
    }

    .eva-footer-logo {
        max-width: 200px;
        margin-bottom: 1rem;
    }

    .eva-footer-links {
        display: flex;
        gap: 1.5rem;
        margin-bottom: 1.5rem;
    }

    .eva-footer-links a {
        color: var(--eva-text);
        text-decoration: none;
        transition: var(--eva-transition);
    }

    .eva-footer-links a:hover {
        color: var(--eva-green);
    }

    .eva-social-links {
        display: flex;
        gap: 1.5rem;
        margin-bottom: 1.5rem;
    }

    .eva-social-links a {
        color: var(--eva-text);
        font-size: 1.5rem;
        transition: var(--eva-transition);
    }

    .eva-social-links a:hover {
        color: var(--eva-green);
        transform: translateY(-3px);
    }

    .eva-footer-copyright {
        font-size: 0.9rem;
        text-align: center;
    }

    .eva-search {
        position: relative;
        margin-bottom: 3rem;
    }

    .eva-search input {
        width: 100%;
        padding: 1rem;
        background-color: var(--eva-terminal-bg);
        border: 1px solid var(--eva-green);
        color: var(--eva-text);
        font-family: 'Courier New', monospace;
        font-size: 1rem;
        border-radius: var(--eva-border-radius);
        outline: none;
        transition: var(--eva-transition);
    }

    .eva-search input:focus {
        box-shadow: 0 0 10px rgba(28, 115, 97, 0.3);
    }

    .eva-search input::placeholder {
        color: rgba(224, 224, 224, 0.5);
    }

    [data-theme="light"] .eva-search input::placeholder {
        color: rgba(51, 51, 51, 0.5);
    }

    .eva-search-icon {
        position: absolute;
        right: 1rem;
        top: 50%;
        transform: translateY(-50%);
        color: var(--eva-green);
        font-size: 1.2rem;
    }

    @keyframes scanline {
        0% {
            transform: translateX(-100%);
        }
        100% {
            transform: translateX(100%);
        }
    }

    @keyframes blink {
        0%, 100% {
            opacity: 1;
        }
        50% {
            opacity: 0;
        }
    }

    .eva-cursor {
        display: inline-block;
        width: 10px;
        height: 1.2em;
        background-color: var(--eva-green);
        margin-left: 2px;
        animation: blink 1s infinite;
        vertical-align: middle;
    }

    @media (max-width: 768px) {
        .eva-courses {
            grid-template-columns: 1fr;
        }

        .eva-header {
            flex-direction: column;
            align-items: flex-start;
            padding: 1rem;
        }

        .eva-nav {
            margin-top: 1rem;
            flex-wrap: wrap;
        }

        .eva-hero {
            padding: 2rem 1rem;
        }

        .eva-hero h1 {
            font-size: 2rem;
        }

        .eva-features {
            grid-template-columns: 1fr;
        }

        .eva-footer {
            flex-direction: column;
            align-items: center;
            text-align: center;
        }

        .eva-notebooks {
            grid-template-columns: 1fr;
        }
    }

    .eva-course.closing .eva-course-content {
        opacity: 0;
        transform: translateY(10px);
        transition: opacity 0.2s ease, transform 0.2s ease;
    }

    .eva-course.closing .eva-course-front {
        opacity: 1;
        transform: translateY(0);
        transition: opacity 0.3s ease 0.1s, transform 0.3s ease 0.1s;
    }
    """
|
906 |
+
|
907 |
+
|
908 |
+
def get_html_header():
    """Generate the HTML header with CSS and meta tags.

    Returns the opening of the page (doctype through the sticky header).
    The literal ``{css}`` inside the ``<style>`` element is a placeholder;
    the caller is expected to substitute the generated stylesheet
    (e.g. via ``str.format``) before writing the page — presumably with
    the output of generate_eva_css(); verify against the caller.
    """
    # Light theme is the initial data-theme; the JS toggle flips it later.
    return """<!DOCTYPE html>
<html lang="en" data-theme="light">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Marimo Learn - Interactive Educational Notebooks</title>
    <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
    <style>
    {css}
    </style>
</head>
<body>
    <div class="eva-container">
        <header class="eva-header">
            <div class="eva-logo">MARIMO LEARN</div>
            <nav class="eva-nav">
                <a href="#features">Features</a>
                <a href="#courses">Courses</a>
                <a href="#contribute">Contribute</a>
                <a href="https://docs.marimo.io" target="_blank">Documentation</a>
                <a href="https://github.com/marimo-team/learn" target="_blank">GitHub</a>
                <button id="themeToggle" class="theme-toggle" aria-label="Toggle dark/light mode">
                    <i class="fas fa-moon"></i>
                </button>
            </nav>
        </header>"""
|
936 |
+
|
937 |
+
|
938 |
+
def get_html_hero_section():
    """Generate the hero section of the page.

    Returns a static HTML fragment; the empty ``eva-cursor`` span is
    animated (blinking) and the heading/paragraph text is re-typed by the
    terminal-typing effect in get_html_scripts().
    """
    return """
        <section class="eva-hero">
            <h1>Interactive Learning with Marimo<span class="eva-cursor"></span></h1>
            <p>
                A curated collection of educational notebooks covering computer science,
                mathematics, data science, and more. Built with marimo - the reactive
                Python notebook that makes data exploration delightful.
            </p>
            <a href="#courses" class="eva-button">Explore Courses</a>
        </section>"""
|
950 |
+
|
951 |
+
|
952 |
+
def get_html_features_section():
    """Generate the features section of the page.

    Returns a static HTML fragment listing three marketing feature cards;
    no interpolation is performed.
    """
    return """
        <section id="features">
            <h2 class="eva-section-title">Why Marimo Learn?</h2>
            <div class="eva-features">
                <div class="eva-feature">
                    <div class="eva-feature-icon"><i class="fas fa-bolt"></i></div>
                    <h3>Reactive Notebooks</h3>
                    <p>Experience the power of reactive programming with marimo notebooks that automatically update when dependencies change.</p>
                </div>
                <div class="eva-feature">
                    <div class="eva-feature-icon"><i class="fas fa-code"></i></div>
                    <h3>Learn by Doing</h3>
                    <p>Interactive examples and exercises help you understand concepts through hands-on practice.</p>
                </div>
                <div class="eva-feature">
                    <div class="eva-feature-icon"><i class="fas fa-graduation-cap"></i></div>
                    <h3>Comprehensive Courses</h3>
                    <p>From Python basics to advanced optimization techniques, our courses cover a wide range of topics.</p>
                </div>
            </div>
        </section>"""
|
975 |
+
|
976 |
+
|
977 |
+
def get_html_courses_start():
    """Generate the beginning of the courses section.

    Opens ``<section id="courses">`` and the ``.eva-courses`` grid; the
    course cards are appended after this and the section is closed by
    get_html_courses_end(). The search input is wired up by the page JS.
    """
    return """
        <section id="courses">
            <h2 class="eva-section-title">Explore Courses</h2>
            <div class="eva-search">
                <input type="text" id="courseSearch" placeholder="Search courses and notebooks...">
                <span class="eva-search-icon"><i class="fas fa-search"></i></span>
            </div>
            <div class="eva-courses">"""
|
987 |
+
|
988 |
+
|
989 |
+
def generate_course_card(course, notebook_count, is_wip):
    """Generate HTML for a single flashcard-style course card.

    Args:
        course: Course dict with "id", "title", "description" and
            "notebooks"; each notebook is a dict with "path",
            "display_name" and, optionally, "original_number".
        notebook_count: Number of notebooks, shown in the card stats.
        is_wip: Whether to render the rotated "In Progress" ribbon badge.

    Returns:
        An HTML fragment (string) for one course card.
    """
    html = f'<div class="eva-course" data-course-id="{course["id"]}">\n'

    # Ribbon badge for courses that are still being written.
    if is_wip:
        html += '    <div class="eva-course-badge"><i class="fas fa-code-branch"></i> In Progress</div>\n'

    html += f'''    <div class="eva-course-header">
        <h2 class="eva-course-title">{course["title"]}</h2>
        <span class="eva-course-toggle"><i class="fas fa-chevron-down"></i></span>
    </div>
    <div class="eva-course-front">
        <p class="eva-course-description">{course["description"]}</p>
        <div class="eva-course-stats">
            <span><i class="fas fa-book"></i> {notebook_count} notebook{"s" if notebook_count != 1 else ""}</span>
        </div>
        <button class="eva-button eva-course-button">View Notebooks</button>
    </div>
    <div class="eva-course-content">
        <div class="eva-notebooks">
'''

    # Add notebooks
    for i, notebook in enumerate(course["notebooks"]):
        # Fall back to a zero-padded 1-based index when the notebook
        # filename carried no numeric prefix.
        notebook_number = notebook.get("original_number", f"{i+1:02d}")
        # Swap only the trailing ".py" for ".html". The previous
        # str.replace(".py", ".html") rewrote the FIRST ".py" occurrence
        # anywhere in the path, which corrupts links for paths that
        # contain ".py" mid-string.
        path = notebook["path"]
        href = path[:-3] + ".html" if path.endswith(".py") else path
        html += f'''            <div class="eva-notebook">
                <span class="eva-notebook-number">{notebook_number}</span>
                <a href="{href}" data-notebook-title="{notebook["display_name"]}">{notebook["display_name"]}</a>
            </div>
'''

    html += '''        </div>
    </div>
</div>
'''
    return html
|
1026 |
+
|
1027 |
+
|
1028 |
+
def generate_course_cards(courses):
    """Generate HTML for all course cards.

    Courses named in the curated ``course_order`` list are emitted first,
    in that order; any remaining courses follow, sorted alphabetically by
    title. Courses with no notebooks are skipped entirely.

    Args:
        courses: Mapping of course id -> course dict (as built by the
            metadata scan), each with "id", "title", "description" and
            "notebooks".

    Returns:
        The concatenated HTML for every rendered course card.
    """
    # Define the custom order for the flagship courses.
    course_order = ["python", "probability", "polars", "optimization", "functional_programming"]

    # Index by id for ordered lookup/removal.
    courses_by_id = {course["id"]: course for course in courses.values()}

    # Heuristic: a course counts as "work in progress" when it has few
    # notebooks or its description advertises unfinished content.
    work_in_progress = set()
    for course_id, course in courses_by_id.items():
        description = course["description"].lower()
        if (len(course["notebooks"]) < 5 or
                "work in progress" in description or
                "help us add" in description or
                "check back later" in description):
            work_in_progress.add(course_id)

    # Build the final ordering once instead of duplicating the rendering
    # loop: curated courses first (popped so they are not repeated), then
    # the remainder alphabetically by title.
    ordered = [courses_by_id.pop(course_id)
               for course_id in course_order
               if course_id in courses_by_id]
    ordered.extend(sorted(courses_by_id.values(), key=lambda c: c["title"]))

    html = ""
    for course in ordered:
        # Skip courses that have nothing to show yet.
        if not course["notebooks"]:
            continue
        notebook_count = len(course["notebooks"])
        is_wip = course["id"] in work_in_progress
        html += generate_course_card(course, notebook_count, is_wip)

    return html
|
1085 |
+
|
1086 |
+
|
1087 |
+
def get_html_courses_end():
    """Generate the end of the courses section.

    Closes the ``.eva-courses`` grid and the ``<section id="courses">``
    element opened by get_html_courses_start().
    """
    return """        </div>
        </section>"""
|
1091 |
+
|
1092 |
+
|
1093 |
+
def get_html_contribute_section():
    """Generate the contribute call-to-action section.

    Returns a static HTML fragment linking to the GitHub repository;
    no interpolation is performed.
    """
    return """
        <section id="contribute" class="eva-cta">
            <h2>Contribute to Marimo Learn</h2>
            <p>
                Help us expand our collection of educational notebooks. Whether you're an expert in machine learning,
                statistics, or any other field, your contributions are welcome!
            </p>
            <a href="https://github.com/marimo-team/learn" target="_blank" class="eva-button">
                <i class="fab fa-github"></i> Contribute on GitHub
            </a>
        </section>"""
|
1106 |
+
|
1107 |
+
|
1108 |
+
def get_html_footer():
    """Generate the page footer.

    Returns a static HTML fragment with the marimo logo, social links,
    site links and the copyright line; no interpolation is performed.
    """
    return """
        <footer class="eva-footer">
            <div class="eva-footer-logo">
                <a href="https://marimo.io" target="_blank">
                    <img src="https://marimo.io/logotype-wide.svg" alt="Marimo" width="200">
                </a>
            </div>
            <div class="eva-social-links">
                <a href="https://github.com/marimo-team" target="_blank" aria-label="GitHub"><i class="fab fa-github"></i></a>
                <a href="https://marimo.io/discord?ref=learn" target="_blank" aria-label="Discord"><i class="fab fa-discord"></i></a>
                <a href="https://twitter.com/marimo_io" target="_blank" aria-label="Twitter"><i class="fab fa-twitter"></i></a>
                <a href="https://www.youtube.com/@marimo-team" target="_blank" aria-label="YouTube"><i class="fab fa-youtube"></i></a>
                <a href="https://www.linkedin.com/company/marimo-io" target="_blank" aria-label="LinkedIn"><i class="fab fa-linkedin"></i></a>
            </div>
            <div class="eva-footer-links">
                <a href="https://marimo.io" target="_blank">Website</a>
                <a href="https://docs.marimo.io" target="_blank">Documentation</a>
                <a href="https://github.com/marimo-team/learn" target="_blank">GitHub</a>
            </div>
            <div class="eva-footer-copyright">
                © 2025 Marimo Inc. All rights reserved.
            </div>
        </footer>"""
|
1133 |
+
|
1134 |
+
|
1135 |
+
def get_html_scripts() -> str:
    """Generate the JavaScript for the page.

    Returns a single ``<script>`` element (as a string) implementing, in order:
    light/dark theme toggling persisted to localStorage, a typewriter effect
    for the hero title/text, flashcard-style expand/collapse of course cards,
    live search over courses and notebooks, and smooth scrolling for in-page
    anchors.  The string is emitted verbatim into index.html, so its contents
    must stay in sync with the ids/classes produced by the other HTML builders
    (``themeToggle``, ``courseSearch``, ``.eva-course``, ``.eva-notebook``, …).
    """
    return """
    <script>
        // Set light theme as default immediately
        document.documentElement.setAttribute('data-theme', 'light');

        document.addEventListener('DOMContentLoaded', function() {
            // Theme toggle functionality
            const themeToggle = document.getElementById('themeToggle');
            const themeIcon = themeToggle.querySelector('i');

            // Update theme icon based on current theme
            updateThemeIcon('light');

            // Check localStorage for saved theme preference
            const savedTheme = localStorage.getItem('theme');
            if (savedTheme && savedTheme !== 'light') {
                document.documentElement.setAttribute('data-theme', savedTheme);
                updateThemeIcon(savedTheme);
            }

            // Toggle theme when button is clicked
            themeToggle.addEventListener('click', () => {
                const currentTheme = document.documentElement.getAttribute('data-theme');
                const newTheme = currentTheme === 'dark' ? 'light' : 'dark';

                document.documentElement.setAttribute('data-theme', newTheme);
                localStorage.setItem('theme', newTheme);
                updateThemeIcon(newTheme);
            });

            function updateThemeIcon(theme) {
                if (theme === 'dark') {
                    themeIcon.className = 'fas fa-sun';
                } else {
                    themeIcon.className = 'fas fa-moon';
                }
            }

            // Terminal typing effect for hero text
            const heroTitle = document.querySelector('.eva-hero h1');
            const heroText = document.querySelector('.eva-hero p');
            const cursor = document.querySelector('.eva-cursor');

            const originalTitle = heroTitle.textContent;
            const originalText = heroText.textContent.trim();

            heroTitle.textContent = '';
            heroText.textContent = '';

            let titleIndex = 0;
            let textIndex = 0;

            function typeTitle() {
                if (titleIndex < originalTitle.length) {
                    heroTitle.textContent += originalTitle.charAt(titleIndex);
                    titleIndex++;
                    setTimeout(typeTitle, 50);
                } else {
                    cursor.style.display = 'none';
                    setTimeout(typeText, 500);
                }
            }

            function typeText() {
                if (textIndex < originalText.length) {
                    heroText.textContent += originalText.charAt(textIndex);
                    textIndex++;
                    setTimeout(typeText, 20);
                }
            }

            typeTitle();

            // Course toggle functionality - flashcard style
            const courseHeaders = document.querySelectorAll('.eva-course-header');
            const courseButtons = document.querySelectorAll('.eva-course-button');

            // Function to toggle course
            function toggleCourse(course) {
                const isActive = course.classList.contains('active');

                // First close all courses with a slight delay for better visual effect
                document.querySelectorAll('.eva-course.active').forEach(c => {
                    if (c !== course) {
                        // Add a closing class for animation
                        c.classList.add('closing');
                        // Remove active class after a short delay
                        setTimeout(() => {
                            c.classList.remove('active');
                            c.classList.remove('closing');
                        }, 300);
                    }
                });

                // Toggle the clicked course
                if (!isActive) {
                    // Add a small delay before opening to allow others to close
                    setTimeout(() => {
                        course.classList.add('active');

                        // Check if the course has any notebooks
                        const notebooks = course.querySelectorAll('.eva-notebook');
                        const content = course.querySelector('.eva-course-content');

                        if (notebooks.length === 0 && !content.querySelector('.eva-empty-message')) {
                            // If no notebooks, show a message
                            const emptyMessage = document.createElement('p');
                            emptyMessage.className = 'eva-empty-message';
                            emptyMessage.textContent = 'No notebooks available in this course yet.';
                            emptyMessage.style.color = 'var(--eva-text)';
                            emptyMessage.style.fontStyle = 'italic';
                            emptyMessage.style.opacity = '0.7';
                            emptyMessage.style.textAlign = 'center';
                            emptyMessage.style.padding = '1rem 0';
                            content.appendChild(emptyMessage);
                        }

                        // Animate notebooks to appear sequentially
                        notebooks.forEach((notebook, index) => {
                            notebook.style.opacity = '0';
                            notebook.style.transform = 'translateX(-10px)';
                            setTimeout(() => {
                                notebook.style.opacity = '1';
                                notebook.style.transform = 'translateX(0)';
                            }, 50 + (index * 30)); // Stagger the animations
                        });
                    }, 100);
                }
            }

            // Add click event to course headers
            courseHeaders.forEach(header => {
                header.addEventListener('click', function(e) {
                    e.preventDefault();
                    e.stopPropagation();

                    const currentCourse = this.closest('.eva-course');
                    toggleCourse(currentCourse);
                });
            });

            // Add click event to course buttons
            courseButtons.forEach(button => {
                button.addEventListener('click', function(e) {
                    e.preventDefault();
                    e.stopPropagation();

                    const currentCourse = this.closest('.eva-course');
                    toggleCourse(currentCourse);
                });
            });

            // Search functionality with improved matching
            const searchInput = document.getElementById('courseSearch');
            const courses = document.querySelectorAll('.eva-course');
            const notebooks = document.querySelectorAll('.eva-notebook');

            searchInput.addEventListener('input', function() {
                const searchTerm = this.value.toLowerCase();

                if (searchTerm === '') {
                    // Reset all visibility
                    courses.forEach(course => {
                        course.style.display = 'block';
                        course.classList.remove('active');
                    });

                    notebooks.forEach(notebook => {
                        notebook.style.display = 'flex';
                    });

                    // Open the first course with notebooks by default when search is cleared
                    for (let i = 0; i < courses.length; i++) {
                        const courseNotebooks = courses[i].querySelectorAll('.eva-notebook');
                        if (courseNotebooks.length > 0) {
                            courses[i].classList.add('active');
                            break;
                        }
                    }

                    return;
                }

                // First hide all courses
                courses.forEach(course => {
                    course.style.display = 'none';
                    course.classList.remove('active');
                });

                // Then show courses and notebooks that match the search
                let hasResults = false;

                // Track which courses have matching notebooks
                const coursesWithMatchingNotebooks = new Set();

                // First check notebooks
                notebooks.forEach(notebook => {
                    const notebookTitle = notebook.querySelector('a').getAttribute('data-notebook-title').toLowerCase();
                    const matchesSearch = notebookTitle.includes(searchTerm);

                    notebook.style.display = matchesSearch ? 'flex' : 'none';

                    if (matchesSearch) {
                        const course = notebook.closest('.eva-course');
                        coursesWithMatchingNotebooks.add(course.getAttribute('data-course-id'));
                        hasResults = true;
                    }
                });

                // Then check course titles and descriptions
                courses.forEach(course => {
                    const courseId = course.getAttribute('data-course-id');
                    const courseTitle = course.querySelector('.eva-course-title').textContent.toLowerCase();
                    const courseDescription = course.querySelector('.eva-course-description').textContent.toLowerCase();

                    const courseMatches = courseTitle.includes(searchTerm) || courseDescription.includes(searchTerm);

                    // Show course if it matches or has matching notebooks
                    if (courseMatches || coursesWithMatchingNotebooks.has(courseId)) {
                        course.style.display = 'block';
                        course.classList.add('active');
                        hasResults = true;

                        // If course matches but doesn't have matching notebooks, show all its notebooks
                        if (courseMatches && !coursesWithMatchingNotebooks.has(courseId)) {
                            course.querySelectorAll('.eva-notebook').forEach(nb => {
                                nb.style.display = 'flex';
                            });
                        }
                    }
                });
            });

            // Open the first course with notebooks by default
            let firstCourseWithNotebooks = null;
            for (let i = 0; i < courses.length; i++) {
                const courseNotebooks = courses[i].querySelectorAll('.eva-notebook');
                if (courseNotebooks.length > 0) {
                    firstCourseWithNotebooks = courses[i];
                    break;
                }
            }

            if (firstCourseWithNotebooks) {
                firstCourseWithNotebooks.classList.add('active');
            } else if (courses.length > 0) {
                // If no courses have notebooks, just open the first one
                courses[0].classList.add('active');
            }

            // Smooth scrolling for anchor links
            document.querySelectorAll('a[href^="#"]').forEach(anchor => {
                anchor.addEventListener('click', function(e) {
                    e.preventDefault();

                    const targetId = this.getAttribute('href');
                    const targetElement = document.querySelector(targetId);

                    if (targetElement) {
                        window.scrollTo({
                            top: targetElement.offsetTop - 100,
                            behavior: 'smooth'
                        });
                    }
                });
            });
        });
    </script>"""
|
1405 |
+
|
1406 |
+
|
1407 |
+
def get_html_footer_closing() -> str:
    """Generate closing HTML tags.

    Closes the wrapper ``<div>`` opened by the page header builder, plus
    ``</body>`` and ``</html>``; written last by ``generate_index``.
    """
    return """
    </div>
</body>
</html>"""
|
1413 |
+
|
1414 |
+
|
1415 |
+
def generate_index(courses: Dict[str, Dict[str, Any]], output_dir: str) -> None:
    """Generate the index.html file with Neon Genesis Evangelion aesthetics.

    Assembles the page from the individual HTML-builder helpers in render
    order and writes the result to ``<output_dir>/index.html``.  The output
    directory is created if needed.  I/O failures are reported to stdout
    rather than raised.

    Args:
        courses: Mapping of course id to course metadata/notebook info,
            as produced by ``organize_notebooks_by_course``.
        output_dir: Directory the site is being built into.
    """
    print("Generating index.html")

    index_path = os.path.join(output_dir, "index.html")
    os.makedirs(output_dir, exist_ok=True)

    try:
        with open(index_path, "w", encoding="utf-8") as out:
            # Each callable produces one page section; written in render order.
            sections = (
                get_html_header().format(css=generate_eva_css()),
                get_html_hero_section(),
                get_html_features_section(),
                get_html_courses_start(),
                generate_course_cards(courses),
                get_html_courses_end(),
                get_html_contribute_section(),
                get_html_footer(),
                get_html_scripts(),
                get_html_footer_closing(),
            )
            for section in sections:
                out.write(section)
    except IOError as e:
        print(f"Error generating index.html: {e}")
|
1450 |
+
|
1451 |
+
|
1452 |
+
def main() -> None:
    """Build all marimo notebooks into a static site.

    Discovers course directories (any non-excluded top-level directory
    containing ``.py`` files, unless ``--course-dirs`` names them
    explicitly), exports each notebook to WASM HTML, generates the index
    page, and writes the organized course metadata to ``courses.json`` in
    the output directory.
    """
    parser = argparse.ArgumentParser(description="Build marimo notebooks")
    parser.add_argument(
        "--output-dir", default="_site", help="Output directory for built files"
    )
    parser.add_argument(
        "--course-dirs", nargs="+", default=None,
        help="Specific course directories to build (default: all directories with .py files)"
    )
    args = parser.parse_args()

    # Find all course directories (directories containing .py files)
    all_notebooks: List[str] = []

    # Directories to exclude from course detection
    excluded_dirs = ["scripts", "env", "__pycache__", ".git", ".github", "assets"]

    if args.course_dirs:
        course_dirs = args.course_dirs
    else:
        # Automatically detect course directories (any directory with .py files)
        course_dirs = []
        for item in os.listdir("."):
            if (os.path.isdir(item) and
                    not item.startswith(".") and
                    not item.startswith("_") and
                    item not in excluded_dirs):
                # Check if directory contains .py files
                if list(Path(item).glob("*.py")):
                    course_dirs.append(item)

    print(f"Found course directories: {', '.join(course_dirs)}")

    for directory in course_dirs:
        dir_path = Path(directory)
        if not dir_path.exists():
            print(f"Warning: Directory not found: {dir_path}")
            continue

        # Collect notebooks, skipping private modules and anything under
        # __pycache__.  BUGFIX: the old check ("/__pycache__/" not in
        # str(path)) relied on forward-slash separators and never matched
        # on Windows; comparing against path.parts is separator-agnostic.
        notebooks = [
            str(path)
            for path in dir_path.rglob("*.py")
            if not path.name.startswith("_") and "__pycache__" not in path.parts
        ]
        all_notebooks.extend(notebooks)

    if not all_notebooks:
        print("No notebooks found!")
        return

    # Export notebooks sequentially; keep only the ones that succeed so the
    # index never links to a page that failed to build.
    successful_notebooks = []
    for nb in all_notebooks:
        # Determine if notebook should be exported as app or notebook.
        # For now, export all as notebooks.
        if export_html_wasm(nb, args.output_dir, as_app=False):
            successful_notebooks.append(nb)

    # Organize notebooks by course (only include successfully exported notebooks)
    courses = organize_notebooks_by_course(successful_notebooks)

    # Generate index with organized courses
    generate_index(courses, args.output_dir)

    # Save course data as JSON for potential use by other tools
    courses_json_path = os.path.join(args.output_dir, "courses.json")
    with open(courses_json_path, "w", encoding="utf-8") as f:
        json.dump(courses, f, indent=2)

    print(f"Build complete! Site generated in {args.output_dir}")
    print(f"Successfully exported {len(successful_notebooks)} out of {len(all_notebooks)} notebooks")
|
1520 |
+
|
1521 |
+
|
1522 |
+
# Script entry point: run the full site build when executed directly
# (no effect when this module is imported).
if __name__ == "__main__":
    main()
|
scripts/preview.py
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python3
|
2 |
+
|
3 |
+
import os
|
4 |
+
import subprocess
|
5 |
+
import argparse
|
6 |
+
import webbrowser
|
7 |
+
import time
|
8 |
+
import sys
|
9 |
+
from pathlib import Path
|
10 |
+
|
11 |
+
def main():
    """Build the site (unless --no-build) and serve it with http.server.

    Returns a process exit code: 0 on success or Ctrl+C, 1 on any error.
    The working directory is always restored before returning.
    """
    parser = argparse.ArgumentParser(description="Build and preview marimo notebooks site")
    parser.add_argument(
        "--port", default=8000, type=int, help="Port to run the server on"
    )
    parser.add_argument(
        "--no-build", action="store_true", help="Skip building the site (just serve existing files)"
    )
    parser.add_argument(
        "--output-dir", default="_site", help="Output directory for built files"
    )
    opts = parser.parse_args()

    # Remember where we started; we chdir into the site to serve it and must
    # come back on every exit path.
    start_dir = os.getcwd()

    try:
        if not opts.no_build:
            # Run the build script as a subprocess; a failed build is only a
            # warning so stale output can still be previewed.
            print("Building site...")
            builder = Path("scripts/build.py")
            if not builder.exists():
                print(f"Error: Build script not found at {builder}")
                return 1

            build_cmd = [sys.executable, str(builder), "--output-dir", opts.output_dir]
            if subprocess.run(build_cmd, check=False).returncode != 0:
                print("Warning: Build process completed with errors.")

        site_dir = Path(opts.output_dir)
        if not site_dir.exists():
            print(f"Error: Output directory '{opts.output_dir}' does not exist.")
            return 1

        # Serve from inside the built site so http.server picks up index.html.
        os.chdir(opts.output_dir)

        url = f"http://localhost:{opts.port}"
        print(f"Opening {url} in your browser...")
        webbrowser.open(url)

        print(f"Starting server on port {opts.port}...")
        print("Press Ctrl+C to stop the server")

        # Blocks until the server process exits (normally via Ctrl+C).
        subprocess.run([sys.executable, "-m", "http.server", str(opts.port)])

        return 0
    except KeyboardInterrupt:
        print("\nServer stopped.")
        return 0
    except Exception as e:
        print(f"Error: {e}")
        return 1
    finally:
        # Always return to the original directory
        os.chdir(start_dir)
|
74 |
+
|
75 |
+
# Script entry point: propagate main()'s exit code to the shell.
if __name__ == "__main__":
    sys.exit(main())
|