Commit 5de8d31 by etrotta · 2 parents: 25e395f 1cd4542

Merge marimo-team/learn (add __marimo__ to gitignore)

.github/workflows/deploy.yml ADDED
@@ -0,0 +1,56 @@
+name: Deploy to GitHub Pages
+
+on:
+  push:
+    branches: ['main']
+  workflow_dispatch:
+
+concurrency:
+  group: 'pages'
+  cancel-in-progress: false
+
+env:
+  UV_SYSTEM_PYTHON: 1
+
+jobs:
+  build:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: 🚀 Install uv
+        uses: astral-sh/setup-uv@v4
+
+      - name: 🐍 Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: 3.12
+
+      - name: 📦 Install dependencies
+        run: |
+          uv pip install marimo
+
+      - name: 🛠️ Export notebooks
+        run: |
+          python scripts/build.py
+
+      - name: 📤 Upload artifact
+        uses: actions/upload-pages-artifact@v3
+        with:
+          path: _site
+
+  deploy:
+    needs: build
+
+    permissions:
+      pages: write
+      id-token: write
+
+    environment:
+      name: github-pages
+      url: ${{ steps.deployment.outputs.page_url }}
+    runs-on: ubuntu-latest
+    steps:
+      - name: 🚀 Deploy to GitHub Pages
+        id: deployment
+        uses: actions/deploy-pages@v4
.github/workflows/hf_sync.yml ADDED
@@ -0,0 +1,45 @@
+name: Sync to Hugging Face hub
+on:
+  push:
+    branches: [main]
+
+  # to run this workflow manually from the Actions tab
+  workflow_dispatch:
+
+jobs:
+  sync-to-hub:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v4
+        with:
+          fetch-depth: 0
+
+      - name: Configure Git
+        run: |
+          git config --global user.name "GitHub Action"
+          git config --global user.email "[email protected]"
+
+      - name: Prepend frontmatter to README
+        run: |
+          if [ -f README.md ] && ! grep -q "^---" README.md; then
+            FRONTMATTER="---
+          title: marimo learn
+          emoji: 🧠
+          colorFrom: blue
+          colorTo: indigo
+          sdk: docker
+          sdk_version: \"latest\"
+          app_file: app.py
+          pinned: false
+          ---
+
+          "
+            echo "$FRONTMATTER$(cat README.md)" > README.md
+            git add README.md
+            git commit -m "Add HF frontmatter to README" || echo "No changes to commit"
+          fi
+
+      - name: Push to hub
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+        run: git push -f https://mylessss:[email protected]/spaces/marimo-team/marimo-learn main
.github/workflows/typos.yaml CHANGED
@@ -13,4 +13,3 @@ jobs:
   uses: styfle/[email protected]
   - uses: actions/checkout@v4
   - uses: crate-ci/[email protected]
- name: Tests
.gitignore CHANGED
@@ -172,3 +172,6 @@ cython_debug/
 
 # Marimo specific
 __marimo__
+
+# Generated site content
+_site/
Dockerfile ADDED
@@ -0,0 +1,26 @@
+FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim
+
+WORKDIR /app
+
+# Create a non-root user
+RUN useradd -m appuser
+
+# Copy application files
+COPY _server/main.py _server/main.py
+COPY polars/ polars/
+COPY duckdb/ duckdb/
+
+# Set proper ownership
+RUN chown -R appuser:appuser /app
+
+# Switch to non-root user
+USER appuser
+
+# Create virtual environment and install dependencies
+RUN uv venv
+RUN uv export --script _server/main.py | uv pip install -r -
+
+ENV PORT=7860
+EXPOSE 7860
+
+CMD ["uv", "run", "_server/main.py"]
README.md CHANGED
@@ -3,7 +3,7 @@
 </p>
 
 <p align="center">
-  <em>A curated collection of educational <a href="https://github.com/marimo-team/marimo">marimo</a> notebooks</em>.
+  <span><em>A curated collection of educational <a href="https://github.com/marimo-team/marimo">marimo</a> notebooks</em> || <a href="https://discord.gg/rT48v2Y9fe">💬 Discord</a></span>
 </p>
 
 # 📚 Learn
@@ -29,6 +29,7 @@ notebooks for educators, students, and practitioners.
 - 📏 Linear algebra
 - ❄️ Polars
 - 🔥 Pytorch
+- 🗄️ Duckdb
 - 📈 Altair
 - 📈 Plotly
 - 📈 matplotlib
@@ -57,12 +58,27 @@ Here's a contribution checklist:
 If you aren't comfortable adding a new notebook or course, you can also request
 what you'd like to see by [filing an issue](https://github.com/marimo-team/learn/issues/new?template=example_request.yaml).
 
+## Building and Previewing
+
+The site is built using a Python script that exports marimo notebooks to HTML and generates an index page.
+
+```bash
+# Build the site
+python scripts/build.py --output-dir _site
+
+# Preview the site (builds first)
+python scripts/preview.py
+
+# Preview without rebuilding
+python scripts/preview.py --no-build
+```
+
 ## Community
 
 We're building a community. Come hang out with us!
 
 - 🌟 [Star us on GitHub](https://github.com/marimo-team/examples)
-- 💬 [Chat with us on Discord](https://marimo.io/discord?ref=readme)
+- 💬 [Chat with us on Discord](https://discord.gg/rT48v2Y9fe)
 - 📧 [Subscribe to our Newsletter](https://marimo.io/newsletter)
 - ☁️ [Join our Cloud Waitlist](https://marimo.io/cloud)
 - ✏️ [Start a GitHub Discussion](https://github.com/marimo-team/marimo/discussions)
_server/README.md ADDED
@@ -0,0 +1,24 @@
+# marimo learn server
+
+This folder contains server code for hosting marimo apps.
+
+## Running the server
+
+```bash
+cd _server
+uv run --no-project main.py
+```
+
+## Building a Docker image
+
+From the root directory, run:
+
+```bash
+docker build -t marimo-learn .
+```
+
+## Running the Docker container
+
+```bash
+docker run -p 7860:7860 marimo-learn
+```
_server/main.py ADDED
@@ -0,0 +1,90 @@
+# /// script
+# requires-python = ">=3.12"
+# dependencies = [
+#     "fastapi",
+#     "marimo",
+#     "starlette",
+#     "python-dotenv",
+#     "pydantic",
+#     "duckdb",
+#     "altair==5.5.0",
+#     "beautifulsoup4==4.13.3",
+#     "httpx==0.28.1",
+#     "marimo",
+#     "nest-asyncio==1.6.0",
+#     "numba==0.61.0",
+#     "numpy==2.1.3",
+#     "polars==1.24.0",
+# ]
+# ///
+
+import logging
+import os
+from pathlib import Path
+
+import marimo
+from dotenv import load_dotenv
+from fastapi import FastAPI, Request
+from fastapi.responses import HTMLResponse
+
+# Load environment variables
+load_dotenv()
+
+# Set up logging
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+# Get port from environment variable or use default
+PORT = int(os.environ.get("PORT", 7860))
+
+root_dir = Path(__file__).parent.parent
+
+ROOTS = [
+    root_dir / "polars",
+    root_dir / "duckdb",
+]
+
+
+server = marimo.create_asgi_app(include_code=True)
+app_names: list[str] = []
+
+for root in ROOTS:
+    for filename in root.iterdir():
+        if filename.is_file() and filename.suffix == ".py":
+            app_path = root.stem + "/" + filename.stem
+            server = server.with_app(path=f"/{app_path}", root=str(filename))
+            app_names.append(app_path)
+
+# Create a FastAPI app
+app = FastAPI()
+
+logger.info(f"Mounting {len(app_names)} apps")
+for app_name in app_names:
+    logger.info(f"  /{app_name}")
+
+
+@app.get("/")
+async def home(request: Request):
+    html_content = """
+    <!DOCTYPE html>
+    <html>
+        <head>
+            <title>marimo learn</title>
+        </head>
+        <body>
+            <h1>Welcome to marimo learn!</h1>
+            <p>This is a collection of interactive tutorials for learning data science libraries with marimo.</p>
+            <p>Check out the <a href="https://github.com/marimo-team/learn">GitHub repository</a> for more information.</p>
+        </body>
+    </html>
+    """
+    return HTMLResponse(content=html_content)
+
+
+app.mount("/", server.build())
+
+# Run the server
+if __name__ == "__main__":
+    import uvicorn
+
+    uvicorn.run(app, host="0.0.0.0", port=PORT, log_level="info")
duckdb/README.md ADDED
@@ -0,0 +1,28 @@
+# Learn DuckDB
+
+_🚧 This collection is a work in progress. Please help us add notebooks!_
+
+This collection of marimo notebooks is designed to teach you the basics of
+DuckDB, a fast in-process OLAP engine that can interoperate with dataframes.
+These notebooks also show how marimo gives DuckDB superpowers.
+
+**Help us build this course! ⚒️**
+
+We're seeking contributors to help us build these notebooks. Every contributor
+will be acknowledged as an author in this README and in their contributed
+notebooks. Head over to the [tracking
+issue](https://github.com/marimo-team/learn/issues/48) to sign up for a planned
+notebook or propose your own.
+
+**Running notebooks.** To run a notebook locally, use
+
+```bash
+uvx marimo edit <file_url>
+```
+
+You can also open notebooks in our online playground by prepending marimo.app/ to a notebook's URL.
+
+
+**Authors.**
+
+Thanks to all our notebook authors!
functional_programming/05_functors.py ADDED
@@ -0,0 +1,1313 @@
1
+ # /// script
2
+ # requires-python = ">=3.9"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # ]
6
+ # ///
7
+
8
+ import marimo
9
+
10
+ __generated_with = "0.11.17"
11
+ app = marimo.App(app_title="Category Theory and Functors")
12
+
13
+
14
+ @app.cell(hide_code=True)
15
+ def _(mo):
16
+ mo.md(
17
+ """
18
+ # Category Theory and Functors
19
+
20
+ In this notebook, you will learn:
21
+
22
+ * Why `length` is a *functor* from the category of `list concatenation` to the category of `integer addition`
23
+ * How to *lift* an ordinary function into a specific *computational context*
24
+ * How to write an *adapter* between two categories
25
+
26
+ In short, a mathematical functor is a **mapping** between two categories in category theory. In practice, a functor represents a type that can be mapped over.
27
+
28
+ /// admonition | Intuitions
29
+
30
+ - A simple intuition is that a `Functor` represents a **container** of values, along with the ability to apply a function uniformly to every element in the container.
31
+ - Another intuition is that a `Functor` represents some sort of **computational context**.
32
+ - Mathematically, `Functors` generalize the idea of a container or a computational context.
33
+ ///
34
+
35
+ We will start with intuition, introduce the basics of category theory, and then examine functors from a categorical perspective.
36
+
37
+ /// details | Notebook metadata
38
+ type: info
39
+
40
+ version: 0.1.1 | last modified: 2025-03-16 | author: [métaboulie](https://github.com/metaboulie)<br/>
41
+ reviewer: [Haleshot](https://github.com/Haleshot)
42
+
43
+ ///
44
+ """
45
+ )
46
+ return
47
+
48
+
49
+ @app.cell(hide_code=True)
50
+ def _(mo):
51
+ mo.md(
52
+ """
53
+ # Functor as a Computational Context
54
+
55
+ A [**Functor**](https://wiki.haskell.org/Functor) is an abstraction that represents a computational context with the ability to apply a function to every value inside it without altering the structure of the context itself. This enables transformations while preserving the shape of the data.
56
+
57
+ To understand this, let's look at a simple example.
58
+
59
+ ## [The One-Way Wrapper Design Pattern](http://blog.sigfpe.com/2007/04/trivial-monad.html)
60
+
61
+ Often, we need to wrap data in some kind of context. However, when performing operations on wrapped data, we typically have to:
62
+
63
+ 1. Unwrap the data.
64
+ 2. Modify the unwrapped data.
65
+ 3. Rewrap the modified data.
66
+
67
+ This process is tedious and inefficient. Instead, we want to wrap data **once** and apply functions directly to the wrapped data without unwrapping it.
68
+
69
+ /// admonition | Rules for a One-Way Wrapper
70
+
71
+ 1. We can wrap values, but we cannot unwrap them.
72
+ 2. We should still be able to apply transformations to the wrapped data.
73
+ 3. Any operation that depends on wrapped data should itself return a wrapped result.
74
+ ///
75
+
76
+ Let's define such a `Wrapper` class:
77
+
78
+ ```python
79
+ from dataclasses import dataclass
80
+ from typing import Callable, Generic, TypeVar
81
+
82
+ A = TypeVar("A")
83
+ B = TypeVar("B")
84
+
85
+ @dataclass
86
+ class Wrapper(Generic[A]):
87
+ value: A
88
+ ```
89
+
90
+ Now, we can create an instance of wrapped data:
91
+
92
+ ```python
93
+ wrapped = Wrapper(1)
94
+ ```
95
+
96
+ ### Mapping Functions Over Wrapped Data
97
+
98
+ To modify wrapped data while keeping it wrapped, we define an `fmap` method:
99
+
100
+ ```python
101
+ @dataclass
102
+ class Wrapper(Functor, Generic[A]):
103
+ value: A
104
+
105
+ @classmethod
106
+ def fmap(cls, f: Callable[[A], B], a: "Wrapper[A]") -> "Wrapper[B]":
107
+ return Wrapper(f(a.value))
108
+ ```
109
+
110
+ Now, we can apply transformations without unwrapping:
111
+
112
+ ```python
113
+ >>> Wrapper.fmap(lambda x: x + 1, wrapper)
114
+ Wrapper(value=2)
115
+
116
+ >>> Wrapper.fmap(lambda x: [x], wrapper)
117
+ Wrapper(value=[1])
118
+ ```
119
+
120
+ > Try using the `Wrapper` in the cell below.
121
+ """
122
+ )
123
+ return
124
+
125
+
126
+ @app.cell
127
+ def _(A, B, Callable, Functor, Generic, dataclass, pp):
128
+ @dataclass
129
+ class Wrapper(Functor, Generic[A]):
130
+ value: A
131
+
132
+ @classmethod
133
+ def fmap(cls, f: Callable[[A], B], a: "Wrapper[A]") -> "Wrapper[B]":
134
+ return Wrapper(f(a.value))
135
+
136
+
137
+ wrapper = Wrapper(1)
138
+
139
+ pp(Wrapper.fmap(lambda x: x + 1, wrapper))
140
+ pp(Wrapper.fmap(lambda x: [x], wrapper))
141
+ return Wrapper, wrapper
142
+
143
+
144
+ @app.cell(hide_code=True)
145
+ def _(mo):
146
+ mo.md(
147
+ """
148
+ We can analyze the type signature of `fmap` for `Wrapper`:
149
+
150
+ * `f` is of type `Callable[[A], B]`
151
+ * `a` is of type `Wrapper[A]`
152
+ * The return value is of type `Wrapper[B]`
153
+
154
+ Thus, in Python's type system, we can express the type signature of `fmap` as:
155
+
156
+ ```python
157
+ fmap(f: Callable[[A], B], a: Wrapper[A]) -> Wrapper[B]:
158
+ ```
159
+
160
+ Essentially, `fmap`:
161
+
162
+ 1. Takes a function `Callable[[A], B]` and a `Wrapper[A]` instance as input.
163
+ 2. Applies the function to the value inside the wrapper.
164
+ 3. Returns a new `Wrapper[B]` instance with the transformed value, leaving the original wrapper and its internal data unmodified.
165
+
166
+ Now, let's examine `list` as a similar kind of wrapper.
167
+ """
168
+ )
169
+ return
170
+
171
+
172
+ @app.cell(hide_code=True)
173
+ def _(mo):
174
+ mo.md(
175
+ """
176
+ ## The List Wrapper
177
+
178
+ We can define a `List` class to represent a wrapped list that supports `fmap`:
179
+
180
+ ```python
181
+ @dataclass
182
+ class List(Functor, Generic[A]):
183
+ value: list[A]
184
+
185
+ @classmethod
186
+ def fmap(cls, f: Callable[[A], B], a: "List[A]") -> "List[B]":
187
+ return List([f(x) for x in a.value])
188
+ ```
189
+
190
+ Now, we can apply transformations:
191
+
192
+ ```python
193
+ >>> flist = List([1, 2, 3, 4])
194
+ >>> List.fmap(lambda x: x + 1, flist)
195
+ List(value=[2, 3, 4, 5])
196
+ >>> List.fmap(lambda x: [x], flist)
197
+ List(value=[[1], [2], [3], [4]])
198
+ ```
199
+ """
200
+ )
201
+ return
202
+
203
+
204
+ @app.cell
205
+ def _(A, B, Callable, Functor, Generic, dataclass, pp):
206
+ @dataclass
207
+ class List(Functor, Generic[A]):
208
+ value: list[A]
209
+
210
+ @classmethod
211
+ def fmap(cls, f: Callable[[A], B], a: "List[A]") -> "List[B]":
212
+ return List([f(x) for x in a.value])
213
+
214
+
215
+ flist = List([1, 2, 3, 4])
216
+ pp(List.fmap(lambda x: x + 1, flist))
217
+ pp(List.fmap(lambda x: [x], flist))
218
+ return List, flist
219
+
220
+
221
+ @app.cell(hide_code=True)
222
+ def _(mo):
223
+ mo.md(
224
+ """
225
+ ### Extracting the Type of `fmap`
226
+
227
+ The type signature of `fmap` for `List` is:
228
+
229
+ ```python
230
+ fmap(f: Callable[[A], B], a: List[A]) -> List[B]
231
+ ```
232
+
233
+ Similarly, for `Wrapper`:
234
+
235
+ ```python
236
+ fmap(f: Callable[[A], B], a: Wrapper[A]) -> Wrapper[B]
237
+ ```
238
+
239
+ Both follow the same pattern, which we can generalize as:
240
+
241
+ ```python
242
+ fmap(f: Callable[[A], B], a: Functor[A]) -> Functor[B]
243
+ ```
244
+
245
+ where `Functor` can be `Wrapper`, `List`, or any other wrapper type that follows the same structure.
246
+
247
+ ### Functors in Haskell (optional)
248
+
249
+ In Haskell, the type of `fmap` is:
250
+
251
+ ```haskell
252
+ fmap :: Functor f => (a -> b) -> f a -> f b
253
+ ```
254
+
255
+ or equivalently:
256
+
257
+ ```haskell
258
+ fmap :: Functor f => (a -> b) -> (f a -> f b)
259
+ ```
260
+
261
+ This means that `fmap` **lifts** an ordinary function into the **functor world**, allowing it to operate within a computational context.
262
+
263
+ Now, let's define an abstract class for `Functor`.
264
+ """
265
+ )
266
+ return
267
+
268
+
269
+ @app.cell(hide_code=True)
270
+ def _(mo):
271
+ mo.md(
272
+ """
273
+ ## Defining Functor
274
+
275
+ Recall that, a **Functor** is an abstraction that allows us to apply a function to values inside a computational context while preserving its structure.
276
+
277
+ To define `Functor` in Python, we use an abstract base class:
278
+
279
+ ```python
280
+ from dataclasses import dataclass
281
+ from typing import Callable, Generic, TypeVar
282
+ from abc import ABC, abstractmethod
283
+
284
+ A = TypeVar("A")
285
+ B = TypeVar("B")
286
+
287
+ @dataclass
288
+ class Functor(ABC, Generic[A]):
289
+ @classmethod
290
+ @abstractmethod
291
+ def fmap(f: Callable[[A], B], a: "Functor[A]") -> "Functor[B]":
292
+ raise NotImplementedError
293
+ ```
294
+
295
+ We can now extend custom wrappers, containers, or computation contexts with this `Functor` base class, implement the `fmap` method, and apply any function.
296
+
297
+ Next, let's implement a more complex data structure: [RoseTree](https://en.wikipedia.org/wiki/Rose_tree).
298
+ """
299
+ )
300
+ return
301
+
302
+
303
+ @app.cell(hide_code=True)
304
+ def _(mo):
305
+ mo.md(
306
+ """
307
+ ## Case Study: RoseTree
308
+
309
+ A **RoseTree** is a tree where:
310
+
311
+ - Each node holds a **value**.
312
+ - Each node has a **list of child nodes** (which are also RoseTrees).
313
+
314
+ This structure is useful for representing hierarchical data, such as:
315
+
316
+ - Abstract Syntax Trees (ASTs)
317
+ - File system directories
318
+ - Recursive computations
319
+
320
+ We can implement `RoseTree` by extending the `Functor` class:
321
+
322
+ ```python
323
+ from dataclasses import dataclass
324
+ from typing import Callable, Generic, TypeVar
325
+
326
+ A = TypeVar("A")
327
+ B = TypeVar("B")
328
+
329
+ @dataclass
330
+ class RoseTree(Functor, Generic[a]):
331
+
332
+ value: A
333
+ children: list["RoseTree[A]"]
334
+
335
+ @classmethod
336
+ def fmap(cls, f: Callable[[A], B], a: "RoseTree[A]") -> "RoseTree[B]":
337
+ return RoseTree(
338
+ f(a.value), [cls.fmap(f, child) for child in a.children]
339
+ )
340
+
341
+ def __repr__(self) -> str:
342
+ return f"Node: {self.value}, Children: {self.children}"
343
+ ```
344
+
345
+ - The function is applied **recursively** to each node's value.
346
+ - The tree structure **remains unchanged**.
347
+ - Only the values inside the tree are modified.
348
+
349
+ > Try using `RoseTree` in the cell below.
350
+ """
351
+ )
352
+ return
353
+
354
+
355
+ @app.cell(hide_code=True)
356
+ def _(A, B, Callable, Functor, Generic, dataclass, mo):
357
+ @dataclass
358
+ class RoseTree(Functor, Generic[A]):
359
+ """
360
+ ### Doc: RoseTree
361
+
362
+ A Functor implementation of `RoseTree`, allowing transformation of values while preserving the tree structure.
363
+
364
+ **Attributes**
365
+
366
+ - `value (A)`: The value stored in the node.
367
+ - `children (list[RoseTree[A]])`: A list of child nodes forming the tree structure.
368
+
369
+ **Methods:**
370
+
371
+ - `fmap(f: Callable[[A], B], a: "RoseTree[A]") -> "RoseTree[B]"`
372
+
373
+ Applies a function to each value in the tree, producing a new `RoseTree[b]` with transformed values.
374
+
375
+ **Implementation logic:**
376
+
377
+ - The function `f` is applied to the root node's `value`.
378
+ - Each child in `children` recursively calls `fmap`, ensuring all values in the tree are mapped.
379
+ - The overall tree structure remains unchanged.
380
+ """
381
+
382
+ value: A
383
+ children: list["RoseTree[A]"]
384
+
385
+ @classmethod
386
+ def fmap(cls, f: Callable[[A], B], a: "RoseTree[A]") -> "RoseTree[B]":
387
+ return RoseTree(
388
+ f(a.value), [cls.fmap(f, child) for child in a.children]
389
+ )
390
+
391
+ def __repr__(self) -> str:
392
+ return f"Node: {self.value}, Children: {self.children}"
393
+
394
+
395
+ mo.md(RoseTree.__doc__)
396
+ return (RoseTree,)
397
+
398
+
399
+ @app.cell
400
+ def _(RoseTree, pp):
401
+ rosetree = RoseTree(1, [RoseTree(2, []), RoseTree(3, [RoseTree(4, [])])])
402
+
403
+ pp(rosetree)
404
+ pp(RoseTree.fmap(lambda x: [x], rosetree))
405
+ pp(RoseTree.fmap(lambda x: RoseTree(x, []), rosetree))
406
+ return (rosetree,)
407
+
408
+
409
+ @app.cell(hide_code=True)
410
+ def _(mo):
411
+ mo.md(
412
+ """
413
+ ## Generic Functions that can be Used with Any Functor
414
+
415
+ One of the powerful features of functors is that we can write **generic functions** that can work with any functor.
416
+
417
+ Remember that in Haskell, the type of `fmap` can be written as:
418
+
419
+ ```haskell
420
+ fmap :: Functor f => (a -> b) -> (f a -> f b)
421
+ ```
422
+
423
+ Translating to Python, we get:
424
+
425
+ ```python
426
+ def fmap(func: Callable[[A], B]) -> Callable[[Functor[A]], Functor[B]]
427
+ ```
428
+
429
+ This means that `fmap`:
430
+
431
+ - Takes an **ordinary function** `Callable[[A], B]` as input.
432
+ - Outputs a function that:
433
+ - Takes a **functor** of type `Functor[A]` as input.
434
+ - Outputs a **functor** of type `Functor[B]`.
435
+
436
+ We can implement a similar idea in Python:
437
+
438
+ ```python
439
+ fmap = lambda f, functor: functor.__class__.fmap(f, functor)
440
+ inc = lambda functor: fmap(lambda x: x + 1, functor)
441
+ ```
442
+
443
+ - **`fmap`**: Lifts an ordinary function (`f`) to the functor world, allowing the function to operate on the wrapped value inside the functor.
444
+ - **`inc`**: A specific instance of `fmap` that operates on any functor. It takes a functor, applies the function `lambda x: x + 1` to every value inside it, and returns a new functor with the updated values.
445
+
446
+ Thus, **`fmap`** transforms an ordinary function into a **function that operates on functors**, and **`inc`** is a specific case where it increments the value inside the functor.
447
+
448
+ ### Applying the `inc` Function to Various Functors
449
+
450
+ You can now apply `inc` to any functor like `Wrapper`, `List`, or `RoseTree`:
451
+
452
+ ```python
453
+ # Applying `inc` to a Wrapper
454
+ wrapper = Wrapper(5)
455
+ inc(wrapper) # Wrapper(value=6)
456
+
457
+ # Applying `inc` to a List
458
+ list_wrapper = List([1, 2, 3])
459
+ inc(list_wrapper) # List(value=[2, 3, 4])
460
+
461
+ # Applying `inc` to a RoseTree
462
+ tree = RoseTree(1, [RoseTree(2, []), RoseTree(3, [])])
463
+ inc(tree) # RoseTree(value=2, children=[RoseTree(value=3, children=[]), RoseTree(value=4, children=[])])
464
+ ```
465
+
466
+ > Try using `fmap` in the cell below.
467
+ """
468
+ )
469
+ return
470
+
471
+
472
+ @app.cell
473
+ def _(flist, pp, rosetree, wrapper):
474
+ fmap = lambda f, functor: functor.__class__.fmap(f, functor)
475
+ inc = lambda functor: fmap(lambda x: x + 1, functor)
476
+
477
+ pp(inc(wrapper))
478
+ pp(inc(flist))
479
+ pp(inc(rosetree))
480
+ return fmap, inc
481
+
482
+
483
+ @app.cell(hide_code=True)
484
+ def _(mo):
485
+ mo.md(
486
+ """
487
+ ## Functor laws
488
+
489
+ In addition to providing a function `fmap` of the specified type, functors are also required to satisfy two equational laws:
490
+
491
+ ```haskell
492
+ fmap id = id -- fmap preserves identity
493
+ fmap (g . h) = fmap g . fmap h -- fmap distributes over composition
494
+ ```
495
+
496
+ 1. `fmap` should preserve the **identity function**, in the sense that applying `fmap` to this function returns the same function as the result.
497
+ 2. `fmap` should also preserve **function composition**. Applying two composed functions `g` and `h` to a functor via `fmap` should give the same result as first applying `fmap` to `g` and then applying `fmap` to `h`.
498
+
499
+ /// admonition |
500
+ - Any `Functor` instance satisfying the first law `(fmap id = id)` will automatically satisfy the [second law](https://github.com/quchen/articles/blob/master/second_functor_law.mo) as well.
501
+ ///
502
+
503
+ ### Functor Law Verification
504
+
505
+ We can define `id` and `compose` in `Python` as below:
506
+
507
+ ```python
508
+ id = lambda x: x
509
+ compose = lambda f, g: lambda x: f(g(x))
510
+ ```
511
+
512
+ We can add a helper function `check_functor_law` to verify that an instance satisfies the functor laws.
513
+
514
+ ```Python
515
+ check_functor_law = lambda functor: repr(fmap(id, functor)) == repr(functor)
516
+ ```
517
+
518
+ We can verify the functor we've defined.
519
+ """
520
+ )
521
+ return
522
+
523
+
524
+ @app.cell
525
+ def _():
526
+ id = lambda x: x
527
+ compose = lambda f, g: lambda x: f(g(x))
528
+ return compose, id
529
+
530
+
531
+ @app.cell
532
+ def _(fmap, id):
533
+ check_functor_law = lambda functor: repr(fmap(id, functor)) == repr(functor)
534
+ return (check_functor_law,)
535
+
536
+
537
+ @app.cell
538
+ def _(check_functor_law, flist, pp, rosetree, wrapper):
539
+ for functor in (wrapper, flist, rosetree):
540
+ pp(check_functor_law(functor))
541
+ return (functor,)
542
+
543
+
544
+ @app.cell(hide_code=True)
545
+ def _(mo):
546
+ mo.md(
547
+ """
548
+ And here is an `EvilFunctor`. We can verify it's not a valid `Functor`.
549
+
550
+ ```python
551
+ @dataclass
552
+ class EvilFunctor(Functor, Generic[A]):
553
+ value: list[A]
554
+
555
+ @classmethod
556
+ def fmap(cls, f: Callable[[A], B], a: "EvilFunctor[A]") -> "EvilFunctor[B]":
557
+ return (
558
+ cls([a.value[0]] * 2 + list(map(f, a.value[1:])))
559
+ if a.value
560
+ else []
561
+ )
562
+ ```
563
+ """
564
+ )
565
+ return
566
+
567
+
568
+ @app.cell
569
+ def _(A, B, Callable, Functor, Generic, check_functor_law, dataclass, pp):
570
+ @dataclass
571
+ class EvilFunctor(Functor, Generic[A]):
572
+ value: list[A]
573
+
574
+ @classmethod
575
+ def fmap(
576
+ cls, f: Callable[[A], B], a: "EvilFunctor[A]"
577
+ ) -> "EvilFunctor[B]":
578
+ return (
579
+ cls([a.value[0]] * 2 + [f(x) for x in a.value[1:]])
580
+ if a.value
581
+ else []
582
+ )
583
+
584
+
585
+ pp(check_functor_law(EvilFunctor([1, 2, 3, 4])))
586
+ return (EvilFunctor,)
587
+
588
+
589
+ @app.cell(hide_code=True)
590
+ def _(mo):
591
+ mo.md(
592
+ """
593
+ ## Final definition of Functor
594
+
595
+ We can now draft the final definition of `Functor` with some utility functions.
596
+
597
+ ```Python
598
+ @classmethod
599
+ @abstractmethod
600
+ def fmap(cls, f: Callable[[A], B], a: "Functor[A]") -> "Functor[B]":
601
+ return NotImplementedError
602
+
603
+ @classmethod
604
+ def const_fmap(cls, a: "Functor[A]", b: B) -> "Functor[B]":
605
+ return cls.fmap(lambda _: b, a)
606
+
607
+ @classmethod
608
+ def void(cls, a: "Functor[A]") -> "Functor[None]":
609
+ return cls.const_fmap(a, None)
610
+ ```
611
+ """
612
+ )
613
+ return
614
+
615
+
616
+ @app.cell(hide_code=True)
617
+ def _(A, ABC, B, Callable, Generic, abstractmethod, dataclass, mo):
618
+ @dataclass
619
+ class Functor(ABC, Generic[A]):
620
+ """
621
+ ### Doc: Functor
622
+
623
+ A generic interface for types that support mapping over their values.
624
+
625
+ **Methods:**
626
+
627
+ - `fmap(f: Callable[[A], B], a: Functor[A]) -> Functor[B]`
628
+ Abstract method to apply a function to all values inside a functor.
629
+
630
+ - `const_fmap(a: "Functor[A]", b: B) -> Functor[B]`
631
+ Replaces all values inside a functor with a constant `b`, preserving the original structure.
632
+
633
+ - `void(a: "Functor[A]") -> Functor[None]`
634
+ Equivalent to `const_fmap(a, None)`, transforming all values in a functor into `None`.
635
+ """
636
+
637
+ @classmethod
638
+ @abstractmethod
639
+ def fmap(cls, f: Callable[[A], B], a: "Functor[A]") -> "Functor[B]":
640
+ raise NotImplementedError
641
+
642
+ @classmethod
643
+ def const_fmap(cls, a: "Functor[A]", b: B) -> "Functor[B]":
644
+ return cls.fmap(lambda _: b, a)
645
+
646
+ @classmethod
647
+ def void(cls, a: "Functor[A]") -> "Functor[None]":
648
+ return cls.const_fmap(a, None)
649
+
650
+
651
+ mo.md(Functor.__doc__)
652
+ return (Functor,)
653
+
654
+
655
+ @app.cell(hide_code=True)
656
+ def _(mo):
657
+ mo.md("""> Try with utility functions in the cell below""")
658
+ return
659
+
660
+
661
+ @app.cell
662
+ def _(List, RoseTree, flist, pp, rosetree):
663
+ pp(RoseTree.const_fmap(rosetree, "λ"))
664
+ pp(RoseTree.void(rosetree))
665
+ pp(List.const_fmap(flist, "λ"))
666
+ pp(List.void(flist))
667
+ return
668
+
669
+
670
+ @app.cell(hide_code=True)
671
+ def _(mo):
672
+ mo.md(
673
+ """
674
+ ## Functors for Non-Iterable Types
675
+
676
+ In the previous examples, we implemented functors for container types like `List` and `RoseTree`, which are inherently **iterable**. This makes them a natural fit for functors, since iterables can be mapped over element by element.
677
+
678
+ However, **functors are not limited to iterables**. There are cases where we want to apply the concept of functors to types that are not inherently iterable, such as types that represent optional values, computations, or other data structures.
679
+
680
+ ### The Maybe Functor
681
+
682
+ One example is the **`Maybe`** type from Haskell, which is used to represent computations that can either result in a value or no value (`Nothing`).
683
+
684
+ We can define the `Maybe` functor as below:
685
+
686
+ ```python
687
+ @dataclass
688
+ class Maybe(Functor, Generic[A]):
689
+ value: None | A
690
+
691
+ @classmethod
692
+ def fmap(cls, f: Callable[[A], B], a: "Maybe[A]") -> "Maybe[B]":
693
+ return (
694
+ cls(None) if a.value is None else cls(f(a.value))
695
+ )
696
+
697
+ def __repr__(self):
698
+ return "Nothing" if self.value is None else repr(self.value)
699
+ ```
700
+ """
701
+ )
702
+ return
703
+
704
+
705
+ @app.cell
706
+ def _(A, B, Callable, Functor, Generic, dataclass):
707
+ @dataclass
708
+ class Maybe(Functor, Generic[A]):
709
+ value: None | A
710
+
711
+ @classmethod
712
+ def fmap(cls, f: Callable[[A], B], a: "Maybe[A]") -> "Maybe[B]":
713
+ return cls(None) if a.value is None else cls(f(a.value))
714
+
715
+ def __repr__(self):
716
+ return "Nothing" if self.value is None else repr(self.value)
717
+ return (Maybe,)
718
+
719
+
720
+ @app.cell(hide_code=True)
721
+ def _(mo):
722
+ mo.md(
723
+ """
724
+ **`Maybe`** is a functor that can either hold a value or be `Nothing` (equivalent to `None` in Python). The `fmap` method applies a function to the value inside the functor, if it exists. If the value is `None` (representing `Nothing`), `fmap` simply returns `Nothing` (i.e. `Maybe(None)`) without applying the function.
725
+
726
+ By using `Maybe` as a functor, we gain the ability to apply transformations (`fmap`) to potentially absent values, without having to explicitly handle the `None` case every time.
727
+
728
+ > Try using `Maybe` in the cell below.
729
+ """
730
+ )
731
+ return
732
+
733
+
734
+ @app.cell
735
+ def _(Maybe, pp):
736
+ mint = Maybe(1)
737
+ mnone = Maybe(None)
738
+
739
+ pp(Maybe.fmap(lambda x: x + 1, mint))
740
+ pp(Maybe.fmap(lambda x: x + 1, mnone))
741
+ return mint, mnone
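Because `fmap` absorbs the `Nothing` case, transformations compose without any explicit `None` checks. Here is a standalone sketch (redefining a minimal `Maybe` outside the notebook's cell structure) of chaining several `fmap` calls:

```python
from dataclasses import dataclass


@dataclass
class Maybe:
    value: object

    @classmethod
    def fmap(cls, f, a):
        # Apply f only when a value is present; propagate Nothing otherwise.
        return cls(None) if a.value is None else cls(f(a.value))

    def __repr__(self):
        return "Nothing" if self.value is None else repr(self.value)


# Chain transformations; Nothing short-circuits through every step.
step1 = Maybe.fmap(lambda x: x + 1, Maybe(10))  # 11
step2 = Maybe.fmap(lambda x: x * 2, step1)      # 22
print(step2)  # 22
print(Maybe.fmap(lambda x: x * 2, Maybe.fmap(lambda x: x + 1, Maybe(None))))  # Nothing
```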
742
+
743
+
744
+ @app.cell(hide_code=True)
745
+ def _(mo):
746
+ mo.md(
747
+ """
748
+ ## Limitations of Functor
749
+
750
+ Functors abstract the idea of mapping a function over each element of a structure. Suppose now that we wish to generalise this idea to allow functions with any number of arguments to be mapped, rather than being restricted to functions with a single argument. More precisely, suppose that we wish to define a hierarchy of `fmap` functions with the following types:
751
+
752
+ ```haskell
753
+ fmap0 :: a -> f a
754
+
755
+ fmap1 :: (a -> b) -> f a -> f b
756
+
757
+ fmap2 :: (a -> b -> c) -> f a -> f b -> f c
758
+
759
+ fmap3 :: (a -> b -> c -> d) -> f a -> f b -> f c -> f d
760
+ ```
761
+
762
+ And we would have to declare a special version of the functor class for each case.
763
+
764
+ We will learn how to resolve this problem in the next notebook on `Applicatives`.
765
+ """
766
+ )
767
+ return
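To make the limitation concrete, here is a hedged standalone sketch (the `fmap2` name mirrors the Haskell signature above; `Box` is a hypothetical minimal wrapper, not part of the notebook's `Functor` class) showing that mapping a two-argument function needs its own one-off definition:

```python
from dataclasses import dataclass


@dataclass
class Box:
    """Hypothetical minimal functor-like wrapper for illustration."""

    value: object


def fmap1(f, a):
    # The ordinary single-argument fmap.
    return Box(None) if a.value is None else Box(f(a.value))


def fmap2(f, a, b):
    # A separate definition is required: fmap1 alone cannot combine
    # two wrapped values with a two-argument function.
    if a.value is None or b.value is None:
        return Box(None)
    return Box(f(a.value, b.value))


print(fmap2(lambda x, y: x + y, Box(1), Box(2)))     # Box(value=3)
print(fmap2(lambda x, y: x + y, Box(1), Box(None)))  # Box(value=None)
```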
768
+
769
+
770
+ @app.cell(hide_code=True)
771
+ def _(mo):
772
+ mo.md(
773
+ """
774
+ # Introduction to Categories
775
+
776
+ A [category](https://en.wikibooks.org/wiki/Haskell/Category_theory#Introduction_to_categories) is, in essence, a simple collection. It has three components:
777
+
778
+ - A collection of **objects**.
779
+ - A collection of **morphisms**, each of which ties two objects (a _source object_ and a _target object_) together. If $f$ is a morphism with source object $C$ and target object $B$, we write $f : C → B$.
780
+ - A notion of **composition** of these morphisms. If $g : A → B$ and $f : B → C$ are two morphisms, they can be composed, resulting in a morphism $f ∘ g : A → C$.
781
+
782
+ ## Category laws
783
+
784
+ There are three laws that categories need to follow.
785
+
786
+ 1. The composition of morphisms needs to be **associative**. Symbolically, $f ∘ (g ∘ h) = (f ∘ g) ∘ h$
787
+
788
+ - Morphisms are applied right to left, so with $f ∘ g$ first $g$ is applied, then $f$.
789
+
790
+ 2. The category needs to be **closed** under the composition operation. So if $f : B → C$ and $g : A → B$, then there must be some morphism $h : A → C$ in the category such that $h = f ∘ g$.
791
+
792
+ 3. Given a category $C$, for every object $A$ there needs to be an **identity** morphism, $id_A : A → A$, that is an identity under composition with other morphisms. Put precisely, for every morphism $g : A → B$: $g ∘ id_A = id_B ∘ g = g$
793
+
794
+ /// attention | The definition of a category does not define:
795
+
796
+ - what `∘` is,
797
+ - what `id` is, or
798
+ - what `f`, `g`, and `h` might be.
799
+
800
+ Instead, category theory leaves it up to us to discover what they might be.
801
+ ///
802
+ """
803
+ )
804
+ return
805
+
806
+
807
+ @app.cell(hide_code=True)
808
+ def _(mo):
809
+ mo.md(
810
+ """
811
+ ## The Python category
812
+
813
+ The main category we'll be concerning ourselves with in this part is the Python category, or we can give it a shorter name: `Py`. `Py` treats Python types as objects and Python functions as morphisms. A function `def f(a: A) -> B` for types `A` and `B` is a morphism in `Py`.
814
+
815
+ Remember that we defined the `id` and `compose` function above as:
816
+
817
+ ```Python
818
+ def id(x: A) -> A:
819
+ return x
820
+
821
+ def compose(f: Callable[[B], C], g: Callable[[A], B]) -> Callable[[A], C]:
822
+ return lambda x: f(g(x))
823
+ ```
824
+
825
+ The second law (closure) is easy to check: composing two Python functions with `compose` yields another Python function.
826
+
827
+ For the first law, we have:
828
+
829
+ ```python
830
+ # compose(f, g) = lambda x: f(g(x))
831
+ f ∘ (g ∘ h)
832
+ = compose(f, compose(g, h))
833
+ = lambda x: f(compose(g, h)(x))
834
+ = lambda x: f((lambda y: g(h(y)))(x))
835
+ = lambda x: f(g(h(x)))
836
+
837
+ (f ∘ g) ∘ h
838
+ = compose(compose(f, g), h)
839
+ = lambda x: compose(f, g)(h(x))
840
+ = lambda x: (lambda y: f(g(y)))(h(x))
841
+ = lambda x: f(g(h(x)))
842
+ ```
843
+
844
+ For the third law, we have:
845
+
846
+ ```python
847
+ g ∘ id_A
848
+ = compose(g, id)  # g: Callable[[A], B], id: Callable[[A], A]
849
+ = lambda x: g(id(x))
850
+ = lambda x: g(x) # id(x) = x
851
+ = g
852
+ ```
853
+ A similar proof applies to $id_B ∘ g = g$.
854
+
855
+ Thus `Py` is a valid category.
856
+ """
857
+ )
858
+ return
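The equational reasoning above can also be spot-checked numerically. This is a standalone sketch using the same `id`/`compose` definitions (with `id` renamed `identity` here to avoid shadowing Python's builtin):

```python
def identity(x):
    return x


def compose(f, g):
    return lambda x: f(g(x))


f = lambda x: x + 1
g = lambda x: x * 2
h = lambda x: x - 3

# Law 1 (associativity) and Law 3 (identity) on a few sample inputs.
for x in range(5):
    assert compose(f, compose(g, h))(x) == compose(compose(f, g), h)(x)
    assert compose(f, identity)(x) == f(x) == compose(identity, f)(x)
print("associativity and identity hold on all samples")
```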
859
+
860
+
861
+ @app.cell(hide_code=True)
862
+ def _(mo):
863
+ mo.md(
864
+ """
865
+ # Functors, again
866
+
867
+ A functor is essentially a transformation between categories, so given categories $C$ and $D$, a functor $F : C → D$:
868
+
869
+ - Maps any object $A$ in $C$ to $F ( A )$, in $D$.
870
+ - Maps morphisms $f : A → B$ in $C$ to $F ( f ) : F ( A ) → F ( B )$ in $D$.
871
+
872
+ /// admonition | Note
873
+
874
+ Endofunctors are functors from a category to itself.
875
+
876
+ ///
877
+ """
878
+ )
879
+ return
880
+
881
+
882
+ @app.cell(hide_code=True)
883
+ def _(mo):
884
+ mo.md(
885
+ """
886
+ ## Functors on the category of Python
887
+
888
+ Remember that a functor has two parts: it maps objects in one category to objects in another and morphisms in the first category to morphisms in the second.
889
+
890
+ Functors in Python are from `Py` to `func`, where `func` is the subcategory of `Py` defined on just that functor's types. E.g. the RoseTree functor goes from `Py` to `RoseTree`, where `RoseTree` is the category containing only RoseTree types, that is, `RoseTree[T]` for any type `T`. The morphisms in `RoseTree` are functions defined on RoseTree types, that is, functions `Callable[[RoseTree[T]], RoseTree[U]]` for types `T`, `U`.
891
+
892
+ Recall the definition of `Functor`:
893
+
894
+ ```Python
895
+ @dataclass
896
+ class Functor(ABC, Generic[A])
897
+ ```
898
+
899
+ And RoseTree:
900
+
901
+ ```Python
902
+ @dataclass
903
+ class RoseTree(Functor, Generic[A])
904
+ ```
905
+
906
+ **Here's the key part:** the _type constructor_ `RoseTree` takes any type `T` to a new type, `RoseTree[T]`. Also, `fmap` restricted to `RoseTree` types takes a function `Callable[[A], B]` to a function `Callable[[RoseTree[A]], RoseTree[B]]`.
907
+
908
+ But that's it. We've defined two parts, something that takes objects in `Py` to objects in another category (that of `RoseTree` types and functions defined on `RoseTree` types), and something that takes morphisms in `Py` to morphisms in this category. So `RoseTree` is a functor.
909
+
910
+ To sum up:
911
+
912
+ - We work in the category **Py** and its subcategories.
913
+ - **Objects** are types (e.g., `int`, `str`, `list`).
914
+ - **Morphisms** are functions (`Callable[[A], B]`).
915
+ - **Things that take a type and return another type** are type constructors (`RoseTree[T]`).
916
+ - **Things that take a function and return another function** are higher-order functions (`Callable[[Callable[[A], B]], Callable[[C], D]]`).
917
+ - **Abstract base classes (ABC)** and duck typing provide a way to express polymorphism, capturing the idea that in category theory, structures are often defined over multiple objects at once.
918
+ """
919
+ )
920
+ return
921
+
922
+
923
+ @app.cell(hide_code=True)
924
+ def _(mo):
925
+ mo.md(
926
+ """
927
+ ## Functor laws, again
928
+
929
+ Once again there are a few axioms that functors have to obey.
930
+
931
+ 1. Given an identity morphism $id_A$ on an object $A$, $F ( id_A )$ must be the identity morphism on $F ( A )$, i.e.: ${\displaystyle F(\operatorname {id} _{A})=\operatorname {id} _{F(A)}}$
932
+ 2. Functors must distribute over morphism composition, i.e. ${\displaystyle F(f\circ g)=F(f)\circ F(g)}$
933
+ """
934
+ )
935
+ return
936
+
937
+
938
+ @app.cell(hide_code=True)
939
+ def _(mo):
940
+ mo.md(
941
+ """
942
+ Remember that we defined the `fmap`, `id` and `compose` as
943
+ ```python
944
+ fmap = lambda f, functor: functor.__class__.fmap(f, functor)
945
+ id = lambda x: x
946
+ compose = lambda f, g: lambda x: f(g(x))
947
+ ```
948
+
949
+ Let's prove that `fmap` is a functor.
950
+
951
+ First, let's define a `Category` for a specific `Functor`. We choose to define the `Category` for the `Wrapper` as `WrapperCategory` here for simplicity, but remember that `Wrapper` can be any `Functor` (e.g. `List`, `RoseTree`, `Maybe`, and more):
952
+
953
+ **Notice that** in this case, we can actually view `fmap` as:
954
+ ```python
955
+ fmap = lambda f, functor: functor.fmap(f, functor)
956
+ ```
957
+
958
+ We define `WrapperCategory` as:
959
+
960
+ ```python
961
+ @dataclass
962
+ class WrapperCategory:
963
+ @staticmethod
964
+ def id(wrapper: Wrapper[A]) -> Wrapper[A]:
965
+ return Wrapper(wrapper.value)
966
+
967
+ @staticmethod
968
+ def compose(
969
+ f: Callable[[Wrapper[B]], Wrapper[C]],
970
+ g: Callable[[Wrapper[A]], Wrapper[B]],
971
+ wrapper: Wrapper[A]
972
+ ) -> "Wrapper[C]":
973
+ return f(g(Wrapper(wrapper.value)))
974
+ ```
975
+
976
+ And `Wrapper` is:
977
+
978
+ ```Python
979
+ @dataclass
980
+ class Wrapper(Functor, Generic[A]):
981
+ value: A
982
+
983
+ @classmethod
984
+ def fmap(cls, f: Callable[[A], B], a: "Wrapper[A]") -> "Wrapper[B]":
985
+ return Wrapper(f(a.value))
986
+ ```
987
+ """
988
+ )
989
+ return
990
+
991
+
992
+ @app.cell(hide_code=True)
993
+ def _(mo):
994
+ mo.md(
995
+ """
996
+ We can prove that:
997
+
998
+ ```python
999
+ fmap(id, wrapper)
1000
+ = Wrapper.fmap(id, wrapper)
1001
+ = Wrapper(id(wrapper.value))
1002
+ = Wrapper(wrapper.value)
1003
+ = WrapperCategory.id(wrapper)
1004
+ ```
1005
+ and:
1006
+ ```python
1007
+ fmap(compose(f, g), wrapper)
1008
+ = Wrapper.fmap(compose(f, g), wrapper)
1009
+ = Wrapper(compose(f, g)(wrapper.value))
1010
+ = Wrapper(f(g(wrapper.value)))
1011
+
1012
+ WrapperCategory.compose(lambda w: fmap(f, w), lambda w: fmap(g, w), wrapper)
1013
+ = fmap(f, fmap(g, Wrapper(wrapper.value)))
1014
+ = fmap(f, Wrapper.fmap(g, wrapper))
1015
+ = fmap(f, Wrapper(g(wrapper.value)))
1016
+ = Wrapper.fmap(f, Wrapper(g(wrapper.value)))
1017
+ = Wrapper(f(Wrapper(g(wrapper.value)).value))
1018
+ = Wrapper(f(g(wrapper.value))) # Wrapper(g(wrapper.value)).value = g(wrapper.value)
1019
+ ```
1020
+
1021
+ So our `Wrapper` is a valid `Functor`.
1022
+
1023
+ > Try validating functor laws for `Wrapper` below.
1024
+ """
1025
+ )
1026
+ return
1027
+
1028
+
1029
+ @app.cell
1030
+ def _(A, B, C, Callable, Wrapper, dataclass):
1031
+ @dataclass
1032
+ class WrapperCategory:
1033
+ @staticmethod
1034
+ def id(wrapper: Wrapper[A]) -> Wrapper[A]:
1035
+ return Wrapper(wrapper.value)
1036
+
1037
+ @staticmethod
1038
+ def compose(
1039
+ f: Callable[[Wrapper[B]], Wrapper[C]],
1040
+ g: Callable[[Wrapper[A]], Wrapper[B]],
1041
+ wrapper: Wrapper[A],
1042
+ ) -> "Wrapper[C]":
1043
+ return f(g(Wrapper(wrapper.value)))
1044
+ return (WrapperCategory,)
1045
+
1046
+
1047
+ @app.cell
1048
+ def _(WrapperCategory, fmap, id, pp, wrapper):
1049
+ pp(fmap(id, wrapper) == WrapperCategory.id(wrapper))
1050
+ return
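The cell above checks only the identity law. The composition law can be checked the same way; here is a standalone sketch that redefines a minimal `Wrapper` together with the `fmap` and `compose` helpers quoted earlier:

```python
from dataclasses import dataclass


@dataclass
class Wrapper:
    value: object

    @classmethod
    def fmap(cls, f, a):
        return cls(f(a.value))


fmap = lambda f, functor: functor.__class__.fmap(f, functor)
compose = lambda f, g: lambda x: f(g(x))

f = lambda x: x + 1
g = lambda x: x * 10

w = Wrapper(3)
# Composition law: F(f ∘ g) == F(f) ∘ F(g)
print(fmap(compose(f, g), w) == fmap(f, fmap(g, w)))  # True
```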
1051
+
1052
+
1053
+ @app.cell(hide_code=True)
1054
+ def _(mo):
1055
+ mo.md(
1056
+ """
1057
+ ## Length as a Functor
1058
+
1059
+ Remember that a functor is a transformation between two categories. It is not limited to functors from `Py` to `func`; it also covers transformations between other mathematical structures.
1060
+
1061
+ Let’s prove that **`length`** can be viewed as a functor. Specifically, we will demonstrate that `length` is a functor from the **category of list concatenation** to the **category of integer addition**.
1062
+
1063
+ ### Category of List Concatenation
1064
+
1065
+ First, let’s define the category of list concatenation:
1066
+
1067
+ ```python
1068
+ @dataclass
1069
+ class ListConcatenation(Generic[A]):
1070
+ value: list[A]
1071
+
1072
+ @staticmethod
1073
+ def id() -> "ListConcatenation[A]":
1074
+ return ListConcatenation([])
1075
+
1076
+ @staticmethod
1077
+ def compose(
1078
+ this: "ListConcatenation[A]", other: "ListConcatenation[A]"
1079
+ ) -> "ListConcatenation[a]":
1080
+ return ListConcatenation(this.value + other.value)
1081
+ ```
1082
+ """
1083
+ )
1084
+ return
1085
+
1086
+
1087
+ @app.cell
1088
+ def _(A, Generic, dataclass):
1089
+ @dataclass
1090
+ class ListConcatenation(Generic[A]):
1091
+ value: list[A]
1092
+
1093
+ @staticmethod
1094
+ def id() -> "ListConcatenation[A]":
1095
+ return ListConcatenation([])
1096
+
1097
+ @staticmethod
1098
+ def compose(
1099
+ this: "ListConcatenation[A]", other: "ListConcatenation[A]"
1100
+ ) -> "ListConcatenation[a]":
1101
+ return ListConcatenation(this.value + other.value)
1102
+ return (ListConcatenation,)
1103
+
1104
+
1105
+ @app.cell(hide_code=True)
1106
+ def _(mo):
1107
+ mo.md(
1108
+ """
1109
+ - **Identity**: The identity element is an empty list (`ListConcatenation([])`).
1110
+ - **Composition**: The composition of two lists is their concatenation (`this.value + other.value`).
1111
+ """
1112
+ )
1113
+ return
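A quick standalone check (with minimal redefinitions outside the notebook) that `ListConcatenation` behaves as described: the empty list is an identity for composition, and composition is associative:

```python
from dataclasses import dataclass


@dataclass
class ListConcatenation:
    value: list

    @staticmethod
    def id():
        return ListConcatenation([])

    @staticmethod
    def compose(this, other):
        return ListConcatenation(this.value + other.value)


a = ListConcatenation([1, 2])
b = ListConcatenation([3])
c = ListConcatenation([4, 5])

LC = ListConcatenation
print(LC.compose(LC.id(), a) == a == LC.compose(a, LC.id()))  # True: identity
print(LC.compose(LC.compose(a, b), c) == LC.compose(a, LC.compose(b, c)))  # True: associativity
```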
1114
+
1115
+
1116
+ @app.cell(hide_code=True)
1117
+ def _(mo):
1118
+ mo.md(
1119
+ """
1120
+ ### Category of Integer Addition
1121
+
1122
+ Now, let's define the category of integer addition:
1123
+
1124
+ ```python
1125
+ @dataclass
1126
+ class IntAddition:
1127
+ value: int
1128
+
1129
+ @staticmethod
1130
+ def id() -> "IntAddition":
1131
+ return IntAddition(0)
1132
+
1133
+ @staticmethod
1134
+ def compose(this: "IntAddition", other: "IntAddition") -> "IntAddition":
1135
+ return IntAddition(this.value + other.value)
1136
+ ```
1137
+ """
1138
+ )
1139
+ return
1140
+
1141
+
1142
+ @app.cell
1143
+ def _(dataclass):
1144
+ @dataclass
1145
+ class IntAddition:
1146
+ value: int
1147
+
1148
+ @staticmethod
1149
+ def id() -> "IntAddition":
1150
+ return IntAddition(0)
1151
+
1152
+ @staticmethod
1153
+ def compose(this: "IntAddition", other: "IntAddition") -> "IntAddition":
1154
+ return IntAddition(this.value + other.value)
1155
+ return (IntAddition,)
1156
+
1157
+
1158
+ @app.cell(hide_code=True)
1159
+ def _(mo):
1160
+ mo.md(
1161
+ """
1162
+ - **Identity**: The identity element is `IntAddition(0)` (the additive identity).
1163
+ - **Composition**: The composition of two integers is their sum (`this.value + other.value`).
1164
+ """
1165
+ )
1166
+ return
1167
+
1168
+
1169
+ @app.cell(hide_code=True)
1170
+ def _(mo):
1171
+ mo.md(
1172
+ """
1173
+ ### Defining the Length Functor
1174
+
1175
+ We now define the `length` function as a functor, mapping from the category of list concatenation to the category of integer addition:
1176
+
1177
+ ```python
1178
+ length = lambda l: IntAddition(len(l.value))
1179
+ ```
1180
+ """
1181
+ )
1182
+ return
1183
+
1184
+
1185
+ @app.cell(hide_code=True)
1186
+ def _(IntAddition):
1187
+ length = lambda l: IntAddition(len(l.value))
1188
+ return (length,)
1189
+
1190
+
1191
+ @app.cell(hide_code=True)
1192
+ def _(mo):
1193
+ mo.md("""This function takes an instance of `ListConcatenation`, computes its length, and returns an `IntAddition` instance with the computed length.""")
1194
+ return
1195
+
1196
+
1197
+ @app.cell(hide_code=True)
1198
+ def _(mo):
1199
+ mo.md(
1200
+ """
1201
+ ### Verifying Functor Laws
1202
+
1203
+ Now, let’s verify that `length` satisfies the two functor laws.
1204
+
1205
+ #### 1. **Identity Law**:
1206
+ The identity law states that applying the functor to the identity element of one category should give the identity element of the other category.
1207
+
1208
+ ```python
1209
+ > length(ListConcatenation.id()) == IntAddition.id()
1210
+ True
1211
+ ```
1212
+ """
1213
+ )
1214
+ return
1215
+
1216
+
1217
+ @app.cell(hide_code=True)
1218
+ def _(mo):
1219
+ mo.md("""This ensures that the length of an empty list (identity in the `ListConcatenation` category) is `0` (identity in the `IntAddition` category).""")
1220
+ return
1221
+
1222
+
1223
+ @app.cell(hide_code=True)
1224
+ def _(mo):
1225
+ mo.md(
1226
+ """
1227
+ #### 2. **Composition Law**:
1228
+ The composition law states that the functor should preserve composition. Applying the functor to a composed element should be the same as composing the functor applied to the individual elements.
1229
+
1230
+ ```python
1231
+ > lista = ListConcatenation([1, 2])
1232
+ > listb = ListConcatenation([3, 4])
1233
+ > length(ListConcatenation.compose(lista, listb)) == IntAddition.compose(
1234
+ > length(lista), length(listb)
1235
+ > )
1236
+ True
1237
+ ```
1238
+ """
1239
+ )
1240
+ return
1241
+
1242
+
1243
+ @app.cell(hide_code=True)
1244
+ def _(mo):
1245
+ mo.md("""This ensures that the length of the concatenation of two lists is the same as the sum of the lengths of the individual lists.""")
1246
+ return
1247
+
1248
+
1249
+ @app.cell
1250
+ def _(IntAddition, ListConcatenation, length, pp):
1251
+ pp(length(ListConcatenation.id()) == IntAddition.id())
1252
+ lista = ListConcatenation([1, 2])
1253
+ listb = ListConcatenation([3, 4])
1254
+ pp(
1255
+ length(ListConcatenation.compose(lista, listb))
1256
+ == IntAddition.compose(length(lista), length(listb))
1257
+ )
1258
+ return lista, listb
1259
+
1260
+
1261
+ @app.cell(hide_code=True)
1262
+ def _(mo):
1263
+ mo.md(
1264
+ """
1265
+ # Further reading
1266
+
1267
+ - [The Trivial Monad](http://blog.sigfpe.com/2007/04/trivial-monad.html)
1268
+ - [Haskellwiki. Category Theory](https://en.wikibooks.org/wiki/Haskell/Category_theory)
1269
+ - [Haskellforall. The Category Design Pattern](https://www.haskellforall.com/2012/08/the-category-design-pattern.html)
1270
+ - [Haskellforall. The Functor Design Pattern](https://www.haskellforall.com/2012/09/the-functor-design-pattern.html)
1271
+
1272
+ /// attention | ATTENTION
1273
+ The functor design pattern doesn't work at all if you aren't using categories in the first place. This is why you should structure your tools using the compositional category design pattern so that you can take advantage of functors to easily mix your tools together.
1274
+ ///
1275
+
1276
+ - [Haskellwiki. Functor](https://wiki.haskell.org/index.php?title=Functor)
1277
+ - [Haskellwiki. Typeclassopedia#Functor](https://wiki.haskell.org/index.php?title=Typeclassopedia#Functor)
1278
+ - [Haskellwiki. Typeclassopedia#Category](https://wiki.haskell.org/index.php?title=Typeclassopedia#Category)
1279
+ """
1280
+ )
1281
+ return
1282
+
1283
+
1284
+ @app.cell(hide_code=True)
1285
+ def _():
1286
+ import marimo as mo
1287
+ return (mo,)
1288
+
1289
+
1290
+ @app.cell(hide_code=True)
1291
+ def _():
1292
+ from abc import abstractmethod, ABC
1293
+ return ABC, abstractmethod
1294
+
1295
+
1296
+ @app.cell(hide_code=True)
1297
+ def _():
1298
+ from dataclasses import dataclass
1299
+ from typing import Callable, Generic, TypeVar
1300
+ from pprint import pp
1301
+ return Callable, Generic, TypeVar, dataclass, pp
1302
+
1303
+
1304
+ @app.cell(hide_code=True)
1305
+ def _(TypeVar):
1306
+ A = TypeVar("A")
1307
+ B = TypeVar("B")
1308
+ C = TypeVar("C")
1309
+ return A, B, C
1310
+
1311
+
1312
+ if __name__ == "__main__":
1313
+ app.run()
functional_programming/CHANGELOG.md ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Changelog of the functional-programming course
2
+
3
+ ## 2025-03-11
4
+
5
+ * Demo version of notebook `05_functors.py`
6
+
7
+ ## 2025-03-13
8
+
9
+ * `0.1.0` version of notebook `05_functors`
10
+
11
+ Thanks to [Akshay](https://github.com/akshayka) and [Haleshot](https://github.com/Haleshot) for reviewing
12
+
13
+ ## 2025-03-16
14
+
15
+ + Use uppercased letters for `Generic` types, e.g. `A = TypeVar("A")`
16
+ + Refactor the `Functor` class, changing `fmap` and utility methods to `classmethod`
17
+
18
+ For example:
19
+
20
+ ```python
21
+ @dataclass
22
+ class Wrapper(Functor, Generic[A]):
23
+ value: A
24
+
25
+ @classmethod
26
+ def fmap(cls, f: Callable[[A], B], a: "Wrapper[A]") -> "Wrapper[B]":
27
+ return Wrapper(f(a.value))
28
+
29
+ >>> Wrapper.fmap(lambda x: x + 1, wrapper)
30
+ Wrapper(value=2)
31
+ ```
32
+
33
+ + Move the `check_functor_law` method from `Functor` class to a standard function
34
+ - Rename `ListWrapper` to `List` for simplicity
35
+ - Remove the `Just` class
36
+ + Rewrite proofs
functional_programming/README.md ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Learn Functional Programming
2
+
3
+ _🚧 This collection is a
4
+ [work in progress](https://github.com/marimo-team/learn/issues/51)._
5
+
6
+ This series of marimo notebooks introduces the powerful paradigm of functional
7
+ programming through Python. Taking inspiration from Haskell and Category Theory,
8
+ we'll build a strong foundation in FP concepts that can transform how you
9
+ approach software development.
10
+
11
+ ## What You'll Learn
12
+
13
+ **Using only Python's standard library**, we'll construct functional programming
14
+ concepts from first principles.
15
+
16
+ Topics include:
17
+
18
+ - Recursion and higher-order functions
19
+ - Category theory fundamentals
20
+ - Functors, applicatives, and monads
21
+ - Composable abstractions for robust code
22
+
23
+ ## Timeline & Collaboration
24
+
25
+ I'm currently studying functional programming and Haskell, estimating about 2
26
+ months or even longer to complete this series. The structure may evolve as the
27
+ project develops.
28
+
29
+ If you're interested in collaborating or have questions, please reach out to me
30
+ on Discord (@eugene.hs).
31
+
32
+ **Running notebooks.** To run a notebook locally, use
33
+
34
+ ```bash
35
+ uvx marimo edit <URL>
36
+ ```
37
+
38
+ For example, run the `Functor` tutorial with
39
+
40
+ ```bash
41
+ uvx marimo edit https://github.com/marimo-team/learn/blob/main/Functional_programming/05_functors.py
42
+ ```
43
+
44
+ You can also open notebooks in our online playground by appending `marimo.app/`
45
+ to a notebook's URL:
46
+ [marimo.app/github.com/marimo-team/learn/blob/main/functional_programming/05_functors.py](https://marimo.app/https://github.com/marimo-team/learn/blob/main/functional_programming/05_functors.py).
47
+
48
+ # Description of notebooks
49
+
50
+ Check [here](https://github.com/marimo-team/learn/issues/51) for current series
51
+ structure.
52
+
53
+ | Notebook | Description | References |
54
+ | ----------------------------------------------------------------------------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
55
+ | [05. Category and Functors](https://github.com/marimo-team/learn/blob/main/Functional_programming/05_functors.py) | Learn why `len` is a _Functor_ from `list concatenation` to `integer addition`, how to _lift_ an ordinary function into a _computation context_, and how to write an _adapter_ between two categories. | - [The Trivial Monad](http://blog.sigfpe.com/2007/04/trivial-monad.html) <br> - [Haskellwiki. Category Theory](https://en.wikibooks.org/wiki/Haskell/Category_theory) <br> - [Haskellforall. The Category Design Pattern](https://www.haskellforall.com/2012/08/the-category-design-pattern.html) <br> - [Haskellforall. The Functor Design Pattern](https://www.haskellforall.com/2012/09/the-functor-design-pattern.html) <br> - [Haskellwiki. Functor](https://wiki.haskell.org/index.php?title=Functor) <br> - [Haskellwiki. Typeclassopedia#Functor](https://wiki.haskell.org/index.php?title=Typeclassopedia#Functor) <br> - [Haskellwiki. Typeclassopedia#Category](https://wiki.haskell.org/index.php?title=Typeclassopedia#Category) |
56
+
57
+ **Authors.**
58
+
59
+ Thanks to all our notebook authors!
60
+
61
+ - [métaboulie](https://github.com/metaboulie)
optimization/05_portfolio_optimization.py CHANGED
@@ -47,7 +47,7 @@ def _(mo):
47
  r"""
48
  ## Asset returns and risk
49
 
50
- We will only model investments held for one period. The initial prices are $p_i > 0$. The end of period prices are $p_i^+ >0$. The asset (fractional) returns are $r_i = (p_i^+-p_i)/p_i$. The porfolio (fractional) return is $R = r^Tw$.
51
 
52
  A common model is that $r$ is a random variable with mean ${\bf E}r = \mu$ and covariance ${\bf E{(r-\mu)(r-\mu)^T}} = \Sigma$.
53
  It follows that $R$ is a random variable with ${\bf E}R = \mu^T w$ and ${\bf var}(R) = w^T\Sigma w$. In real-world applications, $\mu$ and $\Sigma$ are estimated from data and models, and $w$ is chosen using a library like CVXPY.
 
47
  r"""
48
  ## Asset returns and risk
49
 
50
+ We will only model investments held for one period. The initial prices are $p_i > 0$. The end of period prices are $p_i^+ >0$. The asset (fractional) returns are $r_i = (p_i^+-p_i)/p_i$. The portfolio (fractional) return is $R = r^Tw$.
51
 
52
  A common model is that $r$ is a random variable with mean ${\bf E}r = \mu$ and covariance ${\bf E{(r-\mu)(r-\mu)^T}} = \Sigma$.
53
  It follows that $R$ is a random variable with ${\bf E}R = \mu^T w$ and ${\bf var}(R) = w^T\Sigma w$. In real-world applications, $\mu$ and $\Sigma$ are estimated from data and models, and $w$ is chosen using a library like CVXPY.
polars/04_basic_operations.py ADDED
@@ -0,0 +1,631 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.13"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "polars==1.23.0",
6
+ # ]
7
+ # ///
8
+
9
+ import marimo
10
+
11
+ __generated_with = "0.11.13"
12
+ app = marimo.App(width="medium")
13
+
14
+
15
+ @app.cell
16
+ def _():
17
+ import marimo as mo
18
+ return (mo,)
19
+
20
+
21
+ @app.cell(hide_code=True)
22
+ def _(mo):
23
+ mo.md(
24
+ r"""
25
+ # Basic operations on data
26
+ _By [Joram Mutenge](https://www.udemy.com/user/joram-mutenge/)._
27
+
28
+ In this notebook, you'll learn how to perform arithmetic operations, comparisons, and conditionals on a Polars dataframe. We'll work with a DataFrame that tracks software usage by year, categorized as either Vintage (old) or Modern (new).
29
+ """
30
+ )
31
+ return
32
+
33
+
34
+ @app.cell
35
+ def _():
36
+ import polars as pl
37
+
38
+ df = pl.DataFrame(
39
+ {
40
+ "software": [
41
+ "Lotus-123",
42
+ "WordStar",
43
+ "dBase III",
44
+ "VisiCalc",
45
+ "WinZip",
46
+ "MS-DOS",
47
+ "HyperCard",
48
+ "WordPerfect",
49
+ "Excel",
50
+ "Photoshop",
51
+ "Visual Studio",
52
+ "Slack",
53
+ "Zoom",
54
+ "Notion",
55
+ "Figma",
56
+ "Spotify",
57
+ "VSCode",
58
+ "Docker",
59
+ ],
60
+ "users": [
61
+ 10000,
62
+ 4500,
63
+ 2500,
64
+ 3000,
65
+ 1800,
66
+ 17000,
67
+ 2200,
68
+ 1900,
69
+ 500000,
70
+ 12000000,
71
+ 1500000,
72
+ 3000000,
73
+ 4000000,
74
+ 2000000,
75
+ 2500000,
76
+ 4500000,
77
+ 6000000,
78
+ 3500000,
79
+ ],
80
+ "category": ["Vintage"] * 8 + ["Modern"] * 10,
81
+ "year": [
82
+ 1985,
83
+ 1980,
84
+ 1984,
85
+ 1979,
86
+ 1991,
87
+ 1981,
88
+ 1987,
89
+ 1982,
90
+ 1987,
91
+ 1990,
92
+ 1997,
93
+ 2013,
94
+ 2011,
95
+ 2016,
96
+ 2016,
97
+ 2008,
98
+ 2015,
99
+ 2013,
100
+ ],
101
+ }
102
+ )
103
+
104
+ df
105
+ return df, pl
106
+
107
+
108
+ @app.cell(hide_code=True)
109
+ def _(mo):
110
+ mo.md(
111
+ r"""
112
+ ## Arithmetic
113
+ ### Addition
114
+ Let's add 42 users to each piece of software. This means adding 42 to each value under **users**.
115
+ """
116
+ )
117
+ return
118
+
119
+
120
+ @app.cell
121
+ def _(df, pl):
122
+ df.with_columns(pl.col("users") + 42)
123
+ return
124
+
125
+
126
+ @app.cell(hide_code=True)
127
+ def _(mo):
128
+ mo.md(r"""Another way to perform the above operation is using the built-in function.""")
129
+ return
130
+
131
+
132
+ @app.cell
133
+ def _(df, pl):
134
+ df.with_columns(pl.col("users").add(42))
135
+ return
136
+
137
+
138
+ @app.cell(hide_code=True)
139
+ def _(mo):
140
+ mo.md(
141
+ r"""
142
+ ### Subtraction
143
+ Let's subtract 42 users from each piece of software.
144
+ """
145
+ )
146
+ return
147
+
148
+
149
+ @app.cell
150
+ def _(df, pl):
151
+ df.with_columns(pl.col("users") - 42)
152
+ return
153
+
154
+
155
+ @app.cell(hide_code=True)
156
+ def _(mo):
157
+ mo.md(r"""Alternatively, you could subtract like this:""")
158
+ return
159
+
160
+
161
+ @app.cell
162
+ def _(df, pl):
163
+ df.with_columns(pl.col("users").sub(42))
164
+ return
165
+
166
+
167
+ @app.cell(hide_code=True)
168
+ def _(mo):
169
+ mo.md(
170
+ r"""
171
+ ### Division
172
+ Suppose the **users** values are inflated; we can reduce them by dividing by 1000. Here's how to do it.
173
+ """
174
+ )
175
+ return
176
+
177
+
178
+ @app.cell
179
+ def _(df, pl):
180
+ df.with_columns(pl.col("users") / 1000)
181
+ return
182
+
183
+
184
+ @app.cell(hide_code=True)
185
+ def _(mo):
186
+ mo.md(r"""Or we could do it with a built-in expression.""")
187
+ return
188
+
189
+
190
+ @app.cell
191
+ def _(df, pl):
192
+ df.with_columns(pl.col("users").truediv(1000))
193
+ return
194
+
195
+
196
+ @app.cell(hide_code=True)
197
+ def _(mo):
198
+ mo.md(r"""If we didn't care about the remainder after division (i.e remove numbers after decimal point) we could do it like this.""")
199
+ return
200
+
201
+
202
+ @app.cell
203
+ def _(df, pl):
204
+ df.with_columns(pl.col("users").floordiv(1000))
205
+ return
206
+
207
+
208
+ @app.cell(hide_code=True)
209
+ def _(mo):
210
+ mo.md(
211
+ r"""
212
+ ### Multiplication
213
+ Let's pretend the *users* values are deflated and increase them by multiplying by 100.
214
+ """
215
+ )
216
+ return
217
+
218
+
219
+ @app.cell
220
+ def _(df, pl):
221
+ (df.with_columns(pl.col("users") * 100))
222
+ return
223
+
224
+
225
+ @app.cell(hide_code=True)
226
+ def _(mo):
227
+ mo.md(r"""Polars also has a built-in function for multiplication.""")
228
+ return
229
+
230
+
231
+ @app.cell
232
+ def _(df, pl):
233
+ df.with_columns(pl.col("users").mul(100))
234
+ return
235
+
236
+
237
+ @app.cell(hide_code=True)
238
+ def _(mo):
239
+ mo.md(r"""So far, we've only modified the values in an existing column. Let's create a column **decade** that will represent the years as decades. Thus 1985 will be 1980 and 2008 will be 2000.""")
240
+ return
241
+
242
+
243
+ @app.cell
244
+ def _(df, pl):
245
+ (df.with_columns(decade=pl.col("year").floordiv(10).mul(10)))
246
+ return
247
+
248
+
249
+ @app.cell(hide_code=True)
250
+ def _(mo):
251
+ mo.md(r"""We could create a new column another way as follows:""")
252
+ return
253
+
254
+
255
+ @app.cell
256
+ def _(df, pl):
257
+ df.with_columns((pl.col("year").floordiv(10).mul(10)).alias("decade"))
258
+ return
259
+
260
+
261
+ @app.cell(hide_code=True)
262
+ def _(mo):
263
+ mo.md(
264
+ r"""
265
+ **Tip**
266
+ Polars encourages you to perform your operations as a chain. This enables you to take advantage of the query optimizer. We'll build upon the above code as a chain.
267
+
268
+ ## Comparison
269
+ ### Equal
270
+ Let's get all the software categorized as Vintage.
271
+ """
272
+ )
273
+ return
274
+
275
+
276
+ @app.cell
277
+ def _(df, pl):
278
+ (
279
+ df.with_columns(decade=pl.col("year").floordiv(10).mul(10))
280
+ .filter(pl.col("category") == "Vintage")
281
+ )
282
+ return
283
+
284
+
285
+ @app.cell(hide_code=True)
286
+ def _(mo):
287
+ mo.md(r"""We could also do a double comparison. VisiCal is the only software that's vintage and in the decade 1970s. Let's perform this comparison operation.""")
288
+ return
289
+
290
+
291
+ @app.cell
292
+ def _(df, pl):
293
+ (
294
+ df.with_columns(decade=pl.col("year").floordiv(10).mul(10))
295
+ .filter(pl.col("category") == "Vintage")
296
+ .filter(pl.col("decade") == 1970)
297
+ )
298
+ return
299
+
300
+
301
+ @app.cell(hide_code=True)
302
+ def _(mo):
303
+ mo.md(
304
+ r"""
305
+ We could also do this comparison in one line, if readability is not a concern.
306
+
307
+ **Notice** that we must enclose each of the two expressions joined by the `&` in parentheses.
308
+ """
309
+ )
310
+ return
311
+
312
+
313
+ @app.cell
314
+ def _(df, pl):
315
+ (
316
+ df.with_columns(decade=pl.col("year").floordiv(10).mul(10))
317
+ .filter((pl.col("category") == "Vintage") & (pl.col("decade") == 1970))
318
+ )
319
+ return
320
+
321
+
322
+ @app.cell(hide_code=True)
323
+ def _(mo):
324
+ mo.md(r"""We can also use the built-in function for equal to comparisons.""")
325
+ return
326
+
327
+
328
+ @app.cell
329
+ def _(df, pl):
330
+ (df
331
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
332
+ .filter(pl.col('category').eq('Vintage'))
333
+ )
334
+ return
335
+
336
+
337
+ @app.cell(hide_code=True)
338
+ def _(mo):
339
+ mo.md(
340
+ r"""
341
+ ### Not equal
342
+ We can also check whether something is `not` equal to something else. In this case, the category is not Vintage.
343
+ """
344
+ )
345
+ return
346
+
347
+
348
+ @app.cell
349
+ def _(df, pl):
350
+ (df
351
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
352
+ .filter(pl.col('category') != 'Vintage')
353
+ )
354
+ return
355
+
356
+
357
+ @app.cell(hide_code=True)
358
+ def _(mo):
359
+ mo.md(r"""Or with the built-in function.""")
360
+ return
361
+
362
+
363
+ @app.cell
364
+ def _(df, pl):
365
+ (df
366
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
367
+ .filter(pl.col('category').ne('Vintage'))
368
+ )
369
+ return
370
+
371
+
372
+ @app.cell(hide_code=True)
373
+ def _(mo):
374
+ mo.md(r"""Or if you want to be extra clever, you can use the negation symbol `~` used in logic.""")
375
+ return
376
+
377
+
378
+ @app.cell
379
+ def _(df, pl):
380
+ (df
381
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
382
+ .filter(~pl.col('category').eq('Vintage'))
383
+ )
384
+ return
385
+
386
+
387
+ @app.cell(hide_code=True)
388
+ def _(mo):
389
+ mo.md(
390
+ r"""
391
+ ### Greater than
392
+ Let's get the software where the year is greater than 2008 from the above dataframe.
393
+ """
394
+ )
395
+ return
396
+
397
+
398
+ @app.cell
399
+ def _(df, pl):
400
+ (df
401
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
402
+ .filter(~pl.col('category').eq('Vintage'))
403
+ .filter(pl.col('year') > 2008)
404
+ )
405
+ return
406
+
407
+
408
+ @app.cell(hide_code=True)
409
+ def _(mo):
410
+ mo.md(r"""Or if we wanted the year 2008 to be included, we could use great or equal to.""")
411
+ return
412
+
413
+
414
+ @app.cell
415
+ def _(df, pl):
416
+ (df
417
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
418
+ .filter(~pl.col('category').eq('Vintage'))
419
+ .filter(pl.col('year') >= 2008)
420
+ )
421
+ return
422
+
423
+
424
+ @app.cell(hide_code=True)
425
+ def _(mo):
426
+ mo.md(r"""We could do the previous two operations with built-in functions. Here's with greater than.""")
427
+ return
428
+
429
+
430
+ @app.cell
431
+ def _(df, pl):
432
+ (df
433
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
434
+ .filter(~pl.col('category').eq('Vintage'))
435
+ .filter(pl.col('year').gt(2008))
436
+ )
437
+ return
438
+
439
+
440
+ @app.cell(hide_code=True)
441
+ def _(mo):
442
+ mo.md(r"""And here's with greater or equal to""")
443
+ return
444
+
445
+
446
+ @app.cell
447
+ def _(df, pl):
448
+ (df
449
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
450
+ .filter(~pl.col('category').eq('Vintage'))
451
+ .filter(pl.col('year').ge(2008))
452
+ )
453
+ return
454
+
455
+
456
+ @app.cell(hide_code=True)
457
+ def _(mo):
458
+ mo.md(
459
+ r"""
460
+ **Note**: For "less than", and "less or equal to" you can use the operators `<` or `<=`. Alternatively, you can use built-in functions `lt` or `le` respectively.
461
+
462
+ ### Is between
463
+ Polars also allows us to filter between a range of values. Let's get the modern software where the year is between 2013 and 2016. This is inclusive on both ends (i.e., both years are part of the result).
464
+ """
465
+ )
466
+ return
467
+
468
+
469
+ @app.cell
470
+ def _(df, pl):
471
+ (df
472
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
473
+ .filter(pl.col('category').eq('Modern'))
474
+ .filter(pl.col('year').is_between(2013, 2016))
475
+ )
476
+ return
477
+
478
+
479
+ @app.cell(hide_code=True)
480
+ def _(mo):
481
+ mo.md(
482
+ r"""
483
+ ### Or operator
484
+ If we only want either one of the conditions in the comparison to be met, we could use `|`, which is the `or` operator.
485
+
486
+ Let's get software that is either modern or used in the decade 1980s.
487
+ """
488
+ )
489
+ return
490
+
491
+
492
+ @app.cell
493
+ def _(df, pl):
494
+ (df
495
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
496
+ .filter((pl.col('category') == 'Modern') | (pl.col('decade') == 1980))
497
+ )
498
+ return
499
+
500
+
501
+ @app.cell(hide_code=True)
502
+ def _(mo):
503
+ mo.md(
504
+ r"""
505
+ ## Conditionals
506
+ Polars also allows you to create new columns based on a condition. Let's create a column *status* that will indicate whether the software is "discontinued" or "in use".
507
+
508
+ Here's a list of products that are no longer in use.
509
+ """
510
+ )
511
+ return
512
+
513
+
514
+ @app.cell
515
+ def _():
516
+ discontinued_list = ['Lotus-123', 'WordStar', 'dBase III', 'VisiCalc', 'MS-DOS', 'HyperCard']
517
+ return (discontinued_list,)
518
+
519
+
520
+ @app.cell(hide_code=True)
521
+ def _(mo):
522
+ mo.md(r"""Here's how we can get a dataframe of the products that are discontinued.""")
523
+ return
524
+
525
+
526
+ @app.cell
527
+ def _(df, discontinued_list, pl):
528
+ (df
529
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
530
+ .filter(pl.col('software').is_in(discontinued_list))
531
+ )
532
+ return
533
+
534
+
535
+ @app.cell(hide_code=True)
536
+ def _(mo):
537
+ mo.md(r"""Now, let's create the **status** column.""")
538
+ return
539
+
540
+
541
+ @app.cell
542
+ def _(df, discontinued_list, pl):
543
+ (df
544
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
545
+ .with_columns(pl.when(pl.col('software').is_in(discontinued_list))
546
+ .then(pl.lit('Discontinued'))
547
+ .otherwise(pl.lit('In use'))
548
+ .alias('status')
549
+ )
550
+ )
551
+ return
552
+
553
+
554
+ @app.cell(hide_code=True)
555
+ def _(mo):
556
+ mo.md(
557
+ r"""
558
+ ## Unique counts
559
+ Sometimes you may want to see only the unique values in a column. Let's check the unique decades we have in our DataFrame.
560
+ """
561
+ )
562
+ return
563
+
564
+
565
+ @app.cell
566
+ def _(df, discontinued_list, pl):
567
+ (df
568
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
569
+ .with_columns(pl.when(pl.col('software').is_in(discontinued_list))
570
+ .then(pl.lit('Discontinued'))
571
+ .otherwise(pl.lit('In use'))
572
+ .alias('status')
573
+ )
574
+ .select('decade').unique()
575
+ )
576
+ return
577
+
578
+
579
+ @app.cell(hide_code=True)
580
+ def _(mo):
581
+ mo.md(r"""Finally, let's find out the number of software used in each decade.""")
582
+ return
583
+
584
+
585
+ @app.cell
586
+ def _(df, discontinued_list, pl):
587
+ (df
588
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
589
+ .with_columns(pl.when(pl.col('software').is_in(discontinued_list))
590
+ .then(pl.lit('Discontinued'))
591
+ .otherwise(pl.lit('In use'))
592
+ .alias('status')
593
+ )
594
+ ['decade'].value_counts()
595
+ )
596
+ return
597
+
598
+
599
+ @app.cell(hide_code=True)
600
+ def _(mo):
601
+ mo.md(r"""We could also rewrite the above code as follows:""")
602
+ return
603
+
604
+
605
+ @app.cell
606
+ def _(df, discontinued_list, pl):
607
+ (df
608
+ .with_columns(decade=pl.col('year').floordiv(10).mul(10))
609
+ .with_columns(pl.when(pl.col('software').is_in(discontinued_list))
610
+ .then(pl.lit('Discontinued'))
611
+ .otherwise(pl.lit('In use'))
612
+ .alias('status')
613
+ )
614
+ .select('decade').to_series().value_counts()
615
+ )
616
+ return
617
+
618
+
619
+ @app.cell(hide_code=True)
620
+ def _(mo):
621
+ mo.md(r"""Hopefully, we've picked your interest to try out Polars the next time you analyze your data.""")
622
+ return
623
+
624
+
625
+ @app.cell
626
+ def _():
627
+ return
628
+
629
+
630
+ if __name__ == "__main__":
631
+ app.run()
polars/10_strings.py ADDED
@@ -0,0 +1,1004 @@
1
+ # /// script
2
+ # requires-python = ">=3.12"
3
+ # dependencies = [
4
+ # "altair==5.5.0",
5
+ # "marimo",
6
+ # "numpy==2.2.3",
7
+ # "polars==1.24.0",
8
+ # ]
9
+ # ///
10
+
11
+ import marimo
12
+
13
+ __generated_with = "0.11.17"
14
+ app = marimo.App(width="medium")
15
+
16
+
17
+ @app.cell(hide_code=True)
18
+ def _(mo):
19
+ mo.md(
20
+ r"""
21
+ # Strings
22
+
23
+ _By [Péter Ferenc Gyarmati](http://github.com/peter-gy)_.
24
+
25
+ In this chapter we're going to dig into string manipulation. For a fun twist, we'll be mostly playing around with a dataset that every Polars user has bumped into without really thinking about it—the source code of the `polars` module itself. More precisely, we'll use a dataframe that pulls together all the Polars expressions and their docstrings, giving us a cool, hands-on way to explore the expression API in a truly data-driven manner.
26
+
27
+ We'll cover parsing, length calculation, case conversion, and much more, with practical examples and visualizations. Finally, we will combine various techniques you learned in prior chapters to build a fully interactive playground in which you can execute the official code examples of Polars expressions.
28
+ """
29
+ )
30
+ return
31
+
32
+
33
+ @app.cell(hide_code=True)
34
+ def _(mo):
35
+ mo.md(
36
+ r"""
37
+ ## 🛠️ Parsing & Conversion
38
+
39
+ Let's warm up with one of the most frequent use cases: parsing raw strings into various formats.
40
+ We'll take a tiny dataframe with metadata about Python packages represented as raw JSON strings and we'll use Polars string expressions to parse the attributes into their true data types.
41
+ """
42
+ )
43
+ return
44
+
45
+
46
+ @app.cell
47
+ def _(pl):
48
+ pip_metadata_raw_df = pl.DataFrame(
49
+ [
50
+ '{"package": "polars", "version": "1.24.0", "released_at": "2025-03-02T20:31:12+0000", "size_mb": "30.9"}',
51
+ '{"package": "marimo", "version": "0.11.14", "released_at": "2025-03-04T00:28:57+0000", "size_mb": "10.7"}',
52
+ ],
53
+ schema={"raw_json": pl.String},
54
+ )
55
+ pip_metadata_raw_df
56
+ return (pip_metadata_raw_df,)
57
+
58
+
59
+ @app.cell(hide_code=True)
60
+ def _(mo):
61
+ mo.md(r"""We can use the [`json_decode`](https://docs.pola.rs/api/python/stable/reference/series/api/polars.Series.str.json_decode.html) expression to parse the raw JSON strings into Polars-native structs and we can use the [unnest](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.unnest.html) dataframe operation to have a dedicated column per parsed attribute.""")
62
+ return
63
+
64
+
65
+ @app.cell
66
+ def _(pip_metadata_raw_df, pl):
67
+ pip_metadata_df = pip_metadata_raw_df.select(json=pl.col('raw_json').str.json_decode()).unnest('json')
68
+ pip_metadata_df
69
+ return (pip_metadata_df,)
70
+
71
+
72
+ @app.cell(hide_code=True)
73
+ def _(mo):
74
+ mo.md(r"""This is already a much friendlier representation of the data we started out with, but note that since the JSON entries had only string attributes, all values are strings, even the temporal `released_at` and numerical `size_mb` columns.""")
75
+ return
76
+
77
+
78
+ @app.cell(hide_code=True)
79
+ def _(mo):
80
+ mo.md(r"""As we know that the `size_mb` column should have a decimal representation, we go ahead and use [`to_decimal`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_decimal.html#polars.Expr.str.to_decimal) to perform the conversion.""")
81
+ return
82
+
83
+
84
+ @app.cell
85
+ def _(pip_metadata_df, pl):
86
+ pip_metadata_df.select(
87
+ 'package',
88
+ 'version',
89
+ pl.col('size_mb').str.to_decimal(),
90
+ )
91
+ return
92
+
93
+
94
+ @app.cell(hide_code=True)
95
+ def _(mo):
96
+ mo.md(
97
+ r"""
98
+ Moving on to the `released_at` attribute which indicates the exact time when a given Python package got released, we have a bit more options to consider. We can convert to `Date`, `DateTime`, and `Time` types based on the desired temporal granularity. The [`to_date`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_date.html), [`to_datetime`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_datetime.html), and [`to_time`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_time.html) expressions are here to help us with the conversion, all we need is to provide the desired format string.
99
+
100
+ Since Polars uses Rust under the hood to implement all its expressions, we need to consult the [`chrono::format`](https://docs.rs/chrono/latest/chrono/format/strftime/index.html) reference to come up with appropriate format strings.
101
+
102
+ Here's a quick reference:
103
+
104
+ | Specifier | Meaning |
105
+ |-----------|--------------------|
106
+ | `%Y` | Year (e.g., 2025) |
107
+ | `%m` | Month (01-12) |
108
+ | `%d` | Day (01-31) |
109
+ | `%H` | Hour (00-23) |
110
+ | `%z` | UTC offset |
111
+
112
+ The raw strings we are working with look like `"2025-03-02T20:31:12+0000"`. We can match this using the `"%Y-%m-%dT%H:%M:%S%z"` format string.
113
+ """
114
+ )
115
+ return
116
+
117
+
118
+ @app.cell
119
+ def _(pip_metadata_df, pl):
120
+ pip_metadata_df.select(
121
+ 'package',
122
+ 'version',
123
+ release_date=pl.col('released_at').str.to_date('%Y-%m-%dT%H:%M:%S%z'),
124
+ release_datetime=pl.col('released_at').str.to_datetime('%Y-%m-%dT%H:%M:%S%z'),
125
+ release_time=pl.col('released_at').str.to_time('%Y-%m-%dT%H:%M:%S%z'),
126
+ )
127
+ return
128
+
129
+
130
+ @app.cell(hide_code=True)
131
+ def _(mo):
132
+ mo.md(r"""Alternatively, instead of using three different functions to perform the conversion to date, we can use a single one, [`strptime`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strptime.html) which takes the desired temporal data type as its first parameter.""")
133
+ return
134
+
135
+
136
+ @app.cell
137
+ def _(pip_metadata_df, pl):
138
+ pip_metadata_df.select(
139
+ 'package',
140
+ 'version',
141
+ release_date=pl.col('released_at').str.strptime(pl.Date, '%Y-%m-%dT%H:%M:%S%z'),
142
+ release_datetime=pl.col('released_at').str.strptime(pl.Datetime, '%Y-%m-%dT%H:%M:%S%z'),
143
+ release_time=pl.col('released_at').str.strptime(pl.Time, '%Y-%m-%dT%H:%M:%S%z'),
144
+ )
145
+ return
146
+
147
+
148
+ @app.cell(hide_code=True)
149
+ def _(mo):
150
+ mo.md(r"""And to wrap up this section on parsing and conversion, let's consider a final scenario. What if we don't want to parse the entire raw JSON string, because we only need a subset of its attributes? Well, in this case we can leverage the [`json_path_match`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.json_path_match.html) expression to extract only the desired attributes using standard [JSONPath](https://goessner.net/articles/JsonPath/) syntax.""")
151
+ return
152
+
153
+
154
+ @app.cell
155
+ def _(pip_metadata_raw_df, pl):
156
+ pip_metadata_raw_df.select(
157
+ package=pl.col("raw_json").str.json_path_match("$.package"),
158
+ version=pl.col("raw_json").str.json_path_match("$.version"),
159
+ release_date=pl.col("raw_json")
160
+ .str.json_path_match("$.size_mb")
161
+ .str.to_decimal(),
162
+ )
163
+ return
164
+
165
+
166
+ @app.cell(hide_code=True)
167
+ def _(mo):
168
+ mo.md(
169
+ r"""
170
+ ## 📊 Dataset Overview
171
+
172
+ Now that we got our hands dirty, let's consider a somewhat wilder dataset for the subsequent sections: a dataframe of metadata about every single expression in your current Polars module.
173
+
174
+ At the risk of stating the obvious, in the previous section, when we typed `pl.col('raw_json').str.json_decode()`, we accessed the `json_decode` member of the `str` expression namespace through the `pl.col('raw_json')` expression *instance*. Under the hood, deep inside the Polars source code, there is a corresponding `def json_decode(...)` method with a carefully authored docstring explaining the purpose and signature of the member.
175
+
176
+ Since Python makes module introspection simple, we can easily enumerate all Polars expressions and organize their metadata in `expressions_df`, to be used for all the upcoming string manipulation examples.
177
+ """
178
+ )
179
+ return
180
+
181
+
182
+ @app.cell(hide_code=True)
183
+ def _(pl):
184
+ def list_members(expr, namespace) -> list[dict]:
185
+ """Iterates through the attributes of `expr` and returns their metadata"""
186
+ members = []
187
+ for attrname in expr.__dir__():
188
+ is_namespace = attrname in pl.Expr._accessors
189
+ is_private = attrname.startswith("_")
190
+ if is_namespace or is_private:
191
+ continue
192
+
193
+ attr = getattr(expr, attrname)
194
+ members.append(
195
+ {
196
+ "namespace": namespace,
197
+ "member": attrname,
198
+ "docstring": attr.__doc__,
199
+ }
200
+ )
201
+ return members
202
+
203
+
204
+ def list_expr_meta() -> list[dict]:
205
+ # Dummy expression instance to 'crawl'
206
+ expr = pl.lit("")
207
+ root_members = list_members(expr, "root")
208
+ namespaced_members: list[list[dict]] = [
209
+ list_members(getattr(expr, namespace), namespace)
210
+ for namespace in pl.Expr._accessors
211
+ ]
212
+ return sum(namespaced_members, root_members)
213
+
214
+
215
+ expressions_df = pl.from_dicts(list_expr_meta(), infer_schema_length=None).sort('namespace', 'member')
216
+ expressions_df
217
+ return expressions_df, list_expr_meta, list_members
218
+
219
+
220
+ @app.cell(hide_code=True)
221
+ def _(mo):
222
+ mo.md(r"""As the following visualization shows, `str` is one of the richest Polars expression namespaces with multiple dozens of functions in it.""")
223
+ return
224
+
225
+
226
+ @app.cell(hide_code=True)
227
+ def _(alt, expressions_df):
228
+ expressions_df.plot.bar(
229
+ x=alt.X("count(member):Q", title='Count of Expressions'),
230
+ y=alt.Y("namespace:N", title='Namespace').sort("-x"),
231
+ )
232
+ return
233
+
234
+
235
+ @app.cell(hide_code=True)
236
+ def _(mo):
237
+ mo.md(
238
+ r"""
239
+ ## 📏 Length Calculation
240
+
241
+ A common use case is to compute the length of a string. Most people associate string length exclusively with the number of characters a string consists of; however, in certain scenarios it is also useful to know how much memory is required to store it, i.e., how many bytes are needed to represent the textual data.
242
+
243
+ The expressions [`len_chars`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.len_chars.html) and [`len_bytes`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.len_bytes.html) are here to help us with these calculations.
244
+
245
+ Below, we compute `docstring_len_chars` and `docstring_len_bytes` columns to see how many characters and bytes the documentation of each expression is made up of.
246
+ """
247
+ )
248
+ return
249
+
250
+
251
+ @app.cell
252
+ def _(expressions_df, pl):
253
+ docstring_length_df = expressions_df.select(
254
+ 'namespace',
255
+ 'member',
256
+ docstring_len_chars=pl.col("docstring").str.len_chars(),
257
+ docstring_len_bytes=pl.col("docstring").str.len_bytes(),
258
+ )
259
+ docstring_length_df
260
+ return (docstring_length_df,)
261
+
262
+
263
+ @app.cell(hide_code=True)
264
+ def _(mo):
265
+ mo.md(r"""As the dataframe preview above and the scatterplot below show, the docstring length measured in bytes is almost always bigger than the length expressed in characters. This is due to the fact that the docstrings include characters which require more than a single byte to represent, such as "╞" for displaying dataframe header and body separators.""")
266
+ return
267
+
268
+
269
+ @app.cell
270
+ def _(alt, docstring_length_df):
271
+ docstring_length_df.plot.point(
272
+ x=alt.X('docstring_len_chars', title='Docstring Length (Chars)'),
273
+ y=alt.Y('docstring_len_bytes', title='Docstring Length (Bytes)'),
274
+ tooltip=['namespace', 'member', 'docstring_len_chars', 'docstring_len_bytes'],
275
+ )
276
+ return
277
+
278
+
279
+ @app.cell(hide_code=True)
280
+ def _(mo):
281
+ mo.md(
282
+ r"""
283
+ ## 🔠 Case Conversion
284
+
285
+ Another frequent string transformation is lowercasing, uppercasing, and titlecasing. We can use [`to_lowercase`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_lowercase.html), [`to_uppercase`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_uppercase.html) and [`to_titlecase`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.to_titlecase.html) for doing so.
286
+ """
287
+ )
288
+ return
289
+
290
+
291
+ @app.cell
292
+ def _(expressions_df, pl):
293
+ expressions_df.select(
294
+ member_lower=pl.col('member').str.to_lowercase(),
295
+ member_upper=pl.col('member').str.to_uppercase(),
296
+ member_title=pl.col('member').str.to_titlecase(),
297
+ )
298
+ return
299
+
300
+
301
+ @app.cell(hide_code=True)
302
+ def _(mo):
303
+ mo.md(
304
+ r"""
305
+ ## ➕ Padding
306
+
307
+ Sometimes we need to ensure that strings have a fixed-size character length. [`pad_start`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.pad_start.html) and [`pad_end`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.pad_end.html) can be used to fill the "front" or "back" of a string with a supplied character, while [`zfill`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.zfill.html) is a utility for padding the start of a string with `"0"` until it reaches a particular length. In other words, `zfill` is a more specific version of `pad_start`, where the `fill_char` parameter is explicitly set to `"0"`.
308
+
309
+ In the example below we take the unique Polars expression namespaces and pad them so that they have a uniform length which you can control via a slider.
310
+ """
311
+ )
312
+ return
313
+
314
+
315
+ @app.cell(hide_code=True)
316
+ def _(mo):
317
+ padding = mo.ui.slider(0, 16, step=1, value=8, label="Padding Size")
318
+ return (padding,)
319
+
320
+
321
+ @app.cell
322
+ def _(expressions_df, padding, pl):
323
+ padded_df = expressions_df.select("namespace").unique().select(
324
+ "namespace",
325
+ namespace_front_padded=pl.col("namespace").str.pad_start(padding.value, "_"),
326
+ namespace_back_padded=pl.col("namespace").str.pad_end(padding.value, "_"),
327
+ namespace_zfilled=pl.col("namespace").str.zfill(padding.value),
328
+ )
329
+ return (padded_df,)
330
+
331
+
332
+ @app.cell(hide_code=True)
333
+ def _(mo, padded_df, padding):
334
+ mo.vstack([
335
+ padding,
336
+ padded_df,
337
+ ])
338
+ return
339
+
340
+
341
+ @app.cell(hide_code=True)
342
+ def _(mo):
343
+ mo.md(
344
+ r"""
345
+ ## 🔄 Replacing
346
+
347
+ Let's say we want to convert from `snake_case` API member names to `kebab-case`, that is, we need to replace the underscore character with a hyphen. For operations like that, we can use [`replace`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.replace.html) and [`replace_all`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.replace_all.html).
348
+
349
+ As the example below demonstrates, `replace` stops after the first occurrence of the to-be-replaced pattern, while `replace_all` goes all the way through and changes all underscores to hyphens resulting in the `kebab-case` representation we were looking for.
350
+ """
351
+ )
352
+ return
353
+
354
+
355
+ @app.cell
356
+ def _(expressions_df, pl):
357
+ expressions_df.select(
358
+ "member",
359
+ member_kebab_case_partial=pl.col("member").str.replace("_", "-"),
360
+ member_kebab_case=pl.col("member").str.replace_all("_", "-"),
361
+ ).sort(pl.col("member").str.len_chars(), descending=True)
362
+ return
363
+
364
+
365
+ @app.cell(hide_code=True)
366
+ def _(mo):
367
+ mo.md(
368
+ r"""
369
+ A related expression is [`replace_many`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.replace_many.html), which accepts *many* pairs of to-be-matched patterns and corresponding replacements and uses the [Aho–Corasick algorithm](https://en.wikipedia.org/wiki/Aho%E2%80%93Corasick_algorithm) to carry out the operation with great performance.
370
+
371
+ In the example below we replace all instances of `"min"` with `"minimum"` and `"max"` with `"maximum"` using a single expression.
372
+ """
373
+ )
374
+ return
375
+
376
+
377
+ @app.cell
378
+ def _(expressions_df, pl):
379
+ expressions_df.select(
380
+ "member",
381
+ member_modified=pl.col("member").str.replace_many(
382
+ {
383
+ "min": "minimum",
384
+ "max": "maximum",
385
+ }
386
+ ),
387
+ )
388
+ return
389
+
390
+
391
+ @app.cell(hide_code=True)
392
+ def _(mo):
393
+ mo.md(
394
+ r"""
395
+ ## 🔍 Searching & Matching
396
+
397
+ A common need when working with strings is to determine whether their content satisfies some condition: whether it starts or ends with a particular substring or contains a certain pattern.
398
+
399
+ Let's suppose we want to determine whether a member of the Polars expression API is a "converter", such as `to_decimal`, identified by its `"to_"` prefix. We can use [`starts_with`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.starts_with.html) to perform this check.
400
+ """
401
+ )
402
+ return
403
+
404
+
405
+ @app.cell
406
+ def _(expressions_df, pl):
407
+ expressions_df.select(
408
+ "namespace",
409
+ "member",
410
+ is_converter=pl.col("member").str.starts_with("to_"),
411
+ ).sort(-pl.col("is_converter").cast(pl.Int8))
412
+ return
413
+
414
+
415
+ @app.cell(hide_code=True)
416
+ def _(mo):
417
+ mo.md(
418
+ r"""
419
+ Throughout this course, as you have gained familiarity with the expression API, you might have noticed that some members end with an underscore, such as `or_`; that's because their base name is a reserved Python keyword.
420
+
421
+ Let's use [`ends_with`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.ends_with.html) to find all the members which are named after such keywords.
422
+ """
423
+ )
424
+ return
425
+
426
+
427
+ @app.cell
428
+ def _(expressions_df, pl):
429
+ expressions_df.select(
430
+ "namespace",
431
+ "member",
432
+ is_escaped_keyword=pl.col("member").str.ends_with("_"),
433
+ ).sort(-pl.col("is_escaped_keyword").cast(pl.Int8))
434
+ return
435
+
436
+
437
+ @app.cell(hide_code=True)
438
+ def _(mo):
439
+ mo.md(
440
+ r"""
441
+ Now let's move on to analyzing the docstrings in a bit more detail. Based on their content we can determine whether a member is deprecated, accepts parameters, comes with examples, or references external URL(s) & related members.
442
+
443
+ As demonstrated below, we can compute all these boolean attributes using [`contains`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.contains.html) to check whether the docstring includes a particular substring.
444
+ """
445
+ )
446
+ return
447
+
448
+
449
+ @app.cell
450
+ def _(expressions_df, pl):
451
+ expressions_df.select(
452
+ 'namespace',
453
+ 'member',
454
+ is_deprecated=pl.col('docstring').str.contains('.. deprecated', literal=True),
455
+ has_parameters=pl.col('docstring').str.contains('Parameters'),
456
+ has_examples=pl.col('docstring').str.contains('Examples'),
457
+ has_related_members=pl.col('docstring').str.contains('See Also'),
458
+ has_url=pl.col('docstring').str.contains('https?://'),
459
+ )
460
+ return
461
+
462
+
463
+ @app.cell(hide_code=True)
464
+ def _(mo):
465
+ mo.md(r"""For scenarios where we want to combine multiple substrings to check for, we can use the [`contains`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.contains.html) expression to check for the presence of various patterns.""")
466
+ return
467
+
468
+
469
+ @app.cell
470
+ def _(expressions_df, pl):
471
+ expressions_df.select(
472
+ 'namespace',
473
+ 'member',
474
+ has_reference=pl.col('docstring').str.contains_any(['See Also', 'https://'])
475
+ )
476
+ return
477
+
478
+
479
+ @app.cell(hide_code=True)
480
+ def _(mo):
481
+ mo.md(
482
+ r"""
483
+ From the above analysis we can see that almost all the members come with code examples. It would be interesting to know how many variable assignments are going on within each of these examples, right? That's not as simple as checking for a pre-defined literal substring though, because variables can have arbitrary names: any valid Python identifier is allowed. While `contains` supports regular expressions in addition to literal strings, it would not suffice for this exercise either, because it only tells us whether there is at least one occurrence of the sought pattern rather than the exact number of matches.
484
+
485
+ Fortunately, we can take advantage of [`count_matches`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.count_matches.html) to achieve exactly what we want. We specify the regular expression `r'[a-zA-Z_][a-zA-Z0-9_]* = '` according to the [`regex` Rust crate](https://docs.rs/regex/latest/regex/) to match Python identifiers and we leave the rest to Polars.
486
+
487
+ In `count_matches(r'[a-zA-Z_][a-zA-Z0-9_]* = ')`:
488
+
489
+ - `[a-zA-Z_]` matches a letter or underscore (start of a Python identifier).
490
+ - `[a-zA-Z0-9_]*` matches zero or more letters, digits, or underscores.
491
+ - ` = ` matches a space, equals sign, and space (indicating assignment).
492
+
493
+ This finds variable assignments like `x = ` or `df_result = ` in docstrings.
494
+ """
495
+ )
496
+ return
497
+
498
+
499
+ @app.cell
500
+ def _(expressions_df, pl):
501
+ expressions_df.select(
502
+ 'namespace',
503
+ 'member',
504
+ variable_assignment_count=pl.col('docstring').str.count_matches(r'[a-zA-Z_][a-zA-Z0-9_]* = '),
505
+ )
506
+ return
507
+
508
+
509
+ @app.cell(hide_code=True)
510
+ def _(mo):
511
+ mo.md(r"""A related application example is to *find* the first index where a particular pattern is present, so that it can be used for downstream processing such as slicing. Below we use the [`find`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.find.html) expression to determine the index at which a code example starts in the docstring - identified by the Python shell substring `">>>"`.""")
512
+ return
513
+
514
+
515
+ @app.cell
516
+ def _(expressions_df, pl):
517
+ expressions_df.select(
518
+ 'namespace',
519
+ 'member',
520
+ code_example_start=pl.col('docstring').str.find('>>>'),
521
+ )
522
+ return
523
+
524
+
525
+ @app.cell(hide_code=True)
526
+ def _(mo):
527
+ mo.md(
528
+ r"""
529
+ ## ✂️ Slicing and Substrings
530
+
531
+ Sometimes we are only interested in a particular substring. We can use [`head`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.head.html), [`tail`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.tail.html) and [`slice`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.slice.html) to extract a substring from the start, end, or between arbitrary indices.
532
+ """
533
+ )
534
+ return
535
+
536
+
537
+ @app.cell
538
+ def _(mo):
539
+ slice = mo.ui.slider(1, 50, step=1, value=25, label="Slice Size")  # note: shadows the built-in slice
540
+ return (slice,)
541
+
542
+
543
+ @app.cell
544
+ def _(expressions_df, pl, slice):
545
+ sliced_df = expressions_df.select(
546
+ # First slice.value chars
547
+ docstring_head=pl.col("docstring").str.head(slice.value),
548
+ # 2*slice.value chars after the first slice.value chars
549
+ docstring_slice=pl.col("docstring").str.slice(slice.value, 2*slice.value),
550
+ # Last slice.value chars
551
+ docstring_tail=pl.col("docstring").str.tail(slice.value),
552
+ )
553
+ return (sliced_df,)
554
+
555
+
556
+ @app.cell
557
+ def _(mo, slice, sliced_df):
558
+ mo.vstack([
559
+ slice,
560
+ sliced_df,
561
+ ])
562
+ return
563
+
564
+
565
+ @app.cell(hide_code=True)
566
+ def _(mo):
567
+ mo.md(
568
+ r"""
569
+ ## ➗ Splitting
570
+
571
+ Certain strings follow a well-defined structure and we might be only interested in some parts of them. For example, when dealing with `snake_cased_expression` member names we might be curious to get only the first, second, or $n^{\text{th}}$ word before an underscore. We would need to *split* the string at a particular pattern for downstream processing.
572
+
573
+ The [`split`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.split.html), [`split_exact`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.split_exact.html) and [`splitn`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.splitn.html) expressions enable us to achieve this.
574
+
575
+ The primary difference between these string splitting utilities is that `split` produces a list whose length varies with the number of resulting segments, `splitn` returns a struct of exactly `n` fields (null-padded when fewer splits are possible), while `split_exact` returns a struct of `n + 1` fields (one per split plus the remainder, likewise null-padded).
576
+ """
577
+ )
578
+ return
579
+
580
+
581
+ @app.cell
582
+ def _(expressions_df, pl):
583
+ expressions_df.select(
584
+ 'member',
585
+ member_name_parts=pl.col('member').str.split('_'),
586
+ member_name_parts_n=pl.col('member').str.splitn('_', n=2),
587
+ member_name_parts_exact=pl.col('member').str.split_exact('_', n=2),
588
+ )
589
+ return
590
+
591
+
592
+ @app.cell(hide_code=True)
593
+ def _(mo):
594
+ mo.md(r"""As a more practical example, we can use the `split` expression with some aggregation to count the number of times a particular word occurs in member names across all namespaces. This enables us to create a word cloud of the API members' constituents!""")
595
+ return
596
+
597
+
598
+ @app.cell(hide_code=True)
599
+ def _(mo, wordcloud, wordcloud_height, wordcloud_width):
600
+ mo.vstack([
601
+ wordcloud_width,
602
+ wordcloud_height,
603
+ wordcloud,
604
+ ])
605
+ return
606
+
607
+
608
+ @app.cell(hide_code=True)
609
+ def _(mo):
610
+ wordcloud_width = mo.ui.slider(0, 64, step=1, value=32, label="Word Cloud Width")
611
+ wordcloud_height = mo.ui.slider(0, 32, step=1, value=16, label="Word Cloud Height")
612
+ return wordcloud_height, wordcloud_width
613
+
614
+
615
+ @app.cell(hide_code=True)
616
+ def _(alt, expressions_df, pl, random, wordcloud_height, wordcloud_width):
617
+ wordcloud_df = (
618
+ expressions_df.select(pl.col("member").str.split("_"))
619
+ .explode("member")
620
+ .group_by("member")
621
+ .agg(pl.len())
622
+ # Generating random x and y coordinates to distribute the words in the 2D space
623
+ .with_columns(
624
+ x=pl.col("member").map_elements(
625
+ lambda e: random.randint(0, wordcloud_width.value),
626
+ return_dtype=pl.UInt8,
627
+ ),
628
+ y=pl.col("member").map_elements(
629
+ lambda e: random.randint(0, wordcloud_height.value),
630
+ return_dtype=pl.UInt8,
631
+ ),
632
+ )
633
+ )
634
+
635
+ wordcloud = alt.Chart(wordcloud_df).mark_text(baseline="middle").encode(
636
+ x=alt.X("x:O", axis=None),
637
+ y=alt.Y("y:O", axis=None),
638
+ text="member:N",
639
+ color=alt.Color("len:Q", scale=alt.Scale(scheme="bluepurple")),
640
+ size=alt.Size("len:Q", legend=None),
641
+ tooltip=["member", "len"],
642
+ ).configure_view(strokeWidth=0)
643
+ return wordcloud, wordcloud_df
644
+
645
+
646
+ @app.cell(hide_code=True)
647
+ def _(mo):
648
+ mo.md(
649
+ r"""
650
+ ## 🔗 Concatenation & Joining
651
+
652
+ Often we would like to create longer strings from strings we already have. We might want to create a formatted, sentence-like string or join multiple existing strings in our dataframe into a single one.
653
+
654
+ The top-level [`concat_str`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.concat_str.html) expression enables us to combine strings *horizontally* in a dataframe. As the example below shows, we can take the `member` and `namespace` column of each row and construct a `description` column in which each row will correspond to the value ``f"- Expression `{member}` belongs to namespace `{namespace}`"``.
655
+ """
656
+ )
657
+ return
658
+
659
+
660
+ @app.cell
661
+ def _(expressions_df, pl):
662
+ descriptions_df = expressions_df.sample(5).select(
663
+ description=pl.concat_str(
664
+ [
665
+ pl.lit("- Expression "),
666
+ pl.lit("`"),
667
+ "member",
668
+ pl.lit("`"),
669
+ pl.lit(" belongs to namespace "),
670
+ pl.lit("`"),
671
+ "namespace",
672
+ pl.lit("`"),
673
+ ],
674
+ )
675
+ )
676
+ descriptions_df
677
+ return (descriptions_df,)
678
+
679
+
680
+ @app.cell(hide_code=True)
681
+ def _(mo):
682
+ mo.md(
683
+ r"""
684
+ Now that we have constructed these bullet points through *horizontal* concatenation of strings, we can perform a *vertical* one so that we end up with a single string in which we have a bullet point on each line.
685
+
686
+ We will use the [`join`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.join.html) expression to do so.
687
+ """
688
+ )
689
+ return
690
+
691
+
692
+ @app.cell
693
+ def _(descriptions_df, pl):
694
+ descriptions_df.select(pl.col('description').str.join('\n'))
695
+ return
696
+
697
+
698
+ @app.cell(hide_code=True)
699
+ def _(descriptions_df, mo, pl):
700
+ mo.md(f"""In fact, since the string we constructed dynamically is valid markdown, we can display it dynamically using Marimo's `mo.md` utility!
701
+
702
+ ---
703
+
704
+ {descriptions_df.select(pl.col('description').str.join('\n')).to_numpy().squeeze().tolist()}
705
+ """)
706
+ return
707
+
708
+
709
+ @app.cell(hide_code=True)
710
+ def _(mo):
711
+ mo.md(
712
+ r"""
713
+ ## 🔍 Pattern-based Extraction
714
+
715
+ In the vast majority of the cases, when dealing with unstructured text data, all we really want is to extract something structured from it. A common use case is to extract URLs from text to get a better understanding of related content.
716
+
717
+ In the example below that's exactly what we do. We scan the `docstring` of each API member and extract URLs from them using [`extract`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.extract.html) and [`extract_all`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.extract_all.html) using a simple regular expression to match http and https URLs.
718
+
719
+ Note that `extract` stops after the first match and returns a scalar result (or `null` if there was no match), while `extract_all` returns a (potentially empty) list of matches.
720
+ """
721
+ )
722
+ return
723
+
724
+
725
+ @app.cell
726
+ def _(expressions_df, pl):
727
+ url_pattern = r'(https?://[^\s>]+)'
728
+ expressions_df.select(
729
+ 'namespace',
730
+ 'member',
731
+ url_match=pl.col('docstring').str.extract(url_pattern),
732
+ url_matches=pl.col('docstring').str.extract_all(url_pattern),
733
+ ).filter(pl.col('url_match').is_not_null())
734
+ return (url_pattern,)
735
+
736
+
737
+ @app.cell(hide_code=True)
738
+ def _(mo):
739
+ mo.md(
740
+ r"""
741
+ Note that in each `docstring` where a code example involving dataframes is present, we will see an output such as "shape: (5, 2)" indicating the number of rows and columns of the dataframe produced by the sample code. Let's say we would like to *capture* this information in a structured way.
742
+
743
+ [`extract_groups`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.extract_groups.html) is a really powerful expression allowing us to achieve exactly that.
744
+
745
+ Below we define the regular expression `r"shape:\s*\((?<height>\S+),\s*(?<width>\S+)\)"` with two capture groups, named `height` and `width`, and pass it as the parameter of `extract_groups`. After execution, for each `docstring`, we end up with fully structured data we can further process downstream!
746
+ """
747
+ )
748
+ return
749
+
750
+
751
+ @app.cell
752
+ def _(expressions_df, pl):
753
+ expressions_df.select(
754
+ 'namespace',
755
+ 'member',
756
+ example_df_shape=pl.col('docstring').str.extract_groups(r"shape:\s*\((?<height>\S+),\s*(?<width>\S+)\)"),
757
+ )
758
+ return
759
+
760
+
761
+ @app.cell(hide_code=True)
762
+ def _(mo):
763
+ mo.md(
764
+ r"""
765
+ ## 🧹 Stripping
766
+
767
+ Strings might require some cleaning before further processing, such as the removal of some characters from the beginning or end of the text. [`strip_chars_start`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars_start.html), [`strip_chars_end`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars_end.html) and [`strip_chars`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_chars.html) are here to facilitate this.
768
+
769
+ All we need to do is to specify a set of characters we would like to get rid of and Polars handles the rest for us.
770
+ """
771
+ )
772
+ return
773
+
774
+
775
+ @app.cell
776
+ def _(expressions_df, pl):
777
+ expressions_df.select(
778
+ "member",
779
+ member_front_stripped=pl.col("member").str.strip_chars_start("a"),
780
+ member_back_stripped=pl.col("member").str.strip_chars_end("n"),
781
+ member_fully_stripped=pl.col("member").str.strip_chars("na"),
782
+ )
783
+ return
784
+
785
+
786
+ @app.cell(hide_code=True)
787
+ def _(mo):
788
+ mo.md(
789
+ r"""
790
+ Note that when using the above expressions, the specified characters do not need to form a sequence; they are handled as a set. However, in certain use cases we only want to strip complete substrings, so we would need our input to be strictly treated as a sequence rather than as a set.
791
+
792
+ That's exactly the rationale behind [`strip_prefix`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_prefix.html) and [`strip_suffix`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.strip_suffix.html).
793
+
794
+ Below we use these to remove the `"to_"` prefixes and `"_with"` suffixes from each member name.
795
+ """
796
+ )
797
+ return
798
+
799
+
800
+ @app.cell
801
+ def _(expressions_df, pl):
802
+ expressions_df.select(
803
+ "member",
804
+ member_prefix_stripped=pl.col("member").str.strip_prefix("to_"),
805
+ member_suffix_stripped=pl.col("member").str.strip_suffix("_with"),
806
+ ).slice(20)
807
+ return
808
+
809
+
810
+ @app.cell(hide_code=True)
811
+ def _(mo):
812
+ mo.md(
813
+ r"""
814
+ ## 🔑 Encoding & Decoding
815
+
816
+ Should you find yourself in need of encoding your strings into [base64](https://en.wikipedia.org/wiki/Base64) or [hexadecimal](https://en.wikipedia.org/wiki/Hexadecimal) format, then Polars has your back with its [`encode`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.encode.html) expression.
817
+ """
818
+ )
819
+ return
820
+
821
+
822
+ @app.cell
823
+ def _(expressions_df, pl):
824
+ encoded_df = expressions_df.select(
825
+ "member",
826
+ member_base64=pl.col('member').str.encode('base64'),
827
+ member_hex=pl.col('member').str.encode('hex'),
828
+ )
829
+ encoded_df
830
+ return (encoded_df,)
831
+
832
+
833
+ @app.cell(hide_code=True)
834
+ def _(mo):
835
+ mo.md(r"""And of course, you can convert back into a human-readable representation using the [`decode`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.str.decode.html) expression.""")
836
+ return
837
+
838
+
839
+ @app.cell
840
+ def _(encoded_df, pl):
841
+ encoded_df.with_columns(
842
+ member_base64_decoded=pl.col('member_base64').str.decode('base64').cast(pl.String),
843
+ member_hex_decoded=pl.col('member_hex').str.decode('hex').cast(pl.String),
844
+ )
845
+ return
846
+
847
+
848
+ @app.cell(hide_code=True)
849
+ def _(mo):
850
+ mo.md(
851
+ r"""
852
+ ## 🚀 Application: Dynamic Execution of Polars Examples
853
+
854
+ Now that we are familiar with string expressions, we can combine them with other Polars operations to build a fully interactive playground where code examples of Polars expressions can be explored.
855
+
856
+ We make use of string expressions to extract the raw Python source code of examples from the docstrings and we leverage the interactive Marimo environment to enable the selection of expressions via a searchable dropdown and a fully functional code editor whose output is rendered with Marimo's rich display utilities.
857
+
858
+ In other words, we will use Polars to execute Polars. ❄️ How cool is that?
859
+
860
+ ---
861
+ """
862
+ )
863
+ return
864
+
865
+
866
+ @app.cell(hide_code=True)
867
+ def _(
868
+ example_editor,
869
+ execution_result,
870
+ expression,
871
+ expression_description,
872
+ expression_docs_link,
873
+ mo,
874
+ ):
875
+ mo.vstack(
876
+ [
877
+ mo.md(f'### {expression.value}'),
878
+ expression,
879
+ mo.hstack([expression_description, expression_docs_link]),
880
+ example_editor,
881
+ execution_result,
882
+ ]
883
+ )
884
+ return
885
+
886
+
887
+ @app.cell(hide_code=True)
888
+ def _(mo, selected_expression_record):
889
+ expression_description = mo.md(selected_expression_record["description"])
890
+ expression_docs_link = mo.md(
891
+ f"🐻‍❄️ [Official Docs](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.{selected_expression_record['expr']}.html)"
892
+ )
893
+ return expression_description, expression_docs_link
894
+
895
+
896
+ @app.cell(hide_code=True)
897
+ def _(example_editor, execute_code):
898
+ execution_result = execute_code(example_editor.value)
899
+ return (execution_result,)
900
+
901
+
902
+ @app.cell(hide_code=True)
903
+ def _(code_df, mo):
904
+ expression = mo.ui.dropdown(code_df.get_column('expr'), value='arr.all', searchable=True)
905
+ return (expression,)
906
+
907
+
908
+ @app.cell(hide_code=True)
909
+ def _(code_df, expression):
910
+ selected_expression_record = code_df.filter(expr=expression.value).to_dicts()[0]
911
+ return (selected_expression_record,)
912
+
913
+
914
+ @app.cell(hide_code=True)
915
+ def _(mo, selected_expression_record):
916
+ example_editor = mo.ui.code_editor(value=selected_expression_record["code"])
917
+ return (example_editor,)
918
+
919
+
920
+ @app.cell(hide_code=True)
921
+ def _(expressions_df, pl):
922
+ code_df = (
923
+ expressions_df.select(
924
+ expr=pl.when(pl.col("namespace") == "root")
925
+ .then("member")
926
+ .otherwise(pl.concat_str(["namespace", "member"], separator=".")),
927
+ description=pl.col("docstring")
928
+ .str.split("\n\n")
929
+ .list.get(0)
930
+ .str.slice(9),
931
+ docstring_lines=pl.col("docstring").str.split("\n"),
932
+ )
933
+ .with_row_index()
934
+ .explode("docstring_lines")
935
+ .rename({"docstring_lines": "docstring_line"})
936
+ .with_columns(pl.col("docstring_line").str.strip_chars(" "))
937
+ .filter(pl.col("docstring_line").str.contains_any([">>> ", "... "]))
938
+ .with_columns(pl.col("docstring_line").str.slice(4))
939
+ .group_by(pl.exclude("docstring_line"), maintain_order=True)
940
+ .agg(code=pl.col("docstring_line").str.join("\n"))
941
+ .drop("index")
942
+ )
943
+ return (code_df,)
944
+
945
+
946
+ @app.cell(hide_code=True)
947
+ def _():
948
+ def execute_code(code: str):
949
+ import ast
950
+
951
+ # Create a new local namespace for execution
952
+ local_namespace = {}
953
+
954
+ # Parse the code into an AST to identify the last expression
955
+ parsed_code = ast.parse(code)
956
+
957
+ # Check if there's at least one statement
958
+ if not parsed_code.body:
959
+ return None
960
+
961
+ # If the last statement is an expression, we'll need to get its value
962
+ last_is_expr = isinstance(parsed_code.body[-1], ast.Expr)
963
+
964
+ if last_is_expr:
965
+ # Split the code: everything except the last statement, and the last statement
966
+ last_expr = ast.Expression(parsed_code.body[-1].value)
967
+
968
+ # Remove the last statement from the parsed code
969
+ parsed_code.body = parsed_code.body[:-1]
970
+
971
+ # Execute everything except the last statement
972
+ if parsed_code.body:
973
+ exec(
974
+ compile(parsed_code, "<string>", "exec"),
975
+ globals(),
976
+ local_namespace,
977
+ )
978
+
979
+ # Execute the last statement and get its value
980
+ result = eval(
981
+ compile(last_expr, "<string>", "eval"), globals(), local_namespace
982
+ )
983
+ return result
984
+ else:
985
+ # If the last statement is not an expression (e.g., an assignment),
986
+ # execute the entire code and return None
987
+ exec(code, globals(), local_namespace)
988
+ return None
989
+ return (execute_code,)
990
+
991
+
992
+ @app.cell(hide_code=True)
993
+ def _():
994
+ import polars as pl
995
+ import marimo as mo
996
+ import altair as alt
997
+ import random
998
+
999
+ random.seed(42)
1000
+ return alt, mo, pl, random
1001
+
1002
+
1003
+ if __name__ == "__main__":
1004
+ app.run()
polars/12_aggregations.py ADDED
@@ -0,0 +1,355 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # /// script
2
+ # requires-python = ">=3.13"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "polars==1.23.0",
6
+ # ]
7
+ # ///
8
+
9
+ import marimo
10
+
11
+ __generated_with = "0.11.14"
12
+ app = marimo.App(width="medium")
13
+
14
+
15
+ @app.cell
16
+ def _():
17
+ import marimo as mo
18
+ return (mo,)
19
+
20
+
21
+ @app.cell(hide_code=True)
22
+ def _(mo):
23
+ mo.md(
24
+ r"""
25
+ # Aggregations
26
+ _By [Joram Mutenge](https://www.udemy.com/user/joram-mutenge/)._
27
+
28
+ In this notebook, you'll learn how to perform different types of aggregations in Polars, including grouping by categories and time. We'll analyze sales data from a clothing store, focusing on three product categories: hats, socks, and sweaters.
29
+ """
30
+ )
31
+ return
32
+
33
+
34
+ @app.cell
35
+ def _():
36
+ import polars as pl
37
+
38
+ df = (pl.read_csv('https://raw.githubusercontent.com/jorammutenge/learn-rust/refs/heads/main/sample_sales.csv', try_parse_dates=True)
39
+ .rename(lambda col: col.replace(' ','_').lower())
40
+ )
41
+ df
42
+ return df, pl
43
+
44
+
45
+ @app.cell(hide_code=True)
46
+ def _(mo):
47
+ mo.md(
48
+ r"""
49
+ ## Grouping by category
50
+ ### With single category
51
+ Let's find out how many of each product category we sold.
52
+ """
53
+ )
54
+ return
55
+
56
+
57
+ @app.cell
58
+ def _(df, pl):
59
+ (df
60
+ .group_by('category')
61
+ .agg(pl.sum('quantity'))
62
+ )
63
+ return
64
+
65
+
66
+ @app.cell(hide_code=True)
67
+ def _(mo):
68
+ mo.md(
69
+ r"""
70
+ It looks like we sold more sweaters. Maybe this was a winter season.
71
+
72
+ Let's add another aggregation to see the total dollar amount spent on each product category.
73
+ """
74
+ )
75
+ return
76
+
77
+
78
+ @app.cell
79
+ def _(df, pl):
80
+ (df
81
+ .group_by('category')
82
+ .agg(pl.sum('quantity'),
83
+ pl.sum('ext_price'))
84
+ )
85
+ return
86
+
87
+
88
+ @app.cell(hide_code=True)
89
+ def _(mo):
90
+ mo.md(r"""We could also write aggregate code for the two columns as a single line.""")
91
+ return
92
+
93
+
94
+ @app.cell
95
+ def _(df, pl):
96
+ (df
97
+ .group_by('category')
98
+ .agg(pl.sum('quantity','ext_price'))
99
+ )
100
+ return
101
+
102
+
103
+ @app.cell(hide_code=True)
104
+ def _(mo):
105
+ mo.md(r"""Actually, the way we've been writing the aggregate lines is syntactic sugar. Here's a longer way of doing it as shown in the [Polars documentation](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.dataframe.group_by.GroupBy.agg.html).""")
106
+ return
107
+
108
+
109
+ @app.cell
110
+ def _(df, pl):
111
+ (df
112
+ .group_by('category')
113
+ .agg(pl.col('quantity').sum(),
114
+ pl.col('ext_price').sum())
115
+ )
116
+ return
117
+
118
+
119
+ @app.cell(hide_code=True)
120
+ def _(mo):
121
+ mo.md(
122
+ r"""
123
+ ### With multiple categories
124
+ We can also group by multiple categories. Let's find out how many items we sold in each product category for each SKU. This more detailed aggregation will produce more rows than the previous DataFrame.
125
+ """
126
+ )
127
+ return
128
+
129
+
130
+ @app.cell
131
+ def _(df, pl):
132
+ (df
133
+ .group_by('category','sku')
134
+ .agg(pl.sum('quantity'))
135
+ )
136
+ return
137
+
138
+
139
+ @app.cell(hide_code=True)
140
+ def _(mo):
141
+ mo.md(
142
+ r"""
143
+ Aggregations when grouping data are not limited to sums. You can also use functions like [`max`, `min`, `median`, `first`, and `last`](https://docs.pola.rs/user-guide/expressions/aggregation/#basic-aggregations).
144
+
145
+ Let's find the largest sale quantity for each product category.
146
+ """
147
+ )
148
+ return
149
+
150
+
151
+ @app.cell
152
+ def _(df, pl):
153
+ (df
154
+ .group_by('category')
155
+ .agg(pl.max('quantity'))
156
+ )
157
+ return
158
+
159
+
160
+ @app.cell(hide_code=True)
161
+ def _(mo):
162
+ mo.md(
163
+ r"""
164
+ Let's make the aggregation more interesting. We'll identify the first customer to purchase from each product category, along with the quantity they bought and the amount they spent.
165
+
166
+ **Note:** To make this work, we'll have to sort by date from earliest to latest.
167
+ """
168
+ )
169
+ return
170
+
171
+
172
+ @app.cell
173
+ def _(df, pl):
174
+ (df
175
+ .sort('date')
176
+ .group_by('category')
177
+ .agg(pl.first('account_name','quantity','ext_price'))
178
+ )
179
+ return
180
+
181
+
182
+ @app.cell(hide_code=True)
183
+ def _(mo):
184
+ mo.md(
185
+ r"""
186
+ ## Grouping by time
187
+ Since `datetime` is a special data type in Polars, we can perform various group-by aggregations on it.
188
+
189
+ Our dataset spans a two-year period. Let's calculate the total dollar sales for each year. We'll do it the naive way first so you can appreciate grouping with time.
190
+ """
191
+ )
192
+ return
193
+
194
+
195
+ @app.cell
196
+ def _(df, pl):
197
+ (df
198
+ .with_columns(year=pl.col('date').dt.year())
199
+ .group_by('year')
200
+ .agg(pl.sum('ext_price').round(2))
201
+ )
202
+ return
203
+
204
+
205
+ @app.cell(hide_code=True)
206
+ def _(mo):
207
+ mo.md(
208
+ r"""
209
+ We had more sales in 2014.
210
+
211
+ Now let's perform the above operation by grouping with time. This requires sorting the dataframe first.
212
+ """
213
+ )
214
+ return
215
+
216
+
217
+ @app.cell
218
+ def _(df, pl):
219
+ (df
220
+ .sort('date')
221
+ .group_by_dynamic('date', every='1y')
222
+ .agg(pl.sum('ext_price'))
223
+ )
224
+ return
225
+
226
+
227
+ @app.cell(hide_code=True)
228
+ def _(mo):
229
+ mo.md(
230
+ r"""
231
+ The beauty of grouping with time is that it allows us to resample the data by selecting whatever time interval we want.
232
+
233
+ Let's find out what the quarterly sales were for 2014.
234
+ """
235
+ )
236
+ return
237
+
238
+
239
+ @app.cell
240
+ def _(df, pl):
241
+ (df
242
+ .filter(pl.col('date').dt.year() == 2014)
243
+ .sort('date')
244
+ .group_by_dynamic('date', every='1q')
245
+ .agg(pl.sum('ext_price'))
246
+ )
247
+ return
248
+
249
+
250
+ @app.cell(hide_code=True)
251
+ def _(mo):
252
+ mo.md(
253
+ r"""
254
+ Here's an interesting question we can answer that takes advantage of grouping by time.
255
+
256
+ Let's find the hour of the day containing the single largest sale in dollars.
257
+ """
258
+ )
259
+ return
260
+
261
+
262
+ @app.cell
263
+ def _(df, pl):
264
+ (df
265
+ .sort('date')
266
+ .group_by_dynamic('date', every='1h')
267
+ .agg(pl.max('ext_price'))
268
+ .filter(pl.col('ext_price') == pl.col('ext_price').max())
269
+ )
270
+ return
271
+
272
+
273
+ @app.cell(hide_code=True)
274
+ def _(mo):
275
+ mo.md(r"""Just for fun, let's find the median number of items sold in each SKU and the total dollar amount in each SKU every six days.""")
276
+ return
277
+
278
+
279
+ @app.cell
280
+ def _(df, pl):
281
+ (df
282
+ .sort('date')
283
+ .group_by_dynamic('date', every='6d')
284
+ .agg(pl.first('sku'),
285
+ pl.median('quantity'),
286
+ pl.sum('ext_price'))
287
+ )
288
+ return
289
+
290
+
291
+ @app.cell(hide_code=True)
292
+ def _(mo):
293
+ mo.md(r"""Let's rename the columns to clearly indicate the type of aggregation performed. This will help us identify the aggregation method used on a column without needing to check the code.""")
294
+ return
295
+
296
+
297
+ @app.cell
298
+ def _(df, pl):
299
+ (df
300
+ .sort('date')
301
+ .group_by_dynamic('date', every='6d')
302
+ .agg(pl.first('sku'),
303
+ pl.median('quantity').alias('median_qty'),
304
+ pl.sum('ext_price').alias('total_dollars'))
305
+ )
306
+ return
307
+
308
+
309
+ @app.cell(hide_code=True)
310
+ def _(mo):
311
+ mo.md(
312
+ r"""
313
+ ## Grouping with over
314
+
315
+ Sometimes, we may want to perform an aggregation but also keep all the columns and rows of the dataframe.
316
+
317
+ Let's assign a value to indicate the number of times each customer visited and bought something.
318
+ """
319
+ )
320
+ return
321
+
322
+
323
+ @app.cell
324
+ def _(df, pl):
325
+ (df
326
+ .with_columns(buy_freq=pl.col('account_name').len().over('account_name'))
327
+ )
328
+ return
329
+
330
+
331
+ @app.cell(hide_code=True)
332
+ def _(mo):
333
+ mo.md(r"""Finally, let's determine which customers visited the store the most and bought something.""")
334
+ return
335
+
336
+
337
+ @app.cell
338
+ def _(df, pl):
339
+ (df
340
+ .with_columns(buy_freq=pl.col('account_name').len().over('account_name'))
341
+ .filter(pl.col('buy_freq') == pl.col('buy_freq').max())
342
+ .select('account_name','buy_freq')
343
+ .unique()
344
+ )
345
+ return
346
+
347
+
348
+ @app.cell(hide_code=True)
349
+ def _(mo):
350
+ mo.md(r"""There's more you can do with aggregations in Polars such as [sorting with aggregations](https://docs.pola.rs/user-guide/expressions/aggregation/#sorting). We hope that in this notebook, we've armed you with the tools to get started.""")
351
+ return
352
+
353
+
354
+ if __name__ == "__main__":
355
+ app.run()
polars/14_user_defined_functions.py ADDED
@@ -0,0 +1,946 @@
1
+ # /// script
2
+ # requires-python = ">=3.12"
3
+ # dependencies = [
4
+ # "altair==5.5.0",
5
+ # "beautifulsoup4==4.13.3",
6
+ # "httpx==0.28.1",
7
+ # "marimo",
8
+ # "nest-asyncio==1.6.0",
9
+ # "numba==0.61.0",
10
+ # "numpy==2.1.3",
11
+ # "polars==1.24.0",
12
+ # ]
13
+ # ///
14
+
15
+ import marimo
16
+
17
+ __generated_with = "0.11.17"
18
+ app = marimo.App(width="medium")
19
+
20
+
21
+ @app.cell(hide_code=True)
22
+ def _(mo):
23
+ mo.md(
24
+ r"""
25
+ # User-Defined Functions
26
+
27
+ _By [Péter Ferenc Gyarmati](http://github.com/peter-gy)_.
28
+
29
+ Throughout the previous chapters, you've seen how Polars provides a comprehensive set of built-in expressions for flexible data transformation. But what happens when you need something *more*? Perhaps your project has unique requirements, or you need to integrate functionality from an external Python library. This is where User-Defined Functions (UDFs) come into play, allowing you to extend Polars with your own custom logic.
30
+
31
+ In this chapter, we'll weigh the performance trade-offs of UDFs, pinpoint situations where they're truly beneficial, and explore different ways to effectively incorporate them into your Polars workflows. We'll walk through a complete, practical example.
32
+ """
33
+ )
34
+ return
35
+
36
+
37
+ @app.cell(hide_code=True)
38
+ def _(mo):
39
+ mo.md(
40
+ r"""
41
+ ## ⚖️ The Cost of UDFs
42
+
43
+ > Performance vs. Flexibility
44
+
45
+ Polars' built-in expressions are highly optimized for speed and parallel processing. User-defined functions (UDFs), however, introduce a significant performance overhead because they rely on standard Python code, which often runs in a single thread and bypasses Polars' logical optimizations. Therefore, always prioritize native Polars operations *whenever possible*.
46
+
47
+ However, UDFs become inevitable when you need to:
48
+
49
+ - **Integrate external libraries:** Use functionality not directly available in Polars.
50
+ - **Implement custom logic:** Handle complex transformations that can't be easily expressed with Polars' built-in functions.
51
+
52
+ Let's dive into a real-world project where UDFs were the only way to get the job done, demonstrating a scenario where native Polars expressions simply weren't sufficient.
53
+ """
54
+ )
55
+ return
56
+
57
+
58
+ @app.cell(hide_code=True)
59
+ def _(mo):
60
+ mo.md(
61
+ r"""
62
+ ## 📊 Project Overview
63
+
64
+ > Scraping and Analyzing Observable Notebook Statistics
65
+
66
+ If you're into data visualization, you've probably seen [D3.js](https://d3js.org/) and [Observable Plot](https://observablehq.com/plot/). Both have extensive galleries showcasing amazing visualizations. Each gallery item is a standalone [Observable notebook](https://observablehq.com/documentation/notebooks/), with metrics like stars, comments, and forks – indicators of popularity. But getting and analyzing these statistics directly isn't straightforward. We'll need to scrape the web.
67
+ """
68
+ )
69
+ return
70
+
71
+
72
+ @app.cell(hide_code=True)
73
+ def _(mo):
74
+ mo.hstack(
75
+ [
76
+ mo.image(
77
+ "https://minio.peter.gy/static/assets/marimo/learn/polars/14_d3-gallery.png?0",
78
+ width=600,
79
+ caption="Screenshot of https://observablehq.com/@d3/gallery",
80
+ ),
81
+ mo.image(
82
+ "https://minio.peter.gy/static/assets/marimo/learn/polars/14_plot-gallery.png?0",
83
+ width=600,
84
+ caption="Screenshot of https://observablehq.com/@observablehq/plot-gallery",
85
+ ),
86
+ ]
87
+ )
88
+ return
89
+
90
+
91
+ @app.cell(hide_code=True)
92
+ def _(mo):
93
+ mo.md(r"""Our goal is to use Polars UDFs to fetch the HTML content of these gallery pages. Then, we'll use the `BeautifulSoup` Python library to parse the HTML and extract the relevant metadata. After some data wrangling with native Polars expressions, we'll have a DataFrame listing each visualization notebook. Then, we'll use another UDF to retrieve the number of likes, forks, and comments for each notebook. Finally, we will create our own high-performance UDF to implement a custom notebook ranking scheme. This will involve multiple steps, showcasing different UDF approaches.""")
94
+ return
95
+
96
+
97
+ @app.cell(hide_code=True)
98
+ def _(mo):
99
+ mo.mermaid('''
100
+ graph LR;
101
+ url_df --> |"UDF: Fetch HTML"| html_df
102
+ html_df --> |"UDF: Parse with BeautifulSoup"| parsed_html_df
103
+ parsed_html_df --> |"Native Polars: Extract Data"| notebooks_df
104
+ notebooks_df --> |"UDF: Get Notebook Stats"| notebook_stats_df
105
+ notebook_stats_df --> |"Numba UDF: Compute Popularity"| notebook_popularity_df
106
+ ''')
107
+ return
108
+
109
+
110
+ @app.cell(hide_code=True)
111
+ def _(mo):
112
+ mo.md(r"""Our starting point, `url_df`, is a simple DataFrame with a single `url` column containing the URLs of the D3 and Observable Plot gallery notebooks.""")
113
+ return
114
+
115
+
116
+ @app.cell(hide_code=True)
117
+ def _(pl):
118
+ url_df = pl.from_dict(
119
+ {
120
+ "url": [
121
+ "https://observablehq.com/@d3/gallery",
122
+ "https://observablehq.com/@observablehq/plot-gallery",
123
+ ]
124
+ }
125
+ )
126
+ url_df
127
+ return (url_df,)
128
+
129
+
130
+ @app.cell(hide_code=True)
131
+ def _(mo):
132
+ mo.md(
133
+ r"""
134
+ ## 🔂 Element-Wise UDFs
135
+
136
+ > Processing Value by Value
137
+
138
+ The most common way to use UDFs is to apply them element-wise. This means our custom function will execute for *each individual row* in a specified column. Our first task is to fetch the HTML content for each URL in `url_df`.
139
+
140
+ We'll define a Python function that takes a `url` (a string) as input, uses the `httpx` library (an HTTP client) to fetch the content, and returns the HTML as a string. We then integrate this function into Polars using the [`map_elements`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.map_elements.html) expression.
141
+
142
+ You'll notice we have to explicitly specify the `return_dtype`. This is *crucial*. Polars doesn't automatically know what our custom function will return. We're responsible for defining the function's logic and, therefore, its output type. By providing the `return_dtype`, we help Polars maintain its internal representation of the DataFrame's schema, enabling query optimization. Think of it as giving Polars a "heads-up" about the data type it should expect.
143
+ """
144
+ )
145
+ return
146
+
147
+
148
+ @app.cell(hide_code=True)
149
+ def _(httpx, pl, url_df):
150
+ html_df = url_df.with_columns(
151
+ html=pl.col("url").map_elements(
152
+ lambda url: httpx.get(url).text,
153
+ return_dtype=pl.String,
154
+ )
155
+ )
156
+ html_df
157
+ return (html_df,)
158
+
159
+
160
+ @app.cell(hide_code=True)
161
+ def _(mo):
162
+ mo.md(
163
+ r"""
164
+ Now, `html_df` holds the HTML for each URL. We need to parse it. Again, a UDF is the way to go. Parsing HTML with native Polars expressions would be a nightmare! Instead, we'll use the [`beautifulsoup4`](https://pypi.org/project/beautifulsoup4/) library, a standard tool for this.
165
+
166
+ These Observable pages are built with [Next.js](https://nextjs.org/), which helpfully serializes page properties as JSON within the HTML. This simplifies our UDF: we'll extract the raw JSON from the `<script id="__NEXT_DATA__" type="application/json">` tag. We'll use [`map_elements`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.map_elements.html) again. For clarity, we'll define this UDF as a named function, `extract_nextjs_data`, since it's a bit more complex than a simple HTTP request.
167
+ """
168
+ )
169
+ return
170
+
171
+
172
+ @app.cell(hide_code=True)
173
+ def _(BeautifulSoup):
174
+ def extract_nextjs_data(html: str) -> str:
175
+ soup = BeautifulSoup(html, "html.parser")
176
+ script_tag = soup.find("script", id="__NEXT_DATA__")
177
+ return script_tag.text
178
+ return (extract_nextjs_data,)
179
+
180
+
181
+ @app.cell(hide_code=True)
182
+ def _(extract_nextjs_data, html_df, pl):
183
+ parsed_html_df = html_df.select(
184
+ "url",
185
+ next_data=pl.col("html").map_elements(
186
+ extract_nextjs_data,
187
+ return_dtype=pl.String,
188
+ ),
189
+ )
190
+ parsed_html_df
191
+ return (parsed_html_df,)
192
+
193
+
194
+ @app.cell(hide_code=True)
195
+ def _(mo):
196
+ mo.md(r"""With some data wrangling of the raw JSON (using *native* Polars expressions!), we get `notebooks_df`, containing the metadata for each notebook.""")
197
+ return
198
+
199
+
200
+ @app.cell(hide_code=True)
201
+ def _(parsed_html_df, pl):
202
+ notebooks_df = (
203
+ parsed_html_df.select(
204
+ "url",
205
+ # We extract the content of every cell present in the gallery notebooks
206
+ cell=pl.col("next_data")
207
+ .str.json_path_match("$.props.pageProps.initialNotebook.nodes")
208
+ .str.json_decode()
209
+ .list.eval(pl.element().struct.field("value")),
210
+ )
211
+ # We want one row per cell
212
+ .explode("cell")
213
+ # Only keep categorized notebook listing cells starting with H3
214
+ .filter(pl.col("cell").str.starts_with("### "))
215
+ # Split up the cells into [heading, description, config] sections
216
+ .with_columns(pl.col("cell").str.split("\n\n"))
217
+ .select(
218
+ gallery_url="url",
219
+ # Text after the '### ' heading, ignoring '<!--' comments
220
+ category=pl.col("cell").list.get(0).str.extract(r"###\s+(.*?)(?:\s+<!--.*?-->|$)"),
221
+ # Paragraph after heading
222
+ description=pl.col("cell")
223
+ .list.get(1)
224
+ .str.strip_chars(" ")
225
+ .str.replace_all("](/", "](https://observablehq.com/", literal=True),
226
+ # Parsed notebook config from ${previews([{...}])}
227
+ notebooks=pl.col("cell")
228
+ .list.get(2)
229
+ .str.strip_prefix("${previews([")
230
+ .str.strip_suffix("]})}")
231
+ .str.strip_chars(" \n")
232
+ .str.split("},")
233
+ # Simple regex-based attribute extraction from JS/JSON objects like
234
+ # ```js
235
+ # {
236
+ # path: "@d3/spilhaus-shoreline-map",
237
+ # "thumbnail": "66a87355e205d820...",
238
+ # title: "Spilhaus shoreline map",
239
+ # "author": "D3"
240
+ # }
241
+ # ```
242
+ .list.eval(
243
+ pl.struct(
244
+ *(
245
+ pl.element()
246
+ .str.extract(rf'(?:"{key}"|{key})\s*:\s*"([^"]*)"')
247
+ .alias(key)
248
+ for key in ["path", "thumbnail", "title"]
249
+ )
250
+ )
251
+ ),
252
+ )
253
+ .explode("notebooks")
254
+ .unnest("notebooks")
255
+ .filter(pl.col("path").is_not_null())
256
+ # Final projection to end up with directly usable values
257
+ .select(
258
+ pl.concat_str(
259
+ [
260
+ pl.lit("https://static.observableusercontent.com/thumbnail/"),
261
+ "thumbnail",
262
+ pl.lit(".jpg"),
263
+ ],
264
+ ).alias("notebook_thumbnail_src"),
265
+ "category",
266
+ "title",
267
+ "description",
268
+ pl.concat_str(
269
+ [pl.lit("https://observablehq.com"), "path"], separator="/"
270
+ ).alias("notebook_url"),
271
+ )
272
+ )
273
+ notebooks_df
274
+ return (notebooks_df,)
275
+
276
+
277
+ @app.cell(hide_code=True)
278
+ def _(mo):
279
+ mo.md(
280
+ r"""
281
+ ## 📦 Batch-Wise UDFs
282
+
283
+ > Processing Entire Series
284
+
285
+ `map_elements` calls the UDF for *each row*. Fine for our tiny, two-rows-tall `url_df`. But `notebooks_df` has almost 400 rows! Individual HTTP requests for each would be painfully slow.
286
+
287
+ We want stats for each notebook in `notebooks_df`. To avoid sequential requests, we'll use Polars' [`map_batches`](https://docs.pola.rs/api/python/stable/reference/expressions/api/polars.Expr.map_batches.html). This lets us process an *entire Series* (a column) at once.
288
+
289
+ Our UDF, `fetch_html_batch`, will take a *Series* of URLs and use `asyncio` to make concurrent requests – a huge performance boost.
290
+ """
291
+ )
292
+ return
293
+
294
+
295
+ @app.cell(hide_code=True)
296
+ def _(Iterable, asyncio, httpx, mo):
297
+ async def _fetch_html_batch(urls: Iterable[str]) -> tuple[str, ...]:
298
+ async with httpx.AsyncClient(timeout=15) as client:
299
+ res = await asyncio.gather(*(client.get(url) for url in urls))
300
+ return tuple((r.text for r in res))
301
+
302
+
303
+ @mo.cache
304
+ def fetch_html_batch(urls: Iterable[str]) -> tuple[str, ...]:
305
+ return asyncio.run(_fetch_html_batch(urls))
306
+ return (fetch_html_batch,)
307
+
308
+
309
+ @app.cell(hide_code=True)
310
+ def _(mo):
311
+ mo.callout(
312
+ mo.md("""
313
+ Since `fetch_html_batch` is a pure Python function and performs multiple network requests, it's a good candidate for caching. We use [`mo.cache`](https://docs.marimo.io/api/caching/#marimo.cache) to avoid redundant requests to the same URL. This is a simple way to improve performance without modifying the core logic.
314
+ """
315
+ ),
316
+ kind="info",
317
+ )
318
+ return
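The memoization idea is independent of marimo; the standard library's `functools.lru_cache` gives the same behaviour for pure functions (hypothetical fetcher, no real requests):

```python
from functools import lru_cache

log = []

@lru_cache(maxsize=None)
def fetch(url: str) -> str:
    log.append(url)  # records actual invocations
    return f"<html>{url}</html>"

first = fetch("https://a.example")
second = fetch("https://a.example")  # cache hit: fetch body not re-run
```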
319
+
320
+
321
+ @app.cell(hide_code=True)
322
+ def _(mo, notebooks_df):
323
+ category = mo.ui.dropdown(
324
+ notebooks_df.sort("category").get_column("category"),
325
+ value="Maps",
326
+ )
327
+ return (category,)
328
+
329
+
330
+ @app.cell(hide_code=True)
331
+ def _(category, extract_nextjs_data, fetch_html_batch, notebooks_df, pl):
332
+ notebook_stats_df = (
333
+ # Setting filter upstream to limit number of concurrent HTTP requests
334
+ notebooks_df.filter(category=category.value)
335
+ .with_columns(
336
+ notebook_html=pl.col("notebook_url")
337
+ .map_batches(fetch_html_batch, return_dtype=pl.List(pl.String))
338
+ .explode()
339
+ )
340
+ .with_columns(
341
+ notebook_data=pl.col("notebook_html")
342
+ .map_elements(
343
+ extract_nextjs_data,
344
+ return_dtype=pl.String,
345
+ )
346
+ .str.json_path_match("$.props.pageProps.initialNotebook")
347
+ .str.json_decode()
348
+ )
349
+ .drop("notebook_html")
350
+ .with_columns(
351
+ *[
352
+ pl.col("notebook_data").struct.field(key).alias(key)
353
+ for key in ["likes", "forks", "comments", "license"]
354
+ ]
355
+ )
356
+ .drop("notebook_data")
357
+ .with_columns(pl.col("comments").list.len())
358
+ .select(
359
+ pl.exclude("description", "notebook_url"),
360
+ "description",
361
+ "notebook_url",
362
+ )
363
+ .sort("likes", descending=True)
364
+ )
365
+ return (notebook_stats_df,)
366
+
367
+
368
+ @app.cell(hide_code=True)
369
+ def _(mo, notebook_stats_df):
370
+ notebooks = mo.ui.table(notebook_stats_df, selection='single', initial_selection=[2], page_size=5)
371
+ notebook_height = mo.ui.slider(start=400, stop=2000, value=825, step=25, show_value=True, label='Notebook Height')
372
+ return notebook_height, notebooks
373
+
374
+
375
+ @app.cell(hide_code=True)
376
+ def _():
377
+ def nb_iframe(notebook_url: str, height=825) -> str:
378
+ embed_url = notebook_url.replace(
379
+ "https://observablehq.com", "https://observablehq.com/embed"
380
+ )
381
+ return f'<iframe width="100%" height="{height}" frameborder="0" src="{embed_url}?cell=*"></iframe>'
382
+ return (nb_iframe,)
383
+
384
+
385
+ @app.cell(hide_code=True)
386
+ def _(mo):
387
+ mo.md(r"""Now that we have access to notebook-level statistics, we can rank the visualizations by the number of likes they received & display them interactively.""")
388
+ return
389
+
390
+
391
+ @app.cell(hide_code=True)
392
+ def _(mo):
393
+ mo.callout("💡 Explore the visualizations by paging through the table below and selecting any of its rows.")
394
+ return
395
+
396
+
397
+ @app.cell(hide_code=True)
398
+ def _(category, mo, nb_iframe, notebook_height, notebooks):
399
+ notebook = notebooks.value.to_dicts()[0]
400
+ mo.vstack(
401
+ [
402
+ mo.hstack([category, notebook_height]),
403
+ notebooks,
404
+ mo.md(f"{notebook['description']}"),
405
+ mo.md('---'),
406
+ mo.md(nb_iframe(notebook["notebook_url"], notebook_height.value)),
407
+ ]
408
+ )
409
+ return (notebook,)
410
+
411
+
412
+ @app.cell(hide_code=True)
413
+ def _(mo):
414
+ mo.md(
415
+ r"""
416
+ ## ⚙️ Row-Wise UDFs
417
+
418
+ > Accessing All Columns at Once
419
+
420
+ Sometimes, you need to work with *all* columns of a row at once. This is where [`map_rows`](https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.map_rows.html) comes in. It operates directly on the DataFrame, passing each row to your UDF *as a tuple*.
421
+
422
+ Below, `create_notebook_summary` takes a row from `notebook_stats_df` (as a tuple) and returns a formatted Markdown string summarizing the notebook's key stats. We're essentially reducing the DataFrame to a single column. While this *could* be done with native Polars expressions, it would be much more cumbersome. This example demonstrates a case where a row-wise UDF simplifies the code, even if the underlying operation isn't inherently complex.
423
+ """
424
+ )
425
+ return
426
+
427
+
428
+ @app.cell(hide_code=True)
429
+ def _():
430
+ def create_notebook_summary(row: tuple) -> str:
431
+ (
432
+ thumbnail_src,
433
+ category,
434
+ title,
435
+ likes,
436
+ forks,
437
+ comments,
438
+ license,
439
+ description,
440
+ notebook_url,
441
+ ) = row
442
+ return (
443
+ f"""
444
+ ### [{title}]({notebook_url})
445
+
446
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 12px; margin: 12px 0;">
447
+ <div>⭐ <strong>Likes:</strong> {likes}</div>
448
+ <div>↗️ <strong>Forks:</strong> {forks}</div>
449
+ <div>💬 <strong>Comments:</strong> {comments}</div>
450
+ <div>⚖️ <strong>License:</strong> {license}</div>
451
+ </div>
452
+
453
+ <a href="{notebook_url}" target="_blank">
454
+ <img src="{thumbnail_src}" style="height: 300px;" />
455
+ </a>
456
+ """.strip('\n')
457
+ )
458
+ return (create_notebook_summary,)
459
+
460
+
461
+ @app.cell(hide_code=True)
462
+ def _(create_notebook_summary, notebook_stats_df, pl):
463
+ notebook_summary_df = notebook_stats_df.map_rows(
464
+ create_notebook_summary,
465
+ return_dtype=pl.String,
466
+ ).rename({"map": "summary"})
467
+ notebook_summary_df.head(1)
468
+ return (notebook_summary_df,)
469
+
470
+
471
+ @app.cell(hide_code=True)
472
+ def _(mo):
473
+ mo.callout("💡 You can explore individual notebook statistics through the carousel. Discover the visualization's source code by clicking the notebook title or the thumbnail.")
474
+ return
475
+
476
+
477
+ @app.cell(hide_code=True)
478
+ def _(mo, notebook_summary_df):
479
+ mo.carousel(
480
+ [
481
+ mo.lazy(mo.md(summary))
482
+ for summary in notebook_summary_df.get_column("summary")
483
+ ]
484
+ )
485
+ return
486
+
487
+
488
+ @app.cell(hide_code=True)
489
+ def _(mo):
490
+ mo.md(
491
+ r"""
492
+ ## 🚀 Higher-performance UDFs
493
+
494
+ > Leveraging Numba to Make Python Fast
495
+
496
+ Python code doesn't *always* mean slow code. While UDFs *often* introduce performance overhead, there are exceptions. NumPy's universal functions ([`ufuncs`](https://numpy.org/doc/stable/reference/ufuncs.html)) and generalized universal functions ([`gufuncs`](https://numpy.org/neps/nep-0005-generalized-ufuncs.html)) provide high-performance operations on NumPy arrays, thanks to low-level implementations.
497
+
498
+ But NumPy's built-in functions are predefined. We can't easily use them for *custom* logic. Enter [`numba`](https://numba.pydata.org/). Numba is a just-in-time (JIT) compiler that translates Python functions into optimized machine code *at runtime*. It provides decorators like [`numba.guvectorize`](https://numba.readthedocs.io/en/stable/user/vectorize.html#the-guvectorize-decorator) that let us create our *own* high-performance `gufuncs` – *without* writing low-level code!
499
+ """
500
+ )
501
+ return
502
+
503
+
504
+ @app.cell(hide_code=True)
505
+ def _(mo):
506
+ mo.md(
507
+ r"""
508
+ Let's create a custom popularity metric to rank notebooks, considering likes, forks, *and* comments (not just likes). We'll define `weighted_popularity_numba`, decorated with `@numba.guvectorize`. The decorator arguments specify that we're taking three integer vectors of length `n` and returning a float vector of length `n`.
509
+
510
+ The weighted popularity score for each notebook is calculated using the following formula:
511
+
512
+ $$
513
+ \begin{equation}
514
+ \text{score}_i = w_l \cdot l_i^{f} + w_f \cdot f_i^{f} + w_c \cdot c_i^{f}
515
+ \end{equation}
516
+ $$
517
+
518
+ with:
519
+ """
520
+ )
521
+ return
522
+
523
+
524
+ @app.cell(hide_code=True)
525
+ def _(mo, non_linear_factor, weight_comments, weight_forks, weight_likes):
526
+ mo.md(rf"""
527
+ | Symbol | Description |
528
+ |--------|-------------|
529
+ | $\text{{score}}_i$ | Popularity score for the *i*-th notebook |
530
+ | $w_l = {weight_likes.value}$ | Weight for likes |
531
+ | $l_i$ | Number of likes for the *i*-th notebook |
532
+ | $w_f = {weight_forks.value}$ | Weight for forks |
533
+ | $f_i$ | Number of forks for the *i*-th notebook |
534
+ | $w_c = {weight_comments.value}$ | Weight for comments |
535
+ | $c_i$ | Number of comments for the *i*-th notebook |
536
+ | $f = {non_linear_factor.value}$ | Non-linear factor (exponent) |
537
+ """)
538
+ return
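In plain Python, the score for a single notebook with the default slider values ($w_l=0.5$, $w_f=0.3$, $w_c=0.5$, $f=1.2$) and made-up counts works out as:

```python
w_l, w_f, w_c, f = 0.5, 0.3, 0.5, 1.2  # default slider values

def score(likes: int, forks: int, comments: int) -> float:
    # score_i = w_l * l^f + w_f * f^f + w_c * c^f
    return w_l * likes**f + w_f * forks**f + w_c * comments**f

s = score(100, 20, 5)  # hypothetical notebook stats
```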
539
+
540
+
541
+ @app.cell(hide_code=True)
542
+ def _(mo):
543
+ weight_likes = mo.ui.slider(
544
+ start=0.1,
545
+ stop=1,
546
+ value=0.5,
547
+ step=0.1,
548
+ show_value=True,
549
+ label="⭐ Weight for Likes",
550
+ )
551
+ weight_forks = mo.ui.slider(
552
+ start=0.1,
553
+ stop=1,
554
+ value=0.3,
555
+ step=0.1,
556
+ show_value=True,
557
+ label="↗️ Weight for Forks",
558
+ )
559
+ weight_comments = mo.ui.slider(
560
+ start=0.1,
561
+ stop=1,
562
+ value=0.5,
563
+ step=0.1,
564
+ show_value=True,
565
+ label="💬 Weight for Comments",
566
+ )
567
+ non_linear_factor = mo.ui.slider(
568
+ start=1,
569
+ stop=2,
570
+ value=1.2,
571
+ step=0.1,
572
+ show_value=True,
573
+ label="🎢 Non-Linear Factor",
574
+ )
575
+ return non_linear_factor, weight_comments, weight_forks, weight_likes
576
+
577
+
578
+ @app.cell(hide_code=True)
579
+ def _(
580
+ non_linear_factor,
581
+ np,
582
+ numba,
583
+ weight_comments,
584
+ weight_forks,
585
+ weight_likes,
586
+ ):
587
+ w_l = weight_likes.value
588
+ w_f = weight_forks.value
589
+ w_c = weight_comments.value
590
+ nlf = non_linear_factor.value
591
+
592
+
593
+ @numba.guvectorize(
594
+ [(numba.int64[:], numba.int64[:], numba.int64[:], numba.float64[:])],
595
+ "(n), (n), (n) -> (n)",
596
+ )
597
+ def weighted_popularity_numba(
598
+ likes: np.ndarray,
599
+ forks: np.ndarray,
600
+ comments: np.ndarray,
601
+ out: np.ndarray,
602
+ ):
603
+ for i in range(likes.shape[0]):
604
+ out[i] = (
605
+ w_l * (likes[i] ** nlf)
606
+ + w_f * (forks[i] ** nlf)
607
+ + w_c * (comments[i] ** nlf)
608
+ )
609
+ return nlf, w_c, w_f, w_l, weighted_popularity_numba
610
+
611
+
612
+ @app.cell(hide_code=True)
613
+ def _(mo):
614
+ mo.md(r"""We apply our JIT-compiled UDF using `map_batches`, as before. The key is that we're passing entire columns directly to `weighted_popularity_numba`. Polars and Numba handle the conversion to NumPy arrays behind the scenes. This direct integration is a major benefit of using `guvectorize`.""")
615
+ return
616
+
617
+
618
+ @app.cell(hide_code=True)
619
+ def _(notebook_stats_df, pl, weighted_popularity_numba):
620
+ notebook_popularity_df = (
621
+ notebook_stats_df.select(
622
+ pl.col("notebook_thumbnail_src").alias("thumbnail"),
623
+ "title",
624
+ "likes",
625
+ "forks",
626
+ "comments",
627
+ popularity=pl.struct(["likes", "forks", "comments"]).map_batches(
628
+ lambda obj: weighted_popularity_numba(
629
+ obj.struct.field("likes"),
630
+ obj.struct.field("forks"),
631
+ obj.struct.field("comments"),
632
+ ),
633
+ return_dtype=pl.Float64,
634
+ ),
635
+ url="notebook_url",
636
+ )
637
+ )
638
+ return (notebook_popularity_df,)
639
+
640
+
641
+ @app.cell(hide_code=True)
642
+ def _(mo):
643
+ mo.callout("💡 Adjust the hyperparameters of the popularity ranking UDF. How do the weights and non-linear factor affect the notebook rankings?")
644
+ return
645
+
646
+
647
+ @app.cell(hide_code=True)
648
+ def _(
649
+ mo,
650
+ non_linear_factor,
651
+ notebook_popularity_df,
652
+ weight_comments,
653
+ weight_forks,
654
+ weight_likes,
655
+ ):
656
+ mo.vstack(
657
+ [
658
+ mo.hstack([weight_likes, weight_forks]),
659
+ mo.hstack([weight_comments, non_linear_factor]),
660
+ notebook_popularity_df,
661
+ ]
662
+ )
663
+ return
664
+
665
+
666
+ @app.cell(hide_code=True)
667
+ def _(mo):
668
+ mo.md(r"""As the slope chart below demonstrates, this new ranking strategy significantly changes the notebook order, as it considers forks and comments, not just likes.""")
669
+ return
670
+
671
+
672
+ @app.cell(hide_code=True)
673
+ def _(alt, notebook_popularity_df, pl):
674
+ notebook_ranks_df = (
675
+ notebook_popularity_df.sort("likes", descending=True)
676
+ .with_row_index("rank_by_likes")
677
+ .with_columns(pl.col("rank_by_likes") + 1)
678
+ .sort("popularity", descending=True)
679
+ .with_row_index("rank_by_popularity")
680
+ .with_columns(pl.col("rank_by_popularity") + 1)
681
+ .select("thumbnail", "title", "rank_by_popularity", "rank_by_likes")
682
+ .unpivot(
683
+ ["rank_by_popularity", "rank_by_likes"],
684
+ index="title",
685
+ variable_name="strategy",
686
+ value_name="rank",
687
+ )
688
+ )
689
+
690
+ # Slope chart to visualize rank differences by strategy
691
+ lines = notebook_ranks_df.plot.line(
692
+ x="strategy:O",
693
+ y="rank:Q",
694
+ color="title:N",
695
+ )
696
+ points = notebook_ranks_df.plot.point(
697
+ x="strategy:O",
698
+ y="rank:Q",
699
+ color=alt.Color("title:N", legend=None),
700
+ fill="title:N",
701
+ )
702
+ (points + lines).properties(width=400)
703
+ return lines, notebook_ranks_df, points
704
+
705
+
706
+ @app.cell(hide_code=True)
707
+ def _(mo):
708
+ mo.md(
709
+ r"""
710
+ ## ⏱️ Quantifying the Overhead
711
+
712
+ > UDF Performance Comparison
713
+
714
+ To truly understand the performance implications of using UDFs, let's conduct a benchmark. We'll create a DataFrame with random numbers and perform the same numerical operation using four different methods:
715
+
716
+ 1. **Native Polars:** Using Polars' built-in expressions.
717
+ 2. **`map_elements`:** Applying a Python function element-wise.
718
+ 3. **`map_batches`:** Applying a Python function to the entire Series.
719
+ 4. **`map_batches` with Numba:** Applying a JIT-compiled function to batches, similar to a generalized universal function.
720
+
721
+ We'll use a simple but non-trivial calculation: `result = (x * 2.5 + 5) / (x + 1)`. This involves multiplication, addition, and division, giving us a realistic representation of a common numerical operation. We'll use the `timeit` module to accurately measure execution times over multiple trials.
722
+ """
723
+ )
724
+ return
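The shape of such a benchmark, reduced to a pure-Python loop vs. a vectorized NumPy computation of the same `(x * 2.5 + 5) / (x + 1)` expression (the sample size and trial count here are arbitrary):

```python
import timeit

import numpy as np

x = np.random.default_rng(42).random(10_000)

def run_loop():
    # Element-wise Python loop: one interpreter round-trip per value.
    return [(v * 2.5 + 5) / (v + 1) for v in x]

def run_vectorized():
    # Single vectorized NumPy expression over the whole array.
    return (x * 2.5 + 5) / (x + 1)

t_loop = timeit.timeit(run_loop, number=20)
t_vec = timeit.timeit(run_vectorized, number=20)
```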
725
+
726
+
727
+ @app.cell(hide_code=True)
728
+ def _(mo):
729
+ mo.callout("💡 Tweak the benchmark parameters to explore how execution times change with different sample sizes and trial counts. Do you notice anything surprising as you decrease the number of samples?")
730
+ return
731
+
732
+
733
+ @app.cell(hide_code=True)
734
+ def _(benchmark_plot, mo, num_samples, num_trials):
735
+ mo.vstack(
736
+ [
737
+ mo.hstack([num_samples, num_trials]),
738
+ mo.md(
739
+ f"""---
740
+ Performance comparison over **{num_trials.value:,} trials** with **{num_samples.value:,} samples**.
741
+
742
+ > Lower execution times are better.
743
+ """
744
+ ),
745
+ benchmark_plot,
746
+ ]
747
+ )
748
+ return
749
+
750
+
751
+ @app.cell(hide_code=True)
752
+ def _(mo):
753
+ mo.md(
754
+ r"""
755
+ As anticipated, the `Batch-Wise UDF (Python)` and `Element-Wise UDF` exhibit significantly worse performance, essentially acting as pure-Python for-each loops.
756
+
757
+ However, when Python serves as an interface to lower-level, high-performance libraries, we observe substantial improvements. The `Batch-Wise UDF (NumPy)` lags behind both `Batch-Wise UDF (Numba)` and `Native Polars`, but it still represents a considerable improvement over pure-Python UDFs due to its vectorized computations.
758
+
759
+ Numba's Just-In-Time (JIT) compilation delivers a dramatic performance boost, achieving speeds comparable to native Polars expressions. This demonstrates that UDFs, particularly when combined with tools like Numba, don't inevitably lead to bottlenecks in numerical computations.
760
+ """
761
+ )
762
+ return
763
+
764
+
765
+ @app.cell(hide_code=True)
766
+ def _(mo):
767
+ num_samples = mo.ui.slider(
768
+ start=1_000,
769
+ stop=1_000_000,
770
+ value=250_000,
771
+ step=1000,
772
+ show_value=True,
773
+ debounce=True,
774
+ label="Number of Samples",
775
+ )
776
+ num_trials = mo.ui.slider(
777
+ start=50,
778
+ stop=1_000,
779
+ value=100,
780
+ step=50,
781
+ show_value=True,
782
+ debounce=True,
783
+ label="Number of Trials",
784
+ )
785
+ return num_samples, num_trials
786
+
787
+
788
+ @app.cell(hide_code=True)
789
+ def _(np, num_samples, pl):
790
+ rng = np.random.default_rng(42)
791
+ sample_df = pl.from_dict({"x": rng.random(num_samples.value)})
792
+ return rng, sample_df
793
+
794
+
795
+ @app.cell(hide_code=True)
796
+ def _(np, num_trials, numba, pl, sample_df, timeit):
797
+ def run_native():
798
+ sample_df.with_columns(
799
+ result_native=(pl.col("x") * 2.5 + 5) / (pl.col("x") + 1)
800
+ )
801
+
802
+
803
+ def _calculate_elementwise(x: float) -> float:
804
+ return (x * 2.5 + 5) / (x + 1)
805
+
806
+
807
+ def run_map_elements():
808
+ sample_df.with_columns(
809
+ result_map_elements=pl.col("x").map_elements(
810
+ _calculate_elementwise,
811
+ return_dtype=pl.Float64,
812
+ )
813
+ )
814
+
815
+
816
+ def _calculate_batchwise_numpy(x_series: pl.Series) -> pl.Series:
817
+ x_array = x_series.to_numpy()
818
+ result_array = (x_array * 2.5 + 5) / (x_array + 1)
819
+ return pl.Series(result_array)
820
+
821
+
822
+ def run_map_batches_numpy():
823
+ sample_df.with_columns(
824
+ result_map_batches_numpy=pl.col("x").map_batches(
825
+ _calculate_batchwise_numpy,
826
+ return_dtype=pl.Float64,
827
+ )
828
+ )
829
+
830
+
831
+ def _calculate_batchwise_python(x_series: pl.Series) -> pl.Series:
832
+ x_array = x_series.to_list()
833
+ result_array = [_calculate_elementwise(x) for x in x_array]
834
+ return pl.Series(result_array)
835
+
836
+
837
+ def run_map_batches_python():
838
+ sample_df.with_columns(
839
+ result_map_batches_python=pl.col("x").map_batches(
840
+ _calculate_batchwise_python,
841
+ return_dtype=pl.Float64,
842
+ )
843
+ )
844
+
845
+
846
+ @numba.guvectorize([(numba.float64[:], numba.float64[:])], "(n) -> (n)")
847
+ def _calculate_batchwise_numba(x: np.ndarray, out: np.ndarray):
848
+ for i in range(x.shape[0]):
849
+ out[i] = (x[i] * 2.5 + 5) / (x[i] + 1)
850
+
851
+
852
+ def run_map_batches_numba():
853
+ sample_df.with_columns(
854
+ result_map_batches_numba=pl.col("x").map_batches(
855
+ _calculate_batchwise_numba,
856
+ return_dtype=pl.Float64,
857
+ )
858
+ )
859
+
860
+
861
+ def time_method(callable_name: str, number=num_trials.value) -> float:
862
+ fn = globals()[callable_name]
863
+ return timeit.timeit(fn, number=number)
864
+ return (
865
+ run_map_batches_numba,
866
+ run_map_batches_numpy,
867
+ run_map_batches_python,
868
+ run_map_elements,
869
+ run_native,
870
+ time_method,
871
+ )
872
+
873
+
874
+ @app.cell(hide_code=True)
875
+ def _(alt, pl, time_method):
876
+ benchmark_df = pl.from_dicts(
877
+ [
878
+ {
879
+ "title": "Native Polars",
880
+ "callable_name": "run_native",
881
+ },
882
+ {
883
+ "title": "Element-Wise UDF",
884
+ "callable_name": "run_map_elements",
885
+ },
886
+ {
887
+ "title": "Batch-Wise UDF (NumPy)",
888
+ "callable_name": "run_map_batches_numpy",
889
+ },
890
+ {
891
+ "title": "Batch-Wise UDF (Python)",
892
+ "callable_name": "run_map_batches_python",
893
+ },
894
+ {
895
+ "title": "Batch-Wise UDF (Numba)",
896
+ "callable_name": "run_map_batches_numba",
897
+ },
898
+ ]
899
+ ).with_columns(
900
+ time=pl.col("callable_name").map_elements(
901
+ time_method, return_dtype=pl.Float64
902
+ )
903
+ )
904
+
905
+ benchmark_plot = benchmark_df.plot.bar(
906
+ x=alt.X("title:N", title="Method", sort="-y"),
907
+ y=alt.Y("time:Q", title="Execution Time (s)", axis=alt.Axis(format=".3f")),
908
+ ).properties(width=400)
909
+ return benchmark_df, benchmark_plot
910
+
911
+
912
+ @app.cell(hide_code=True)
913
+ def _():
914
+ import asyncio
915
+ import timeit
916
+ from typing import Iterable
917
+
918
+ import altair as alt
919
+ import httpx
920
+ import marimo as mo
921
+ import nest_asyncio
922
+ import numba
923
+ import numpy as np
924
+ from bs4 import BeautifulSoup
925
+
926
+ import polars as pl
927
+
928
+ # Fixes RuntimeError: asyncio.run() cannot be called from a running event loop
929
+ nest_asyncio.apply()
930
+ return (
931
+ BeautifulSoup,
932
+ Iterable,
933
+ alt,
934
+ asyncio,
935
+ httpx,
936
+ mo,
937
+ nest_asyncio,
938
+ np,
939
+ numba,
940
+ pl,
941
+ timeit,
942
+ )
943
+
944
+
945
+ if __name__ == "__main__":
946
+ app.run()
polars/README.md CHANGED
@@ -23,3 +23,4 @@ You can also open notebooks in our online playground by appending marimo.app/ to
23
  Thanks to all our notebook authors!
24
 
25
  * [Koushik Khan](https://github.com/koushikkhan)
26
+ * [Péter Gyarmati](https://github.com/peter-gy)
probability/08_bayes_theorem.py CHANGED
@@ -307,7 +307,7 @@ def _(mo):
307
  mo.md(
308
  r"""
309
 
310
- _This interactive exmaple was made with [marimo](https://github.com/marimo-team/marimo/blob/main/examples/misc/bayes_theorem.py), and is [based on an explanation of Bayes' Theorem by Grant Sanderson](https://www.youtube.com/watch?v=HZGCoVF3YvM&list=PLzq7odmtfKQw2KIbQq0rzWrqgifHKkPG1&index=1&t=3s)_.
310
+ _This interactive example was made with [marimo](https://github.com/marimo-team/marimo/blob/main/examples/misc/bayes_theorem.py), and is [based on an explanation of Bayes' Theorem by Grant Sanderson](https://www.youtube.com/watch?v=HZGCoVF3YvM&list=PLzq7odmtfKQw2KIbQq0rzWrqgifHKkPG1&index=1&t=3s)_.
311
 
312
 Bayes theorem provides a convenient way to calculate the probability
313
 of a hypothesis event $H$ given evidence $E$:
probability/10_probability_mass_function.py ADDED
@@ -0,0 +1,711 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.0",
6
+ # "numpy==2.2.3",
7
+ # "scipy==1.15.2",
8
+ # ]
9
+ # ///
10
+
11
+ import marimo
12
+
13
+ __generated_with = "0.11.17"
14
+ app = marimo.App(width="medium", app_title="Probability Mass Functions")
15
+
16
+
17
+ @app.cell(hide_code=True)
18
+ def _(mo):
19
+ mo.md(
20
+ r"""
21
+ # Probability Mass Functions
22
+
23
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/pmf/), by Stanford professor Chris Piech._
24
+
25
+ For a random variable, the most important thing to know is: how likely is each outcome? For a discrete random variable, this information is called the "**Probability Mass Function**". The probability mass function (PMF) provides the "mass" (i.e. amount) of "probability" for each possible assignment of the random variable.
26
+
27
+ Formally, the Probability Mass Function is a mapping between the values that the random variable could take on and the probability of the random variable taking on said value. In mathematics, we call these associations functions. There are many different ways of representing functions: you can write an equation, you can make a graph, you can even store many samples in a list.
28
+ """
29
+ )
30
+ return
31
+
32
+
33
+ @app.cell(hide_code=True)
34
+ def _(mo):
35
+ mo.md(
36
+ r"""
37
+ ## Properties of a PMF
38
+
39
+ For a function $p_X(x)$ to be a valid PMF, it must satisfy:
40
+
41
+ 1. **Non-negativity**: $p_X(x) \geq 0$ for all $x$
42
+ 2. **Unit total probability**: $\sum_x p_X(x) = 1$
43
+
44
+ ### Probabilities Must Sum to 1
45
+
46
+ For a variable (call it $X$) to be a proper random variable, it must be the case that if you sum up the values of $P(X=x)$ over all possible values $x$ that $X$ can take on, the result is 1:
47
+
48
+ $$\sum_x P(X=x) = 1$$
49
+
50
+ This is because a random variable taking on a value is an event (for example $X=3$). Each of those events is mutually exclusive because a random variable will take on exactly one value. Those mutually exclusive cases define an entire sample space. Why? Because $X$ must take on some value.
51
+ """
52
+ )
53
+ return
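These two properties are mechanical to check in code. A minimal sketch (the fair-die PMF here is an illustration, not one of the notebook's cells):

```python
# PMF of a fair six-sided die, stored as a dict {value: probability}
die_pmf = {x: 1 / 6 for x in range(1, 7)}

# Property 1: non-negativity
assert all(p >= 0 for p in die_pmf.values())

# Property 2: probabilities sum to 1 (tolerate floating-point error)
assert abs(sum(die_pmf.values()) - 1) < 1e-9
```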
54
+
55
+
56
+ @app.cell(hide_code=True)
57
+ def _(mo):
58
+ mo.md(
59
+ r"""
60
+ ## PMFs as Graphs
61
+
62
+ Let's start by looking at PMFs as graphs where the $x$-axis is the values that the random variable could take on and the $y$-axis is the probability of the random variable taking on said value.
63
+
64
+ In the following example, we show two PMFs:
65
+
66
+ - On the left: PMF for the random variable $X$ = the value of a single six-sided die roll
67
+ - On the right: PMF for the random variable $Y$ = value of the sum of two dice rolls
68
+ """
69
+ )
70
+ return
71
+
72
+
73
+ @app.cell(hide_code=True)
74
+ def _(np, plt):
75
+ # Single die PMF
76
+ single_die_values = np.arange(1, 7)
77
+ single_die_probs = np.ones(6) / 6
78
+
79
+ # Two dice sum PMF
80
+ two_dice_values = np.arange(2, 13)
81
+ two_dice_probs = []
82
+
83
+ for dice_sum in two_dice_values:
84
+ if dice_sum <= 7:
85
+ dice_prob = (dice_sum-1) / 36
86
+ else:
87
+ dice_prob = (13-dice_sum) / 36
88
+ two_dice_probs.append(dice_prob)
89
+
90
+ # Create side-by-side plots
91
+ dice_fig, (dice_ax1, dice_ax2) = plt.subplots(1, 2, figsize=(12, 4))
92
+
93
+ # Single die plot
94
+ dice_ax1.bar(single_die_values, single_die_probs, width=0.4)
95
+ dice_ax1.set_xticks(single_die_values)
96
+ dice_ax1.set_xlabel('Value of die roll (x)')
97
+ dice_ax1.set_ylabel('Probability: P(X = x)')
98
+ dice_ax1.set_title('PMF of a Single Die Roll')
99
+ dice_ax1.grid(alpha=0.3)
100
+
101
+ # Two dice sum plot
102
+ dice_ax2.bar(two_dice_values, two_dice_probs, width=0.4)
103
+ dice_ax2.set_xticks(two_dice_values)
104
+ dice_ax2.set_xlabel('Sum of two dice (y)')
105
+ dice_ax2.set_ylabel('Probability: P(Y = y)')
106
+ dice_ax2.set_title('PMF of Sum of Two Dice')
107
+ dice_ax2.grid(alpha=0.3)
108
+
109
+ plt.tight_layout()
110
+ plt.gca()
111
+ return (
112
+ dice_ax1,
113
+ dice_ax2,
114
+ dice_fig,
115
+ dice_prob,
116
+ dice_sum,
117
+ single_die_probs,
118
+ single_die_values,
119
+ two_dice_probs,
120
+ two_dice_values,
121
+ )
122
+
123
+
124
+ @app.cell(hide_code=True)
125
+ def _(mo):
126
+ mo.md(
127
+ r"""
128
+ The information provided in these graphs shows the likelihood of a random variable taking on different values.
129
+
130
+ In the graph on the right, the value "6" on the $x$-axis is associated with the probability $\frac{5}{36}$ on the $y$-axis. This $x$-value refers to the event "the sum of two dice is 6", or $Y = 6$. The $y$-axis tells us that the probability of that event is $\frac{5}{36}$. In full: $P(Y = 6) = \frac{5}{36}$.
131
+
132
+ The value "2" is associated with "$\frac{1}{36}$", which tells us that $P(Y = 2) = \frac{1}{36}$: the probability that two dice sum to 2 is $\frac{1}{36}$. There is no value associated with "1" because the sum of two dice cannot be 1.
133
+ """
134
+ )
135
+ return
136
+
137
+
138
+ @app.cell(hide_code=True)
139
+ def _(mo):
140
+ mo.md(
141
+ r"""
142
+ ## PMFs as Equations
143
+
144
+ Here is the exact same information in equation form:
145
+
146
+ For a single die roll $X$:
147
+ $$P(X=x) = \frac{1}{6} \quad \text{ if } 1 \leq x \leq 6$$
148
+
149
+ For the sum of two dice $Y$:
150
+ $$P(Y=y) = \begin{cases}
151
+ \frac{(y-1)}{36} & \text{ if } 2 \leq y \leq 7\\
152
+ \frac{(13-y)}{36} & \text{ if } 8 \leq y \leq 12
153
+ \end{cases}$$
154
+
155
+ Let's implement the PMF for $Y$, the sum of two dice, in Python code:
156
+ """
157
+ )
158
+ return
159
+
160
+
161
+ @app.cell
162
+ def _():
163
+ def pmf_sum_two_dice(y_val):
164
+ """Returns the probability that the sum of two dice is y"""
165
+ if y_val < 2 or y_val > 12:
166
+ return 0
167
+ if y_val <= 7:
168
+ return (y_val-1) / 36
169
+ else:
170
+ return (13-y_val) / 36
171
+
172
+ # Test the function for a few values
173
+ test_values = [1, 2, 7, 12, 13]
174
+ for test_y in test_values:
175
+ print(f"P(Y = {test_y}) = {pmf_sum_two_dice(test_y)}")
176
+ return pmf_sum_two_dice, test_values, test_y
177
+
178
+
179
+ @app.cell(hide_code=True)
180
+ def _(mo):
181
+ mo.md(r"""Now, let's verify that our PMF satisfies the property that the sum of all probabilities equals 1:""")
182
+ return
183
+
184
+
185
+ @app.cell
186
+ def _(pmf_sum_two_dice):
187
+ # Verify that probabilities sum to 1
188
+ verify_total_prob = sum(pmf_sum_two_dice(y_val) for y_val in range(2, 13))
189
+ # Round to 10 decimal places to handle floating-point precision
190
+ verify_total_prob_rounded = round(verify_total_prob, 10)
191
+ print(f"Sum of all probabilities: {verify_total_prob_rounded}")
192
+ return verify_total_prob, verify_total_prob_rounded
193
+
194
+
195
+ @app.cell(hide_code=True)
196
+ def _(plt, pmf_sum_two_dice):
197
+ # Create a visual verification
198
+ verify_y_values = list(range(2, 13))
199
+ verify_probabilities = [pmf_sum_two_dice(y_val) for y_val in verify_y_values]
200
+
201
+ plt.figure(figsize=(10, 4))
202
+ plt.bar(verify_y_values, verify_probabilities, width=0.4)
203
+ plt.xticks(verify_y_values)
204
+ plt.xlabel('Sum of two dice (y)')
205
+ plt.ylabel('Probability: P(Y = y)')
206
+ plt.title('PMF of Sum of Two Dice (Total Probability = 1)')
207
+ plt.grid(alpha=0.3)
208
+
209
+ # Add probability values on top of bars
210
+ for verify_i, verify_prob in enumerate(verify_probabilities):
211
+ plt.text(verify_y_values[verify_i], verify_prob + 0.001, f'{verify_prob:.3f}', ha='center')
212
+
213
+ plt.gca() # Return the current axes to ensure proper display
214
+ return verify_i, verify_prob, verify_probabilities, verify_y_values
215
+
216
+
217
+ @app.cell(hide_code=True)
218
+ def _(mo):
219
+ mo.md(
220
+ r"""
221
+ ## Data to Histograms to Probability Mass Functions
222
+
223
+ One surprising way to store a likelihood function (recall that a PMF is the name of the likelihood function for discrete random variables) is simply as a list of data. Let's simulate summing two dice many times to create an empirical PMF:
224
+ """
225
+ )
226
+ return
227
+
228
+
229
+ @app.cell
230
+ def _(np):
231
+ # Simulate rolling two dice many times
232
+ sim_num_trials = 10000
233
+ np.random.seed(42) # For reproducibility
234
+
235
+ # Generate random dice rolls
236
+ sim_die1 = np.random.randint(1, 7, size=sim_num_trials)
237
+ sim_die2 = np.random.randint(1, 7, size=sim_num_trials)
238
+
239
+ # Calculate the sum
240
+ sim_dice_sums = sim_die1 + sim_die2
241
+
242
+ # Display a small sample of the data
243
+ print(f"First 20 dice sums: {sim_dice_sums[:20]}")
244
+ print(f"Total number of trials: {sim_num_trials}")
245
+ return sim_dice_sums, sim_die1, sim_die2, sim_num_trials
246
+
247
+
248
+ @app.cell(hide_code=True)
249
+ def _(collections, np, plt, sim_dice_sums):
250
+ # Count the frequency of each sum
251
+ sim_counter = collections.Counter(sim_dice_sums)
252
+
253
+ # Sort the values
254
+ sim_sorted_values = sorted(sim_counter.keys())
255
+
256
+ # Calculate the empirical PMF
257
+ sim_empirical_pmf = [sim_counter[x] / len(sim_dice_sums) for x in sim_sorted_values]
258
+
259
+ # Calculate the theoretical PMF
260
+ sim_theoretical_values = np.arange(2, 13)
261
+ sim_theoretical_pmf = []
262
+ for sim_y in sim_theoretical_values:
263
+ if sim_y <= 7:
264
+ sim_prob = (sim_y-1) / 36
265
+ else:
266
+ sim_prob = (13-sim_y) / 36
267
+ sim_theoretical_pmf.append(sim_prob)
268
+
269
+ # Create a comparison plot
270
+ sim_fig, (sim_ax1, sim_ax2) = plt.subplots(1, 2, figsize=(12, 4))
271
+
272
+ # Empirical PMF (normalized histogram)
273
+ sim_ax1.bar(sim_sorted_values, sim_empirical_pmf, width=0.4)
274
+ sim_ax1.set_xticks(sim_sorted_values)
275
+ sim_ax1.set_xlabel('Sum of two dice')
276
+ sim_ax1.set_ylabel('Empirical Probability')
277
+ sim_ax1.set_title(f'Empirical PMF from {len(sim_dice_sums)} Trials')
278
+ sim_ax1.grid(alpha=0.3)
279
+
280
+ # Theoretical PMF
281
+ sim_ax2.bar(sim_theoretical_values, sim_theoretical_pmf, width=0.4)
282
+ sim_ax2.set_xticks(sim_theoretical_values)
283
+ sim_ax2.set_xlabel('Sum of two dice')
284
+ sim_ax2.set_ylabel('Theoretical Probability')
285
+ sim_ax2.set_title('Theoretical PMF')
286
+ sim_ax2.grid(alpha=0.3)
287
+
288
+ plt.tight_layout()
289
+
290
+ # Let's also look at the raw counts (histogram)
291
+ plt.figure(figsize=(10, 4))
292
+ sim_counts = [sim_counter[x] for x in sim_sorted_values]
293
+ plt.bar(sim_sorted_values, sim_counts, width=0.4)
294
+ plt.xticks(sim_sorted_values)
295
+ plt.xlabel('Sum of two dice')
296
+ plt.ylabel('Frequency')
297
+ plt.title('Histogram of Dice Sum Frequencies')
298
+ plt.grid(alpha=0.3)
299
+
300
+ # Add count values on top of bars
301
+ for sim_i, sim_count in enumerate(sim_counts):
302
+ plt.text(sim_sorted_values[sim_i], sim_count + 19, str(sim_count), ha='center')
303
+
304
+ plt.gca() # Return the current axes to ensure proper display
305
+ return (
306
+ sim_ax1,
307
+ sim_ax2,
308
+ sim_count,
309
+ sim_counter,
310
+ sim_counts,
311
+ sim_empirical_pmf,
312
+ sim_fig,
313
+ sim_i,
314
+ sim_prob,
315
+ sim_sorted_values,
316
+ sim_theoretical_pmf,
317
+ sim_theoretical_values,
318
+ sim_y,
319
+ )
320
+
321
+
322
+ @app.cell(hide_code=True)
323
+ def _(mo):
324
+ mo.md(
325
+ r"""
326
+ A normalized histogram (where each value is divided by the length of your data list) is an approximation of the PMF. For a dataset of discrete numbers, a histogram shows the count of each value. By the definition of probability, if you divide this count by the number of experiments run, you arrive at an approximation of the probability of the event $P(Y=y)$.
327
+
328
+ Let's look at a specific example. If we want to approximate $P(Y=3)$ (the probability that the sum of two dice is 3), we can count the number of times that "3" occurs in our data and divide by the total number of trials:
329
+ """
330
+ )
331
+ return
332
+
333
+
334
+ @app.cell
335
+ def _(sim_counter, sim_dice_sums):
336
+ # Calculate P(Y=3) empirically
337
+ sim_count_of_3 = sim_counter[3]
338
+ sim_empirical_prob = sim_count_of_3 / len(sim_dice_sums)
339
+
340
+ # Calculate P(Y=3) theoretically
341
+ sim_theoretical_prob = 2/36 # There are 2 ways to get a sum of 3 out of 36 possible outcomes
342
+
343
+ print(f"Count of sum=3: {sim_count_of_3}")
344
+ print(f"Empirical P(Y=3): {sim_count_of_3}/{len(sim_dice_sums)} = {sim_empirical_prob:.4f}")
345
+ print(f"Theoretical P(Y=3): 2/36 = {sim_theoretical_prob:.4f}")
346
+ print(f"Difference: {abs(sim_empirical_prob - sim_theoretical_prob):.4f}")
347
+ return sim_count_of_3, sim_empirical_prob, sim_theoretical_prob
348
+
349
+
350
+ @app.cell(hide_code=True)
351
+ def _(mo):
352
+ mo.md(
353
+ r"""
354
+ As we can see, with a large number of trials, the empirical PMF becomes a very good approximation of the theoretical PMF. This is an example of the [Law of Large Numbers](https://en.wikipedia.org/wiki/Law_of_large_numbers) in action.
355
+
356
+ ## Interactive Example: Exploring PMFs
357
+
358
+ Let's create an interactive tool to explore different PMFs:
359
+ """
360
+ )
361
+ return
362
+
363
+
364
+ @app.cell
365
+ def _(dist_param1, dist_param2, dist_selection, mo):
366
+ mo.hstack([dist_selection, dist_param1, dist_param2], justify="space-around")
367
+ return
368
+
369
+
370
+ @app.cell(hide_code=True)
371
+ def _(mo):
372
+ dist_selection = mo.ui.dropdown(
373
+ options=[
374
+ "bernoulli",
375
+ "binomial",
376
+ "geometric",
377
+ "poisson"
378
+ ],
379
+ value="bernoulli",
380
+ label="Select a distribution"
381
+ )
382
+
383
+ # Parameters for different distributions
384
+ dist_param1 = mo.ui.slider(
385
+ start=0.05,
386
+ stop=0.95,
387
+ step=0.05,
388
+ value=0.5,
389
+ label="p (success probability)"
390
+ )
391
+
392
+ dist_param2 = mo.ui.slider(
393
+ start=1,
394
+ stop=20,
395
+ step=1,
396
+ value=10,
397
+ label="n (trials) or λ (rate)"
398
+ )
399
+ return dist_param1, dist_param2, dist_selection
400
+
401
+
402
+ @app.cell(hide_code=True)
403
+ def _(dist_param1, dist_param2, dist_selection, np, plt, stats):
404
+ # Set up the plot based on the selected distribution
+ # Defaults for parameters that only some branches set below
+ # (all of these names are returned by this cell)
+ dist_n = int(dist_param2.value)
+ dist_lam = dist_param2.value
405
+ if dist_selection.value == "bernoulli":
406
+ # Bernoulli distribution
407
+ dist_p = dist_param1.value
408
+ dist_x_values = np.array([0, 1])
409
+ dist_pmf_values = [1-dist_p, dist_p]
410
+ dist_title = f"Bernoulli PMF (p = {dist_p:.2f})"
411
+ dist_x_label = "Outcome (0 = Failure, 1 = Success)"
412
+ dist_max_x = 1
413
+
414
+ elif dist_selection.value == "binomial":
415
+ # Binomial distribution
416
+ dist_n = int(dist_param2.value)
417
+ dist_p = dist_param1.value
418
+ dist_x_values = np.arange(0, dist_n+1)
419
+ dist_pmf_values = stats.binom.pmf(dist_x_values, dist_n, dist_p)
420
+ dist_title = f"Binomial PMF (n = {dist_n}, p = {dist_p:.2f})"
421
+ dist_x_label = "Number of Successes"
422
+ dist_max_x = dist_n
423
+
424
+ elif dist_selection.value == "geometric":
425
+ # Geometric distribution
426
+ dist_p = dist_param1.value
427
+ dist_max_x = min(int(5/dist_p), 50) # Limit the range for visualization
428
+ dist_x_values = np.arange(1, dist_max_x+1)
429
+ dist_pmf_values = stats.geom.pmf(dist_x_values, dist_p)
430
+ dist_title = f"Geometric PMF (p = {dist_p:.2f})"
431
+ dist_x_label = "Number of Trials Until First Success"
432
+
433
+ else: # Poisson
434
+ # Poisson distribution
435
+ dist_lam = dist_param2.value
436
+ dist_max_x = int(dist_lam*3) + 1 # Reasonable range for visualization
437
+ dist_x_values = np.arange(0, dist_max_x)
438
+ dist_pmf_values = stats.poisson.pmf(dist_x_values, dist_lam)
439
+ dist_title = f"Poisson PMF (λ = {dist_lam})"
440
+ dist_x_label = "Number of Events"
441
+
442
+ # Create the plot
443
+ plt.figure(figsize=(10, 5))
444
+
445
+ # For discrete distributions, use stem plot for clarity
446
+ dist_markerline, dist_stemlines, dist_baseline = plt.stem(
447
+ dist_x_values, dist_pmf_values, markerfmt='o', basefmt=' '
448
+ )
449
+ plt.setp(dist_markerline, markersize=6)
450
+ plt.setp(dist_stemlines, linewidth=1.5)
451
+
452
+ # Add a bar plot for better visibility
453
+ plt.bar(dist_x_values, dist_pmf_values, alpha=0.3, width=0.4)
454
+
455
+ plt.xlabel(dist_x_label)
456
+ plt.ylabel("Probability: P(X = x)")
457
+ plt.title(dist_title)
458
+ plt.grid(alpha=0.3)
459
+
460
+ # Calculate and display expected value and variance
461
+ if dist_selection.value == "bernoulli":
462
+ dist_mean = dist_p
463
+ dist_variance = dist_p * (1-dist_p)
464
+ elif dist_selection.value == "binomial":
465
+ dist_mean = dist_n * dist_p
466
+ dist_variance = dist_n * dist_p * (1-dist_p)
467
+ elif dist_selection.value == "geometric":
468
+ dist_mean = 1/dist_p
469
+ dist_variance = (1-dist_p)/(dist_p**2)
470
+ else: # Poisson
471
+ dist_mean = dist_lam
472
+ dist_variance = dist_lam
473
+
474
+ dist_std_dev = np.sqrt(dist_variance)
475
+
476
+ # Add text with distribution properties
477
+ dist_props_text = (
478
+ f"Mean: {dist_mean:.3f}\n"
479
+ f"Variance: {dist_variance:.3f}\n"
480
+ f"Std Dev: {dist_std_dev:.3f}\n"
481
+ f"Sum of probabilities: {sum(dist_pmf_values):.6f}"
482
+ )
483
+
484
+ plt.text(0.95, 0.95, dist_props_text,
485
+ transform=plt.gca().transAxes,
486
+ verticalalignment='top',
487
+ horizontalalignment='right',
488
+ bbox=dict(boxstyle='round', facecolor='white', alpha=0.8))
489
+
490
+ plt.gca() # Return the current axes to ensure proper display
491
+ return (
492
+ dist_baseline,
493
+ dist_lam,
494
+ dist_markerline,
495
+ dist_max_x,
496
+ dist_mean,
497
+ dist_n,
498
+ dist_p,
499
+ dist_pmf_values,
500
+ dist_props_text,
501
+ dist_std_dev,
502
+ dist_stemlines,
503
+ dist_title,
504
+ dist_variance,
505
+ dist_x_label,
506
+ dist_x_values,
507
+ )
508
+
509
+
510
+ @app.cell(hide_code=True)
511
+ def _(mo):
512
+ mo.md(
513
+ r"""
514
+ ## Expected Value from a PMF
515
+
516
+ The expected value (or mean) of a discrete random variable is calculated using its PMF:
517
+
518
+ $$E[X] = \sum_x x \cdot p_X(x)$$
519
+
520
+ This represents the long-run average value of the random variable.
521
+ """
522
+ )
523
+ return
524
+
525
+
526
+ @app.cell
527
+ def _(dist_pmf_values, dist_x_values):
528
+ def calc_expected_value(x_values, pmf_values):
529
+ """Calculate the expected value of a discrete random variable."""
530
+ return sum(x * p for x, p in zip(x_values, pmf_values))
531
+
532
+ # Calculate expected value for the current distribution
533
+ ev_dist_mean = calc_expected_value(dist_x_values, dist_pmf_values)
534
+
535
+ print(f"Expected value: {ev_dist_mean:.4f}")
536
+ return calc_expected_value, ev_dist_mean
537
+
538
+
539
+ @app.cell(hide_code=True)
540
+ def _(mo):
541
+ mo.md(
542
+ r"""
543
+ ## Variance from a PMF
544
+
545
+ The variance measures the spread or dispersion of a random variable around its mean:
546
+
547
+ $$\text{Var}(X) = E[(X - E[X])^2] = \sum_x (x - E[X])^2 \cdot p_X(x)$$
548
+
549
+ An alternative formula is:
550
+
551
+ $$\text{Var}(X) = E[X^2] - (E[X])^2 = \sum_x x^2 \cdot p_X(x) - \left(\sum_x x \cdot p_X(x)\right)^2$$
552
+ """
553
+ )
554
+ return
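The two formulas are algebraically equivalent, which is quick to confirm numerically. A small sketch for a fair die (illustrative, not one of the notebook's cells):

```python
# Fair die: values 1..6, each with probability 1/6
values = range(1, 7)
p = 1 / 6

mean = sum(x * p for x in values)  # E[X] = 3.5

# Definition: E[(X - E[X])^2]
var_definition = sum((x - mean) ** 2 * p for x in values)

# Shortcut: E[X^2] - (E[X])^2
var_shortcut = sum(x**2 * p for x in values) - mean**2

print(mean, var_definition, var_shortcut)  # both variances equal 35/12
```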
555
+
556
+
557
+ @app.cell
558
+ def _(dist_pmf_values, dist_x_values, ev_dist_mean, np):
559
+ def calc_variance(x_values, pmf_values, mean_value):
560
+ """Calculate the variance of a discrete random variable."""
561
+ return sum((x - mean_value)**2 * p for x, p in zip(x_values, pmf_values))
562
+
563
+ # Calculate variance for the current distribution
564
+ var_dist_var = calc_variance(dist_x_values, dist_pmf_values, ev_dist_mean)
565
+ var_dist_std_dev = np.sqrt(var_dist_var)
566
+
567
+ print(f"Variance: {var_dist_var:.4f}")
568
+ print(f"Standard deviation: {var_dist_std_dev:.4f}")
569
+ return calc_variance, var_dist_std_dev, var_dist_var
570
+
571
+
572
+ @app.cell(hide_code=True)
573
+ def _(mo):
574
+ mo.md(
575
+ r"""
576
+ ## PMF vs. CDF
577
+
578
+ The **Cumulative Distribution Function (CDF)** is related to the PMF but gives the probability that the random variable $X$ is less than or equal to a value $x$:
579
+
580
+ $$F_X(x) = P(X \leq x) = \sum_{k \leq x} p_X(k)$$
581
+
582
+ While the PMF gives the probability mass at each point, the CDF accumulates these probabilities.
583
+ """
584
+ )
585
+ return
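For a concrete instance of the definition, the CDF of a single fair die can be tabulated directly from its PMF as a running sum (an illustrative sketch, separate from the notebook's cells):

```python
# PMF of a fair die
pmf = {x: 1 / 6 for x in range(1, 7)}

# CDF: F(x) = P(X <= x), the running sum of the PMF over values <= x
cdf = {}
running = 0.0
for x in sorted(pmf):
    running += pmf[x]
    cdf[x] = running

print(cdf[3], cdf[6])  # P(X <= 3) = 3/6 = 0.5, P(X <= 6) = 1 (up to float error)
```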
586
+
587
+
588
+ @app.cell(hide_code=True)
589
+ def _(dist_pmf_values, dist_x_values, np, plt):
590
+ # Calculate the CDF from the PMF
591
+ cdf_dist_values = np.cumsum(dist_pmf_values)
592
+
593
+ # Create a plot comparing PMF and CDF
594
+ cdf_fig, (cdf_ax1, cdf_ax2) = plt.subplots(1, 2, figsize=(12, 4))
595
+
596
+ # PMF plot
597
+ cdf_ax1.bar(dist_x_values, dist_pmf_values, width=0.4, alpha=0.7)
598
+ cdf_ax1.set_xlabel('x')
599
+ cdf_ax1.set_ylabel('P(X = x)')
600
+ cdf_ax1.set_title('Probability Mass Function (PMF)')
601
+ cdf_ax1.grid(alpha=0.3)
602
+
603
+ # CDF plot - using step function with 'post' style for proper discrete representation
604
+ cdf_ax2.step(dist_x_values, cdf_dist_values, where='post', linewidth=2, color='blue')
605
+ cdf_ax2.scatter(dist_x_values, cdf_dist_values, s=50, color='blue')
606
+
607
+ # Set appropriate limits for better visualization
608
+ if len(dist_x_values) > 0:
609
+ x_min = min(dist_x_values) - 0.5
610
+ x_max = max(dist_x_values) + 0.5
611
+ cdf_ax2.set_xlim(x_min, x_max)
612
+ cdf_ax2.set_ylim(0, 1.05) # CDF goes from 0 to 1
613
+
614
+ cdf_ax2.set_xlabel('x')
615
+ cdf_ax2.set_ylabel('P(X ≤ x)')
616
+ cdf_ax2.set_title('Cumulative Distribution Function (CDF)')
617
+ cdf_ax2.grid(alpha=0.3)
618
+
619
+ plt.tight_layout()
620
+ plt.gca() # Return the current axes to ensure proper display
621
+ return cdf_ax1, cdf_ax2, cdf_dist_values, cdf_fig, x_max, x_min
622
+
623
+
624
+ @app.cell(hide_code=True)
625
+ def _(mo):
626
+ mo.md(
627
+ r"""
628
+ The graphs above illustrate the key difference between PMF and CDF:
629
+
630
+ - **PMF (left)**: Shows the probability of the random variable taking each specific value: P(X = x)
631
+ - **CDF (right)**: Shows the probability of the random variable being less than or equal to each value: P(X ≤ x)
632
+
633
+ The CDF at any point is the sum of all PMF values up to and including that point. This is why the CDF is always non-decreasing and eventually reaches 1. For discrete distributions like this one, the CDF forms a step function that jumps at each value in the support of the random variable.
634
+ """
635
+ )
636
+ return
637
+
638
+
639
+ @app.cell(hide_code=True)
640
+ def _(mo):
641
+ mo.md(
642
+ r"""
643
+ ## Test Your Understanding
644
+
645
+ Choose what you believe are the correct options in the questions below:
646
+
647
+ <details>
648
+ <summary>If X is a discrete random variable with PMF p(x), then p(x) must always be less than 1</summary>
649
+ ❌ False! While most values in a PMF are typically less than 1, a PMF can have p(x) = 1 for a specific value if the random variable always takes that value (with 100% probability).
650
+ </details>
651
+
652
+ <details>
653
+ <summary>The sum of all probabilities in a PMF must equal exactly 1</summary>
654
+ ✅ True! This is a fundamental property of any valid PMF. The total probability across all possible values must be 1, as the random variable must take some value.
655
+ </details>
656
+
657
+ <details>
658
+ <summary>A PMF can be estimated from data by creating a normalized histogram</summary>
659
+ ✅ True! Counting the frequency of each value and dividing by the total number of observations gives an empirical PMF.
660
+ </details>
661
+
662
+ <details>
663
+ <summary>The expected value of a discrete random variable is always one of the possible values of the variable</summary>
664
+ ❌ False! The expected value is a weighted average and may not be a value the random variable can actually take. For example, the expected value of a fair die roll is 3.5, which is not a possible outcome.
665
+ </details>
666
+ """
667
+ )
668
+ return
669
+
670
+
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Practical Applications of PMFs
+ 
+         PMFs pop up everywhere - network engineers use them to model traffic patterns, reliability teams predict equipment failures, and marketers analyze purchase behavior. In finance, they help price options; in gaming, they're behind every dice roll. Machine learning algorithms like Naive Bayes rely on them, and they're essential for modeling rare events like genetic mutations or system failures.
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Key Takeaways
+ 
+         PMFs give us the probability picture for discrete random variables - they tell us how likely each value is, must be non-negative, and always sum to 1. We can write them as equations, draw them as graphs, or estimate them from data. They're the foundation for calculating expected values and variances, which we'll explore in our next notebook on Expectation, where we'll learn how to summarize random variables with a single, most "expected" value.
+         """
+     )
+     return
+ 
+ 
+ @app.cell
+ def _():
+     import marimo as mo
+     return (mo,)
+ 
+ 
+ @app.cell
+ def _():
+     import matplotlib.pyplot as plt
+     import numpy as np
+     from scipy import stats
+     import collections
+     return collections, np, plt, stats
+ 
+ 
+ if __name__ == "__main__":
+     app.run()
probability/11_expectation.py ADDED
@@ -0,0 +1,860 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ #     "marimo",
+ #     "matplotlib==3.10.0",
+ #     "numpy==2.2.3",
+ #     "scipy==1.15.2",
+ # ]
+ # ///
+ 
+ import marimo
+ 
+ __generated_with = "0.11.19"
+ app = marimo.App(width="medium", app_title="Expectation")
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         # Expectation
+ 
+         _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/expectation/), by Stanford professor Chris Piech._
+ 
+         A random variable is fully represented by its Probability Mass Function (PMF), which describes each value the random variable can take on and the corresponding probabilities. However, a PMF can contain a lot of information. Sometimes it's useful to summarize a random variable with a single value!
+ 
+         The most common, and arguably the most useful, summary of a random variable is its **Expectation** (also called the expected value or mean).
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Definition of Expectation
+ 
+         The expectation of a random variable $X$, written $E[X]$, is the average of all the values the random variable can take on, each weighted by the probability that the random variable will take on that value.
+ 
+         $$E[X] = \sum_x x \cdot P(X=x)$$
+ 
+         Expectation goes by many other names: Mean, Weighted Average, Center of Mass, 1st Moment. All of these are calculated using the same formula.
+         """
+     )
+     return
+ 
+ 
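The definition above translates directly into code. A small standalone sketch (the `expectation` helper is illustrative and not part of the notebook cells):

```python
def expectation(pmf):
    """E[X] = sum of x * P(X=x) over all values x, for a PMF given as {value: probability}."""
    return sum(x * p for x, p in pmf.items())

# Fair six-sided die: each outcome has probability 1/6
fair_die = {x: 1 / 6 for x in range(1, 7)}
print(expectation(fair_die))  # close to 3.5, up to floating-point rounding
```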
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Intuition Behind Expectation
+ 
+         The expected value represents the long-run average value of a random variable over many independent repetitions of an experiment.
+ 
+         For example, if you roll a fair six-sided die many times and calculate the average of all rolls, that average will approach the expected value of 3.5 as the number of rolls increases.
+ 
+         Let's visualize this concept:
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(np, plt):
+     # Set random seed for reproducibility
+     np.random.seed(42)
+ 
+     # Simulate rolling a die many times
+     exp_num_rolls = 1000
+     exp_die_rolls = np.random.randint(1, 7, size=exp_num_rolls)
+ 
+     # Calculate the running average
+     exp_running_avg = np.cumsum(exp_die_rolls) / np.arange(1, exp_num_rolls + 1)
+ 
+     # Create the plot
+     plt.figure(figsize=(10, 5))
+     plt.plot(range(1, exp_num_rolls + 1), exp_running_avg, label='Running Average')
+     plt.axhline(y=3.5, color='r', linestyle='--', label='Expected Value (3.5)')
+     plt.xlabel('Number of Rolls')
+     plt.ylabel('Average Value')
+     plt.title('Running Average of Die Rolls Approaching Expected Value')
+     plt.legend()
+     plt.grid(alpha=0.3)
+     plt.xscale('log')  # Log scale to better see convergence
+ 
+     # Add annotations
+     plt.annotate('As the number of rolls increases,\nthe average approaches the expected value',
+                  xy=(exp_num_rolls, exp_running_avg[-1]), xytext=(exp_num_rolls/3, 4),
+                  arrowprops=dict(facecolor='black', shrink=0.05, width=1.5))
+ 
+     plt.gca()
+     return exp_die_rolls, exp_num_rolls, exp_running_avg
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(r"""## Properties of Expectation""")
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.accordion(
+         {
+             "1. Linearity of Expectation": mo.md(
+                 r"""
+                 $$E[aX + b] = a \cdot E[X] + b$$
+ 
+                 Where $a$ and $b$ are constants (not random variables).
+ 
+                 This means that if you multiply a random variable by a constant, the expectation is multiplied by that constant. And if you add a constant to a random variable, the expectation increases by that constant.
+                 """
+             ),
+             "2. Expectation of the Sum of Random Variables": mo.md(
+                 r"""
+                 $$E[X + Y] = E[X] + E[Y]$$
+ 
+                 This is true regardless of the relationship between $X$ and $Y$. They can be dependent, and they can have different distributions. This also applies with more than two random variables:
+ 
+                 $$E\left[\sum_{i=1}^n X_i\right] = \sum_{i=1}^n E[X_i]$$
+                 """
+             ),
+             "3. Law of the Unconscious Statistician (LOTUS)": mo.md(
+                 r"""
+                 $$E[g(X)] = \sum_x g(x) \cdot P(X=x)$$
+ 
+                 This allows us to calculate the expected value of a function $g(X)$ of a random variable $X$ when we know the probability distribution of $X$ but don't explicitly know the distribution of $g(X)$.
+ 
+                 This theorem has the humorous name "Law of the Unconscious Statistician" (LOTUS) because it's so useful that you should be able to employ it unconsciously.
+                 """
+             ),
+             "4. Expectation of a Constant": mo.md(
+                 r"""
+                 $$E[a] = a$$
+ 
+                 Sometimes in proofs, you'll end up with the expectation of a constant (rather than a random variable). Since a constant doesn't change, its expected value is just the constant itself.
+                 """
+             ),
+         }
+     )
+     return
+ 
+ 
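Property 2 is worth stressing: linearity needs no independence. A quick standalone sketch with two fully dependent dice (names are illustrative):

```python
import random

random.seed(1)
n = 100_000
xs = [random.randint(1, 6) for _ in range(n)]
ys = [7 - x for x in xs]  # Y is completely determined by X

def mean(values):
    return sum(values) / len(values)

# E[X + Y] = E[X] + E[Y] holds even though X and Y are dependent;
# here X + Y is always exactly 7
print(mean([x + y for x, y in zip(xs, ys)]), mean(xs) + mean(ys))
```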
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Calculating Expectation
+ 
+         Let's calculate the expected value for some common examples:
+ 
+         ### Example 1: Fair Die Roll
+ 
+         For a fair six-sided die, the PMF is:
+ 
+         $$P(X=x) = \frac{1}{6} \text{ for } x \in \{1, 2, 3, 4, 5, 6\}$$
+ 
+         The expected value is:
+ 
+         $$E[X] = 1 \cdot \frac{1}{6} + 2 \cdot \frac{1}{6} + 3 \cdot \frac{1}{6} + 4 \cdot \frac{1}{6} + 5 \cdot \frac{1}{6} + 6 \cdot \frac{1}{6} = \frac{21}{6} = 3.5$$
+ 
+         Let's implement this calculation in Python:
+         """
+     )
+     return
+ 
+ 
+ @app.cell
+ def _():
+     def calc_expectation_die():
+         """Calculate the expected value of a fair six-sided die roll."""
+         exp_die_values = range(1, 7)
+         exp_die_probs = [1/6] * 6
+ 
+         exp_die_expected = sum(x * p for x, p in zip(exp_die_values, exp_die_probs))
+         return exp_die_expected
+ 
+     exp_die_result = calc_expectation_die()
+     print(f"Expected value of a fair die roll: {exp_die_result}")
+     return calc_expectation_die, exp_die_result
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ### Example 2: Sum of Two Dice
+ 
+         Now let's calculate the expected value for the sum of two fair dice. First, we need the PMF:
+         """
+     )
+     return
+ 
+ 
+ @app.cell
+ def _():
+     def pmf_sum_two_dice(y_val):
+         """Returns the probability that the sum of two dice is y_val."""
+         # Count the number of ways to get sum y_val
+         exp_count = 0
+         for dice1 in range(1, 7):
+             for dice2 in range(1, 7):
+                 if dice1 + dice2 == y_val:
+                     exp_count += 1
+         return exp_count / 36  # There are 36 possible outcomes (6×6)
+ 
+     # Test the function for a few values
+     exp_test_values = [2, 7, 12]
+     for exp_test_y in exp_test_values:
+         print(f"P(Y = {exp_test_y}) = {pmf_sum_two_dice(exp_test_y)}")
+     return exp_test_values, exp_test_y, pmf_sum_two_dice
+ 
+ 
+ @app.cell
+ def _(pmf_sum_two_dice):
+     def calc_expectation_sum_two_dice():
+         """Calculate the expected value of the sum of two dice."""
+         exp_sum_two_dice = 0
+         # The sum of two dice can take on the values 2 through 12
+         for exp_x in range(2, 13):
+             exp_pr_x = pmf_sum_two_dice(exp_x)  # PMF gives P(sum is exp_x)
+             exp_sum_two_dice += exp_x * exp_pr_x
+         return exp_sum_two_dice
+ 
+     exp_sum_result = calc_expectation_sum_two_dice()
+ 
+     # Round to 2 decimal places for display
+     exp_sum_result_rounded = round(exp_sum_result, 2)
+ 
+     print(f"Expected value of the sum of two dice: {exp_sum_result_rounded}")
+ 
+     # Let's also verify this with a direct calculation
+     exp_direct_calc = sum(x * pmf_sum_two_dice(x) for x in range(2, 13))
+     exp_direct_calc_rounded = round(exp_direct_calc, 2)
+ 
+     print(f"Direct calculation: {exp_direct_calc_rounded}")
+ 
+     # Verify that this equals 7
+     print(f"Is the expected value exactly 7? {abs(exp_sum_result - 7) < 1e-10}")
+     return (
+         calc_expectation_sum_two_dice,
+         exp_direct_calc,
+         exp_direct_calc_rounded,
+         exp_sum_result,
+         exp_sum_result_rounded,
+     )
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ### Visualizing Expectation
+ 
+         Let's visualize the expectation for the sum of two dice. The expected value is the "center of mass" of the PMF:
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(plt, pmf_sum_two_dice):
+     # Create the visualization
+     exp_y_values = list(range(2, 13))
+     exp_probabilities = [pmf_sum_two_dice(y) for y in exp_y_values]
+ 
+     dice_fig, dice_ax = plt.subplots(figsize=(10, 5))
+     dice_ax.bar(exp_y_values, exp_probabilities, width=0.4)
+     dice_ax.axvline(x=7, color='r', linestyle='--', linewidth=2, label='Expected Value (7)')
+ 
+     dice_ax.set_xticks(exp_y_values)
+     dice_ax.set_xlabel('Sum of two dice (y)')
+     dice_ax.set_ylabel('Probability: P(Y = y)')
+     dice_ax.set_title('PMF of Sum of Two Dice with Expected Value')
+     dice_ax.grid(alpha=0.3)
+     dice_ax.legend()
+ 
+     # Add probability values on top of bars
+     for exp_i, exp_prob in enumerate(exp_probabilities):
+         dice_ax.text(exp_y_values[exp_i], exp_prob + 0.001, f'{exp_prob:.3f}', ha='center')
+ 
+     plt.tight_layout()
+     plt.gca()
+     return dice_ax, dice_fig, exp_i, exp_prob, exp_probabilities, exp_y_values
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Demonstrating the Properties of Expectation
+ 
+         Let's demonstrate some of these properties with examples:
+         """
+     )
+     return
+ 
+ 
+ @app.cell
+ def _(exp_die_result):
+     # Demonstrate linearity of expectation (property 1)
+     # E[aX + b] = a*E[X] + b
+ 
+     # For a die roll X with E[X] = 3.5
+     prop_a = 2
+     prop_b = 10
+ 
+     # Calculate E[2X + 10] using the property
+     prop_expected_using_property = prop_a * exp_die_result + prop_b
+     prop_expected_using_property_rounded = round(prop_expected_using_property, 2)
+ 
+     print(f"Using linearity property: E[{prop_a}X + {prop_b}] = {prop_a} * E[X] + {prop_b} = {prop_expected_using_property_rounded}")
+ 
+     # Calculate E[2X + 10] directly
+     prop_expected_direct = sum((prop_a * x + prop_b) * (1/6) for x in range(1, 7))
+     prop_expected_direct_rounded = round(prop_expected_direct, 2)
+ 
+     print(f"Direct calculation: E[{prop_a}X + {prop_b}] = {prop_expected_direct_rounded}")
+ 
+     # Verify they match
+     print(f"Do they match? {abs(prop_expected_using_property - prop_expected_direct) < 1e-10}")
+     return (
+         prop_a,
+         prop_b,
+         prop_expected_direct,
+         prop_expected_direct_rounded,
+         prop_expected_using_property,
+         prop_expected_using_property_rounded,
+     )
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ### Law of the Unconscious Statistician (LOTUS)
+ 
+         Let's use LOTUS to calculate $E[X^2]$ for a die roll, which will be useful when we study variance:
+         """
+     )
+     return
+ 
+ 
+ @app.cell
+ def _():
+     # Calculate E[X^2] for a die roll using LOTUS (property 3)
+     lotus_die_values = range(1, 7)
+     lotus_die_probs = [1/6] * 6
+ 
+     # Using LOTUS: E[X^2] = sum(x^2 * P(X=x))
+     lotus_expected_x_squared = sum(x**2 * p for x, p in zip(lotus_die_values, lotus_die_probs))
+     lotus_expected_x_squared_rounded = round(lotus_expected_x_squared, 2)
+ 
+     expected_x_squared = 3.5**2
+     expected_x_squared_rounded = round(expected_x_squared, 2)
+ 
+     print(f"E[X^2] for a die roll = {lotus_expected_x_squared_rounded}")
+     print(f"(E[X])^2 for a die roll = {expected_x_squared_rounded}")
+     return (
+         expected_x_squared,
+         expected_x_squared_rounded,
+         lotus_die_probs,
+         lotus_die_values,
+         lotus_expected_x_squared,
+         lotus_expected_x_squared_rounded,
+     )
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         /// note
+         Note that $E[X^2] \neq (E[X])^2$ in general. Their difference, $E[X^2] - (E[X])^2$, is precisely the variance.
+         ///
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Interactive Example
+ 
+         Let's explore how the expected value changes as we adjust the parameters of common probability distributions. This interactive visualization focuses specifically on the relationship between distribution parameters and expected values.
+ 
+         Use the controls below to select a distribution and adjust its parameters. The graph will show how the expected value changes across a range of parameter values.
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     # Create UI elements for distribution selection
+     dist_selection = mo.ui.dropdown(
+         options=[
+             "bernoulli",
+             "binomial",
+             "geometric",
+             "poisson"
+         ],
+         value="bernoulli",
+         label="Select a distribution"
+     )
+     return (dist_selection,)
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(dist_selection):
+     dist_selection.center()
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(dist_description):
+     dist_description
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md("""### Adjust Parameters""")
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(controls):
+     controls
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(
+     dist_selection,
+     lambda_range,
+     np,
+     param_lambda,
+     param_n,
+     param_p,
+     param_range,
+     plt,
+ ):
+     # Calculate expected values based on the selected distribution
+     if dist_selection.value == "bernoulli":
+         # Get parameter range for visualization
+         p_min, p_max = param_range.value
+         param_values = np.linspace(p_min, p_max, 100)
+ 
+         # E[X] = p for Bernoulli
+         expected_values = param_values
+         current_param = param_p.value
+         current_expected = round(current_param, 2)
+         x_label = "p (probability of success)"
+         title = "Expected Value of Bernoulli Distribution"
+         formula = "E[X] = p"
+ 
+     elif dist_selection.value == "binomial":
+         p_min, p_max = param_range.value
+         param_values = np.linspace(p_min, p_max, 100)
+ 
+         # E[X] = np for Binomial
+         n = int(param_n.value)
+         expected_values = [n * p for p in param_values]
+         current_param = param_p.value
+         current_expected = round(n * current_param, 2)
+         x_label = "p (probability of success)"
+         title = f"Expected Value of Binomial Distribution (n={n})"
+         formula = f"E[X] = n × p = {n} × p"
+ 
+     elif dist_selection.value == "geometric":
+         p_min, p_max = param_range.value
+         # Ensure p is not 0 for the geometric distribution
+         p_min = max(0.01, p_min)
+         param_values = np.linspace(p_min, p_max, 100)
+ 
+         # E[X] = 1/p for Geometric
+         expected_values = [1/p for p in param_values]
+         current_param = param_p.value
+         current_expected = round(1 / current_param, 2)
+         x_label = "p (probability of success)"
+         title = "Expected Value of Geometric Distribution"
+         formula = "E[X] = 1/p"
+ 
+     else:  # Poisson
+         lambda_min, lambda_max = lambda_range.value
+         param_values = np.linspace(lambda_min, lambda_max, 100)
+ 
+         # E[X] = lambda for Poisson
+         expected_values = param_values
+         current_param = param_lambda.value
+         current_expected = round(current_param, 2)
+         x_label = "λ (rate parameter)"
+         title = "Expected Value of Poisson Distribution"
+         formula = "E[X] = λ"
+ 
+     # Create the plot
+     dist_fig, dist_ax = plt.subplots(figsize=(10, 6))
+ 
+     # Plot the expected value function
+     dist_ax.plot(param_values, expected_values, 'b-', linewidth=2, label="Expected Value Function")
+ 
+     dist_ax.plot(current_param, current_expected, 'ro', markersize=10, label=f"Current Value: E[X] = {current_expected}")
+ 
+     dist_ax.hlines(current_expected, param_values[0], current_param, colors='r', linestyles='dashed')
+ 
+     dist_ax.vlines(current_param, 0, current_expected, colors='r', linestyles='dashed')
+ 
+     dist_ax.fill_between(param_values, 0, expected_values, alpha=0.2, color='blue')
+ 
+     dist_ax.set_xlabel(x_label, fontsize=12)
+     dist_ax.set_ylabel("Expected Value: E[X]", fontsize=12)
+     dist_ax.set_title(title, fontsize=14, fontweight='bold')
+     dist_ax.grid(True, alpha=0.3)
+ 
+     # Move legend to lower right to avoid overlap with formula
+     dist_ax.legend(loc='lower right', fontsize=10)
+ 
+     # Add formula text box in upper left
+     dist_props = dict(boxstyle='round', facecolor='white', alpha=0.8)
+     dist_ax.text(0.02, 0.95, formula, transform=dist_ax.transAxes, fontsize=12,
+                  verticalalignment='top', bbox=dist_props)
+ 
+     if dist_selection.value == "geometric":
+         max_y = min(50, 2/max(0.01, param_values[0]))
+         dist_ax.set_ylim(0, max_y)
+     elif dist_selection.value == "binomial":
+         dist_ax.set_ylim(0, int(param_n.value) + 1)
+     else:
+         dist_ax.set_ylim(0, max(expected_values) * 1.1)
+ 
+     annotation_x = current_param + (param_values[-1] - param_values[0]) * 0.05
+     annotation_y = current_expected
+ 
+     # Adjust annotation position if it would go off the chart
+     if annotation_x > param_values[-1] * 0.9:
+         annotation_x = current_param - (param_values[-1] - param_values[0]) * 0.2
+ 
+     dist_ax.annotate(
+         f"Parameter: {current_param:.2f}\nE[X] = {current_expected}",
+         xy=(current_param, current_expected),
+         xytext=(annotation_x, annotation_y),
+         arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, alpha=0.7),
+         bbox=dist_props
+     )
+ 
+     plt.tight_layout()
+     plt.gca()
+     return (
+         annotation_x,
+         annotation_y,
+         current_expected,
+         current_param,
+         dist_ax,
+         dist_fig,
+         dist_props,
+         expected_values,
+         formula,
+         lambda_max,
+         lambda_min,
+         max_y,
+         n,
+         p_max,
+         p_min,
+         param_values,
+         title,
+         x_label,
+     )
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Expectation vs. Mode
+ 
+         The expected value (mean) of a random variable is not always the same as its most likely value (mode). Let's explore this with an example:
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(np, plt, stats):
+     # Create a skewed distribution
+     skew_n = 10
+     skew_p = 0.25
+ 
+     # Binomial PMF
+     skew_x_values = np.arange(0, skew_n+1)
+     skew_pmf_values = stats.binom.pmf(skew_x_values, skew_n, skew_p)
+ 
+     # Find the mode (most likely value)
+     skew_mode = skew_x_values[np.argmax(skew_pmf_values)]
+ 
+     # Calculate the expected value
+     skew_expected = skew_n * skew_p
+     skew_expected_rounded = round(skew_expected, 2)
+ 
+     skew_fig, skew_ax = plt.subplots(figsize=(10, 5))
+     skew_ax.bar(skew_x_values, skew_pmf_values, alpha=0.7, width=0.4)
+ 
+     # Add vertical lines for mode and expected value
+     skew_ax.axvline(x=skew_mode, color='g', linestyle='--', linewidth=2,
+                     label=f'Mode = {skew_mode} (Most likely value)')
+     skew_ax.axvline(x=skew_expected, color='r', linestyle='--', linewidth=2,
+                     label=f'Expected Value = {skew_expected_rounded} (Mean)')
+ 
+     skew_ax.annotate('Mode', xy=(skew_mode, 0.05), xytext=(skew_mode-2.0, 0.1),
+                      arrowprops=dict(facecolor='green', shrink=0.05, width=1.5), color='green')
+     skew_ax.annotate('Expected Value', xy=(skew_expected, 0.05), xytext=(skew_expected+1, 0.15),
+                      arrowprops=dict(facecolor='red', shrink=0.05, width=1.5), color='red')
+ 
+     if skew_mode != int(skew_expected):
+         min_x = min(skew_mode, skew_expected)
+         max_x = max(skew_mode, skew_expected)
+         skew_ax.axvspan(min_x, max_x, alpha=0.2, color='purple')
+ 
+         # Add text explaining the difference
+         mid_x = (skew_mode + skew_expected) / 2
+         skew_ax.text(mid_x, max(skew_pmf_values) * 0.5,
+                      f"Difference: {abs(skew_mode - skew_expected_rounded):.2f}",
+                      ha='center', va='center', bbox=dict(facecolor='white', alpha=0.7))
+ 
+     skew_ax.set_xlabel('Number of Successes')
+     skew_ax.set_ylabel('Probability')
+     skew_ax.set_title(f'Binomial Distribution (n={skew_n}, p={skew_p})')
+     skew_ax.grid(alpha=0.3)
+     skew_ax.legend()
+ 
+     plt.tight_layout()
+     plt.gca()
+     return (
+         max_x,
+         mid_x,
+         min_x,
+         skew_ax,
+         skew_expected,
+         skew_expected_rounded,
+         skew_fig,
+         skew_mode,
+         skew_n,
+         skew_p,
+         skew_pmf_values,
+         skew_x_values,
+     )
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         /// note
+         For the sum of two dice we calculated earlier, we found the expected value to be exactly 7. In that case, 7 also happens to be the mode (most likely outcome) of the distribution. However, this is just a coincidence for this particular example!
+ 
+         As we can see from the binomial distribution above, the expected value (2.50) and the mode (2) are often different values (this is common in skewed distributions). The expected value represents the "center of mass" of the distribution, while the mode represents the most likely single outcome.
+         ///
+         """
+     )
+     return
+ 
+ 
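The mode/mean gap above can also be reproduced without scipy. A stdlib-only sketch for the same Binomial(n=10, p=0.25), with illustrative variable names:

```python
from math import comb

n, p = 10, 0.25
# Binomial PMF: P(X=k) = C(n, k) * p^k * (1-p)^(n-k)
pmf = {k: comb(n, k) * p**k * (1 - p) ** (n - k) for k in range(n + 1)}

mode = max(pmf, key=pmf.get)                  # most likely single outcome
mean = sum(k * pk for k, pk in pmf.items())   # expected value, equals n*p

print(mode, mean)  # mode 2, mean 2.5
```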
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## 🤔 Test Your Understanding
+ 
+         Choose what you believe are the correct options in the questions below:
+ 
+         <details>
+         <summary>The expected value of a random variable is always one of the possible values the random variable can take.</summary>
+         ❌ False! The expected value is a weighted average and may not be a value the random variable can actually take. For example, the expected value of a fair die roll is 3.5, which is not a possible outcome.
+         </details>
+ 
+         <details>
+         <summary>If X and Y are independent random variables, then E[X·Y] = E[X]·E[Y].</summary>
+         ✅ True! For independent random variables, the expectation of their product equals the product of their expectations.
+         </details>
+ 
+         <details>
+         <summary>The expected value of a constant random variable (one that always takes the same value) is that constant.</summary>
+         ✅ True! If X = c with probability 1, then E[X] = c.
+         </details>
+ 
+         <details>
+         <summary>The expected value of the sum of two random variables is always the sum of their expected values, regardless of whether they are independent.</summary>
+         ✅ True! This is the linearity of expectation property: E[X + Y] = E[X] + E[Y], which holds regardless of dependence.
+         </details>
+         """
+     )
+     return
+ 
+ 
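The product rule for independent variables in the quiz above can be checked numerically for two fair dice. A quick standalone sketch:

```python
from itertools import product

# Joint PMF of two independent fair dice: P(X=x, Y=y) = 1/36 for every pair
e_xy = sum(x * y / 36 for x, y in product(range(1, 7), repeat=2))
e_x = sum(x / 6 for x in range(1, 7))

# For independent X and Y: E[X·Y] = E[X]·E[Y] = 3.5 * 3.5 = 12.25
print(e_xy, e_x * e_x)
```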
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Practical Applications of Expectation
+ 
+         Expected values show up everywhere - from investment decisions and insurance pricing to machine learning algorithms and game design. Engineers use them to predict system reliability, data scientists to understand customer behavior, and economists to model market outcomes. They're essential for risk assessment in project management and for optimizing resource allocation in operations research.
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(
+         r"""
+         ## Key Takeaways
+ 
+         Expectation gives us a single value that summarizes a random variable's central tendency - it's the weighted average of all possible outcomes, where the weights are probabilities. The linearity property makes expectations easy to work with, even for complex combinations of random variables. While a PMF gives the complete probability picture, expectation provides an essential summary that helps us make decisions under uncertainty. In our next notebook, we'll explore variance, which measures how spread out a random variable's values are around its expectation.
+         """
+     )
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(mo):
+     mo.md(r"""#### Appendix (containing helper code)""")
+     return
+ 
+ 
+ @app.cell(hide_code=True)
+ def _():
+     import marimo as mo
+     return (mo,)
+ 
+ 
+ @app.cell(hide_code=True)
+ def _():
+     import matplotlib.pyplot as plt
+     import numpy as np
+     from scipy import stats
+     import collections
+     return collections, np, plt, stats
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(dist_selection, mo):
+     # Parameter controls for probability-based distributions
+     param_p = mo.ui.slider(
+         start=0.01,
+         stop=0.99,
+         step=0.01,
+         value=0.5,
+         label="p (probability of success)",
+         full_width=True
+     )
+ 
+     # Parameter control for binomial distribution
+     param_n = mo.ui.slider(
+         start=1,
+         stop=50,
+         step=1,
+         value=10,
+         label="n (number of trials)",
+         full_width=True
+     )
+ 
+     # Parameter control for Poisson distribution
+     param_lambda = mo.ui.slider(
+         start=0.1,
+         stop=20,
+         step=0.1,
+         value=5,
+         label="λ (rate parameter)",
+         full_width=True
+     )
+ 
+     # Parameter range sliders for visualization
+     param_range = mo.ui.range_slider(
+         start=0,
+         stop=1,
+         step=0.01,
+         value=[0, 1],
+         label="Parameter range to visualize",
+         full_width=True
+     )
+ 
+     lambda_range = mo.ui.range_slider(
+         start=0,
+         stop=20,
+         step=0.1,
+         value=[0, 20],
+         label="λ range to visualize",
+         full_width=True
+     )
+ 
+     # Display appropriate controls based on the selected distribution
+     if dist_selection.value == "bernoulli":
+         controls = mo.hstack([param_p, param_range], justify="space-around")
+     elif dist_selection.value == "binomial":
+         controls = mo.hstack([param_p, param_n, param_range], justify="space-around")
+     elif dist_selection.value == "geometric":
+         controls = mo.hstack([param_p, param_range], justify="space-around")
+     else:  # poisson
+         controls = mo.hstack([param_lambda, lambda_range], justify="space-around")
+     return controls, lambda_range, param_lambda, param_n, param_p, param_range
+ 
+ 
+ @app.cell(hide_code=True)
+ def _(dist_selection, mo):
+     # Create distribution descriptions based on selection
+     if dist_selection.value == "bernoulli":
+         dist_description = mo.md(
+             r"""
+             **Bernoulli Distribution**
+ 
+             A Bernoulli distribution models a single trial with two possible outcomes: success (1) or failure (0).
+ 
+             - Parameter: $p$ = probability of success
+             - Expected Value: $E[X] = p$
+             - Example: Flipping a coin once (p = 0.5 for a fair coin)
+             """
+         )
+     elif dist_selection.value == "binomial":
+         dist_description = mo.md(
+             r"""
+             **Binomial Distribution**
+ 
+             A Binomial distribution models the number of successes in $n$ independent trials.
+ 
+             - Parameters: $n$ = number of trials, $p$ = probability of success
+             - Expected Value: $E[X] = np$
+             - Example: Number of heads in 10 coin flips
+             """
+         )
+     elif dist_selection.value == "geometric":
+         dist_description = mo.md(
+             r"""
+             **Geometric Distribution**
+ 
+             A Geometric distribution models the number of trials until the first success.
+ 
+             - Parameter: $p$ = probability of success
+             - Expected Value: $E[X] = \frac{1}{p}$
+             - Example: Number of coin flips until first heads
+             """
+         )
+     else:  # poisson
+         dist_description = mo.md(
+             r"""
+             **Poisson Distribution**
+ 
+             A Poisson distribution models the number of events occurring in a fixed interval.
+ 
+             - Parameter: $\lambda$ = average rate of events
+             - Expected Value: $E[X] = \lambda$
+             - Example: Number of emails received per hour
+             """
+         )
+     return (dist_description,)
+ 
+ 
+ if __name__ == "__main__":
+     app.run()
probability/12_variance.py ADDED
@@ -0,0 +1,631 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ # "marimo",
+ # "matplotlib==3.10.0",
+ # "numpy==2.2.3",
+ # "scipy==1.15.2",
+ # "wigglystuff==0.1.10",
+ # ]
+ # ///
+
+ import marimo
+
+ __generated_with = "0.11.20"
+ app = marimo.App(width="medium", app_title="Variance")
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ # Variance
+
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/variance/), by Stanford professor Chris Piech._
+
+ In our previous exploration of random variables, we learned about expectation - a measure of central tendency. However, knowing the average value alone doesn't tell us everything about a distribution. Consider these questions:
+
+ - How spread out are the values around the mean?
+ - How reliable is the expectation as a predictor of individual outcomes?
+ - How much do individual samples typically deviate from the average?
+
+ This is where **variance** comes in - it measures the spread or dispersion of a random variable around its expected value.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Definition of Variance
+
+ The variance of a random variable $X$ with expected value $\mu = E[X]$ is defined as:
+
+ $$\text{Var}(X) = E[(X-\mu)^2]$$
+
+ This definition captures the average squared deviation from the mean. There's also an equivalent, often more convenient formula:
+
+ $$\text{Var}(X) = E[X^2] - (E[X])^2$$
+
+ /// tip
+ The second formula is usually easier to compute, as it only requires calculating $E[X^2]$ and $E[X]$, rather than working with deviations from the mean.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Intuition Through Example
+
+ Let's look at a real-world example that illustrates why variance is important. Consider three different groups of graders evaluating assignments in a massive online course. Each grader has their own "grading distribution" - their pattern of assigning scores to work that deserves a 70/100.
+
+ The visualization below shows the probability distributions for three types of graders. Try clicking and dragging the blue numbers to adjust the parameters and see how they affect the variance.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ /// tip
+ Try adjusting the blue numbers below to see how:
+
+ - Increasing spread increases variance
+ - The mixture ratio affects how many outliers appear in Grader C's distribution
+ - Changing the true grade shifts all distributions but maintains their relative variances
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(controls):
+ controls
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(
+ grader_a_spread,
+ grader_b_spread,
+ grader_c_mix,
+ np,
+ plt,
+ stats,
+ true_grade,
+ ):
+ # Create data for three grader distributions
+ _grader_x = np.linspace(40, 100, 200)
+
+ # Calculate actual variances
+ var_a = grader_a_spread.amount**2
+ var_b = grader_b_spread.amount**2
+ var_c = (1-grader_c_mix.amount) * 3**2 + grader_c_mix.amount * 8**2 # Equal-mean mixture: variance is the weighted sum of component variances
+
+ # Grader A: Wide spread around true grade
+ grader_a = stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=grader_a_spread.amount)
+
+ # Grader B: Narrow spread around true grade
+ grader_b = stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=grader_b_spread.amount)
+
+ # Grader C: Mixture of distributions
+ grader_c = (1-grader_c_mix.amount) * stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=3) + \
+ grader_c_mix.amount * stats.norm.pdf(_grader_x, loc=true_grade.amount, scale=8)
+
+ grader_fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(15, 5))
+
+ # Plot each distribution
+ ax1.fill_between(_grader_x, grader_a, alpha=0.3, color='green', label=f'Var ≈ {var_a:.2f}')
+ ax1.axvline(x=true_grade.amount, color='black', linestyle='--', label='True Grade')
+ ax1.set_title('Grader A: High Variance')
+ ax1.set_xlabel('Grade')
+ ax1.set_ylabel('Pr(G = g)')
+ ax1.set_ylim(0, max(grader_a)*1.1)
+
+ ax2.fill_between(_grader_x, grader_b, alpha=0.3, color='blue', label=f'Var ≈ {var_b:.2f}')
+ ax2.axvline(x=true_grade.amount, color='black', linestyle='--')
+ ax2.set_title('Grader B: Low Variance')
+ ax2.set_xlabel('Grade')
+ ax2.set_ylim(0, max(grader_b)*1.1)
+
+ ax3.fill_between(_grader_x, grader_c, alpha=0.3, color='purple', label=f'Var ≈ {var_c:.2f}')
+ ax3.axvline(x=true_grade.amount, color='black', linestyle='--')
+ ax3.set_title('Grader C: Mixed Distribution')
+ ax3.set_xlabel('Grade')
+ ax3.set_ylim(0, max(grader_c)*1.1)
+
+ # Add annotations to explain what's happening
+ ax1.annotate('Wide spread = high variance',
+ xy=(true_grade.amount, max(grader_a)*0.5),
+ xytext=(true_grade.amount-15, max(grader_a)*0.7),
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
+
+ ax2.annotate('Narrow spread = low variance',
+ xy=(true_grade.amount, max(grader_b)*0.5),
+ xytext=(true_grade.amount+8, max(grader_b)*0.7),
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
+
+ _outlier_x = min(true_grade.amount + 15, _grader_x[-1]) # Clamp so the annotation stays inside the plotted range
+ ax3.annotate('Mixture creates outliers',
+ xy=(_outlier_x, grader_c[np.searchsorted(_grader_x, _outlier_x, side='right') - 1]),
+ xytext=(true_grade.amount+5, max(grader_c)*0.7),
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
+
+ # Add legends and adjust layout
+ for _ax in [ax1, ax2, ax3]:
+ _ax.legend()
+ _ax.grid(alpha=0.2)
+
+ plt.tight_layout()
+ plt.gca()
+ return (
+ ax1,
+ ax2,
+ ax3,
+ grader_a,
+ grader_b,
+ grader_c,
+ grader_fig,
+ var_a,
+ var_b,
+ var_c,
+ )
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ /// note
+ All three distributions have the same expected value (the true grade), but they differ significantly in their spread:
+
+ - **Grader A** has high variance - grades vary widely from the true value
+ - **Grader B** has low variance - grades consistently stay close to the true value
+ - **Grader C** has a mixture distribution - mostly consistent but with occasional extreme values
+
+ This illustrates why variance is crucial: two distributions can have the same mean but behave very differently in practice.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Computing Variance
+
+ Let's work through some concrete examples to understand how to calculate variance.
+
+ ### Example 1: Fair Die Roll
+
+ Consider rolling a fair six-sided die. We'll calculate its variance step by step:
+ """
+ )
+ return
+
+
+ @app.cell
+ def _(np):
+ # Define the die values and probabilities
+ die_values = np.array([1, 2, 3, 4, 5, 6])
+ die_probs = np.array([1/6] * 6)
+
+ # Calculate E[X]
+ expected_value = np.sum(die_values * die_probs)
+
+ # Calculate E[X^2]
+ expected_square = np.sum(die_values**2 * die_probs)
+
+ # Calculate Var(X) = E[X^2] - (E[X])^2
+ variance = expected_square - expected_value**2
+
+ # Calculate standard deviation
+ std_dev = np.sqrt(variance)
+
+ print(f"E[X] = {expected_value:.2f}")
+ print(f"E[X^2] = {expected_square:.2f}")
+ print(f"Var(X) = {variance:.2f}")
+ print(f"Standard Deviation = {std_dev:.2f}")
+ return (
+ die_probs,
+ die_values,
+ expected_square,
+ expected_value,
+ std_dev,
+ variance,
+ )
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ /// note
+ For a fair die:
+
+ - The expected value (3.50) tells us the average roll
+ - The variance (2.92) tells us how much typical rolls deviate from this average
+ - The standard deviation (1.71) gives us this spread in the original units
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Properties of Variance
+
+ Variance has several important properties that make it useful for analyzing random variables:
+
+ 1. **Non-negativity**: $\text{Var}(X) \geq 0$ for any random variable $X$
+ 2. **Variance of a constant**: $\text{Var}(c) = 0$ for any constant $c$
+ 3. **Scaling**: $\text{Var}(aX) = a^2\text{Var}(X)$ for any constant $a$
+ 4. **Translation**: $\text{Var}(X + b) = \text{Var}(X)$ for any constant $b$
+ 5. **Independence**: If $X$ and $Y$ are independent, then $\text{Var}(X + Y) = \text{Var}(X) + \text{Var}(Y)$
+
+ Let's verify a property with an example.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Proof of Variance Formula
+
+ The equivalence of the two variance formulas is a fundamental result in probability theory. Here's the proof:
+
+ Starting with the definition $\text{Var}(X) = E[(X-\mu)^2]$ where $\mu = E[X]$:
+
+ \begin{align}
+ \text{Var}(X) &= E[(X-\mu)^2] \\
+ &= \sum_x(x-\mu)^2P(x) && \text{Definition of Expectation}\\
+ &= \sum_x (x^2 -2\mu x + \mu^2)P(x) && \text{Expanding the square}\\
+ &= \sum_x x^2P(x)- 2\mu \sum_x xP(x) + \mu^2 \sum_x P(x) && \text{Distributing the sum}\\
+ &= E[X^2]- 2\mu E[X] + \mu^2 && \text{Definition of expectation}\\
+ &= E[X^2]- 2(E[X])^2 + (E[X])^2 && \text{Since }\mu = E[X]\\
+ &= E[X^2]- (E[X])^2 && \text{Simplifying}
+ \end{align}
+
+ /// tip
+ This proof shows why the formula $\text{Var}(X) = E[X^2] - (E[X])^2$ is so useful - it's much easier to compute $E[X^2]$ and $E[X]$ separately than to work with deviations directly.
+ """
+ )
+ return
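As a quick numerical sanity check of the proof (a standalone sketch, not part of the notebook's cells; only `numpy` is assumed), the two formulas can be compared directly on the fair-die example:

```python
import numpy as np

# Fair six-sided die
values = np.array([1, 2, 3, 4, 5, 6])
probs = np.full(6, 1 / 6)

mu = np.sum(values * probs)  # E[X]

# Definition: E[(X - mu)^2]
var_def = np.sum((values - mu) ** 2 * probs)

# Shortcut: E[X^2] - (E[X])^2
var_shortcut = np.sum(values**2 * probs) - mu**2

print(var_def, var_shortcut)  # both equal 35/12 ≈ 2.9167
```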
+
+
+ @app.cell
+ def _(die_probs, die_values, np):
+ # Demonstrate scaling property
+ a = 2 # Scale factor
+
+ # Original variance
+ original_var = np.sum(die_values**2 * die_probs) - (np.sum(die_values * die_probs))**2
+
+ # Scaled random variable variance
+ scaled_values = a * die_values
+ scaled_var = np.sum(scaled_values**2 * die_probs) - (np.sum(scaled_values * die_probs))**2
+
+ print(f"Original Variance: {original_var:.2f}")
+ print(f"Scaled Variance (a={a}): {scaled_var:.2f}")
+ print(f"a^2 * Original Variance: {a**2 * original_var:.2f}")
+ print(f"Property holds: {abs(scaled_var - a**2 * original_var) < 1e-10}")
+ return a, original_var, scaled_values, scaled_var
+
+
+ @app.cell
+ def _():
+ # DIY : Prove more properties as shown above
+ return
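One way the DIY cell above could be filled in - a minimal standalone sketch (plain `numpy` assumed) that checks the translation and independence properties numerically on dice:

```python
import numpy as np

values = np.arange(1, 7)
probs = np.full(6, 1 / 6)

def var(vals, ps):
    """Var(X) = E[X^2] - (E[X])^2 for a discrete distribution."""
    return np.sum(vals**2 * ps) - np.sum(vals * ps) ** 2

# Translation: Var(X + b) = Var(X)
b = 10
assert np.isclose(var(values + b, probs), var(values, probs))

# Independence: Var(X + Y) = Var(X) + Var(Y) for two independent dice
sum_vals = (values[:, None] + values[None, :]).ravel()
sum_ps = (probs[:, None] * probs[None, :]).ravel()
assert np.isclose(var(sum_vals, sum_ps), 2 * var(values, probs))

print("translation and independence properties hold")
```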
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Standard Deviation
+
+ While variance is mathematically convenient, it has one practical drawback: its units are squared. For example, if we're measuring grades (0-100), the variance is in "grade points squared." This makes it hard to interpret intuitively.
+
+ The **standard deviation**, denoted by $\sigma$ or $\text{SD}(X)$, is the square root of variance:
+
+ $$\sigma = \sqrt{\text{Var}(X)}$$
+
+ /// tip
+ Standard deviation is often more intuitive because it's in the same units as the original data. For a normal distribution, approximately:
+
+ - 68% of values fall within 1 standard deviation of the mean
+ - 95% of values fall within 2 standard deviations
+ - 99.7% of values fall within 3 standard deviations
+ """
+ )
+ return
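The 68-95-99.7 figures can be checked directly from the normal CDF. A standalone sketch using only the standard library (via the identity $P(|X-\mu| \le k\sigma) = \operatorname{erf}(k/\sqrt{2})$, which holds for any $\mu$ and $\sigma$):

```python
import math

# Probability mass within k standard deviations of the mean,
# for any normal distribution
for k in (1, 2, 3):
    mass = math.erf(k / math.sqrt(2))
    print(f"within ±{k}σ: {mass:.4f}")
# within ±1σ: 0.6827
# within ±2σ: 0.9545
# within ±3σ: 0.9973
```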
+
+
+ @app.cell(hide_code=True)
+ def _(controls1):
+ controls1
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(TangleSlider, mo):
+ normal_mean = mo.ui.anywidget(TangleSlider(
+ amount=0,
+ min_value=-5,
+ max_value=5,
+ step=0.5,
+ digits=1,
+ suffix=" units"
+ ))
+
+ normal_std = mo.ui.anywidget(TangleSlider(
+ amount=1,
+ min_value=0.1,
+ max_value=3,
+ step=0.1,
+ digits=1,
+ suffix=" units"
+ ))
+
+ # Create a grid layout for the controls
+ controls1 = mo.vstack([
+ mo.md("### Interactive Normal Distribution"),
+ mo.hstack([
+ mo.md("Adjust the parameters to see how standard deviation affects the shape of the distribution:"),
+ ]),
+ mo.hstack([
+ mo.md("Mean (μ): "),
+ normal_mean,
+ mo.md(" Standard deviation (σ): "),
+ normal_std
+ ], justify="start"),
+ ])
+ return controls1, normal_mean, normal_std
+
+
+ @app.cell(hide_code=True)
+ def _(normal_mean, normal_std, np, plt, stats):
+ # data for normal distribution
+ _normal_x = np.linspace(-10, 10, 1000)
+ _normal_y = stats.norm.pdf(_normal_x, loc=normal_mean.amount, scale=normal_std.amount)
+
+ # ranges for standard deviation intervals
+ one_sigma_left = normal_mean.amount - normal_std.amount
+ one_sigma_right = normal_mean.amount + normal_std.amount
+ two_sigma_left = normal_mean.amount - 2 * normal_std.amount
+ two_sigma_right = normal_mean.amount + 2 * normal_std.amount
+ three_sigma_left = normal_mean.amount - 3 * normal_std.amount
+ three_sigma_right = normal_mean.amount + 3 * normal_std.amount
+
+ # Create the plot
+ normal_fig, normal_ax = plt.subplots(figsize=(10, 6))
+
+ # Plot the distribution
+ normal_ax.plot(_normal_x, _normal_y, 'b-', linewidth=2)
+
+ # stdev intervals
+ normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= one_sigma_left) & (_normal_x <= one_sigma_right),
+ alpha=0.3, color='red', label='68% (±1σ)')
+ normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= two_sigma_left) & (_normal_x <= two_sigma_right),
+ alpha=0.2, color='green', label='95% (±2σ)')
+ normal_ax.fill_between(_normal_x, 0, _normal_y, where=(_normal_x >= three_sigma_left) & (_normal_x <= three_sigma_right),
+ alpha=0.1, color='blue', label='99.7% (±3σ)')
+
+ # vertical lines for the mean and standard deviations
+ normal_ax.axvline(x=normal_mean.amount, color='black', linestyle='-', linewidth=1.5, label='Mean (μ)')
+ normal_ax.axvline(x=one_sigma_left, color='red', linestyle='--', linewidth=1)
+ normal_ax.axvline(x=one_sigma_right, color='red', linestyle='--', linewidth=1)
+ normal_ax.axvline(x=two_sigma_left, color='green', linestyle='--', linewidth=1)
+ normal_ax.axvline(x=two_sigma_right, color='green', linestyle='--', linewidth=1)
+
+ # annotations
+ normal_ax.annotate(f'μ = {normal_mean.amount:.2f}',
+ xy=(normal_mean.amount, max(_normal_y)*0.5),
+ xytext=(normal_mean.amount + 0.5, max(_normal_y)*0.8),
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
+
+ normal_ax.annotate(f'σ = {normal_std.amount:.2f}',
+ xy=(one_sigma_right, stats.norm.pdf(one_sigma_right, loc=normal_mean.amount, scale=normal_std.amount)),
+ xytext=(one_sigma_right + 0.5, max(_normal_y)*0.6),
+ arrowprops=dict(facecolor='red', shrink=0.05, width=1))
+
+ # labels and title
+ normal_ax.set_xlabel('Value')
+ normal_ax.set_ylabel('Probability Density')
+ normal_ax.set_title(f'Normal Distribution with μ = {normal_mean.amount:.2f} and σ = {normal_std.amount:.2f}')
+
+ # legend and grid
+ normal_ax.legend()
+ normal_ax.grid(alpha=0.3)
+
+ plt.tight_layout()
+ plt.gca()
+ return (
+ normal_ax,
+ normal_fig,
+ one_sigma_left,
+ one_sigma_right,
+ three_sigma_left,
+ three_sigma_right,
+ two_sigma_left,
+ two_sigma_right,
+ )
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ /// tip
+ The interactive visualization above demonstrates how standard deviation (σ) affects the shape of a normal distribution:
+
+ - The **red region** covers μ ± 1σ, containing approximately 68% of the probability
+ - The **green region** covers μ ± 2σ, containing approximately 95% of the probability
+ - The **blue region** covers μ ± 3σ, containing approximately 99.7% of the probability
+
+ This is known as the "68-95-99.7 rule" or the "empirical rule" and is a useful heuristic for understanding the spread of data.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## 🤔 Test Your Understanding
+
+ Choose what you believe are the correct options in the questions below:
+
+ <details>
+ <summary>The variance of a random variable can be negative.</summary>
+ ❌ False! Variance is defined as an expected value of squared deviations, and squares are always non-negative.
+ </details>
+
+ <details>
+ <summary>If X and Y are independent random variables, then Var(X + Y) = Var(X) + Var(Y).</summary>
+ ✅ True! This is one of the key properties of variance for independent random variables.
+ </details>
+
+ <details>
+ <summary>Multiplying a random variable by 2 multiplies its variance by 2.</summary>
+ ❌ False! Multiplying a random variable by a constant a multiplies its variance by a². So multiplying by 2 multiplies variance by 4.
+ </details>
+
+ <details>
+ <summary>Standard deviation is always equal to the square root of variance.</summary>
+ ✅ True! By definition, standard deviation σ = √Var(X).
+ </details>
+
+ <details>
+ <summary>If Var(X) = 0, then X must be a constant.</summary>
+ ✅ True! Zero variance means there is no spread around the mean, so X can only take one value.
+ </details>
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Key Takeaways
+
+ Variance gives us a way to measure how spread out a random variable is around its mean. It's like the "uncertainty" in our expectation - a high variance means individual outcomes can differ widely from what we expect on average.
+
+ Standard deviation brings this measure back to the original units, making it easier to interpret. For grades, a standard deviation of 10 points means typical grades fall within about 10 points of the average.
+
+ Variance pops up everywhere - from weather forecasts (how reliable is the predicted temperature?) to financial investments (how risky is this stock?) to quality control (how consistent is our manufacturing process?).
+
+ In our next notebook, we'll explore more properties of random variables and see how they combine to form more complex distributions.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(r"""Appendix (containing helper code):""")
+ return
+
+
+ @app.cell(hide_code=True)
+ def _():
+ import marimo as mo
+ return (mo,)
+
+
+ @app.cell(hide_code=True)
+ def _():
+ import numpy as np
+ import scipy.stats as stats
+ import matplotlib.pyplot as plt
+ from wigglystuff import TangleSlider
+ return TangleSlider, np, plt, stats
+
+
+ @app.cell(hide_code=True)
+ def _(TangleSlider, mo):
+ # Create interactive elements using TangleSlider for a more inline experience
+ true_grade = mo.ui.anywidget(TangleSlider(
+ amount=70,
+ min_value=50,
+ max_value=90,
+ step=5,
+ digits=0,
+ suffix=" points"
+ ))
+
+ grader_a_spread = mo.ui.anywidget(TangleSlider(
+ amount=10,
+ min_value=5,
+ max_value=20,
+ step=1,
+ digits=0,
+ suffix=" points"
+ ))
+
+ grader_b_spread = mo.ui.anywidget(TangleSlider(
+ amount=2,
+ min_value=1,
+ max_value=5,
+ step=0.5,
+ digits=1,
+ suffix=" points"
+ ))
+
+ grader_c_mix = mo.ui.anywidget(TangleSlider(
+ amount=0.2,
+ min_value=0,
+ max_value=1,
+ step=0.05,
+ digits=2,
+ suffix=" proportion"
+ ))
+ return grader_a_spread, grader_b_spread, grader_c_mix, true_grade
+
+
+ @app.cell(hide_code=True)
+ def _(grader_a_spread, grader_b_spread, grader_c_mix, mo, true_grade):
+ # Create a grid layout for the interactive controls
+ controls = mo.vstack([
+ mo.md("### Adjust Parameters to See How Variance Changes"),
+ mo.hstack([
+ mo.md("**True grade:** The correct score that should be assigned is "),
+ true_grade,
+ mo.md(" out of 100.")
+ ], justify="start"),
+ mo.hstack([
+ mo.md("**Grader A:** Has a wide spread with standard deviation of "),
+ grader_a_spread,
+ mo.md(" points.")
+ ], justify="start"),
+ mo.hstack([
+ mo.md("**Grader B:** Has a narrow spread with standard deviation of "),
+ grader_b_spread,
+ mo.md(" points.")
+ ], justify="start"),
+ mo.hstack([
+ mo.md("**Grader C:** Has a mixture distribution with "),
+ grader_c_mix,
+ mo.md(" proportion of outliers.")
+ ], justify="start"),
+ ])
+ return (controls,)
+
+
+ if __name__ == "__main__":
+ app.run()
probability/13_bernoulli_distribution.py ADDED
@@ -0,0 +1,427 @@
+ # /// script
+ # requires-python = ">=3.10"
+ # dependencies = [
+ # "marimo",
+ # "matplotlib==3.10.0",
+ # "numpy==2.2.3",
+ # "scipy==1.15.2",
+ # ]
+ # ///
+
+ import marimo
+
+ __generated_with = "0.11.22"
+ app = marimo.App(width="medium", app_title="Bernoulli Distribution")
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ # Bernoulli Distribution
+
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/bernoulli/), by Stanford professor Chris Piech._
+
+ ## Parametric Random Variables
+
+ There are many classic and commonly-seen random variable abstractions that show up in the world of probability. At this point, we'll learn about several of the most significant parametric discrete distributions.
+
+ When solving problems, if you can recognize that a random variable fits one of these formats, then you can use its pre-derived Probability Mass Function (PMF), expectation, variance, and other properties. Random variables of this sort are called **parametric random variables**. If you can argue that a random variable falls under one of the studied parametric types, you simply need to provide parameters.
+
+ > A good analogy is a `class` in programming. Creating a parametric random variable is very similar to calling a constructor with input parameters.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Bernoulli Random Variables
+
+ A **Bernoulli random variable** (also called a boolean or indicator random variable) is the simplest kind of parametric random variable. It can take on two values: 1 and 0.
+
+ It takes on a 1 if an experiment with probability $p$ resulted in success and a 0 otherwise.
+
+ Some example uses include:
+
+ - A coin flip (heads = 1, tails = 0)
+ - A random binary digit
+ - Whether a disk drive crashed
+ - Whether someone likes a Netflix movie
+
+ Here $p$ is the parameter, but different instances of Bernoulli random variables might have different values of $p$.
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Key Properties of a Bernoulli Random Variable
+
+ If $X$ is declared to be a Bernoulli random variable with parameter $p$, denoted $X \sim \text{Bern}(p)$, it has the following properties:
+ """
+ )
+ return
+
+
+ @app.cell
+ def _(stats):
+ # Define the Bernoulli distribution function
+ def Bern(p):
+ return stats.bernoulli(p)
+ return (Bern,)
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Bernoulli Distribution Properties
+
+ $\begin{array}{lll}
+ \text{Notation:} & X \sim \text{Bern}(p) \\
+ \text{Description:} & \text{A boolean variable that is 1 with probability } p \\
+ \text{Parameters:} & p, \text{ the probability that } X = 1 \\
+ \text{Support:} & x \text{ is either 0 or 1} \\
+ \text{PMF equation:} & P(X = x) =
+ \begin{cases}
+ p & \text{if }x = 1\\
+ 1-p & \text{if }x = 0
+ \end{cases} \\
+ \text{PMF (smooth):} & P(X = x) = p^x(1-p)^{1-x} \\
+ \text{Expectation:} & E[X] = p \\
+ \text{Variance:} & \text{Var}(X) = p(1-p) \\
+ \end{array}$
+ """
+ )
+ return
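A minimal pure-Python check of the tabulated formulas, outside the notebook's cells (the same numbers also come from `stats.bernoulli(p).mean()` and `.var()` via the scipy dependency):

```python
# PMF from the table: P(X = x) = p^x (1-p)^(1-x) for x in {0, 1}
def bern_pmf(x, p):
    return p**x * (1 - p) ** (1 - x)

p = 0.3
mean = sum(x * bern_pmf(x, p) for x in (0, 1))          # E[X] = p
var = sum(x**2 * bern_pmf(x, p) for x in (0, 1)) - mean**2  # Var(X) = p(1-p)

print(mean, var)  # 0.3 and 0.21, up to float rounding
```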
+
+
+ @app.cell(hide_code=True)
+ def _(mo, p_slider):
+ # Visualization of the Bernoulli PMF
+ _p = p_slider.value
+
+ # Values for PMF
+ values = [0, 1]
+ probabilities = [1 - _p, _p]
+
+ # Relevant statistics
+ expected_value = _p
+ variance = _p * (1 - _p)
+
+ mo.md(f"""
+ ## PMF Graph for Bernoulli($p={_p:.2f}$)
+
+ Parameter $p$: {p_slider}
+
+ Expected value: $E[X] = {expected_value:.2f}$
+
+ Variance: $\\text{{Var}}(X) = {variance:.2f}$
+ """)
+ return expected_value, probabilities, values, variance
+
+
+ @app.cell(hide_code=True)
+ def _(expected_value, p_slider, plt, probabilities, values, variance):
+ # PMF
+ _p = p_slider.value
+ fig, ax = plt.subplots(figsize=(10, 6))
+
+ # Bar plot for PMF
+ ax.bar(values, probabilities, width=0.4, color='blue', alpha=0.7)
+
+ ax.set_xlabel('Values that X can take on')
+ ax.set_ylabel('Probability')
+ ax.set_title(f'PMF of Bernoulli Distribution with p = {_p:.2f}')
+
+ # x-axis limit
+ ax.set_xticks([0, 1])
+ ax.set_xlim(-0.5, 1.5)
+
+ # y-axis w/ some padding
+ ax.set_ylim(0, max(probabilities) * 1.1)
+
+ # Add expectation as vertical line
+ ax.axvline(x=expected_value, color='red', linestyle='--',
+ label=f'E[X] = {expected_value:.2f}')
+
+ # Add variance annotation
+ ax.text(0.5, max(probabilities) * 0.8,
+ f'Var(X) = {variance:.3f}',
+ horizontalalignment='center',
+ bbox=dict(facecolor='white', alpha=0.7))
+
+ ax.legend()
+ plt.tight_layout()
+ plt.gca()
+ return ax, fig
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Proof: Expectation of a Bernoulli
+
+ If $X$ is a Bernoulli with parameter $p$, $X \sim \text{Bern}(p)$:
+
+ \begin{align}
+ E[X] &= \sum_x x \cdot P(X=x) && \text{Definition of expectation} \\
+ &= 1 \cdot p + 0 \cdot (1-p) &&
+ X \text{ can take on values 0 and 1} \\
+ &= p && \text{Remove the 0 term}
+ \end{align}
+
+ ## Proof: Variance of a Bernoulli
+
+ If $X$ is a Bernoulli with parameter $p$, $X \sim \text{Bern}(p)$:
+
+ To compute variance, first compute $E[X^2]$:
+
+ \begin{align}
+ E[X^2]
+ &= \sum_x x^2 \cdot P(X=x) &&\text{LOTUS}\\
+ &= 0^2 \cdot (1-p) + 1^2 \cdot p\\
+ &= p
+ \end{align}
+
+ \begin{align}
+ \text{Var}(X)
+ &= E[X^2] - (E[X])^2 && \text{Def of variance} \\
+ &= p - p^2 && \text{Substitute }E[X^2]=p, E[X] = p \\
+ &= p (1-p) && \text{Factor out }p
+ \end{align}
+ """
+ )
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## Indicator Random Variable
+
+ > **Definition**: An indicator variable is a Bernoulli random variable which takes on the value 1 if an **underlying event occurs**, and 0 _otherwise_.
+
+ Indicator random variables are a convenient way to convert the "true/false" outcome of an event into a number. That number may be easier to incorporate into an equation.
+
+ A random variable $I$ is an indicator variable for an event $A$ if $I = 1$ when $A$ occurs and $I = 0$ if $A$ does not occur. Indicator random variables are Bernoulli random variables, with $p = P(A)$. $I_A$ is a common choice of name for an indicator random variable.
+
+ Here are some properties of indicator random variables:
+
+ - $P(I=1)=P(A)$
+ - $E[I]=P(A)$
+ """
+ )
+ return
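A small standalone sketch of the $E[I] = P(A)$ property, using a hypothetical event (a fair-die roll showing 5 or 6, so $P(A) = 1/3$; only the standard library is assumed):

```python
import random

random.seed(0)  # reproducible run

# Indicator for the event A = "fair die shows 5 or 6", so P(A) = 1/3
def indicator():
    return 1 if random.randint(1, 6) >= 5 else 0

n = 100_000
sample_mean = sum(indicator() for _ in range(n)) / n
print(sample_mean)  # close to 1/3, matching E[I] = P(A)
```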
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ # Simulation of Bernoulli trials
+ mo.md(r"""
+ ## Simulation of Bernoulli Trials
+
+ Let's simulate Bernoulli trials to see the law of large numbers in action. We'll flip a biased coin repeatedly and observe how the proportion of successes approaches the true probability $p$.
+ """)
+
+ # UI element for simulation parameters
+ num_trials_slider = mo.ui.slider(10, 10000, value=1000, step=10, label="Number of trials")
+ p_sim_slider = mo.ui.slider(0.01, 0.99, value=0.65, step=0.01, label="Success probability (p)")
+ return num_trials_slider, p_sim_slider
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(r"""## Simulation""")
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(mo, num_trials_slider, p_sim_slider):
+ mo.hstack([num_trials_slider, p_sim_slider], justify='space-around')
+ return
+
+
+ @app.cell(hide_code=True)
+ def _(np, num_trials_slider, p_sim_slider, plt):
+ # Bernoulli trials
+ _num_trials = num_trials_slider.value
+ p = p_sim_slider.value
+
+ # Random Bernoulli trials
+ trials = np.random.binomial(1, p, size=_num_trials)
+
+ # Cumulative proportion of successes
+ cumulative_mean = np.cumsum(trials) / np.arange(1, _num_trials + 1)
+
+ # Results
+ plt.figure(figsize=(10, 6))
+ plt.plot(range(1, _num_trials + 1), cumulative_mean, label='Proportion of successes')
+ plt.axhline(y=p, color='r', linestyle='--', label=f'True probability (p={p})')
+
+ plt.xscale('log') # Use log scale for better visualization
+ plt.xlabel('Number of trials')
+ plt.ylabel('Proportion of successes')
+ plt.title('Convergence of Sample Proportion to True Probability')
+ plt.legend()
+ plt.grid(True, alpha=0.3)
+
+ # Add annotation
+ plt.annotate('As the number of trials increases,\nthe proportion approaches p',
+ xy=(_num_trials, cumulative_mean[-1]),
+ xytext=(_num_trials/5, p + 0.1),
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
+
+ plt.tight_layout()
+ plt.gca()
+ return cumulative_mean, p, trials
+
+
+ @app.cell(hide_code=True)
+ def _(mo, np, trials):
+ # Calculate statistics from the simulation
+ num_successes = np.sum(trials)
+ num_trials = len(trials)
+ proportion = num_successes / num_trials
+
+ # Display the results
+ mo.md(f"""
+ ### Simulation Results
+
+ - Number of trials: {num_trials}
+ - Number of successes: {num_successes}
+ - Proportion of successes: {proportion:.4f}
+
+ This demonstrates how the sample proportion approaches the true probability $p$ as the number of trials increases.
+ """)
+ return num_successes, num_trials, proportion
+
+
+ @app.cell(hide_code=True)
+ def _(mo):
+ mo.md(
+ r"""
+ ## 🤔 Test Your Understanding
+
+ Pick which of these statements about Bernoulli random variables you think are correct:
+
+ /// details | The variance of a Bernoulli random variable is always less than or equal to 0.25
+ ✅ Correct! The variance $p(1-p)$ reaches its maximum value of 0.25 when $p = 0.5$.
+ ///
+
+ /// details | The expected value of a Bernoulli random variable must be either 0 or 1
+ ❌ Incorrect! The expected value is $p$, which can be any value between 0 and 1.
+ ///
+
+ /// details | If $X \sim \text{Bern}(0.3)$ and $Y \sim \text{Bern}(0.7)$, then $X$ and $Y$ have the same variance
325
+ ✅ Correct! $\text{Var}(X) = 0.3 \times 0.7 = 0.21$ and $\text{Var}(Y) = 0.7 \times 0.3 = 0.21$.
326
+ ///
327
+
328
+ /// details | Two independent coin flips can be modeled as the sum of two Bernoulli random variables
329
+ ✅ Correct! The sum would follow a Binomial distribution with $n=2$.
330
+ ///
331
+ """
332
+ )
333
+ return
334
+
335
+
336
+ @app.cell(hide_code=True)
337
+ def _(mo):
338
+ mo.md(
339
+ r"""
340
+ ## Applications of Bernoulli Random Variables
341
+
342
+ Bernoulli random variables are used in many real-world scenarios:
343
+
344
+ 1. **Quality Control**: Testing if a manufactured item is defective (1) or not (0)
345
+
346
+ 2. **A/B Testing**: Determining if a user clicks (1) or doesn't click (0) on a website button
347
+
348
+ 3. **Medical Testing**: Checking if a patient tests positive (1) or negative (0) for a disease
349
+
350
+ 4. **Election Modeling**: Modeling if a particular voter votes for candidate A (1) or not (0)
351
+
352
+ 5. **Financial Markets**: Modeling if a stock price goes up (1) or down (0) in a simplified model
353
+
354
+ Because Bernoulli random variables are parametric, as soon as you declare a random variable to be of type Bernoulli, you automatically know all of its pre-derived properties!
355
+ """
356
+ )
357
+ return
358
+
359
+
360
+ @app.cell(hide_code=True)
361
+ def _(mo):
362
+ mo.md(
363
+ r"""
364
+ ## Summary
365
+
366
+ And that's a wrap on Bernoulli distributions! We've learnt the simplest of all probability distributions — the one that only has two possible outcomes. Flip a coin, check if an email is spam, see if your blind date shows up — these are all Bernoulli trials with success probability $p$.
367
+
368
+ The beauty of Bernoulli is in its simplicity: just set $p$ (the probability of success) and you're good to go! The PMF gives us $P(X=1) = p$ and $P(X=0) = 1-p$, while expectation is simply $p$ and variance is $p(1-p)$. Oh, and when you're tracking whether specific events happen or not? That's an indicator random variable — just another Bernoulli in disguise!
369
+
370
+ Two key things to remember:
371
+
372
+ /// note
373
+ 💡 **Maximum Variance**: A Bernoulli's variance $p(1-p)$ reaches its maximum at $p=0.5$, making a fair coin the most "unpredictable" Bernoulli random variable.
374
+
375
+ 💡 **Instant Properties**: When you identify a random variable as Bernoulli, you instantly know all its properties—expectation, variance, PMF—without additional calculations.
376
+ ///
377
+
378
+ Next up: Binomial distribution—where we'll see what happens when we let Bernoulli trials have a party and add themselves together!
379
+ """
380
+ )
381
+ return
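The maximum-variance note is quick to confirm numerically; a sketch scanning $p(1-p)$ over a grid of $p$ values:

```python
import numpy as np

p_grid = np.linspace(0, 1, 101)
variance = p_grid * (1 - p_grid)

# The grid point with the largest variance is p = 0.5 (a fair coin),
# where p(1-p) attains its maximum of 0.25.
best_p = p_grid[variance.argmax()]
print(best_p, variance.max())
```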
382
+
383
+
384
+ @app.cell(hide_code=True)
385
+ def _(mo):
386
+ mo.md(r"""#### Appendix (containing helper code for the notebook)""")
387
+ return
388
+
389
+
390
+ @app.cell
391
+ def _():
392
+ import marimo as mo
393
+ return (mo,)
394
+
395
+
396
+ @app.cell(hide_code=True)
397
+ def _():
398
+ from marimo import Html
399
+ return (Html,)
400
+
401
+
402
+ @app.cell(hide_code=True)
403
+ def _():
404
+ import numpy as np
405
+ import matplotlib.pyplot as plt
406
+ from scipy import stats
407
+ import math
408
+
409
+ # Set style for consistent visualizations
410
+ plt.style.use('seaborn-v0_8-whitegrid')
411
+ plt.rcParams['figure.figsize'] = [10, 6]
412
+ plt.rcParams['font.size'] = 12
413
+
414
+ # Set random seed for reproducibility
415
+ np.random.seed(42)
416
+ return math, np, plt, stats
417
+
418
+
419
+ @app.cell(hide_code=True)
420
+ def _(mo):
421
+ # Create a UI element for the parameter p
422
+ p_slider = mo.ui.slider(0.01, 0.99, value=0.65, step=0.01, label="Parameter p")
423
+ return (p_slider,)
424
+
425
+
426
+ if __name__ == "__main__":
427
+ app.run()
probability/14_binomial_distribution.py ADDED
@@ -0,0 +1,545 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.0",
6
+ # "numpy==2.2.4",
7
+ # "scipy==1.15.2",
8
+ # "altair==5.2.0",
9
+ # "wigglystuff==0.1.10",
10
+ # "pandas==2.2.3",
11
+ # ]
12
+ # ///
13
+
14
+ import marimo
15
+
16
+ __generated_with = "0.11.24"
17
+ app = marimo.App(width="medium", app_title="Binomial Distribution")
18
+
19
+
20
+ @app.cell(hide_code=True)
21
+ def _(mo):
22
+ mo.md(
23
+ r"""
24
+ # Binomial Distribution
25
+
26
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/binomial/), by Stanford professor Chris Piech._
27
+
28
+ In this section, we will discuss the binomial distribution. To start, imagine the following example:
29
+
30
+ Consider $n$ independent trials of an experiment where each trial is a "success" with probability $p$. Let $X$ be the number of successes in $n$ trials.
31
+
32
+ This situation is truly common in the natural world, and as such, there has been a lot of research into such phenomena. Random variables like $X$ are called **binomial random variables**. If you can identify that a process fits this description, you can inherit many already-proven properties, such as the PMF formula, expectation, and variance!
33
+ """
34
+ )
35
+ return
36
+
37
+
38
+ @app.cell(hide_code=True)
39
+ def _(mo):
40
+ mo.md(
41
+ r"""
42
+ ## Binomial Random Variable Definition
43
+
44
+ $X \sim \text{Bin}(n, p)$ represents a binomial random variable where:
45
+
46
+ - $X$ is our random variable (number of successes)
47
+ - $\text{Bin}$ indicates it follows a binomial distribution
48
+ - $n$ is the number of trials
49
+ - $p$ is the probability of success in each trial
50
+
51
+ ```
52
+ X ~ Bin(n, p)
53
+ ↑ ↑ ↑
54
+ | | +-- Probability of
55
+ | | success on each
56
+ | | trial
57
+ | +-- Number of trials
58
+ |
59
+ Our random variable
60
+ is distributed
61
+ as a Binomial
62
+ ```
63
+
64
+ Here are a few examples of binomial random variables:
65
+
66
+ - Number of heads in $n$ coin flips
67
+ - Number of 1's in a randomly generated bit string of length $n$
68
+ - Number of crashed disk drives in a cluster of 1000 computers, assuming disks crash independently
69
+ """
70
+ )
71
+ return
72
+
73
+
74
+ @app.cell(hide_code=True)
75
+ def _(mo):
76
+ mo.md(
77
+ r"""
78
+ ## Properties of Binomial Distribution
79
+
80
+ | Property | Formula |
81
+ |----------|---------|
82
+ | Notation | $X \sim \text{Bin}(n, p)$ |
83
+ | Description | Number of "successes" in $n$ identical, independent experiments each with probability of success $p$ |
84
+ | Parameters | $n \in \{0, 1, \dots\}$, the number of experiments<br>$p \in [0, 1]$, the probability that a single experiment gives a "success" |
85
+ | Support | $x \in \{0, 1, \dots, n\}$ |
86
+ | PMF equation | $P(X=x) = {n \choose x}p^x(1-p)^{n-x}$ |
87
+ | Expectation | $E[X] = n \cdot p$ |
88
+ | Variance | $\text{Var}(X) = n \cdot p \cdot (1-p)$ |
89
+
90
+ Let's explore how the binomial distribution changes with different parameters.
91
+ """
92
+ )
93
+ return
94
+
95
+
96
+ @app.cell(hide_code=True)
97
+ def _(TangleSlider, mo):
98
+ # Interactive elements using TangleSlider
99
+ n_slider = mo.ui.anywidget(TangleSlider(
100
+ amount=10,
101
+ min_value=1,
102
+ max_value=30,
103
+ step=1,
104
+ digits=0,
105
+ suffix=" trials"
106
+ ))
107
+
108
+ p_slider = mo.ui.anywidget(TangleSlider(
109
+ amount=0.5,
110
+ min_value=0.01,
111
+ max_value=0.99,
112
+ step=0.01,
113
+ digits=2,
114
+ suffix=" probability"
115
+ ))
116
+
117
+ # Grid layout for the interactive controls
118
+ controls = mo.vstack([
119
+ mo.md("### Adjust Parameters to See How Binomial Distribution Changes"),
120
+ mo.hstack([
121
+ mo.md("**Number of trials (n):** "),
122
+ n_slider
123
+ ], justify="start"),
124
+ mo.hstack([
125
+ mo.md("**Probability of success (p):** "),
126
+ p_slider
127
+ ], justify="start"),
128
+ ])
129
+ return controls, n_slider, p_slider
130
+
131
+
132
+ @app.cell(hide_code=True)
133
+ def _(controls):
134
+ controls
135
+ return
136
+
137
+
138
+ @app.cell(hide_code=True)
139
+ def _(n_slider, np, p_slider, plt, stats):
140
+ # Parameters from sliders
141
+ _n = int(n_slider.amount)
142
+ _p = p_slider.amount
143
+
144
+ # Calculate PMF
145
+ _x = np.arange(0, _n + 1)
146
+ _pmf = stats.binom.pmf(_x, _n, _p)
147
+
148
+ # Relevant stats
149
+ _mean = _n * _p
150
+ _variance = _n * _p * (1 - _p)
151
+ _std_dev = np.sqrt(_variance)
152
+
153
+ _fig, _ax = plt.subplots(figsize=(10, 6))
154
+
155
+ # Plot PMF as bars
156
+ _ax.bar(_x, _pmf, color='royalblue', alpha=0.7, label=f'PMF: P(X=k)')
157
+
158
+ # Add a line
159
+ _ax.plot(_x, _pmf, 'ro-', alpha=0.6, label='PMF line')
160
+
161
+ # Add vertical lines
162
+ _ax.axvline(x=_mean, color='green', linestyle='--', linewidth=2,
163
+ label=f'Mean: {_mean:.2f}')
164
+
165
+ # Shade the stdev region
166
+ _ax.axvspan(_mean - _std_dev, _mean + _std_dev, alpha=0.2, color='green',
167
+ label=f'±1 Std Dev: {_std_dev:.2f}')
168
+
169
+ # Add labels and title
170
+ _ax.set_xlabel('Number of Successes (k)')
171
+ _ax.set_ylabel('Probability: P(X=k)')
172
+ _ax.set_title(f'Binomial Distribution with n={_n}, p={_p:.2f}')
173
+
174
+ # Annotations
175
+ _ax.annotate(f'E[X] = {_mean:.2f}',
176
+ xy=(_mean, stats.binom.pmf(int(_mean), _n, _p)),
177
+ xytext=(_mean + 1, max(_pmf) * 0.8),
178
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
179
+
180
+ _ax.annotate(f'Var(X) = {_variance:.2f}',
181
+ xy=(_mean, stats.binom.pmf(int(_mean), _n, _p) / 2),
182
+ xytext=(_mean + 1, max(_pmf) * 0.6),
183
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
184
+
185
+ # Grid and legend
186
+ _ax.grid(alpha=0.3)
187
+ _ax.legend()
188
+
189
+ plt.tight_layout()
190
+ plt.gca()
191
+ return
192
+
193
+
194
+ @app.cell(hide_code=True)
195
+ def _(mo):
196
+ mo.md(
197
+ r"""
198
+ ## Relationship to Bernoulli Random Variables
199
+
200
+ One way to think of the binomial is as the sum of $n$ Bernoulli variables. Say that $Y_i$ is an indicator Bernoulli random variable which is 1 if experiment $i$ is a success. Then if $X$ is the total number of successes in $n$ experiments, $X \sim \text{Bin}(n, p)$:
201
+
202
+ $$X = \sum_{i=1}^n Y_i$$
203
+
204
+ Recall that the outcome of $Y_i$ will be 1 or 0, so one way to think of $X$ is as the sum of those 1s and 0s.
205
+ """
206
+ )
207
+ return
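This identity is easy to sanity-check by simulation: summing $n$ Bernoulli draws per trial reproduces Binomial($n$, $p$) samples. A sketch with assumed illustration values $n=10$, $p=0.3$:

```python
import numpy as np

rng = np.random.default_rng(42)
n, p = 10, 0.3  # assumed illustration values
num_samples = 200_000

# Each row holds n Bernoulli(p) indicators Y_i; the row sum is one draw of X.
bernoullis = (rng.random((num_samples, n)) < p).astype(int)
X = bernoullis.sum(axis=1)

# The empirical mean should be close to the Binomial mean n*p = 3.
print(f"empirical mean: {X.mean():.3f}, theoretical: {n * p}")
```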
208
+
209
+
210
+ @app.cell(hide_code=True)
211
+ def _(mo):
212
+ mo.md(
213
+ r"""
214
+ ## Binomial Probability Mass Function (PMF)
215
+
216
+ The most important property to know about a binomial is its [Probability Mass Function](https://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/10_probability_mass_function.py):
217
+
218
+ $$P(X=k) = {n \choose k}p^k(1-p)^{n-k}$$
219
+
220
+ ```
221
+ P(X = k) = (n) p^k(1-p)^(n-k)
222
+ ↑ (k)
223
+ | ↑
224
+ | +-- Binomial coefficient:
225
+ | number of ways to choose
226
+ | k successes from n trials
227
+ |
228
+ Probability that our
229
+ variable takes on the
230
+ value k
231
+ ```
232
+
233
+ Recall that we derived this formula in Part 1, which includes a complete worked example: the probability of $k$ heads in $n$ coin flips, where each flip is heads with probability $p$.
234
+
235
+ To briefly review, if you think of each experiment as being distinct, then there are ${n \choose k}$ ways of choosing which $k$ of the $n$ experiments are successes. Each of these mutually exclusive arrangements has probability $p^k \cdot (1-p)^{n-k}$.
236
+
237
+ The name binomial comes from the term ${n \choose k}$ which is formally called the binomial coefficient.
238
+ """
239
+ )
240
+ return
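The PMF formula can also be evaluated from first principles with `math.comb` and checked against `scipy.stats.binom.pmf`; a minimal sketch using $n=5$, $p=0.6$, $k=2$ (the same example used in the code cell later in this notebook):

```python
import math
from scipy import stats

n, p, k = 5, 0.6, 2

# P(X = k) = C(n, k) * p^k * (1-p)^(n-k), computed directly from the formula
pmf_manual = math.comb(n, k) * p**k * (1 - p) ** (n - k)
pmf_scipy = stats.binom.pmf(k, n, p)

print(f"manual: {pmf_manual:.4f}, scipy: {pmf_scipy:.4f}")  # both 0.2304
```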
241
+
242
+
243
+ @app.cell(hide_code=True)
244
+ def _(mo):
245
+ mo.md(
246
+ r"""
247
+ ## Expectation of Binomial
248
+
249
+ There is an easy way to calculate the expectation of a binomial and a hard way. The easy way is to leverage the fact that a binomial is the sum of Bernoulli indicator random variables $X = \sum_{i=1}^{n} Y_i$ where $Y_i$ is an indicator of whether the $i$-th experiment was a success: $Y_i \sim \text{Bernoulli}(p)$.
250
+
251
+ Since the [expectation of the sum](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/11_expectation.py) of random variables is the sum of expectations, we can add the expectation, $E[Y_i] = p$, of each of the Bernoulli's:
252
+
253
+ \begin{align}
254
+ E[X] &= E\Big[\sum_{i=1}^{n} Y_i\Big] && \text{Since }X = \sum_{i=1}^{n} Y_i \\
255
+ &= \sum_{i=1}^{n}E[ Y_i] && \text{Expectation of sum} \\
256
+ &= \sum_{i=1}^{n}p && \text{Expectation of Bernoulli} \\
257
+ &= n \cdot p && \text{Sum $n$ times}
258
+ \end{align}
259
+
260
+ The hard way is to use the definition of expectation:
261
+
262
+ \begin{align}
263
+ E[X] &= \sum_{i=0}^n i \cdot P(X = i) && \text{Def of expectation} \\
264
+ &= \sum_{i=0}^n i \cdot {n \choose i} p^i(1-p)^{n-i} && \text{Sub in PMF} \\
265
+ & \cdots && \text{Many steps later} \\
266
+ &= n \cdot p
267
+ \end{align}
268
+ """
269
+ )
270
+ return
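The "hard way" can at least be verified numerically: summing $i \cdot P(X=i)$ over the support does land on $n \cdot p$. A sketch with assumed values $n=20$, $p=0.3$:

```python
import math

def binomial_expectation(n: int, p: float) -> float:
    """E[X] computed directly from the definition of expectation."""
    return sum(
        i * math.comb(n, i) * p**i * (1 - p) ** (n - i)
        for i in range(n + 1)
    )

print(binomial_expectation(20, 0.3))  # ≈ 6.0, i.e. n * p
```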
271
+
272
+
273
+ @app.cell(hide_code=True)
274
+ def _(mo):
275
+ mo.md(
276
+ r"""
277
+ ## Binomial Distribution in Python
278
+
279
+ As you might expect, you can use binomial distributions in code. The standard library for binomials is `scipy.stats.binom`.
280
+
281
+ One of the most helpful methods that this package provides is a way to calculate the PMF. For example, if $n=5$ and $p=0.6$ and you want to find $P(X=2)$, you could use the following code:
282
+ """
283
+ )
284
+ return
285
+
286
+
287
+ @app.cell
288
+ def _(stats):
289
+ # define variables for x, n, and p
290
+ _n = 5 # Integer value for n
291
+ _p = 0.6
292
+ _x = 2
293
+
294
+ # use scipy to compute the pmf
295
+ p_x = stats.binom.pmf(_x, _n, _p)
296
+
297
+ # use the probability for future work
298
+ print(f'P(X = {_x}) = {p_x:.4f}')
299
+ return (p_x,)
300
+
301
+
302
+ @app.cell(hide_code=True)
303
+ def _(mo):
304
+ mo.md(r"""Another particularly helpful function is the ability to generate a random sample from a binomial. For example, say $X$ represents the number of requests to a website. We can draw 100 samples from this distribution using the following code:""")
305
+ return
306
+
307
+
308
+ @app.cell
309
+ def _(n, p, stats):
310
+ n_int = int(n)
311
+
312
+ # samples from the binomial distribution
313
+ samples = stats.binom.rvs(n_int, p, size=100)
314
+
315
+ # Print the samples
316
+ print(samples)
317
+ return n_int, samples
318
+
319
+
320
+ @app.cell(hide_code=True)
321
+ def _(n_int, np, p, plt, samples, stats):
322
+ # Plot histogram of samples
323
+ plt.figure(figsize=(10, 5))
324
+ plt.hist(samples, bins=np.arange(-0.5, n_int+1.5, 1), alpha=0.7, color='royalblue',
325
+ edgecolor='black', density=True)
326
+
327
+ # Overlay the PMF
328
+ x_values = np.arange(0, n_int+1)
329
+ pmf_values = stats.binom.pmf(x_values, n_int, p)
330
+ plt.plot(x_values, pmf_values, 'ro-', ms=8, label='Theoretical PMF')
331
+
332
+ # Add labels and title
333
+ plt.xlabel('Number of Successes')
334
+ plt.ylabel('Relative Frequency / Probability')
335
+ plt.title(f'Histogram of 100 Samples from Bin({n_int}, {p})')
336
+ plt.legend()
337
+ plt.grid(alpha=0.3)
338
+
339
+ # Annotate
340
+ plt.annotate('Sample mean: %.2f' % np.mean(samples),
341
+ xy=(0.7, 0.9), xycoords='axes fraction',
342
+ bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
343
+ plt.annotate('Theoretical mean: %.2f' % (n_int*p),
344
+ xy=(0.7, 0.8), xycoords='axes fraction',
345
+ bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.3))
346
+
347
+ plt.tight_layout()
348
+ plt.gca()
349
+ return pmf_values, x_values
350
+
351
+
352
+ @app.cell(hide_code=True)
353
+ def _(mo):
354
+ mo.md(
355
+ r"""
356
+ You might be wondering what a random sample is! A random sample is a randomly chosen assignment for our random variable. Above we have 100 such assignments. The probability that value $k$ is chosen is given by the PMF: $P(X=k)$.
357
+
358
+ There are also functions for getting the mean, the variance, and more. You can read the [scipy.stats.binom documentation](https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.binom.html), especially the list of methods.
359
+ """
360
+ )
361
+ return
362
+
363
+
364
+ @app.cell(hide_code=True)
365
+ def _(mo):
366
+ mo.md(
367
+ r"""
368
+ ## Interactive Exploration of Binomial vs. Negative Binomial
369
+
370
+ The standard binomial distribution is a special case of a broader family of distributions. One related distribution is the negative binomial, which can model count data with overdispersion (where the variance is larger than the mean).
371
+
372
+ Below, you can explore how the negative binomial distribution compares to a Poisson distribution (which can be seen as a limiting case of the binomial as $n$ gets large and $p$ gets small, with $np$ held constant).
373
+
374
+ Adjust the sliders to see how the parameters affect the distribution:
375
+
376
+ *Note: The interactive visualization in this section was inspired by work from [liquidcarbon on GitHub](https://github.com/liquidcarbon).*
377
+ """
378
+ )
379
+ return
380
+
381
+
382
+ @app.cell(hide_code=True)
383
+ def _(alpha_slider, chart, equation, mo, mu_slider):
384
+ mo.vstack(
385
+ [
386
+ mo.md(f"## Negative Binomial Distribution (Poisson + Overdispersion)\n{equation}"),
387
+ mo.hstack([mu_slider, alpha_slider], justify="start"),
388
+ chart,
389
+ ], justify='space-around'
390
+ ).center()
391
+ return
392
+
393
+
394
+ @app.cell(hide_code=True)
395
+ def _(mo):
396
+ mo.md(
397
+ r"""
398
+ ## 🤔 Test Your Understanding
399
+ Pick which of these statements about binomial distributions you think are correct:
400
+
401
+ /// details | The variance of a binomial distribution is always equal to its mean
402
+ ❌ Incorrect! The variance is $np(1-p)$ while the mean is $np$. They're only equal when $p=0$ (a degenerate case in which both are zero).
403
+ ///
404
+
405
+ /// details | If $X \sim \text{Bin}(n, p)$ and $Y \sim \text{Bin}(n, 1-p)$, then $X$ and $Y$ have the same variance
406
+ ✅ Correct! $\text{Var}(X) = np(1-p)$ and $\text{Var}(Y) = n(1-p)p$, which are the same.
407
+ ///
408
+
409
+ /// details | As the number of trials increases, the binomial distribution approaches a normal distribution
410
+ ✅ Correct! For large $n$, the binomial distribution can be approximated by a normal distribution with the same mean and variance.
411
+ ///
412
+
413
+ /// details | The PMF of a binomial distribution is symmetric when $p = 0.5$
414
+ ✅ Correct! When $p = 0.5$, the PMF is symmetric around $n/2$.
415
+ ///
416
+
417
+ /// details | The sum of two independent binomial random variables with the same $p$ is also a binomial random variable
418
+ ✅ Correct! If $X \sim \text{Bin}(n_1, p)$ and $Y \sim \text{Bin}(n_2, p)$ are independent, then $X + Y \sim \text{Bin}(n_1 + n_2, p)$.
419
+ ///
420
+
421
+ /// details | The maximum value of the PMF for $\text{Bin}(n,p)$ always occurs at $k = np$
422
+ ❌ Incorrect! The mode (maximum value of PMF) is either $\lfloor (n+1)p \rfloor$ or $\lceil (n+1)p-1 \rceil$ depending on whether $(n+1)p$ is an integer.
423
+ ///
424
+ """
425
+ )
426
+ return
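The mode claim from the last question can be spot-checked in a few lines; a sketch with assumed values $n=10$, $p=0.37$, so that $(n+1)p = 4.07$ is not an integer:

```python
import math
from scipy import stats

n, p = 10, 0.37  # assumed illustration values
pmf = [stats.binom.pmf(k, n, p) for k in range(n + 1)]

# The mode is the k with the largest PMF value, which matches
# floor((n+1)p) = 4 here, not n*p = 3.7.
mode = max(range(n + 1), key=pmf.__getitem__)
print(mode, math.floor((n + 1) * p))
```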
427
+
428
+
429
+ @app.cell(hide_code=True)
430
+ def _(mo):
431
+ mo.md(
432
+ r"""
433
+ ## Summary
434
+
435
+ So we've explored the binomial distribution, and honestly, it's one of the most practical probability distributions you'll encounter. Think about it — anytime you're counting successes in a fixed number of trials (like those coin flips we discussed), this is your go-to distribution.
436
+
437
+ I find it fascinating how the expectation is simply $np$. Such a clean, intuitive formula! And remember that neat visualization we saw earlier? When we adjusted the parameters, you could actually see how the distribution shape changes—becoming more symmetric as $n$ increases.
438
+
439
+ The key things to take away:
440
+
441
+ - The binomial distribution models the number of successes in $n$ independent trials, each with probability $p$ of success
442
+
443
+ - Its PMF is given by the formula $P(X=k) = {n \choose k}p^k(1-p)^{n-k}$, which lets us calculate exactly how likely any specific number of successes is
444
+
445
+ - The expected value is $E[X] = np$ and the variance is $Var(X) = np(1-p)$
446
+
447
+ - It's related to other distributions: it's essentially a sum of Bernoulli random variables, and connects to both the negative binomial and Poisson distributions
448
+
449
+ - In Python, the `scipy.stats.binom` module makes working with binomial distributions straightforward—you can generate random samples and calculate probabilities with just a few lines of code
450
+
451
+ You'll see the binomial distribution pop up everywhere—from computer science to quality control, epidemiology, and data science. Any time you have scenarios with binary outcomes over multiple trials, this distribution has you covered.
452
+ """
453
+ )
454
+ return
455
+
456
+
457
+ @app.cell(hide_code=True)
458
+ def _(mo):
459
+ mo.md(r"""Appendix code (helper functions, variables, etc.):""")
460
+ return
461
+
462
+
463
+ @app.cell
464
+ def _():
465
+ import marimo as mo
466
+ return (mo,)
467
+
468
+
469
+ @app.cell(hide_code=True)
470
+ def _():
471
+ import numpy as np
472
+ import matplotlib.pyplot as plt
473
+ import scipy.stats as stats
474
+ import pandas as pd
475
+ import altair as alt
476
+ from wigglystuff import TangleSlider
477
+ return TangleSlider, alt, np, pd, plt, stats
478
+
479
+
480
+ @app.cell(hide_code=True)
481
+ def _(mo):
482
+ alpha_slider = mo.ui.slider(
483
+ value=0.1,
484
+ steps=[0, 0.01, 0.02, 0.03, 0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.8, 1],
485
+ label="α (overdispersion)",
486
+ show_value=True,
487
+ )
488
+ mu_slider = mo.ui.slider(
489
+ value=100, start=1, stop=100, step=1, label="μ (mean)", show_value=True
490
+ )
491
+ return alpha_slider, mu_slider
492
+
493
+
494
+ @app.cell(hide_code=True)
495
+ def _():
496
+ equation = """
497
+ $$
498
+ P(X = k) = \\frac{\\Gamma(k + \\frac{1}{\\alpha})}{\\Gamma(k + 1) \\Gamma(\\frac{1}{\\alpha})} \\left( \\frac{1}{\\mu \\alpha + 1} \\right)^{\\frac{1}{\\alpha}} \\left( \\frac{\\mu \\alpha}{\\mu \\alpha + 1} \\right)^k
499
+ $$
500
+
501
+ $$
502
+ \\sigma^2 = \\mu + \\alpha \\mu^2
503
+ $$
504
+ """
505
+ return (equation,)
506
+
507
+
508
+ @app.cell(hide_code=True)
509
+ def _(alpha_slider, alt, mu_slider, np, pd, stats):
510
+ mu = mu_slider.value
511
+ alpha = alpha_slider.value
512
+ n = 1000 - mu if alpha == 0 else 1 / alpha
513
+ p = n / (mu + n)
514
+ x = np.arange(0, mu * 3 + 1, 1)
515
+ df = pd.DataFrame(
516
+ {
517
+ "x": x,
518
+ "y": stats.nbinom.pmf(x, n, p),
519
+ "y_poi": stats.nbinom.pmf(x, 1000 - mu, 1 - mu / 1000),
520
+ }
521
+ )
522
+ r1k = stats.nbinom.rvs(n, p, size=1000)
523
+ df["in 95% CI"] = df["x"].between(*np.percentile(r1k, q=[2.5, 97.5]))
524
+ base = alt.Chart(df)
525
+
526
+ chart_poi = base.mark_bar(
527
+ fillOpacity=0.25, width=100 / mu, fill="magenta"
528
+ ).encode(
529
+ x=alt.X("x").scale(domain=(-0.4, x.max() + 0.4), nice=False),
530
+ y=alt.Y("y_poi").scale(domain=(0, df.y_poi.max() * 1.1)).title(None),
531
+ )
532
+ chart_nb = base.mark_bar(fillOpacity=0.75, width=100 / mu).encode(
533
+ x="x",
534
+ y="y",
535
+ fill=alt.Fill("in 95% CI")
536
+ .scale(domain=[False, True], range=["#aaa", "#7c7"])
537
+ .legend(orient="bottom-right"),
538
+ )
539
+
540
+ chart = (chart_poi + chart_nb).configure_view(continuousWidth=450)
541
+ return alpha, base, chart, chart_nb, chart_poi, df, mu, n, p, r1k, x
542
+
543
+
544
+ if __name__ == "__main__":
545
+ app.run()
probability/15_poisson_distribution.py ADDED
@@ -0,0 +1,805 @@
1
+ # /// script
2
+ # requires-python = ">=3.10"
3
+ # dependencies = [
4
+ # "marimo",
5
+ # "matplotlib==3.10.0",
6
+ # "numpy==2.2.4",
7
+ # "scipy==1.15.2",
8
+ # "altair==5.2.0",
9
+ # "wigglystuff==0.1.10",
10
+ # "pandas==2.2.3",
11
+ # ]
12
+ # ///
13
+
14
+ import marimo
15
+
16
+ __generated_with = "0.11.25"
17
+ app = marimo.App(width="medium", app_title="Poisson Distribution")
18
+
19
+
20
+ @app.cell(hide_code=True)
21
+ def _(mo):
22
+ mo.md(
23
+ r"""
24
+ # Poisson Distribution
25
+
26
+ _This notebook is a computational companion to ["Probability for Computer Scientists"](https://chrispiech.github.io/probabilityForComputerScientists/en/part2/poisson/), by Stanford professor Chris Piech._
27
+
28
+ A Poisson random variable counts the number of events in a fixed interval of time (or space). It relies on the Poisson assumption that events occur with a known constant mean rate and independently of the time since the last event.
29
+ """
30
+ )
31
+ return
32
+
33
+
34
+ @app.cell(hide_code=True)
35
+ def _(mo):
36
+ mo.md(
37
+ r"""
38
+ ## Poisson Random Variable Definition
39
+
40
+ $X \sim \text{Poisson}(\lambda)$ represents a Poisson random variable where:
41
+
42
+ - $X$ is our random variable (number of events)
43
+ - $\text{Poisson}$ indicates it follows a Poisson distribution
44
+ - $\lambda$ is the rate parameter (average number of events per time interval)
45
+
46
+ ```
47
+ X ~ Poisson(λ)
48
+ ↑ ↑ ↑
49
+ | | +-- Rate parameter:
50
+ | | average number of
51
+ | | events per interval
52
+ | +-- Indicates Poisson
53
+ | distribution
54
+ |
55
+ Our random variable
56
+ counting number of events
57
+ ```
58
+
59
+ The Poisson distribution is particularly useful when:
60
+
61
+ 1. Events occur independently of each other
62
+ 2. The average rate of occurrence is constant
63
+ 3. Two events cannot occur at exactly the same instant
64
+ 4. The probability of an event is proportional to the length of the time interval
65
+ """
66
+ )
67
+ return
68
+
69
+
70
+ @app.cell(hide_code=True)
71
+ def _(mo):
72
+ mo.md(
73
+ r"""
74
+ ## Properties of Poisson Distribution
75
+
76
+ | Property | Formula |
77
+ |----------|---------|
78
+ | Notation | $X \sim \text{Poisson}(\lambda)$ |
79
+ | Description | Number of events in a fixed time frame if (a) events occur with a constant mean rate and (b) they occur independently of time since last event |
80
+ | Parameters | $\lambda \in \mathbb{R}^{+}$, the constant average rate |
81
+ | Support | $x \in \{0, 1, \dots\}$ |
82
+ | PMF equation | $P(X=x) = \frac{\lambda^x e^{-\lambda}}{x!}$ |
83
+ | Expectation | $E[X] = \lambda$ |
84
+ | Variance | $\text{Var}(X) = \lambda$ |
85
+
86
+ Note that unlike many other distributions, the Poisson distribution's mean and variance are equal, both being $\lambda$.
87
+
88
+ Let's explore how the Poisson distribution changes with different rate parameters.
89
+ """
90
+ )
91
+ return
92
+
93
+
94
+ @app.cell(hide_code=True)
95
+ def _(TangleSlider, mo):
96
+ # interactive elements using TangleSlider
97
+ lambda_slider = mo.ui.anywidget(TangleSlider(
98
+ amount=5,
99
+ min_value=0.1,
100
+ max_value=20,
101
+ step=0.1,
102
+ digits=1,
103
+ suffix=" events"
104
+ ))
105
+
106
+ # interactive controls
107
+ _controls = mo.vstack([
108
+ mo.md("### Adjust the Rate Parameter to See How Poisson Distribution Changes"),
109
+ mo.hstack([
110
+ mo.md("**Rate parameter (λ):** "),
111
+ lambda_slider,
112
+ mo.md("**events per interval.** Higher values shift the distribution rightward and make it more spread out.")
113
+ ], justify="start"),
114
+ ])
115
+ _controls
116
+ return (lambda_slider,)
117
+
118
+
119
+ @app.cell(hide_code=True)
120
+ def _(lambda_slider, np, plt, stats):
121
+ def create_poisson_pmf_plot(lambda_value):
122
+ """Create a visualization of Poisson PMF with annotations for mean and variance."""
123
+ # PMF for values
124
+ max_x = max(20, int(lambda_value * 3)) # Show at least up to 3*lambda
125
+ x = np.arange(0, max_x + 1)
126
+ pmf = stats.poisson.pmf(x, lambda_value)
127
+
128
+ # Relevant key statistics
129
+ mean = lambda_value # For Poisson, mean = lambda
130
+ variance = lambda_value # For Poisson, variance = lambda
131
+ std_dev = np.sqrt(variance)
132
+
133
+ # plot
134
+ fig, ax = plt.subplots(figsize=(10, 6))
135
+
136
+ # PMF as bars
137
+ ax.bar(x, pmf, color='royalblue', alpha=0.7, label='PMF: P(X=k)')
138
+
139
+ # for the PMF values
140
+ ax.plot(x, pmf, 'ro-', alpha=0.6, label='PMF line')
141
+
142
+ # Vertical lines - mean and key values
143
+ ax.axvline(x=mean, color='green', linestyle='--', linewidth=2,
144
+ label=f'Mean: {mean:.2f}')
145
+
146
+ # Stdev region
147
+ ax.axvspan(mean - std_dev, mean + std_dev, alpha=0.2, color='green',
148
+ label=f'±1 Std Dev: {std_dev:.2f}')
149
+
150
+ ax.set_xlabel('Number of Events (k)')
151
+ ax.set_ylabel('Probability: P(X=k)')
152
+ ax.set_title(f'Poisson Distribution with λ={lambda_value:.1f}')
153
+
154
+ # annotations
155
+ ax.annotate(f'E[X] = {mean:.2f}',
156
+ xy=(mean, stats.poisson.pmf(int(mean), lambda_value)),
157
+ xytext=(mean + 1, max(pmf) * 0.8),
158
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
159
+
160
+ ax.annotate(f'Var(X) = {variance:.2f}',
161
+ xy=(mean, stats.poisson.pmf(int(mean), lambda_value) / 2),
162
+ xytext=(mean + 1, max(pmf) * 0.6),
163
+ arrowprops=dict(facecolor='black', shrink=0.05, width=1))
164
+
165
+ ax.grid(alpha=0.3)
166
+ ax.legend()
167
+
168
+ plt.tight_layout()
169
+ return plt.gca()
170
+
171
+ # Get parameter from slider and create plot
172
+ _lambda = lambda_slider.amount
173
+ create_poisson_pmf_plot(_lambda)
174
+ return (create_poisson_pmf_plot,)
175
+
176
+
177
+ @app.cell(hide_code=True)
178
+ def _(mo):
179
+ mo.md(
180
+ r"""
181
+ ## Poisson Intuition: Relation to Binomial Distribution
182
+
183
+ The Poisson distribution can be derived as a limiting case of the [binomial distribution](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/14_binomial_distribution.py).
184
+
185
+ Let's work on a practical example: predicting the number of ride-sharing requests in a specific area over a one-minute interval. From historical data, we know that the average number of requests per minute is $\lambda = 5$.
186
+
187
+ We could approximate this using a binomial distribution by dividing our minute into smaller intervals. For example, we can divide a minute into 60 seconds and treat each second as a [Bernoulli trial](http://marimo.app/https://github.com/marimo-team/learn/blob/main/probability/13_bernoulli_distribution.py) - either there's a request (success) or there isn't (failure).
188
+
189
+ Let's visualize this concept:
190
+ """
191
+ )
192
+ return
193
+
194
+
195
+ @app.cell(hide_code=True)
196
+ def _(fig_to_image, mo, plt):
197
+ def create_time_division_visualization():
198
+ # visualization of dividing a minute into 60 seconds
199
+ fig, ax = plt.subplots(figsize=(12, 2))
200
+
201
+ # Example events hardcoded at 2.75s and 7.12s
202
+ events = [2.75, 7.12]
203
+
204
+ # array of 60 rectangles
205
+ for i in range(60):
206
+ color = 'royalblue' if any(i <= e < i+1 for e in events) else 'lightgray'
207
+ ax.add_patch(plt.Rectangle((i, 0), 0.9, 1, color=color))
208
+
209
+ # markers for events
210
+ for e in events:
211
+ ax.plot(e, 0.5, 'ro', markersize=10)
212
+
213
+ # labels
214
+ ax.set_xlim(0, 60)
215
+ ax.set_ylim(0, 1)
216
+ ax.set_yticks([])
217
+ ax.set_xticks([0, 15, 30, 45, 60])
218
+ ax.set_xticklabels(['0s', '15s', '30s', '45s', '60s'])
219
+ ax.set_xlabel('Time (seconds)')
220
+ ax.set_title('One Minute Divided into 60 Second Intervals')
221
+
222
+ plt.tight_layout()
223
224
+ return fig, events, i
225
+
226
+ # Create visualization and convert to image
227
+ _fig, _events, i = create_time_division_visualization()
228
+ _img = mo.image(fig_to_image(_fig), width="100%")
229
+
230
+ # explanation
231
+ _explanation = mo.md(
232
+ r"""
233
+ In this visualization:
234
+
235
+ - Each rectangle represents a 1-second interval
236
+ - Blue rectangles indicate intervals where an event occurred
237
+ - Red dots show the actual event times (2.75s and 7.12s)
238
+
239
+ If we treat this as a binomial experiment with 60 trials (seconds), we can calculate probabilities using the binomial PMF. But there's a problem: what if multiple events occur within the same second? To address this, we can divide our minute into smaller intervals.
240
+ """
241
+ )
242
+ mo.vstack([_fig, _explanation])
243
+ return create_time_division_visualization, i
244
+
245
+
246
+ @app.cell(hide_code=True)
247
+ def _(mo):
248
+ mo.md(
249
+ r"""
250
+ The total number of requests received over the minute can be approximated as the sum of the sixty indicator variables, which conveniently matches the description of a binomial — a sum of Bernoullis.
251
+
252
+ Specifically, if we define $X$ to be the number of requests in a minute, $X$ is a binomial with $n=60$ trials. What is the probability, $p$, of a success on a single trial? To make the expectation of $X$ equal the observed historical average $\lambda$, we should choose $p$ so that:
253
+
254
+ \begin{align}
255
+ \lambda &= E[X] && \text{Expectation matches historical average} \\
256
+ \lambda &= n \cdot p && \text{Expectation of a Binomial is } n \cdot p \\
257
+ p &= \frac{\lambda}{n} && \text{Solving for $p$}
258
+ \end{align}
259
+
260
+ In this case, since $\lambda=5$ and $n=60$, we should choose $p=\frac{5}{60}=\frac{1}{12}$ and state that $X \sim \text{Bin}(n=60, p=\frac{5}{60})$. Now we can calculate the probability of different numbers of requests using the binomial PMF:
261
+
262
+ $P(X = x) = {n \choose x} p^x (1-p)^{n-x}$
263
+
264
+ For example:
265
+
266
+ \begin{align}
267
+ P(X=1) &= {60 \choose 1} (5/60)^1 (55/60)^{60-1} \approx 0.0295 \\
268
+ P(X=2) &= {60 \choose 2} (5/60)^2 (55/60)^{60-2} \approx 0.0790 \\
269
+ P(X=3) &= {60 \choose 3} (5/60)^3 (55/60)^{60-3} \approx 0.1389
270
+ \end{align}
271
+
272
+ This is a good approximation, but it doesn't account for the possibility of multiple events in a single second. One solution is to divide our minute into even more fine-grained intervals. Let's try 600 deciseconds (tenths of a second):
273
+ """
274
+ )
275
+ return
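The three hand-computed binomial probabilities above can be reproduced directly with `scipy.stats.binom` (a sketch, using the same n=60, λ=5 setup):

```python
from scipy import stats

n, lam = 60, 5
p = lam / n  # probability of a request in any one second: 1/12
for k in (1, 2, 3):
    # matches the worked values: ~0.0295, ~0.0790, ~0.1389
    print(f"P(X={k}) = {stats.binom.pmf(k, n, p):.4f}")
```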
276
+
277
+
278
+ @app.cell(hide_code=True)
279
+ def _(fig_to_image, mo, plt):
280
+ def create_decisecond_visualization(e_value):
281
+ # (Just showing the first 100 for clarity)
282
+ fig, ax = plt.subplots(figsize=(12, 2))
283
+
284
+ # Example events at 2.75s and 7.12s (convert to deciseconds)
285
+ events = [27.5, 71.2]
286
+
287
+ for i in range(100):
288
+ color = 'royalblue' if any(i <= event_val < i + 1 for event_val in events) else 'lightgray'
289
+ ax.add_patch(plt.Rectangle((i, 0), 0.9, 1, color=color))
290
+
291
+ # Markers for events
292
+ for event in events:
293
+ if event < 100: # Only show events in our visible range
294
+ ax.plot(event, 0.5, 'ro', markersize=10)  # event times are already in deciseconds, matching the axis
295
+
296
+ # Add labels
297
+ ax.set_xlim(0, 100)
298
+ ax.set_ylim(0, 1)
299
+ ax.set_yticks([])
300
+ ax.set_xticks([0, 20, 40, 60, 80, 100])
301
+ ax.set_xticklabels(['0s', '2s', '4s', '6s', '8s', '10s'])
302
+ ax.set_xlabel('Time (first 10 seconds shown)')
303
+ ax.set_title('One Minute Divided into 600 Decisecond Intervals (first 100 shown)')
304
+
305
+ plt.tight_layout()
306
307
+ return fig
308
+
309
+ # Create viz and convert to image
310
+ _fig = create_decisecond_visualization(e_value=5)
311
+ _img = mo.image(fig_to_image(_fig), width="100%")
312
+
313
+ # Explanation
314
+ _explanation = mo.md(
315
+ r"""
316
+ With $n=600$ and $p=\frac{5}{600}=\frac{1}{120}$, we can recalculate our probabilities:
317
+
318
+ \begin{align}
319
+ P(X=1) &= {600 \choose 1} (5/600)^1 (595/600)^{600-1} \approx 0.0333 \\
320
+ P(X=2) &= {600 \choose 2} (5/600)^2 (595/600)^{600-2} \approx 0.0837 \\
321
+ P(X=3) &= {600 \choose 3} (5/600)^3 (595/600)^{600-3} \approx 0.1402
322
+ \end{align}
323
+
324
+ As we make our intervals smaller (increasing $n$), our approximation becomes more accurate.
325
+ """
326
+ )
327
+ mo.vstack([_fig, _explanation])
328
+ return (create_decisecond_visualization,)
329
+
330
+
331
+ @app.cell(hide_code=True)
332
+ def _(mo):
333
+ mo.md(
334
+ r"""
335
+ ## The Binomial Distribution in the Limit
336
+
337
+ What happens if we continue dividing our time interval into smaller and smaller pieces? Let's explore how the probabilities change as we increase the number of intervals:
338
+ """
339
+ )
340
+ return
341
+
342
+
343
+ @app.cell(hide_code=True)
344
+ def _(mo):
345
+ intervals_slider = mo.ui.slider(
346
+ start = 60,
347
+ stop = 10000,
348
+ step=100,
349
+ value=600,
350
+ label="Number of intervals to divide a minute")
351
+ return (intervals_slider,)
352
+
353
+
354
+ @app.cell(hide_code=True)
355
+ def _(intervals_slider):
356
+ intervals_slider
357
+ return
358
+
359
+
360
+ @app.cell(hide_code=True)
361
+ def _(intervals_slider, np, pd, plt, stats):
362
+ def create_comparison_plot(n, lambda_value):
363
+ # Calculate probability
364
+ p = lambda_value / n
365
+
366
+ # Binomial probabilities
367
+ x_values = np.arange(0, 15)
368
+ binom_pmf = stats.binom.pmf(x_values, n, p)
369
+
370
+ # True Poisson probabilities
371
+ poisson_pmf = stats.poisson.pmf(x_values, lambda_value)
372
+
373
+ # DF for comparison
374
+ df = pd.DataFrame({
375
+ 'Events': x_values,
376
+ f'Binomial(n={n}, p={p:.6f})': binom_pmf,
377
+ f'Poisson(λ=5)': poisson_pmf,
378
+ 'Difference': np.abs(binom_pmf - poisson_pmf)
379
+ })
380
+
381
+ # Plot both PMFs
382
+ fig, ax = plt.subplots(figsize=(10, 6))
383
+
384
+ # Bar plot for the binomial
385
+ ax.bar(x_values - 0.2, binom_pmf, width=0.4, alpha=0.7,
386
+ color='royalblue', label=f'Binomial(n={n}, p={p:.6f})')
387
+
388
+ # Bar plot for the Poisson
389
+ ax.bar(x_values + 0.2, poisson_pmf, width=0.4, alpha=0.7,
390
+ color='crimson', label='Poisson(λ=5)')
391
+
392
+ # Labels and title
393
+ ax.set_xlabel('Number of Events (k)')
394
+ ax.set_ylabel('Probability')
395
+ ax.set_title(f'Comparison of Binomial and Poisson PMFs with n={n}')
396
+ ax.legend()
397
+ ax.set_xticks(x_values)
398
+ ax.grid(alpha=0.3)
399
+
400
+ plt.tight_layout()
401
+ return df, fig, n, p
402
+
403
+ # Number of intervals from the slider
404
+ n = intervals_slider.value
405
+ _lambda = 5 # Fixed lambda for our example
406
+
407
+ # Comparison plot
408
+ df, fig, n, p = create_comparison_plot(n, _lambda)
409
+ return create_comparison_plot, df, fig, n, p
410
+
411
+
412
+ @app.cell(hide_code=True)
413
+ def _(df, fig, fig_to_image, mo, n, p):
414
+ # table of values
415
+ _styled_df = df.style.format({
416
+ f'Binomial(n={n}, p={p:.6f})': '{:.6f}',
417
+ f'Poisson(λ=5)': '{:.6f}',
418
+ 'Difference': '{:.6f}'
419
+ })
420
+
421
+ # Calculate the max absolute difference
422
+ _max_diff = df['Difference'].max()
423
+
424
+ # output
425
+ _chart = mo.image(fig_to_image(fig), width="100%")
426
+ _explanation = mo.md(f"**Maximum absolute difference between distributions: {_max_diff:.6f}**")
427
+ _table = mo.ui.table(df)
428
+
429
+ mo.vstack([_chart, _explanation, _table])
430
+ return
431
+
432
+
433
+ @app.cell(hide_code=True)
434
+ def _(mo):
435
+ mo.md(
436
+ r"""
437
+ As you can see from the interactive comparison above, as the number of intervals increases, the binomial distribution approaches the Poisson distribution! This is not a coincidence - the Poisson distribution is actually the limiting case of the binomial distribution when:
438
+
439
+ - The number of trials $n$ approaches infinity
440
+ - The probability of success $p$ approaches zero
441
+ - The product $np = \lambda$ remains constant
442
+
443
+ This relationship is why the Poisson distribution is so useful - it's easier to work with than a binomial with a very large number of trials and a very small probability of success.
444
+
445
+ ## Derivation of the Poisson PMF
446
+
447
+ Let's derive the Poisson PMF by taking the limit of the binomial PMF as $n \to \infty$. We start with:
448
+
449
+ $P(X=x) = \lim_{n \rightarrow \infty} {n \choose x} (\lambda / n)^x(1-\lambda/n)^{n-x}$
450
+
451
+ While this expression looks intimidating, it simplifies nicely:
452
+
453
+ \begin{align}
454
+ P(X=x)
455
+ &= \lim_{n \rightarrow \infty} {n \choose x} (\lambda / n)^x(1-\lambda/n)^{n-x}
456
+ && \text{Start: binomial in the limit}\\
457
+ &= \lim_{n \rightarrow \infty}
458
+ {n \choose x} \cdot
459
+ \frac{\lambda^x}{n^x} \cdot
460
+ \frac{(1-\lambda/n)^{n}}{(1-\lambda/n)^{x}}
461
+ && \text{Expanding the power terms} \\
462
+ &= \lim_{n \rightarrow \infty}
463
+ \frac{n!}{(n-x)!x!} \cdot
464
+ \frac{\lambda^x}{n^x} \cdot
465
+ \frac{(1-\lambda/n)^{n}}{(1-\lambda/n)^{x}}
466
+ && \text{Expanding the binomial term} \\
467
+ &= \lim_{n \rightarrow \infty}
468
+ \frac{n!}{(n-x)!x!} \cdot
469
+ \frac{\lambda^x}{n^x} \cdot
470
+ \frac{e^{-\lambda}}{(1-\lambda/n)^{x}}
471
+ && \text{Using limit rule } \lim_{n \rightarrow \infty}(1-\lambda/n)^{n} = e^{-\lambda}\\
472
+ &= \lim_{n \rightarrow \infty}
473
+ \frac{n!}{(n-x)!x!} \cdot
474
+ \frac{\lambda^x}{n^x} \cdot
475
+ \frac{e^{-\lambda}}{1}
476
+ && \text{As } n \to \infty \text{, } \lambda/n \to 0\\
477
+ &= \lim_{n \rightarrow \infty}
478
+ \frac{n!}{(n-x)!} \cdot
479
+ \frac{1}{x!} \cdot
480
+ \frac{\lambda^x}{n^x} \cdot
481
+ e^{-\lambda}
482
+ && \text{Rearranging terms}\\
483
+ &= \lim_{n \rightarrow \infty}
484
+ \frac{n^x}{1} \cdot
485
+ \frac{1}{x!} \cdot
486
+ \frac{\lambda^x}{n^x} \cdot
487
+ e^{-\lambda}
488
+ && \text{As } n \to \infty \text{, } \frac{n!}{(n-x)!} \approx n^x\\
489
+ &= \lim_{n \rightarrow \infty}
490
+ \frac{\lambda^x}{x!} \cdot
491
+ e^{-\lambda}
492
+ && \text{Canceling } n^x\\
493
+ &=
494
+ \frac{\lambda^x \cdot e^{-\lambda}}{x!}
495
+ && \text{Simplifying}\\
496
+ \end{align}
497
+
498
+ This gives us our elegant Poisson PMF formula: $P(X=x) = \frac{\lambda^x \cdot e^{-\lambda}}{x!}$
499
+ """
500
+ )
501
+ return
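The limit taken in the derivation above can also be observed numerically: holding $np = \lambda$ fixed while growing $n$, the binomial PMF converges to the Poisson PMF (a sketch with `scipy.stats`; the values of $n$ are arbitrary):

```python
import numpy as np
from scipy import stats

lam, k = 5.0, np.arange(15)
for n in (60, 600, 6000, 60000):
    # worst-case pointwise gap between Bin(n, lam/n) and Poisson(lam)
    gap = np.max(np.abs(stats.binom.pmf(k, n, lam / n) - stats.poisson.pmf(k, lam)))
    print(f"n={n:>6}: max |binomial - poisson| = {gap:.6f}")
```

The gap shrinks roughly in proportion to $1/n$, which is why the approximation already looked good at $n = 600$.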
502
+
503
+
504
+ @app.cell(hide_code=True)
505
+ def _(mo):
506
+ mo.md(
507
+ r"""
508
+ ## Poisson Distribution in Python
509
+
510
+ Python's `scipy.stats` module provides functions to work with the Poisson distribution. Let's see how to calculate probabilities and generate random samples.
511
+
512
+ First, let's calculate some probabilities for our ride-sharing example with $\lambda = 5$:
513
+ """
514
+ )
515
+ return
516
+
517
+
518
+ @app.cell
519
+ def _(stats):
520
+ _lambda = 5
521
+
522
+ # Calculate probabilities for X = 1, 2, 3
523
+ p_1 = stats.poisson.pmf(1, _lambda)
524
+ p_2 = stats.poisson.pmf(2, _lambda)
525
+ p_3 = stats.poisson.pmf(3, _lambda)
526
+
527
+ print(f"P(X=1) = {p_1:.5f}")
528
+ print(f"P(X=2) = {p_2:.5f}")
529
+ print(f"P(X=3) = {p_3:.5f}")
530
+
531
+ # Calculate cumulative probability P(X ≤ 3)
532
+ p_leq_3 = stats.poisson.cdf(3, _lambda)
533
+ print(f"P(X≤3) = {p_leq_3:.5f}")
534
+
535
+ # Calculate probability P(X > 10)
536
+ p_gt_10 = 1 - stats.poisson.cdf(10, _lambda)
537
+ print(f"P(X>10) = {p_gt_10:.5f}")
538
+ return p_1, p_2, p_3, p_gt_10, p_leq_3
539
+
540
+
541
+ @app.cell(hide_code=True)
542
+ def _(mo):
543
+ mo.md(r"""We can also generate random samples from a Poisson distribution and visualize their distribution:""")
544
+ return
545
+
546
+
547
+ @app.cell(hide_code=True)
548
+ def _(np, plt, stats):
549
+ def create_samples_plot(lambda_value, sample_size=1000):
550
+ # Random samples
551
+ samples = stats.poisson.rvs(lambda_value, size=sample_size)
552
+
553
+ # theoretical PMF
554
+ x_values = np.arange(0, max(samples) + 1)
555
+ pmf_values = stats.poisson.pmf(x_values, lambda_value)
556
+
557
+ # histograms to compare
558
+ fig, ax = plt.subplots(figsize=(10, 6))
559
+
560
+ # samples as a histogram
561
+ ax.hist(samples, bins=np.arange(-0.5, max(samples) + 1.5, 1),
562
+ alpha=0.7, density=True, label='Random Samples')
563
+
564
+ # theoretical PMF
565
+ ax.plot(x_values, pmf_values, 'ro-', label='Theoretical PMF')
566
+
567
+ # labels and title
568
+ ax.set_xlabel('Number of Events')
569
+ ax.set_ylabel('Relative Frequency / Probability')
570
+ ax.set_title(f'{sample_size} Random Samples from Poisson(λ={lambda_value})')
571
+ ax.legend()
572
+ ax.grid(alpha=0.3)
573
+
574
+ # annotations
575
+ ax.annotate(f'Sample Mean: {np.mean(samples):.2f}',
576
+ xy=(0.7, 0.9), xycoords='axes fraction',
577
+ bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.3))
578
+ ax.annotate(f'Theoretical Mean: {lambda_value:.2f}',
579
+ xy=(0.7, 0.8), xycoords='axes fraction',
580
+ bbox=dict(boxstyle='round,pad=0.5', fc='lightgreen', alpha=0.3))
581
+
582
+ plt.tight_layout()
583
+ return plt.gca()
584
+
585
+ # Use a lambda value of 5 for this example
586
+ _lambda = 5
587
+ create_samples_plot(_lambda)
588
+ return (create_samples_plot,)
589
+
590
+
591
+ @app.cell(hide_code=True)
592
+ def _(mo):
593
+ mo.md(
594
+ r"""
595
+ ## Changing Time Frames
596
+
597
+ One important property of the Poisson distribution is that the rate parameter $\lambda$ scales linearly with the time interval. If events occur at a rate of $\lambda$ per unit time, then over a period of $t$ units, the rate parameter becomes $\lambda \cdot t$.
598
+
599
+ For example, if a website receives an average of 5 requests per minute, what is the distribution of requests over a 20-minute period?
600
+
601
+ The rate parameter for the 20-minute period would be $\lambda = 5 \cdot 20 = 100$ requests.
602
+ """
603
+ )
604
+ return
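The scaling rule can be sketched in a couple of lines (the 90–110 request band is an illustrative query, not from the notebook):

```python
from scipy import stats

rate, t = 5, 20       # 5 requests per minute, observed for 20 minutes
lam = rate * t        # scaled rate parameter: 100
print(stats.poisson.mean(lam))  # expected requests: 100.0
print(stats.poisson.std(lam))   # standard deviation: sqrt(100) = 10.0
# probability of seeing between 90 and 110 requests inclusive
print(stats.poisson.cdf(110, lam) - stats.poisson.cdf(89, lam))
```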
605
+
606
+
607
+ @app.cell(hide_code=True)
608
+ def _(mo):
609
+ rate_slider = mo.ui.slider(
610
+ start = 0.1,
611
+ stop = 10,
612
+ step=0.1,
613
+ value=5,
614
+ label="Rate per unit time (λ)"
615
+ )
616
+
617
+ time_slider = mo.ui.slider(
618
+ start = 1,
619
+ stop = 60,
620
+ step=1,
621
+ value=20,
622
+ label="Time period (t units)"
623
+ )
624
+
625
+ controls = mo.vstack([
626
+ mo.md("### Adjust Parameters to See How Time Scaling Works"),
627
+ mo.hstack([rate_slider, time_slider], justify="space-between")
628
+ ])
629
+ return controls, rate_slider, time_slider
630
+
631
+
632
+ @app.cell
633
+ def _(controls):
634
+ controls.center()
635
+ return
636
+
637
+
638
+ @app.cell(hide_code=True)
639
+ def _(mo, np, plt, rate_slider, stats, time_slider):
640
+ def create_time_scaling_plot(rate, time_period):
641
+ # scaled rate parameter
642
+ lambda_value = rate * time_period
643
+
644
+ # PMF for values
645
+ max_x = max(30, int(lambda_value * 1.5))
646
+ x = np.arange(0, max_x + 1)
647
+ pmf = stats.poisson.pmf(x, lambda_value)
648
+
649
+ # plot
650
+ fig, ax = plt.subplots(figsize=(10, 6))
651
+
652
+ # PMF as bars
653
+ ax.bar(x, pmf, color='royalblue', alpha=0.7,
654
+ label=f'PMF: Poisson(λ={lambda_value:.1f})')
655
+
656
+ # vertical line for mean
657
+ ax.axvline(x=lambda_value, color='red', linestyle='--', linewidth=2,
658
+ label=f'Mean = {lambda_value:.1f}')
659
+
660
+ # labels and title
661
+ ax.set_xlabel('Number of Events')
662
+ ax.set_ylabel('Probability')
663
+ ax.set_title(f'Poisson Distribution Over {time_period} Units (Rate = {rate}/unit)')
664
+
665
+ # better visualization if lambda is large
666
+ if lambda_value > 10:
667
+ ax.set_xlim(lambda_value - 4*np.sqrt(lambda_value),
668
+ lambda_value + 4*np.sqrt(lambda_value))
669
+
670
+ ax.legend()
671
+ ax.grid(alpha=0.3)
672
+
673
+ plt.tight_layout()
674
+
675
+ # Create relevant info markdown
676
+ info_text = f"""
677
+ When the rate is **{rate}** events per unit time and we observe for **{time_period}** units:
678
+
679
+ - The expected number of events is **{lambda_value:.1f}**
680
+ - The variance is also **{lambda_value:.1f}**
681
+ - The standard deviation is **{np.sqrt(lambda_value):.2f}**
682
+ - P(X=0) = {stats.poisson.pmf(0, lambda_value):.4f} (probability of no events)
683
+ - P(X≥10) = {1 - stats.poisson.cdf(9, lambda_value):.4f} (probability of 10 or more events)
684
+ """
685
+
686
+ return plt.gca(), info_text
687
+
688
+ # parameters from sliders
689
+ _rate = rate_slider.value
690
+ _time = time_slider.value
691
+
692
+ # store
693
+ _plot, _info_text = create_time_scaling_plot(_rate, _time)
694
+
695
+ # Display info as markdown
696
+ info = mo.md(_info_text)
697
+
698
+ mo.vstack([_plot, info], justify="center")
699
+ return create_time_scaling_plot, info
700
+
701
+
702
+ @app.cell(hide_code=True)
703
+ def _(mo):
704
+ mo.md(
705
+ r"""
706
+ ## 🤔 Test Your Understanding
707
+ Pick which of these statements about Poisson distributions you think are correct:
708
+
709
+ /// details | The variance of a Poisson distribution is always equal to its mean
710
+ ✅ Correct! For a Poisson distribution with parameter $\lambda$, both the mean and variance equal $\lambda$.
711
+ ///
712
+
713
+ /// details | The Poisson distribution can be used to model the number of successes in a fixed number of trials
714
+ ❌ Incorrect! That's the binomial distribution. The Poisson distribution models the number of events in a fixed interval of time or space, not a fixed number of trials.
715
+ ///
716
+
717
+ /// details | If $X \sim \text{Poisson}(\lambda_1)$ and $Y \sim \text{Poisson}(\lambda_2)$ are independent, then $X + Y \sim \text{Poisson}(\lambda_1 + \lambda_2)$
718
+ ✅ Correct! The sum of independent Poisson random variables is also a Poisson random variable with parameter equal to the sum of the individual parameters.
719
+ ///
720
+
721
+ /// details | As $\lambda$ increases, the Poisson distribution approaches a normal distribution
722
+ ✅ Correct! For large values of $\lambda$ (generally $\lambda > 10$), the Poisson distribution is approximately normal with mean $\lambda$ and variance $\lambda$.
723
+ ///
724
+
725
+ /// details | The probability of zero events in a Poisson process is always less than the probability of one event
726
+ ❌ Incorrect! For $\lambda < 1$, the probability of zero events ($e^{-\lambda}$) is actually greater than the probability of one event ($\lambda e^{-\lambda}$).
727
+ ///
728
+
729
+ /// details | The Poisson distribution has a single parameter $\lambda$, which always equals the average number of events per time period
730
+ ✅ Correct! The parameter $\lambda$ represents the average rate of events, and it uniquely defines the distribution.
731
+ ///
732
+ """
733
+ )
734
+ return
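The additivity statement in the quiz (independent Poissons sum to a Poisson) can be verified numerically: the PMF of $X + Y$ is the convolution of the individual PMFs, and for Poisson(2) and Poisson(3) it matches Poisson(5). A sketch with a truncated support:

```python
import numpy as np
from scipy import stats

x = np.arange(60)  # truncated support; tail mass beyond 60 is negligible for these rates
# distribution of X + Y via convolution of the two PMFs
pmf_sum = np.convolve(stats.poisson.pmf(x, 2.0), stats.poisson.pmf(x, 3.0))[: len(x)]
print(np.max(np.abs(pmf_sum - stats.poisson.pmf(x, 5.0))))  # ~0, up to float error
```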
735
+
736
+
737
+ @app.cell(hide_code=True)
738
+ def _(mo):
739
+ mo.md(
740
+ r"""
741
+ ## Summary
742
+
743
+ The Poisson distribution is one of those incredibly useful tools that shows up all over the place. I've always found it fascinating how such a simple formula can model so many real-world phenomena - from website traffic to radioactive decay.
744
+
745
+ What makes the Poisson really cool is that it emerges naturally as we try to model rare events occurring over a continuous interval. Remember that visualization where we kept dividing time into smaller and smaller chunks? As we showed, when you take a binomial distribution and let the number of trials approach infinity while keeping the expected value constant, you end up with the elegant Poisson formula.
746
+
747
+ The key things to remember about the Poisson distribution:
748
+
749
+ - It models the number of events occurring in a fixed interval of time or space, assuming events happen at a constant average rate and independently of each other
750
+
751
+ - Its PMF is given by the elegantly simple formula $P(X=k) = \frac{\lambda^k e^{-\lambda}}{k!}$
752
+
753
+ - Both the mean and variance equal the parameter $\lambda$, which represents the average number of events per interval
754
+
755
+ - It's related to the binomial distribution as a limiting case when $n \to \infty$, $p \to 0$, and $np = \lambda$ remains constant
756
+
757
+ - The rate parameter scales linearly with the length of the interval - if events occur at rate $\lambda$ per unit time, then over $t$ units, the parameter becomes $\lambda t$
758
+
759
+ From modeling website traffic and customer arrivals to defects in manufacturing and radioactive decay, the Poisson distribution provides a powerful and mathematically elegant way to understand random occurrences in our world.
760
+ """
761
+ )
762
+ return
763
+
764
+
765
+ @app.cell(hide_code=True)
766
+ def _(mo):
767
+ mo.md(r"""Appendix code (helper functions, variables, etc.):""")
768
+ return
769
+
770
+
771
+ @app.cell
772
+ def _():
773
+ import marimo as mo
774
+ return (mo,)
775
+
776
+
777
+ @app.cell(hide_code=True)
778
+ def _():
779
+ import numpy as np
780
+ import matplotlib.pyplot as plt
781
+ import scipy.stats as stats
782
+ import pandas as pd
783
+ import altair as alt
784
+ from wigglystuff import TangleSlider
785
+ return TangleSlider, alt, np, pd, plt, stats
786
+
787
+
788
+ @app.cell(hide_code=True)
789
+ def _():
790
+ import io
791
+ import base64
792
+ from matplotlib.figure import Figure
793
+
794
+ # Helper function to convert a matplotlib figure to a base64 data URI that mo.image can render
795
+ def fig_to_image(fig):
796
+ buf = io.BytesIO()
797
+ fig.savefig(buf, format='png')
798
+ buf.seek(0)
799
+ data = f"data:image/png;base64,{base64.b64encode(buf.read()).decode('utf-8')}"
800
+ return data
801
+ return Figure, base64, fig_to_image, io
802
+
803
+
804
+ if __name__ == "__main__":
805
+ app.run()
python/006_dictionaries.py CHANGED
@@ -196,13 +196,13 @@ def _():
196
 
197
  @app.cell
198
  def _(mo, nested_data):
199
- mo.md(f"Alice's age: {nested_data["users"]["alice"]["age"]}")
200
  return
201
 
202
 
203
  @app.cell
204
  def _(mo, nested_data):
205
- mo.md(f"Bob's interests: {nested_data["users"]["bob"]["interests"]}")
206
  return
207
 
208
 
 
196
 
197
  @app.cell
198
  def _(mo, nested_data):
199
+ mo.md(f"Alice's age: {nested_data['users']['alice']['age']}")
200
  return
201
 
202
 
203
  @app.cell
204
  def _(mo, nested_data):
205
+ mo.md(f"Bob's interests: {nested_data['users']['bob']['interests']}")
206
  return
207
 
208
 
scripts/build.py ADDED
@@ -0,0 +1,1523 @@
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import subprocess
5
+ import argparse
6
+ import json
7
+ from typing import List, Dict, Any
8
+ from pathlib import Path
9
+ import re
10
+
11
+
12
+ def export_html_wasm(notebook_path: str, output_dir: str, as_app: bool = False) -> bool:
13
+ """Export a single marimo notebook to HTML format.
14
+
15
+ Returns:
16
+ bool: True if export succeeded, False otherwise
17
+ """
18
+ output_path = notebook_path.replace(".py", ".html")
19
+
20
+ cmd = ["marimo", "export", "html-wasm"]
21
+ if as_app:
22
+ print(f"Exporting {notebook_path} to {output_path} as app")
23
+ cmd.extend(["--mode", "run", "--no-show-code"])
24
+ else:
25
+ print(f"Exporting {notebook_path} to {output_path} as notebook")
26
+ cmd.extend(["--mode", "edit"])
27
+
28
+ try:
29
+ output_file = os.path.join(output_dir, output_path)
30
+ os.makedirs(os.path.dirname(output_file), exist_ok=True)
31
+
32
+ cmd.extend([notebook_path, "-o", output_file])
33
+ print(f"Running command: {' '.join(cmd)}")
34
+
35
+ # Use Popen to handle interactive prompts
36
+ process = subprocess.Popen(
37
+ cmd,
38
+ stdin=subprocess.PIPE,
39
+ stdout=subprocess.PIPE,
40
+ stderr=subprocess.PIPE,
41
+ text=True
42
+ )
43
+
44
+ # Send 'Y' to the prompt
45
+ stdout, stderr = process.communicate(input="Y\n", timeout=60)
46
+
47
+ if process.returncode != 0:
48
+ print(f"Error exporting {notebook_path}:")
49
+ print(f"Command: {' '.join(cmd)}")
50
+ print(f"Return code: {process.returncode}")
51
+ print(f"Stdout: {stdout}")
52
+ print(f"Stderr: {stderr}")
53
+ return False
54
+
55
+ print(f"Successfully exported {notebook_path} to {output_file}")
56
+ return True
57
+ except subprocess.TimeoutExpired:
58
+ print(f"Timeout exporting {notebook_path} - command took too long to execute")
59
+ return False
60
+ except subprocess.CalledProcessError as e:
61
+ print(f"Error exporting {notebook_path}:")
62
+ print(e.stderr)
63
+ return False
64
+ except Exception as e:
65
+ print(f"Unexpected error exporting {notebook_path}: {e}")
66
+ return False
67
+
68
+
69
+ def get_course_metadata(course_dir: Path) -> Dict[str, Any]:
70
+ """Extract metadata from a course directory."""
71
+ metadata = {
72
+ "id": course_dir.name,
73
+ "title": course_dir.name.replace("_", " ").title(),
74
+ "description": "",
75
+ "notebooks": []
76
+ }
77
+
78
+ # Try to read README.md for description
79
+ readme_path = course_dir / "README.md"
80
+ if readme_path.exists():
81
+ with open(readme_path, "r", encoding="utf-8") as f:
82
+ content = f.read()
83
+ # Extract first paragraph as description
84
+ if content:
85
+ lines = content.split("\n")
86
+ # Skip title line if it exists
87
+ start_idx = 1 if lines and lines[0].startswith("#") else 0
88
+ description_lines = []
89
+ for line in lines[start_idx:]:
90
+ if line.strip() and not line.startswith("#"):
91
+ description_lines.append(line)
92
+ elif description_lines: # Stop at the first blank line or heading after the description
93
+ break
94
+ description = " ".join(description_lines).strip()
95
+ # Clean up the description
96
+ description = description.replace("_", "")
97
+ description = description.replace("[work in progress]", "")
98
+ description = description.replace("(https://github.com/marimo-team/learn/issues/51)", "")
99
+ # Remove any other GitHub issue links
100
+ description = re.sub(r'\[.*?\]\(https://github\.com/.*?/issues/\d+\)', '', description)
101
+ description = re.sub(r'https://github\.com/.*?/issues/\d+', '', description)
102
+ # Clean up any double spaces
103
+ description = re.sub(r'\s+', ' ', description).strip()
104
+ metadata["description"] = description
105
+
106
+ return metadata
107
+
108
+
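The README cleanup above chains several string replacements and regexes; pulled out into a standalone helper it is easier to test. A sketch mirroring that logic (the function name `clean_description` is ours):

```python
import re

def clean_description(text: str) -> str:
    """Mirror of the README description cleanup above (illustrative)."""
    text = text.replace("_", "").replace("[work in progress]", "")
    # Strip GitHub issue links, markdown-formatted or bare
    text = re.sub(r'\[.*?\]\(https://github\.com/.*?/issues/\d+\)', '', text)
    text = re.sub(r'https://github\.com/.*?/issues/\d+', '', text)
    # Collapse any double spaces left behind by the removals
    return re.sub(r'\s+', ' ', text).strip()
```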
109
+ def organize_notebooks_by_course(all_notebooks: List[str]) -> Dict[str, Dict[str, Any]]:
110
+ """Organize notebooks by course."""
111
+ courses = {}
112
+
113
+ for notebook_path in all_notebooks:
114
+ path = Path(notebook_path)
115
+ course_id = path.parts[0]
116
+
117
+ if course_id not in courses:
118
+ course_dir = Path(course_id)
119
+ courses[course_id] = get_course_metadata(course_dir)
120
+
121
+ # Extract notebook info
122
+ filename = path.name
123
+ notebook_id = path.stem
124
+
125
+ # Try to extract order from filename (e.g., 001_numbers.py -> 1)
126
+ order = 999
127
+ if "_" in notebook_id:
128
+ try:
129
+ order_str = notebook_id.split("_")[0]
130
+ order = int(order_str)
131
+ except ValueError:
132
+ pass
133
+
134
+ # Create display name by removing order prefix and underscores
135
+ display_name = notebook_id
136
+ if "_" in notebook_id:
137
+ display_name = "_".join(notebook_id.split("_")[1:])
138
+
139
+ # Convert display name to title case, but handle italics properly
140
+ parts = display_name.split("_")
141
+ formatted_parts = []
142
+
143
+ i = 0
144
+ while i < len(parts):
145
+ if i + 1 < len(parts) and parts[i] == "" and parts[i+1] == "":
146
+ # Skip empty parts that might come from consecutive underscores
147
+ i += 2
148
+ continue
149
+
150
+ if i + 1 < len(parts) and (parts[i] == "" or parts[i+1] == ""):
151
+ # This is an italics marker
152
+ if parts[i] == "":
153
+ # Opening italics
154
+ text_part = parts[i+1].replace("_", " ").title()
155
+ formatted_parts.append(f"<em>{text_part}</em>")
156
+ i += 2
157
+ else:
158
+ # Text followed by italics marker
159
+ text_part = parts[i].replace("_", " ").title()
160
+ formatted_parts.append(text_part)
161
+ i += 1
162
+ else:
163
+ # Regular text
164
+ text_part = parts[i].replace("_", " ").title()
165
+ formatted_parts.append(text_part)
166
+ i += 1
167
+
168
+ display_name = " ".join(formatted_parts)
169
+
170
+ courses[course_id]["notebooks"].append({
171
+ "id": notebook_id,
172
+ "path": notebook_path,
173
+ "display_name": display_name,
174
+ "order": order,
175
+ "original_number": notebook_id.split("_")[0] if "_" in notebook_id else ""
176
+ })
177
+
178
+ # Sort notebooks by order
179
+ for course_id in courses:
180
+ courses[course_id]["notebooks"].sort(key=lambda x: x["order"])
181
+
182
+ return courses
183
+
184
+
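The order/display-name parsing above (numeric prefix before the first underscore, `999` as the fallback order, prefix stripped for display) can be condensed into one helper. A sketch of that convention (the name `split_notebook_id` is ours):

```python
def split_notebook_id(notebook_id: str) -> tuple[int, str]:
    """Parse a stem like '001_numbers' into (order, name)."""
    order = 999  # fallback used above for stems without a numeric prefix
    name = notebook_id
    if "_" in notebook_id:
        # The prefix is dropped from the display name even when non-numeric,
        # matching the behavior of the loop above.
        prefix, name = notebook_id.split("_", 1)
        if prefix.isdigit():
            order = int(prefix)
    return order, name
```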
185
+ def generate_eva_css() -> str:
186
+ """Generate Neon Genesis Evangelion inspired CSS with light/dark mode support."""
187
+ return """
188
+ :root {
189
+ /* Light mode colors (default) */
190
+ --eva-purple: #7209b7;
191
+ --eva-green: #1c7361;
192
+ --eva-orange: #e65100;
193
+ --eva-blue: #0039cb;
194
+ --eva-red: #c62828;
195
+ --eva-black: #f5f5f5;
196
+ --eva-dark: #e0e0e0;
197
+ --eva-terminal-bg: rgba(255, 255, 255, 0.9);
198
+ --eva-text: #333333;
199
+ --eva-border-radius: 4px;
200
+ --eva-transition: all 0.3s ease;
201
+ }
202
+
203
+ /* Dark mode colors */
204
+ [data-theme="dark"] {
205
+ --eva-purple: #9a1eb3;
206
+ --eva-green: #1c7361;
207
+ --eva-orange: #ff6600;
208
+ --eva-blue: #0066ff;
209
+ --eva-red: #ff0000;
210
+ --eva-black: #111111;
211
+ --eva-dark: #222222;
212
+ --eva-terminal-bg: rgba(0, 0, 0, 0.85);
213
+ --eva-text: #e0e0e0;
214
+ }
215
+
216
+ body {
217
+ background-color: var(--eva-black);
218
+ color: var(--eva-text);
219
+ font-family: 'Courier New', monospace;
220
+ margin: 0;
221
+ padding: 0;
222
+ line-height: 1.6;
223
+ transition: background-color 0.3s ease, color 0.3s ease;
224
+ }
225
+
226
+ .eva-container {
227
+ max-width: 1200px;
228
+ margin: 0 auto;
229
+ padding: 2rem;
230
+ }
231
+
232
+ .eva-header {
233
+ border-bottom: 2px solid var(--eva-green);
234
+ padding-bottom: 1rem;
235
+ margin-bottom: 2rem;
236
+ display: flex;
237
+ justify-content: space-between;
238
+ align-items: center;
239
+ position: sticky;
240
+ top: 0;
241
+ background-color: var(--eva-black);
242
+ z-index: 100;
243
+ backdrop-filter: blur(5px);
244
+ padding-top: 1rem;
245
+ transition: background-color 0.3s ease;
246
+ }
247
+
248
+ [data-theme="light"] .eva-header {
249
+ background-color: rgba(245, 245, 245, 0.95);
250
+ }
251
+
252
+ .eva-logo {
253
+ font-size: 2.5rem;
254
+ font-weight: bold;
255
+ color: var(--eva-green);
256
+ text-transform: uppercase;
257
+ letter-spacing: 2px;
258
+ text-shadow: 0 0 10px rgba(28, 115, 97, 0.5);
259
+ }
260
+
261
+ [data-theme="light"] .eva-logo {
262
+ text-shadow: 0 0 10px rgba(28, 115, 97, 0.3);
263
+ }
264
+
265
+ .eva-nav {
266
+ display: flex;
267
+ gap: 1.5rem;
268
+ align-items: center;
269
+ }
270
+
271
+ .eva-nav a {
272
+ color: var(--eva-text);
273
+ text-decoration: none;
274
+ text-transform: uppercase;
275
+ font-size: 0.9rem;
276
+ letter-spacing: 1px;
277
+ transition: color 0.3s;
278
+ position: relative;
279
+ padding: 0.5rem 0;
280
+ }
281
+
282
+ .eva-nav a:hover {
283
+ color: var(--eva-green);
284
+ }
285
+
286
+ .eva-nav a:hover::after {
287
+ content: '';
288
+ position: absolute;
289
+ bottom: -5px;
290
+ left: 0;
291
+ width: 100%;
292
+ height: 2px;
293
+ background-color: var(--eva-green);
294
+ animation: scanline 1.5s linear infinite;
295
+ }
296
+
297
+ .theme-toggle {
298
+ background: none;
299
+ border: none;
300
+ color: var(--eva-text);
301
+ cursor: pointer;
302
+ font-size: 1.2rem;
303
+ padding: 0.5rem;
304
+ margin-left: 1rem;
305
+ transition: color 0.3s;
306
+ }
307
+
308
+ .theme-toggle:hover {
309
+ color: var(--eva-green);
310
+ }
311
+
312
+ .eva-hero {
313
+ background-color: var(--eva-terminal-bg);
314
+ border: 1px solid var(--eva-green);
315
+ padding: 3rem 2rem;
316
+ margin-bottom: 3rem;
317
+ position: relative;
318
+ overflow: hidden;
319
+ border-radius: var(--eva-border-radius);
320
+ display: flex;
321
+ flex-direction: column;
322
+ align-items: flex-start;
323
+ background-image: linear-gradient(45deg, rgba(0, 0, 0, 0.9), rgba(0, 0, 0, 0.7)), url('https://raw.githubusercontent.com/marimo-team/marimo/main/docs/_static/marimo-logotype-thick.svg');
324
+ background-size: cover;
325
+ background-position: center;
326
+ background-blend-mode: overlay;
327
+ transition: background-color 0.3s ease, border-color 0.3s ease;
328
+ }
329
+
330
+ [data-theme="light"] .eva-hero {
331
+ background-image: linear-gradient(45deg, rgba(255, 255, 255, 0.9), rgba(255, 255, 255, 0.7)), url('https://raw.githubusercontent.com/marimo-team/marimo/main/docs/_static/marimo-logotype-thick.svg');
332
+ }
333
+
334
+ .eva-hero::before {
335
+ content: '';
336
+ position: absolute;
337
+ top: 0;
338
+ left: 0;
339
+ width: 100%;
340
+ height: 2px;
341
+ background-color: var(--eva-green);
342
+ animation: scanline 3s linear infinite;
343
+ }
344
+
345
+ .eva-hero h1 {
346
+ font-size: 2.5rem;
347
+ margin-bottom: 1rem;
348
+ color: var(--eva-green);
349
+ text-transform: uppercase;
350
+ letter-spacing: 2px;
351
+ text-shadow: 0 0 10px rgba(28, 115, 97, 0.5);
352
+ }
353
+
354
+ [data-theme="light"] .eva-hero h1 {
355
+ text-shadow: 0 0 10px rgba(28, 115, 97, 0.3);
356
+ }
357
+
358
+ .eva-hero p {
359
+ font-size: 1.1rem;
360
+ max-width: 800px;
361
+ margin-bottom: 2rem;
362
+ line-height: 1.8;
363
+ }
364
+
365
+ .eva-features {
366
+ display: grid;
367
+ grid-template-columns: repeat(auto-fit, minmax(300px, 1fr));
368
+ gap: 2rem;
369
+ margin-bottom: 3rem;
370
+ }
371
+
372
+ .eva-feature {
373
+ background-color: var(--eva-terminal-bg);
374
+ border: 1px solid var(--eva-blue);
375
+ padding: 1.5rem;
376
+ border-radius: var(--eva-border-radius);
377
+ transition: var(--eva-transition);
378
+ position: relative;
379
+ overflow: hidden;
380
+ }
381
+
382
+ .eva-feature:hover {
383
+ transform: translateY(-5px);
384
+ box-shadow: 0 10px 20px rgba(0, 102, 255, 0.2);
385
+ }
386
+
387
+ .eva-feature-icon {
388
+ font-size: 2rem;
389
+ margin-bottom: 1rem;
390
+ color: var(--eva-blue);
391
+ }
392
+
393
+ .eva-feature h3 {
394
+ font-size: 1.3rem;
395
+ margin-bottom: 1rem;
396
+ color: var(--eva-blue);
397
+ }
398
+
399
+ .eva-section-title {
400
+ font-size: 2rem;
401
+ color: var(--eva-green);
402
+ margin-bottom: 2rem;
403
+ text-transform: uppercase;
404
+ letter-spacing: 2px;
405
+ text-align: center;
406
+ position: relative;
407
+ padding-bottom: 1rem;
408
+ }
409
+
410
+ .eva-section-title::after {
411
+ content: '';
412
+ position: absolute;
413
+ bottom: 0;
414
+ left: 50%;
415
+ transform: translateX(-50%);
416
+ width: 100px;
417
+ height: 2px;
418
+ background-color: var(--eva-green);
419
+ }
420
+
421
+ /* Flashcard view for courses */
422
+ .eva-courses {
423
+ display: grid;
424
+ grid-template-columns: repeat(auto-fill, minmax(350px, 1fr));
425
+ gap: 2rem;
426
+ }
427
+
428
+ .eva-course {
429
+ background-color: var(--eva-terminal-bg);
430
+ border: 1px solid var(--eva-purple);
431
+ border-radius: var(--eva-border-radius);
432
+ transition: var(--eva-transition), height 0.4s cubic-bezier(0.19, 1, 0.22, 1);
433
+ position: relative;
434
+ overflow: hidden;
435
+ height: 350px;
436
+ display: flex;
437
+ flex-direction: column;
438
+ }
439
+
440
+ .eva-course:hover {
441
+ transform: translateY(-5px);
442
+ box-shadow: 0 10px 20px rgba(154, 30, 179, 0.3);
443
+ }
444
+
445
+ .eva-course::after {
446
+ content: '';
447
+ position: absolute;
448
+ bottom: 0;
449
+ left: 0;
450
+ width: 100%;
451
+ height: 2px;
452
+ background-color: var(--eva-purple);
453
+ animation: scanline 2s linear infinite;
454
+ }
455
+
456
+ .eva-course-badge {
457
+ position: absolute;
458
+ top: 15px;
459
+ right: -40px;
460
+ background: linear-gradient(135deg, var(--eva-orange) 0%, #ff9500 100%);
461
+ color: var(--eva-black);
462
+ font-size: 0.65rem;
463
+ padding: 0.3rem 2.5rem;
464
+ text-transform: uppercase;
465
+ font-weight: bold;
466
+ z-index: 3;
467
+ letter-spacing: 1px;
468
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.3);
469
+ transform: rotate(45deg);
470
+ text-shadow: 0 1px 1px rgba(255, 255, 255, 0.2);
471
+ border-top: 1px solid rgba(255, 255, 255, 0.3);
472
+ border-bottom: 1px solid rgba(0, 0, 0, 0.2);
473
+ white-space: nowrap;
474
+ overflow: hidden;
475
+ }
476
+
477
+ .eva-course-badge i {
478
+ margin-right: 4px;
479
+ font-size: 0.7rem;
480
+ }
481
+
482
+ [data-theme="light"] .eva-course-badge {
483
+ box-shadow: 0 2px 5px rgba(0, 0, 0, 0.2);
484
+ text-shadow: 0 1px 1px rgba(255, 255, 255, 0.4);
485
+ }
486
+
487
+ .eva-course-badge::before {
488
+ content: '';
489
+ position: absolute;
490
+ left: 0;
491
+ top: 0;
492
+ width: 100%;
493
+ height: 100%;
494
+ background: linear-gradient(to right, transparent, rgba(255, 255, 255, 0.3), transparent);
495
+ animation: scanline 2s linear infinite;
496
+ }
497
+
498
+ .eva-course-header {
499
+ padding: 1rem 1.5rem;
500
+ cursor: pointer;
501
+ display: flex;
502
+ justify-content: space-between;
503
+ align-items: center;
504
+ border-bottom: 1px solid rgba(154, 30, 179, 0.3);
505
+ z-index: 2;
506
+ background-color: var(--eva-terminal-bg);
507
+ position: absolute;
508
+ top: 0;
509
+ left: 0;
510
+ width: 100%;
511
+ height: 3.5rem;
512
+ box-sizing: border-box;
513
+ }
514
+
515
+ .eva-course-title {
516
+ font-size: 1.3rem;
517
+ color: var(--eva-purple);
518
+ text-transform: uppercase;
519
+ letter-spacing: 1px;
520
+ margin: 0;
521
+ }
522
+
523
+ .eva-course-toggle {
524
+ color: var(--eva-purple);
525
+ font-size: 1.5rem;
526
+ transition: transform 0.4s cubic-bezier(0.19, 1, 0.22, 1);
527
+ }
528
+
529
+ .eva-course.active .eva-course-toggle {
530
+ transform: rotate(180deg);
531
+ }
532
+
533
+ .eva-course-front {
534
+ display: flex;
535
+ flex-direction: column;
536
+ justify-content: space-between;
537
+ padding: 1.5rem;
538
+ margin-top: 3.5rem;
539
+ transition: opacity 0.3s ease, transform 0.3s ease;
540
+ position: absolute;
541
+ top: 0;
542
+ left: 0;
543
+ width: 100%;
544
+ height: calc(100% - 3.5rem);
545
+ background-color: var(--eva-terminal-bg);
546
+ z-index: 1;
547
+ box-sizing: border-box;
548
+ }
549
+
550
+ .eva-course.active .eva-course-front {
551
+ opacity: 0;
552
+ transform: translateY(-10px);
553
+ pointer-events: none;
554
+ }
555
+
556
+ .eva-course-description {
557
+ margin-top: 0.5rem;
558
+ margin-bottom: 1.5rem;
559
+ font-size: 0.9rem;
560
+ line-height: 1.6;
561
+ flex-grow: 1;
562
+ overflow: hidden;
563
+ display: -webkit-box;
564
+ -webkit-line-clamp: 4;
565
+ -webkit-box-orient: vertical;
566
+ max-height: 150px;
567
+ }
568
+
569
+ .eva-course-stats {
570
+ display: flex;
571
+ justify-content: space-between;
572
+ font-size: 0.8rem;
573
+ color: var(--eva-text);
574
+ opacity: 0.7;
575
+ }
576
+
577
+ .eva-course-content {
578
+ position: absolute;
579
+ top: 3.5rem;
580
+ left: 0;
581
+ width: 100%;
582
+ height: calc(100% - 3.5rem);
583
+ padding: 1.5rem;
584
+ background-color: var(--eva-terminal-bg);
585
+ transition: opacity 0.3s ease, transform 0.3s ease;
586
+ opacity: 0;
587
+ transform: translateY(10px);
588
+ pointer-events: none;
589
+ overflow-y: auto;
590
+ z-index: 1;
591
+ box-sizing: border-box;
592
+ }
593
+
594
+ .eva-course.active .eva-course-content {
595
+ opacity: 1;
596
+ transform: translateY(0);
597
+ pointer-events: auto;
598
+ }
599
+
600
+ .eva-course.active {
601
+ height: auto;
602
+ min-height: 350px;
603
+ max-height: 800px;
604
+ transition: height 0.4s cubic-bezier(0.19, 1, 0.22, 1), transform 0.3s ease, box-shadow 0.3s ease;
605
+ }
606
+
607
+ .eva-notebooks {
608
+ margin-top: 1rem;
609
+ display: grid;
610
+ grid-template-columns: repeat(auto-fill, minmax(250px, 1fr));
611
+ gap: 0.75rem;
612
+ }
613
+
614
+ .eva-notebook {
615
+ margin-bottom: 0.5rem;
616
+ padding: 0.75rem;
617
+ border-left: 2px solid var(--eva-blue);
618
+ transition: all 0.25s ease;
619
+ display: flex;
620
+ align-items: center;
621
+ background-color: rgba(0, 0, 0, 0.2);
622
+ border-radius: 0 var(--eva-border-radius) var(--eva-border-radius) 0;
623
+ opacity: 1;
624
+ transform: translateX(0);
625
+ }
626
+
627
+ [data-theme="light"] .eva-notebook {
628
+ background-color: rgba(0, 0, 0, 0.05);
629
+ }
630
+
631
+ .eva-notebook:hover {
632
+ background-color: rgba(0, 102, 255, 0.1);
633
+ padding-left: 1rem;
634
+ transform: translateX(3px);
635
+ }
636
+
637
+ .eva-notebook a {
638
+ color: var(--eva-text);
639
+ text-decoration: none;
640
+ display: block;
641
+ font-size: 0.9rem;
642
+ flex-grow: 1;
643
+ }
644
+
645
+ .eva-notebook a:hover {
646
+ color: var(--eva-blue);
647
+ }
648
+
649
+ .eva-notebook-number {
650
+ color: var(--eva-blue);
651
+ font-size: 0.8rem;
652
+ margin-right: 0.75rem;
653
+ opacity: 0.7;
654
+ min-width: 24px;
655
+ font-weight: bold;
656
+ }
657
+
658
+ .eva-button {
659
+ display: inline-block;
660
+ background-color: transparent;
661
+ color: var(--eva-green);
662
+ border: 1px solid var(--eva-green);
663
+ padding: 0.7rem 1.5rem;
664
+ text-decoration: none;
665
+ text-transform: uppercase;
666
+ font-size: 0.9rem;
667
+ letter-spacing: 1px;
668
+ transition: var(--eva-transition);
669
+ cursor: pointer;
670
+ border-radius: var(--eva-border-radius);
671
+ position: relative;
672
+ overflow: hidden;
673
+ }
674
+
675
+ .eva-button:hover {
676
+ background-color: var(--eva-green);
677
+ color: var(--eva-black);
678
+ }
679
+
680
+ .eva-button::after {
681
+ content: '';
682
+ position: absolute;
683
+ top: 0;
684
+ left: -100%;
685
+ width: 100%;
686
+ height: 100%;
687
+ background: linear-gradient(90deg, transparent, rgba(255, 255, 255, 0.2), transparent);
688
+ transition: 0.5s;
689
+ }
690
+
691
+ .eva-button:hover::after {
692
+ left: 100%;
693
+ }
694
+
695
+ .eva-course-button {
696
+ margin-top: 1rem;
697
+ margin-bottom: 1rem;
698
+ align-self: center;
699
+ }
700
+
701
+ .eva-cta {
702
+ background-color: var(--eva-terminal-bg);
703
+ border: 1px solid var(--eva-orange);
704
+ padding: 3rem 2rem;
705
+ margin: 4rem 0;
706
+ text-align: center;
707
+ border-radius: var(--eva-border-radius);
708
+ position: relative;
709
+ overflow: hidden;
710
+ }
711
+
712
+ .eva-cta h2 {
713
+ font-size: 2rem;
714
+ color: var(--eva-orange);
715
+ margin-bottom: 1.5rem;
716
+ text-transform: uppercase;
717
+ }
718
+
719
+ .eva-cta p {
720
+ max-width: 600px;
721
+ margin: 0 auto 2rem;
722
+ font-size: 1.1rem;
723
+ }
724
+
725
+ .eva-cta .eva-button {
726
+ color: var(--eva-orange);
727
+ border-color: var(--eva-orange);
728
+ }
729
+
730
+ .eva-cta .eva-button:hover {
731
+ background-color: var(--eva-orange);
732
+ color: var(--eva-black);
733
+ }
734
+
735
+ .eva-footer {
736
+ margin-top: 4rem;
737
+ padding-top: 2rem;
738
+ border-top: 2px solid var(--eva-green);
739
+ display: flex;
740
+ flex-direction: column;
741
+ align-items: center;
742
+ gap: 2rem;
743
+ }
744
+
745
+ .eva-footer-logo {
746
+ max-width: 200px;
747
+ margin-bottom: 1rem;
748
+ }
749
+
750
+ .eva-footer-links {
751
+ display: flex;
752
+ gap: 1.5rem;
753
+ margin-bottom: 1.5rem;
754
+ }
755
+
756
+ .eva-footer-links a {
757
+ color: var(--eva-text);
758
+ text-decoration: none;
759
+ transition: var(--eva-transition);
760
+ }
761
+
762
+ .eva-footer-links a:hover {
763
+ color: var(--eva-green);
764
+ }
765
+
766
+ .eva-social-links {
767
+ display: flex;
768
+ gap: 1.5rem;
769
+ margin-bottom: 1.5rem;
770
+ }
771
+
772
+ .eva-social-links a {
773
+ color: var(--eva-text);
774
+ font-size: 1.5rem;
775
+ transition: var(--eva-transition);
776
+ }
777
+
778
+ .eva-social-links a:hover {
779
+ color: var(--eva-green);
780
+ transform: translateY(-3px);
781
+ }
782
+
783
+ .eva-footer-copyright {
784
+ font-size: 0.9rem;
785
+ text-align: center;
786
+ }
787
+
788
+ .eva-search {
789
+ position: relative;
790
+ margin-bottom: 3rem;
791
+ }
792
+
793
+ .eva-search input {
794
+ width: 100%;
795
+ padding: 1rem;
796
+ background-color: var(--eva-terminal-bg);
797
+ border: 1px solid var(--eva-green);
798
+ color: var(--eva-text);
799
+ font-family: 'Courier New', monospace;
800
+ font-size: 1rem;
801
+ border-radius: var(--eva-border-radius);
802
+ outline: none;
803
+ transition: var(--eva-transition);
804
+ }
805
+
806
+ .eva-search input:focus {
807
+ box-shadow: 0 0 10px rgba(28, 115, 97, 0.3);
808
+ }
809
+
810
+ .eva-search input::placeholder {
811
+ color: rgba(224, 224, 224, 0.5);
812
+ }
813
+
814
+ [data-theme="light"] .eva-search input::placeholder {
815
+ color: rgba(51, 51, 51, 0.5);
816
+ }
817
+
818
+ .eva-search-icon {
819
+ position: absolute;
820
+ right: 1rem;
821
+ top: 50%;
822
+ transform: translateY(-50%);
823
+ color: var(--eva-green);
824
+ font-size: 1.2rem;
825
+ }
826
+
827
+ @keyframes scanline {
828
+ 0% {
829
+ transform: translateX(-100%);
830
+ }
831
+ 100% {
832
+ transform: translateX(100%);
833
+ }
834
+ }
835
+
836
+ @keyframes blink {
837
+ 0%, 100% {
838
+ opacity: 1;
839
+ }
840
+ 50% {
841
+ opacity: 0;
842
+ }
843
+ }
844
+
845
+ .eva-cursor {
846
+ display: inline-block;
847
+ width: 10px;
848
+ height: 1.2em;
849
+ background-color: var(--eva-green);
850
+ margin-left: 2px;
851
+ animation: blink 1s infinite;
852
+ vertical-align: middle;
853
+ }
854
+
855
+ @media (max-width: 768px) {
856
+ .eva-courses {
857
+ grid-template-columns: 1fr;
858
+ }
859
+
860
+ .eva-header {
861
+ flex-direction: column;
862
+ align-items: flex-start;
863
+ padding: 1rem;
864
+ }
865
+
866
+ .eva-nav {
867
+ margin-top: 1rem;
868
+ flex-wrap: wrap;
869
+ }
870
+
871
+ .eva-hero {
872
+ padding: 2rem 1rem;
873
+ }
874
+
875
+ .eva-hero h1 {
876
+ font-size: 2rem;
877
+ }
878
+
879
+ .eva-features {
880
+ grid-template-columns: 1fr;
881
+ }
882
+
883
+ .eva-footer {
884
+ flex-direction: column;
885
+ align-items: center;
886
+ text-align: center;
887
+ }
888
+
889
+ .eva-notebooks {
890
+ grid-template-columns: 1fr;
891
+ }
892
+ }
893
+
894
+ .eva-course.closing .eva-course-content {
895
+ opacity: 0;
896
+ transform: translateY(10px);
897
+ transition: opacity 0.2s ease, transform 0.2s ease;
898
+ }
899
+
900
+ .eva-course.closing .eva-course-front {
901
+ opacity: 1;
902
+ transform: translateY(0);
903
+ transition: opacity 0.3s ease 0.1s, transform 0.3s ease 0.1s;
904
+ }
905
+ """
906
+
907
+
908
+ def get_html_header():
909
+ """Generate the HTML header with CSS and meta tags."""
910
+ return """<!DOCTYPE html>
911
+ <html lang="en" data-theme="light">
912
+ <head>
913
+ <meta charset="UTF-8">
914
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
915
+ <title>Marimo Learn - Interactive Educational Notebooks</title>
916
+ <link rel="stylesheet" href="https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css">
917
+ <style>
918
+ {css}
919
+ </style>
920
+ </head>
921
+ <body>
922
+ <div class="eva-container">
923
+ <header class="eva-header">
924
+ <div class="eva-logo">MARIMO LEARN</div>
925
+ <nav class="eva-nav">
926
+ <a href="#features">Features</a>
927
+ <a href="#courses">Courses</a>
928
+ <a href="#contribute">Contribute</a>
929
+ <a href="https://docs.marimo.io" target="_blank">Documentation</a>
930
+ <a href="https://github.com/marimo-team/learn" target="_blank">GitHub</a>
931
+ <button id="themeToggle" class="theme-toggle" aria-label="Toggle dark/light mode">
932
+ <i class="fas fa-moon"></i>
933
+ </button>
934
+ </nav>
935
+ </header>"""
936
+
937
+
938
+ def get_html_hero_section():
939
+ """Generate the hero section of the page."""
940
+ return """
941
+ <section class="eva-hero">
942
+ <h1>Interactive Learning with Marimo<span class="eva-cursor"></span></h1>
943
+ <p>
944
+ A curated collection of educational notebooks covering computer science,
945
+ mathematics, data science, and more. Built with marimo - the reactive
946
+ Python notebook that makes data exploration delightful.
947
+ </p>
948
+ <a href="#courses" class="eva-button">Explore Courses</a>
949
+ </section>"""
950
+
951
+
952
+ def get_html_features_section():
953
+ """Generate the features section of the page."""
954
+ return """
955
+ <section id="features">
956
+ <h2 class="eva-section-title">Why Marimo Learn?</h2>
957
+ <div class="eva-features">
958
+ <div class="eva-feature">
959
+ <div class="eva-feature-icon"><i class="fas fa-bolt"></i></div>
960
+ <h3>Reactive Notebooks</h3>
961
+ <p>Experience the power of reactive programming with marimo notebooks that automatically update when dependencies change.</p>
962
+ </div>
963
+ <div class="eva-feature">
964
+ <div class="eva-feature-icon"><i class="fas fa-code"></i></div>
965
+ <h3>Learn by Doing</h3>
966
+ <p>Interactive examples and exercises help you understand concepts through hands-on practice.</p>
967
+ </div>
968
+ <div class="eva-feature">
969
+ <div class="eva-feature-icon"><i class="fas fa-graduation-cap"></i></div>
970
+ <h3>Comprehensive Courses</h3>
971
+ <p>From Python basics to advanced optimization techniques, our courses cover a wide range of topics.</p>
972
+ </div>
973
+ </div>
974
+ </section>"""
975
+
976
+
977
+ def get_html_courses_start():
978
+ """Generate the beginning of the courses section."""
979
+ return """
980
+ <section id="courses">
981
+ <h2 class="eva-section-title">Explore Courses</h2>
982
+ <div class="eva-search">
983
+ <input type="text" id="courseSearch" placeholder="Search courses and notebooks...">
984
+ <span class="eva-search-icon"><i class="fas fa-search"></i></span>
985
+ </div>
986
+ <div class="eva-courses">"""
987
+
988
+
989
+ def generate_course_card(course, notebook_count, is_wip):
990
+ """Generate HTML for a single course card."""
991
+ html = f'<div class="eva-course" data-course-id="{course["id"]}">\n'
992
+
993
+ # Add WIP badge if needed
994
+ if is_wip:
995
+ html += ' <div class="eva-course-badge"><i class="fas fa-code-branch"></i> In Progress</div>\n'
996
+
997
+ html += f''' <div class="eva-course-header">
998
+ <h2 class="eva-course-title">{course["title"]}</h2>
999
+ <span class="eva-course-toggle"><i class="fas fa-chevron-down"></i></span>
1000
+ </div>
1001
+ <div class="eva-course-front">
1002
+ <p class="eva-course-description">{course["description"]}</p>
1003
+ <div class="eva-course-stats">
1004
+ <span><i class="fas fa-book"></i> {notebook_count} notebook{"s" if notebook_count != 1 else ""}</span>
1005
+ </div>
1006
+ <button class="eva-button eva-course-button">View Notebooks</button>
1007
+ </div>
1008
+ <div class="eva-course-content">
1009
+ <div class="eva-notebooks">
1010
+ '''
1011
+
1012
+ # Add notebooks
1013
+ for i, notebook in enumerate(course["notebooks"]):
1014
+ notebook_number = notebook.get("original_number", f"{i+1:02d}")
1015
+ html += f''' <div class="eva-notebook">
1016
+ <span class="eva-notebook-number">{notebook_number}</span>
1017
+ <a href="{notebook["path"].replace(".py", ".html")}" data-notebook-title="{notebook["display_name"]}">{notebook["display_name"]}</a>
1018
+ </div>
1019
+ '''
1020
+
1021
+ html += ''' </div>
1022
+ </div>
1023
+ </div>
1024
+ '''
1025
+ return html
1026
+
1027
+
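One caveat about the card generation above: course titles and descriptions come from README files and are interpolated into the HTML unescaped, so any markup in a README would be injected verbatim. If that matters, the stdlib covers it; a sketch (the helper name `safe` is ours, and wiring it into the f-strings above is left as an exercise):

```python
import html

def safe(text: str) -> str:
    """Escape README-sourced text before interpolating into card HTML."""
    return html.escape(text, quote=True)
```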
1028
+ def generate_course_cards(courses):
1029
+ """Generate HTML for all course cards."""
1030
+ html = ""
1031
+
1032
+ # Define the custom order for courses
1033
+ course_order = ["python", "probability", "polars", "optimization", "functional_programming"]
1034
+
1035
+ # Create a dictionary of courses by ID for easy lookup
1036
+ courses_by_id = {course["id"]: course for course in courses.values()}
1037
+
1038
+ # Determine which courses are "work in progress" based on description or notebook count
1039
+ work_in_progress = set()
1040
+ for course_id, course in courses_by_id.items():
1041
+ # Consider a course as "work in progress" if it has few notebooks or contains specific phrases
1042
+ if (len(course["notebooks"]) < 5 or
1043
+ "work in progress" in course["description"].lower() or
1044
+ "help us add" in course["description"].lower() or
1045
+ "check back later" in course["description"].lower()):
1046
+ work_in_progress.add(course_id)
1047
+
1048
+ # First output courses in the specified order
1049
+ for course_id in course_order:
1050
+ if course_id in courses_by_id:
1051
+ course = courses_by_id[course_id]
1052
+
1053
+ # Skip if no notebooks
1054
+ if not course["notebooks"]:
1055
+ continue
1056
+
1057
+ # Count notebooks
1058
+ notebook_count = len(course["notebooks"])
1059
+
1060
+ # Determine if this course is a work in progress
1061
+ is_wip = course_id in work_in_progress
1062
+
1063
+ html += generate_course_card(course, notebook_count, is_wip)
1064
+
1065
+ # Remove from the dictionary so we don't output it again
1066
+ del courses_by_id[course_id]
1067
+
1068
+ # Then output any remaining courses alphabetically
1069
+ sorted_remaining_courses = sorted(courses_by_id.values(), key=lambda x: x["title"])
1070
+
1071
+ for course in sorted_remaining_courses:
1072
+ # Skip if no notebooks
1073
+ if not course["notebooks"]:
1074
+ continue
1075
+
1076
+ # Count notebooks
1077
+ notebook_count = len(course["notebooks"])
1078
+
1079
+ # Determine if this course is a work in progress
1080
+ is_wip = course["id"] in work_in_progress
1081
+
1082
+ html += generate_course_card(course, notebook_count, is_wip)
1083
+
1084
+ return html
1085
+
1086
+
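The work-in-progress heuristic above (fewer than five notebooks, or tell-tale phrases in the description) reads more clearly as a predicate. A sketch of the same rule (the name `is_work_in_progress` is ours):

```python
def is_work_in_progress(course: dict) -> bool:
    """Heuristic from above: few notebooks, or tell-tale phrases."""
    description = course["description"].lower()
    return (
        len(course["notebooks"]) < 5
        or "work in progress" in description
        or "help us add" in description
        or "check back later" in description
    )
```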
1087
+ def get_html_courses_end():
1088
+ """Generate the end of the courses section."""
1089
+ return """ </div>
1090
+ </section>"""
1091
+
1092
+
1093
+ def get_html_contribute_section():
1094
+ """Generate the contribute section."""
1095
+ return """
1096
+ <section id="contribute" class="eva-cta">
1097
+ <h2>Contribute to Marimo Learn</h2>
1098
+ <p>
1099
+ Help us expand our collection of educational notebooks. Whether you're an expert in machine learning,
1100
+ statistics, or any other field, your contributions are welcome!
1101
+ </p>
1102
+ <a href="https://github.com/marimo-team/learn" target="_blank" class="eva-button">
1103
+ <i class="fab fa-github"></i> Contribute on GitHub
1104
+ </a>
1105
+ </section>"""
1106
+
1107
+
1108
+ def get_html_footer():
1109
+ """Generate the page footer."""
1110
+ return """
1111
+ <footer class="eva-footer">
1112
+ <div class="eva-footer-logo">
1113
+ <a href="https://marimo.io" target="_blank">
1114
+ <img src="https://marimo.io/logotype-wide.svg" alt="Marimo" width="200">
1115
+ </a>
1116
+ </div>
1117
+ <div class="eva-social-links">
1118
+ <a href="https://github.com/marimo-team" target="_blank" aria-label="GitHub"><i class="fab fa-github"></i></a>
1119
+ <a href="https://marimo.io/discord?ref=learn" target="_blank" aria-label="Discord"><i class="fab fa-discord"></i></a>
1120
+ <a href="https://twitter.com/marimo_io" target="_blank" aria-label="Twitter"><i class="fab fa-twitter"></i></a>
1121
+ <a href="https://www.youtube.com/@marimo-team" target="_blank" aria-label="YouTube"><i class="fab fa-youtube"></i></a>
1122
+ <a href="https://www.linkedin.com/company/marimo-io" target="_blank" aria-label="LinkedIn"><i class="fab fa-linkedin"></i></a>
1123
+ </div>
1124
+ <div class="eva-footer-links">
1125
+ <a href="https://marimo.io" target="_blank">Website</a>
1126
+ <a href="https://docs.marimo.io" target="_blank">Documentation</a>
1127
+ <a href="https://github.com/marimo-team/learn" target="_blank">GitHub</a>
+ </div>
+ <div class="eva-footer-copyright">
+ © 2025 Marimo Inc. All rights reserved.
+ </div>
+ </footer>"""
+
+
+ def get_html_scripts():
+ """Generate the JavaScript for the page."""
+ return """
+ <script>
+ // Set light theme as default immediately
+ document.documentElement.setAttribute('data-theme', 'light');
+
+ document.addEventListener('DOMContentLoaded', function() {
+ // Theme toggle functionality
+ const themeToggle = document.getElementById('themeToggle');
+ const themeIcon = themeToggle.querySelector('i');
+
+ // Update theme icon based on current theme
+ updateThemeIcon('light');
+
+ // Check localStorage for saved theme preference
+ const savedTheme = localStorage.getItem('theme');
+ if (savedTheme && savedTheme !== 'light') {
+ document.documentElement.setAttribute('data-theme', savedTheme);
+ updateThemeIcon(savedTheme);
+ }
+
+ // Toggle theme when button is clicked
+ themeToggle.addEventListener('click', () => {
+ const currentTheme = document.documentElement.getAttribute('data-theme');
+ const newTheme = currentTheme === 'dark' ? 'light' : 'dark';
+
+ document.documentElement.setAttribute('data-theme', newTheme);
+ localStorage.setItem('theme', newTheme);
+ updateThemeIcon(newTheme);
+ });
+
+ function updateThemeIcon(theme) {
+ if (theme === 'dark') {
+ themeIcon.className = 'fas fa-sun';
+ } else {
+ themeIcon.className = 'fas fa-moon';
+ }
+ }
+
+ // Terminal typing effect for hero text
+ const heroTitle = document.querySelector('.eva-hero h1');
+ const heroText = document.querySelector('.eva-hero p');
+ const cursor = document.querySelector('.eva-cursor');
+
+ const originalTitle = heroTitle.textContent;
+ const originalText = heroText.textContent.trim();
+
+ heroTitle.textContent = '';
+ heroText.textContent = '';
+
+ let titleIndex = 0;
+ let textIndex = 0;
+
+ function typeTitle() {
+ if (titleIndex < originalTitle.length) {
+ heroTitle.textContent += originalTitle.charAt(titleIndex);
+ titleIndex++;
+ setTimeout(typeTitle, 50);
+ } else {
+ cursor.style.display = 'none';
+ setTimeout(typeText, 500);
+ }
+ }
+
+ function typeText() {
+ if (textIndex < originalText.length) {
+ heroText.textContent += originalText.charAt(textIndex);
+ textIndex++;
+ setTimeout(typeText, 20);
+ }
+ }
+
+ typeTitle();
+
+ // Course toggle functionality - flashcard style
+ const courseHeaders = document.querySelectorAll('.eva-course-header');
+ const courseButtons = document.querySelectorAll('.eva-course-button');
+
+ // Function to toggle course
+ function toggleCourse(course) {
+ const isActive = course.classList.contains('active');
+
+ // First close all courses with a slight delay for better visual effect
+ document.querySelectorAll('.eva-course.active').forEach(c => {
+ if (c !== course) {
+ // Add a closing class for animation
+ c.classList.add('closing');
+ // Remove active class after a short delay
+ setTimeout(() => {
+ c.classList.remove('active');
+ c.classList.remove('closing');
+ }, 300);
+ }
+ });
+
+ // Toggle the clicked course
+ if (!isActive) {
+ // Add a small delay before opening to allow others to close
+ setTimeout(() => {
+ course.classList.add('active');
+
+ // Check if the course has any notebooks
+ const notebooks = course.querySelectorAll('.eva-notebook');
+ const content = course.querySelector('.eva-course-content');
+
+ if (notebooks.length === 0 && !content.querySelector('.eva-empty-message')) {
+ // If no notebooks, show a message
+ const emptyMessage = document.createElement('p');
+ emptyMessage.className = 'eva-empty-message';
+ emptyMessage.textContent = 'No notebooks available in this course yet.';
+ emptyMessage.style.color = 'var(--eva-text)';
+ emptyMessage.style.fontStyle = 'italic';
+ emptyMessage.style.opacity = '0.7';
+ emptyMessage.style.textAlign = 'center';
+ emptyMessage.style.padding = '1rem 0';
+ content.appendChild(emptyMessage);
+ }
+
+ // Animate notebooks to appear sequentially
+ notebooks.forEach((notebook, index) => {
+ notebook.style.opacity = '0';
+ notebook.style.transform = 'translateX(-10px)';
+ setTimeout(() => {
+ notebook.style.opacity = '1';
+ notebook.style.transform = 'translateX(0)';
+ }, 50 + (index * 30)); // Stagger the animations
+ });
+ }, 100);
+ }
+ }
+
+ // Add click event to course headers
+ courseHeaders.forEach(header => {
+ header.addEventListener('click', function(e) {
+ e.preventDefault();
+ e.stopPropagation();
+
+ const currentCourse = this.closest('.eva-course');
+ toggleCourse(currentCourse);
+ });
+ });
+
+ // Add click event to course buttons
+ courseButtons.forEach(button => {
+ button.addEventListener('click', function(e) {
+ e.preventDefault();
+ e.stopPropagation();
+
+ const currentCourse = this.closest('.eva-course');
+ toggleCourse(currentCourse);
+ });
+ });
+
+ // Search functionality with improved matching
+ const searchInput = document.getElementById('courseSearch');
+ const courses = document.querySelectorAll('.eva-course');
+ const notebooks = document.querySelectorAll('.eva-notebook');
+
+ searchInput.addEventListener('input', function() {
+ const searchTerm = this.value.toLowerCase();
+
+ if (searchTerm === '') {
+ // Reset all visibility
+ courses.forEach(course => {
+ course.style.display = 'block';
+ course.classList.remove('active');
+ });
+
+ notebooks.forEach(notebook => {
+ notebook.style.display = 'flex';
+ });
+
+ // Open the first course with notebooks by default when search is cleared
+ for (let i = 0; i < courses.length; i++) {
+ const courseNotebooks = courses[i].querySelectorAll('.eva-notebook');
+ if (courseNotebooks.length > 0) {
+ courses[i].classList.add('active');
+ break;
+ }
+ }
+
+ return;
+ }
+
+ // First hide all courses
+ courses.forEach(course => {
+ course.style.display = 'none';
+ course.classList.remove('active');
+ });
+
+ // Then show courses and notebooks that match the search
+ let hasResults = false;
+
+ // Track which courses have matching notebooks
+ const coursesWithMatchingNotebooks = new Set();
+
+ // First check notebooks
+ notebooks.forEach(notebook => {
+ const notebookTitle = notebook.querySelector('a').getAttribute('data-notebook-title').toLowerCase();
+ const matchesSearch = notebookTitle.includes(searchTerm);
+
+ notebook.style.display = matchesSearch ? 'flex' : 'none';
+
+ if (matchesSearch) {
+ const course = notebook.closest('.eva-course');
+ coursesWithMatchingNotebooks.add(course.getAttribute('data-course-id'));
+ hasResults = true;
+ }
+ });
+
+ // Then check course titles and descriptions
+ courses.forEach(course => {
+ const courseId = course.getAttribute('data-course-id');
+ const courseTitle = course.querySelector('.eva-course-title').textContent.toLowerCase();
+ const courseDescription = course.querySelector('.eva-course-description').textContent.toLowerCase();
+
+ const courseMatches = courseTitle.includes(searchTerm) || courseDescription.includes(searchTerm);
+
+ // Show course if it matches or has matching notebooks
+ if (courseMatches || coursesWithMatchingNotebooks.has(courseId)) {
+ course.style.display = 'block';
+ course.classList.add('active');
+ hasResults = true;
+
+ // If course matches but doesn't have matching notebooks, show all its notebooks
+ if (courseMatches && !coursesWithMatchingNotebooks.has(courseId)) {
+ course.querySelectorAll('.eva-notebook').forEach(nb => {
+ nb.style.display = 'flex';
+ });
+ }
+ }
+ });
+ });
+
+ // Open the first course with notebooks by default
+ let firstCourseWithNotebooks = null;
+ for (let i = 0; i < courses.length; i++) {
+ const courseNotebooks = courses[i].querySelectorAll('.eva-notebook');
+ if (courseNotebooks.length > 0) {
+ firstCourseWithNotebooks = courses[i];
+ break;
+ }
+ }
+
+ if (firstCourseWithNotebooks) {
+ firstCourseWithNotebooks.classList.add('active');
+ } else if (courses.length > 0) {
+ // If no courses have notebooks, just open the first one
+ courses[0].classList.add('active');
+ }
+
+ // Smooth scrolling for anchor links
+ document.querySelectorAll('a[href^="#"]').forEach(anchor => {
+ anchor.addEventListener('click', function(e) {
+ e.preventDefault();
+
+ const targetId = this.getAttribute('href');
+ const targetElement = document.querySelector(targetId);
+
+ if (targetElement) {
+ window.scrollTo({
+ top: targetElement.offsetTop - 100,
+ behavior: 'smooth'
+ });
+ }
+ });
+ });
+ });
+ </script>"""
+
+
+ def get_html_footer_closing():
+ """Generate closing HTML tags."""
+ return """
+ </div>
+ </body>
+ </html>"""
+
+
+ def generate_index(courses: Dict[str, Dict[str, Any]], output_dir: str) -> None:
+ """Generate the index.html file with Neon Genesis Evangelion aesthetics."""
+ print("Generating index.html")
+
+ index_path = os.path.join(output_dir, "index.html")
+ os.makedirs(output_dir, exist_ok=True)
+
+ try:
+ with open(index_path, "w", encoding="utf-8") as f:
+ # Build the page HTML from individual components
+ header = get_html_header().format(css=generate_eva_css())
+ hero = get_html_hero_section()
+ features = get_html_features_section()
+ courses_start = get_html_courses_start()
+ course_cards = generate_course_cards(courses)
+ courses_end = get_html_courses_end()
+ contribute = get_html_contribute_section()
+ footer = get_html_footer()
+ scripts = get_html_scripts()
+ closing = get_html_footer_closing()
+
+ # Write all elements to the file
+ f.write(header)
+ f.write(hero)
+ f.write(features)
+ f.write(courses_start)
+ f.write(course_cards)
+ f.write(courses_end)
+ f.write(contribute)
+ f.write(footer)
+ f.write(scripts)
+ f.write(closing)
+
+ except IOError as e:
+ print(f"Error generating index.html: {e}")
+
+
+ def main() -> None:
+ parser = argparse.ArgumentParser(description="Build marimo notebooks")
+ parser.add_argument(
+ "--output-dir", default="_site", help="Output directory for built files"
+ )
+ parser.add_argument(
+ "--course-dirs", nargs="+", default=None,
+ help="Specific course directories to build (default: all directories with .py files)"
+ )
+ args = parser.parse_args()
+
+ # Find all course directories (directories containing .py files)
+ all_notebooks: List[str] = []
+
+ # Directories to exclude from course detection
+ excluded_dirs = ["scripts", "env", "__pycache__", ".git", ".github", "assets"]
+
+ if args.course_dirs:
+ course_dirs = args.course_dirs
+ else:
+ # Automatically detect course directories (any directory with .py files)
+ course_dirs = []
+ for item in os.listdir("."):
+ if (os.path.isdir(item) and
+ not item.startswith(".") and
+ not item.startswith("_") and
+ item not in excluded_dirs):
+ # Check if directory contains .py files
+ if list(Path(item).glob("*.py")):
+ course_dirs.append(item)
+
+ print(f"Found course directories: {', '.join(course_dirs)}")
+
+ for directory in course_dirs:
+ dir_path = Path(directory)
+ if not dir_path.exists():
+ print(f"Warning: Directory not found: {dir_path}")
+ continue
+
+ notebooks = [str(path) for path in dir_path.rglob("*.py")
+ if not path.name.startswith("_") and "/__pycache__/" not in str(path)]
+ all_notebooks.extend(notebooks)
+
+ if not all_notebooks:
+ print("No notebooks found!")
+ return
+
+ # Export notebooks sequentially
+ successful_notebooks = []
+ for nb in all_notebooks:
+ # Determine if notebook should be exported as app or notebook
+ # For now, export all as notebooks
+ if export_html_wasm(nb, args.output_dir, as_app=False):
+ successful_notebooks.append(nb)
+
+ # Organize notebooks by course (only include successfully exported notebooks)
+ courses = organize_notebooks_by_course(successful_notebooks)
+
+ # Generate index with organized courses
+ generate_index(courses, args.output_dir)
+
+ # Save course data as JSON for potential use by other tools
+ courses_json_path = os.path.join(args.output_dir, "courses.json")
+ with open(courses_json_path, "w", encoding="utf-8") as f:
+ json.dump(courses, f, indent=2)
+
+ print(f"Build complete! Site generated in {args.output_dir}")
+ print(f"Successfully exported {len(successful_notebooks)} out of {len(all_notebooks)} notebooks")
+
+
+ if __name__ == "__main__":
+ main()
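The notebook-discovery rules in `main()` above can be sketched standalone. This is a minimal illustration, not the committed code: the helper name `find_notebooks` is hypothetical, and it checks `__pycache__` via `Path.parts` rather than the substring test used above, so the filter also holds on Windows path separators.

```python
import tempfile
from pathlib import Path


def find_notebooks(course_dir: str) -> list[str]:
    """Recursively collect .py notebooks, skipping _private files and __pycache__.

    Mirrors the rglob filter in build.py's main() (hypothetical helper name).
    """
    root = Path(course_dir)
    return sorted(
        str(p)
        for p in root.rglob("*.py")
        if not p.name.startswith("_") and "__pycache__" not in p.parts
    )


# Tiny demo against a throwaway directory tree.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "intro.py").write_text("# notebook")
    (Path(tmp) / "_draft.py").write_text("# skipped: leading underscore")
    pycache = Path(tmp) / "__pycache__"
    pycache.mkdir()
    (pycache / "cached.py").write_text("# skipped: inside __pycache__")

    found = [Path(f).name for f in find_notebooks(tmp)]
    print(found)  # only intro.py survives the filter
```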
scripts/preview.py ADDED
@@ -0,0 +1,76 @@
+ #!/usr/bin/env python3
+
+ import os
+ import subprocess
+ import argparse
+ import webbrowser
+ import time
+ import sys
+ from pathlib import Path
+
+ def main():
+ parser = argparse.ArgumentParser(description="Build and preview marimo notebooks site")
+ parser.add_argument(
+ "--port", default=8000, type=int, help="Port to run the server on"
+ )
+ parser.add_argument(
+ "--no-build", action="store_true", help="Skip building the site (just serve existing files)"
+ )
+ parser.add_argument(
+ "--output-dir", default="_site", help="Output directory for built files"
+ )
+ args = parser.parse_args()
+
+ # Store the current directory
+ original_dir = os.getcwd()
+
+ try:
+ # Build the site if not skipped
+ if not args.no_build:
+ print("Building site...")
+ build_script = Path("scripts/build.py")
+ if not build_script.exists():
+ print(f"Error: Build script not found at {build_script}")
+ return 1
+
+ result = subprocess.run(
+ [sys.executable, str(build_script), "--output-dir", args.output_dir],
+ check=False
+ )
+ if result.returncode != 0:
+ print("Warning: Build process completed with errors.")
+
+ # Check if the output directory exists
+ output_dir = Path(args.output_dir)
+ if not output_dir.exists():
+ print(f"Error: Output directory '{args.output_dir}' does not exist.")
+ return 1
+
+ # Change to the output directory
+ os.chdir(args.output_dir)
+
+ # Open the browser
+ url = f"http://localhost:{args.port}"
+ print(f"Opening {url} in your browser...")
+ webbrowser.open(url)
+
+ # Start the server
+ print(f"Starting server on port {args.port}...")
+ print("Press Ctrl+C to stop the server")
+
+ # Use the appropriate Python executable
+ subprocess.run([sys.executable, "-m", "http.server", str(args.port)])
+
+ return 0
+ except KeyboardInterrupt:
+ print("\nServer stopped.")
+ return 0
+ except Exception as e:
+ print(f"Error: {e}")
+ return 1
+ finally:
+ # Always return to the original directory
+ os.chdir(original_dir)
+
+ if __name__ == "__main__":
+ sys.exit(main())