YiftachEde commited on
Commit
5c79851
·
1 Parent(s): 0820934
Files changed (1) hide show
  1. torch_patch.py +66 -0
torch_patch.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Patch for torch module to make it compatible with newer diffusers versions
3
+ while using PyTorch 2.0.1
4
+ """
5
+ import torch
6
+ import sys
7
+ import warnings
8
+ import types
9
+ import functools
10
+
11
+ # Check if the attributes already exist
12
+ if not hasattr(torch, 'float8_e4m3fn'):
13
+ # Add missing attributes for compatibility
14
+ # These won't actually function, but they'll allow imports to succeed
15
+ torch.float8_e4m3fn = torch.float16 # Use float16 as a placeholder type
16
+ warnings.warn(
17
+ "Added placeholder for torch.float8_e4m3fn. Actual 8-bit operations won't work, "
18
+ "but imports should succeed. Using PyTorch 2.0.1 with newer diffusers."
19
+ )
20
+
21
+ if not hasattr(torch, 'float8_e5m2'):
22
+ torch.float8_e5m2 = torch.float16 # Use float16 as a placeholder type
23
+
24
+ # Add other missing torch types that might be referenced
25
+ for type_name in ['bfloat16', 'bfloat8', 'float8_e4m3fnuz']:
26
+ if not hasattr(torch, type_name):
27
+ setattr(torch, type_name, torch.float16)
28
+
29
+ # Create a placeholder for torch._dynamo if it doesn't exist
30
+ if not hasattr(torch, '_dynamo'):
31
+ torch._dynamo = types.ModuleType('torch._dynamo')
32
+ sys.modules['torch._dynamo'] = torch._dynamo
33
+
34
+ # Add common attributes/functions used by torch._dynamo
35
+ torch._dynamo.config = types.SimpleNamespace(suppress_errors=True)
36
+ torch._dynamo.optimize = lambda *args, **kwargs: lambda f: f
37
+ torch._dynamo.disable = lambda: None
38
+ torch._dynamo.reset_repro_cache = lambda: None
39
+
40
+ # Add torch.compile if it doesn't exist
41
+ if not hasattr(torch, 'compile'):
42
+ # Just return the function unchanged
43
+ torch.compile = lambda fn, **kwargs: fn
44
+
45
+ # Create a placeholder for torch.cuda.amp if it doesn't exist
46
+ if not hasattr(torch.cuda, 'amp'):
47
+ torch.cuda.amp = types.ModuleType('torch.cuda.amp')
48
+ sys.modules['torch.cuda.amp'] = torch.cuda.amp
49
+
50
+ # Mock autocast
51
+ class MockAutocast:
52
+ def __init__(self, *args, **kwargs):
53
+ pass
54
+ def __enter__(self):
55
+ return self
56
+ def __exit__(self, *args):
57
+ pass
58
+ def __call__(self, func):
59
+ @functools.wraps(func)
60
+ def wrapper(*args, **kwargs):
61
+ return func(*args, **kwargs)
62
+ return wrapper
63
+
64
+ torch.cuda.amp.autocast = MockAutocast
65
+
66
+ print("PyTorch patched for compatibility with newer diffusers - using latest diffusers with PyTorch 2.0.1")