Spaces:
Runtime error
Runtime error
"""Migrate LangChain to the most recent version.""" | |
# Adapted from bump-pydantic | |
# https://github.com/pydantic/bump-pydantic | |
import difflib | |
import functools | |
import multiprocessing | |
import os | |
import time | |
import traceback | |
from pathlib import Path | |
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union | |
import libcst as cst | |
import typer | |
from libcst.codemod import CodemodContext, ContextAwareTransformer | |
from libcst.helpers import calculate_module_and_package | |
from libcst.metadata import FullRepoManager, FullyQualifiedNameProvider, ScopeProvider | |
from rich.console import Console | |
from rich.progress import Progress | |
from typer import Argument, Exit, Option, Typer | |
from typing_extensions import ParamSpec | |
from langchain_cli.namespaces.migrate.codemods import Rule, gather_codemods | |
from langchain_cli.namespaces.migrate.glob_helpers import match_glob | |
# Typer application object for the migrate command.
app = Typer(invoke_without_command=True, add_completion=False)

# Generic typing helpers; not referenced by the code visible in this module
# section (NOTE(review): presumably kept for decorators/helpers elsewhere).
P = ParamSpec("P")
T = TypeVar("T")

# Glob patterns skipped by default; used as the default of `main`'s
# `ignore` option.
DEFAULT_IGNORES = [".venv/**"]
def main(
    path: Path = Argument(..., exists=True, dir_okay=True, allow_dash=False),
    disable: List[Rule] = Option(default=[], help="Disable a rule."),
    diff: bool = Option(False, help="Show diff instead of applying changes."),
    ignore: List[str] = Option(
        default=DEFAULT_IGNORES, help="Ignore a path glob pattern."
    ),
    log_file: Path = Option("log.txt", help="Log errors to this file."),
    include_ipynb: bool = Option(
        False, help="Include Jupyter Notebook files in the migration."
    ),
):
    """Migrate langchain to the most recent version."""
    # NOTE: the docstring above is kept to a single line on purpose; Typer
    # surfaces it as the command help text.
    #
    # Flow: confirm (unless --diff) -> collect files -> run the codemods on
    # every file in a multiprocessing pool -> report results.
    # Exits with code 1 when --diff produced at least one non-empty diff.
    if not diff:
        # Files are modified in place, so ask for explicit confirmation first.
        if not typer.confirm(
            "βοΈ This script will help you migrate to a recent version LangChain. "
            "This migration script will attempt to replace old imports in the code "
            "with new ones.\n\n"
            "π You will need to run the migration script TWICE to migrate (e.g., "
            "to update llms import from langchain, the script will first move them to "
            "corresponding imports from the community package, and on the second "
            "run will migrate from the community package to the partner package "
            "when possible). \n\n"
            "π You can pre-view the changes by running with the --diff flag. \n\n"
            "π« You can disable specific import changes by using the --disable "
            "flag. \n\n"
            "π Update your pyproject.toml or requirements.txt file to "
            "reflect any imports from new packages. For example, if you see new "
            "imports from langchain_openai, langchain_anthropic or "
            "langchain_text_splitters you "
            "should them to your dependencies! \n\n"
            'β οΈ This script is a "best-effort", and is likely to make some '
            "mistakes.\n\n"
            "π‘οΈ Backup your code prior to running the migration script -- it will "
            "modify your files!\n\n"
            "β Do you want to continue?"
        ):
            raise Exit()

    console = Console(log_time=True)
    console.log("Start langchain-cli migrate")
    # NOTE: LIBCST_PARSER_TYPE=native is required according to https://github.com/Instagram/LibCST/issues/487.
    os.environ["LIBCST_PARSER_TYPE"] = "native"

    # A single file is processed on its own; a directory is searched
    # recursively for Python files (and notebooks when requested).
    if path.is_file():
        package = path.parent
        all_files = [path]
    else:
        package = path
        all_files = sorted(package.glob("**/*.py"))
        if include_ipynb:
            all_files.extend(sorted(package.glob("**/*.ipynb")))

    filtered_files = [
        file
        for file in all_files
        if not any(match_glob(file, pattern) for pattern in ignore)
    ]
    files = [str(file.relative_to(".")) for file in filtered_files]

    if not files:
        console.log("No files to process.")
        raise Exit()
    if len(files) == 1:
        console.log("Found 1 file to process.")
    else:
        console.log(f"Found {len(files)} files to process.")

    providers = {FullyQualifiedNameProvider, ScopeProvider}
    metadata_manager = FullRepoManager(".", files, providers=providers)  # type: ignore[arg-type]
    metadata_manager.resolve_cache()

    scratch: dict[str, Any] = {}
    start_time = time.time()

    # Everything except the filename is bound up front so the pool can map
    # the worker over the list of files.
    partial_run_codemods = functools.partial(
        get_and_run_codemods, disable, metadata_manager, scratch, package, diff
    )

    count_errors = 0
    difflines: List[List[str]] = []
    # The log file is opened in a `with` block so it is flushed and closed
    # even if a worker or the progress display raises.
    with log_file.open("a+", encoding="utf8") as log_fp, Progress(
        *Progress.get_default_columns(), transient=True
    ) as progress:
        task = progress.add_task(description="Executing codemods...", total=len(files))
        with multiprocessing.Pool() as pool:
            for error, _difflines in pool.imap_unordered(partial_run_codemods, files):
                progress.advance(task)

                # Ignore empty diffs so that "nothing to change" never
                # triggers the non-zero exit below.
                if _difflines:
                    difflines.append(_difflines)

                if error is not None:
                    count_errors += 1
                    # Keep log entries separated even when the message does
                    # not end with a newline.
                    log_fp.write(error if error.endswith("\n") else f"{error}\n")

    if not diff:
        # A file counts as refactored when its mtime is newer than the start
        # of this run.
        modified = [Path(f) for f in files if os.stat(f).st_mtime > start_time]
        if modified:
            console.log(f"Refactored {len(modified)} files.")
        else:
            console.log("No files were modified.")

    for _difflines in difflines:
        color_diff(console, _difflines)

    if count_errors > 0:
        console.log(f"Found {count_errors} errors. Please check the {log_file} file.")
    else:
        console.log("Run successfully!")

    if difflines:
        raise Exit(1)
def get_and_run_codemods(
    disabled_rules: List[Rule],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Gather the enabled codemods and apply them to one file.

    This is the multiprocessing.Pool entry point: the codemods are gathered
    here from ``disabled_rules`` and everything else is forwarded unchanged
    to ``run_codemods``, whose ``(error, diff_lines)`` result is returned.
    """
    active_codemods = gather_codemods(disabled=disabled_rules)
    return run_codemods(
        codemods=active_codemods,
        metadata_manager=metadata_manager,
        scratch=scratch,
        package=package,
        diff=diff,
        filename=filename,
    )
def _rewrite_file(
    filename: str,
    codemods: List[Type[ContextAwareTransformer]],
    diff: bool,
    context: CodemodContext,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Apply the codemods to a Python source file.

    Args:
        filename: Path of the file to process.
        codemods: Transformer classes, applied in order; each one receives
            the output of the previous one.
        diff: When true, the file is left untouched and a unified diff of
            the would-be change is returned instead.
        context: Codemod context handed to every transformer.

    Returns:
        ``(None, diff_lines)`` in diff mode when the code would change,
        otherwise ``(None, None)``. The first element is always ``None``;
        failures propagate as exceptions.
    """
    file_path = Path(filename)
    # Read-only access is enough here, so diff mode also works on files that
    # are not writable.
    code = file_path.read_text(encoding="utf-8")

    tree = cst.parse_module(code)
    for codemod in codemods:
        tree = codemod(context=context).transform_module(tree)
    output_code = tree.code

    if code == output_code:
        return None, None

    if diff:
        lines = difflib.unified_diff(
            code.splitlines(keepends=True),
            output_code.splitlines(keepends=True),
            fromfile=filename,
            tofile=filename,
        )
        return None, list(lines)

    # Only touch the file when the content actually changed.
    file_path.write_text(output_code, encoding="utf-8")
    return None, None
def _rewrite_notebook( | |
filename: str, | |
codemods: List[Type[ContextAwareTransformer]], | |
diff: bool, | |
context: CodemodContext, | |
) -> Tuple[Optional[str], Optional[List[str]]]: | |
"""Try to rewrite a Jupyter Notebook file.""" | |
import nbformat | |
file_path = Path(filename) | |
if file_path.suffix != ".ipynb": | |
raise ValueError("Only Jupyter Notebook files (.ipynb) are supported.") | |
with file_path.open("r", encoding="utf-8") as fp: | |
notebook = nbformat.read(fp, as_version=4) | |
diffs = [] | |
for cell in notebook.cells: | |
if cell.cell_type == "code": | |
code = "".join(cell.source) | |
# Skip code if any of the lines begin with a magic command or | |
# a ! command. | |
# We can try to handle later. | |
if any( | |
line.startswith("!") or line.startswith("%") | |
for line in code.splitlines() | |
): | |
continue | |
input_tree = cst.parse_module(code) | |
# TODO(Team): Quick hack, need to figure out | |
# how to handle this correctly. | |
# This prevents the code from trying to re-insert the imports | |
# for every cell in the notebook. | |
local_context = CodemodContext() | |
for codemod in codemods: | |
transformer = codemod(context=local_context) | |
output_tree = transformer.transform_module(input_tree) | |
input_tree = output_tree | |
output_code = input_tree.code | |
if code != output_code: | |
cell.source = output_code.splitlines(keepends=True) | |
if diff: | |
cell_diff = difflib.unified_diff( | |
code.splitlines(keepends=True), | |
output_code.splitlines(keepends=True), | |
fromfile=filename, | |
tofile=filename, | |
) | |
diffs.extend(list(cell_diff)) | |
if diff: | |
return None, diffs | |
with file_path.open("w", encoding="utf-8") as fp: | |
nbformat.write(notebook, fp) | |
return None, None | |
def run_codemods(
    codemods: List[Type[ContextAwareTransformer]],
    metadata_manager: FullRepoManager,
    scratch: Dict[str, Any],
    package: Path,
    diff: bool,
    filename: str,
) -> Tuple[Union[str, None], Union[List[str], None]]:
    """Run the codemods on one file, turning failures into error text.

    A codemod context is built for ``filename`` (module/package names are
    derived from ``package``), seeded with ``scratch``, and handed to the
    notebook or plain-file rewriter depending on the file extension.

    Returns:
        The rewriter's ``(error, diff_lines)`` tuple on success, or
        ``(message, None)`` when parsing or any other step raised.
    """
    try:
        names = calculate_module_and_package(str(package), filename)
        context = CodemodContext(
            metadata_manager=metadata_manager,
            filename=filename,
            full_module_name=names.name,
            full_package_name=names.package,
        )
        context.scratch.update(scratch)

        # Notebooks need cell-by-cell handling; everything else is treated
        # as a Python source file.
        rewrite = _rewrite_notebook if filename.endswith(".ipynb") else _rewrite_file
        return rewrite(filename, codemods, diff, context)
    except cst.ParserSyntaxError as exc:
        message = (
            f"A syntax error happened on {filename}. This file cannot be "
            f"formatted.\n{exc}"
        )
        return message, None
    except Exception:
        # Boundary of the worker: report the traceback instead of crashing
        # the whole run.
        details = traceback.format_exc()
        return f"An error happened on {filename}.\n{details}", None
def color_diff(console: Console, lines: Iterable[str]) -> None:
    """Print diff lines to the console, colored by their leading marker.

    Lines starting with ``+`` are green, ``-`` red, ``^`` blue and everything
    else white. A trailing newline is stripped from each line before
    printing.

    Args:
        console: Console used for output.
        lines: Diff lines to print.
    """
    for line in lines:
        line = line.rstrip("\n")
        if line.startswith("+"):
            style = "green"
        elif line.startswith("-"):
            style = "red"
        elif line.startswith("^"):
            style = "blue"
        else:
            style = "white"
        # Diff lines are raw source code: disable console markup and emoji
        # substitution so text such as ``List[str]`` or ``:name:`` is printed
        # verbatim instead of being interpreted.
        console.print(line, style=style, markup=False, emoji=False)