TechDev committed on
Commit
c80821d
·
verified ·
1 Parent(s): 2851c63

Upload FastTelethon.py

Browse files
Files changed (1) hide show
  1. FastTelethon.py +308 -0
FastTelethon.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copied from https://github.com/tulir/mautrix-telegram/blob/master/mautrix_telegram/util/parallel_file_transfer.py
2
+ # Copyright (C) 2021 Tulir Asokan
3
+ import asyncio
4
+ import hashlib
5
+ import inspect
6
+ import logging
7
+ import math
8
+ import os
9
+ from collections import defaultdict
10
+ from typing import Optional, List, AsyncGenerator, Union, Awaitable, DefaultDict, Tuple, BinaryIO
11
+
12
+ from telethon import utils, helpers, TelegramClient
13
+ from telethon.crypto import AuthKey
14
+ from telethon.network import MTProtoSender
15
+ from telethon.tl.alltlobjects import LAYER
16
+ from telethon.tl.functions import InvokeWithLayerRequest
17
+ from telethon.tl.functions.auth import ExportAuthorizationRequest, ImportAuthorizationRequest
18
+ from telethon.tl.functions.upload import (GetFileRequest, SaveFilePartRequest,
19
+ SaveBigFilePartRequest)
20
+ from telethon.tl.types import (Document, InputFileLocation, InputDocumentFileLocation,
21
+ InputPhotoFileLocation, InputPeerPhotoFileLocation, TypeInputFile,
22
+ InputFileBig, InputFile)
23
+
24
+ try:
25
+ from mautrix.crypto.attachments import async_encrypt_attachment
26
+ except ImportError:
27
+ async_encrypt_attachment = None
28
+
29
# Logger used throughout this module; it is registered under the name "telethon".
log: logging.Logger = logging.getLogger("telethon")

# Union of the file-reference types the download helpers in this module are
# annotated to accept.
TypeLocation = Union[
    Document,
    InputDocumentFileLocation,
    InputPeerPhotoFileLocation,
    InputFileLocation,
    InputPhotoFileLocation,
]
33
+
34
+
35
class DownloadSender:
    """Fetch a fixed number of chunks of one file over a single sender.

    The same ``GetFileRequest`` object is re-sent for every chunk; after each
    successful call its ``offset`` is advanced by ``stride``.
    """

    client: TelegramClient
    sender: MTProtoSender
    request: GetFileRequest
    remaining: int
    stride: int

    def __init__(self, client: TelegramClient, sender: MTProtoSender, file: TypeLocation, offset: int, limit: int,
                 stride: int, count: int) -> None:
        """Store the connection and build the reusable request.

        :param client: client whose ``_call`` is used to issue requests.
        :param sender: connection the requests are sent through.
        :param file: location passed to ``GetFileRequest``.
        :param offset: offset of the first chunk this sender fetches.
        :param limit: ``limit`` value of every request.
        :param stride: amount added to the request offset after each chunk.
        :param count: number of chunks this sender will fetch.
        """
        self.client = client
        self.sender = sender
        self.stride = stride
        self.remaining = count
        self.request = GetFileRequest(file, offset=offset, limit=limit)

    async def next(self) -> Optional[bytes]:
        """Fetch the next chunk, or return ``None`` once ``remaining`` reaches zero."""
        if not self.remaining:
            return None
        response = await self.client._call(self.sender, self.request)
        # Bookkeeping happens only after the call succeeded.
        self.remaining -= 1
        self.request.offset += self.stride
        return response.bytes

    def disconnect(self) -> Awaitable[None]:
        """Return the awaitable produced by disconnecting the underlying sender."""
        return self.sender.disconnect()
60
+
61
+
62
class UploadSender:
    """Upload parts of one file over a single sender, one part in flight at a time.

    The same request object is re-sent for every part; after each successful
    call its ``file_part`` is advanced by ``stride``.
    """

    client: TelegramClient
    sender: MTProtoSender
    request: Union[SaveFilePartRequest, SaveBigFilePartRequest]
    part_count: int
    stride: int
    previous: Optional[asyncio.Task]
    loop: asyncio.AbstractEventLoop

    def __init__(self, client: TelegramClient, sender: MTProtoSender, file_id: int, part_count: int, big: bool,
                 index: int,
                 stride: int, loop: asyncio.AbstractEventLoop) -> None:
        """Store the connection and build the reusable save-part request.

        :param client: client whose ``_call`` is used to issue requests.
        :param sender: connection the requests are sent through.
        :param file_id: id given to the request.
        :param part_count: total number of parts (sent with big-file requests
            and used in the debug log).
        :param big: build a ``SaveBigFilePartRequest`` instead of a
            ``SaveFilePartRequest``.
        :param index: number of the first part this sender uploads.
        :param stride: amount added to ``file_part`` after each upload.
        :param loop: event loop used to create the background upload task.
        """
        self.client = client
        self.sender = sender
        self.part_count = part_count
        if big:
            self.request = SaveBigFilePartRequest(file_id, index, part_count, b"")
        else:
            self.request = SaveFilePartRequest(file_id, index, b"")
        self.stride = stride
        self.previous = None
        self.loop = loop

    async def next(self, data: bytes) -> None:
        """Wait for the previous part (if any), then start uploading ``data`` in a task.

        An exception raised by the previous upload task propagates from here.
        """
        if self.previous:
            await self.previous
        self.previous = self.loop.create_task(self._next(data))

    async def _next(self, data: bytes) -> None:
        """Send ``data`` as the current part and advance ``file_part`` by ``stride``."""
        self.request.bytes = data
        log.debug(f"Sending file part {self.request.file_part}/{self.part_count}"
                  f" with {len(data)} bytes")
        await self.client._call(self.sender, self.request)
        self.request.file_part += self.stride

    async def disconnect(self) -> None:
        """Wait for the pending upload, then disconnect the sender.

        The sender is disconnected even when the pending upload raised, so a
        failed part no longer leaves the connection open; the upload's
        exception still propagates to the caller.
        """
        try:
            if self.previous:
                await self.previous
        finally:
            await self.sender.disconnect()
101
+
102
+
103
class ParallelTransferrer:
    """Transfer one file through several senders connected to the same DC."""

    client: TelegramClient
    loop: asyncio.AbstractEventLoop
    dc_id: int
    senders: Optional[List[Union[DownloadSender, UploadSender]]]
    auth_key: AuthKey
    upload_ticker: int

    def __init__(self, client: TelegramClient, dc_id: Optional[int] = None) -> None:
        """Remember the client and pick the DC and auth key to use.

        :param client: client used for requests and for its loop and session.
        :param dc_id: DC to connect to; falls back to the session's DC.
        """
        self.client = client
        self.loop = self.client.loop
        self.dc_id = dc_id or self.client.session.dc_id
        # The session's key is only reused for the session's own DC; for any
        # other DC the key starts as None and is set by _create_sender().
        self.auth_key = (None if dc_id and self.client.session.dc_id != dc_id
                         else self.client.session.auth_key)
        self.senders = None
        self.upload_ticker = 0

    async def _cleanup(self) -> None:
        """Disconnect every sender and forget them; a no-op when there are none."""
        if not self.senders:
            self.senders = None
            return
        try:
            await asyncio.gather(*[sender.disconnect() for sender in self.senders])
        finally:
            # Reset even when a disconnect failed so a second call is harmless.
            self.senders = None

    @staticmethod
    def _get_connection_count(file_size: int, max_count: int = 20,
                              full_size: int = 100 * 1024 * 1024) -> int:
        """Scale the connection count linearly with ``file_size``.

        Files larger than ``full_size`` get ``max_count`` connections. The
        result is never below 1, so an empty file cannot yield zero
        connections (which would make ``divmod`` in ``_init_download`` raise).
        """
        if file_size > full_size:
            return max_count
        return max(1, math.ceil((file_size / full_size) * max_count))

    async def _init_download(self, connections: int, file: TypeLocation, part_count: int,
                             part_size: int) -> None:
        """Create ``connections`` download senders that interleave the file's parts."""
        minimum, remainder = divmod(part_count, connections)

        def get_part_count() -> int:
            # The first ``remainder`` senders each take one extra part.
            nonlocal remainder
            if remainder > 0:
                remainder -= 1
                return minimum + 1
            return minimum

        # The first cross-DC sender will export+import the authorization, so we always create it
        # before creating any other senders.
        self.senders = [
            await self._create_download_sender(file, 0, part_size, connections * part_size,
                                               get_part_count()),
            *await asyncio.gather(
                *[self._create_download_sender(file, i, part_size, connections * part_size,
                                               get_part_count())
                  for i in range(1, connections)])
        ]

    async def _create_download_sender(self, file: TypeLocation, index: int, part_size: int,
                                      stride: int,
                                      part_count: int) -> DownloadSender:
        """Build a ``DownloadSender`` that starts at byte ``index * part_size``."""
        return DownloadSender(self.client, await self._create_sender(), file, index * part_size, part_size,
                              stride, part_count)

    async def _init_upload(self, connections: int, file_id: int, part_count: int, big: bool
                           ) -> None:
        """Create ``connections`` upload senders; sender ``i`` starts at part ``i``."""
        # As for downloads, the first sender is created on its own before the rest.
        self.senders = [
            await self._create_upload_sender(file_id, part_count, big, 0, connections),
            *await asyncio.gather(
                *[self._create_upload_sender(file_id, part_count, big, i, connections)
                  for i in range(1, connections)])
        ]

    async def _create_upload_sender(self, file_id: int, part_count: int, big: bool, index: int,
                                    stride: int) -> UploadSender:
        """Build an ``UploadSender`` on a freshly created connection."""
        return UploadSender(self.client, await self._create_sender(), file_id, part_count, big, index, stride,
                            loop=self.loop)

    async def _create_sender(self) -> MTProtoSender:
        """Connect a new ``MTProtoSender`` to ``self.dc_id``.

        When no auth key is known yet, authorization is exported through the
        main client, imported on the new connection, and the sender's
        resulting key is stored for the senders created afterwards.
        """
        dc = await self.client._get_dc(self.dc_id)
        sender = MTProtoSender(self.auth_key, loggers=self.client._log)
        await sender.connect(self.client._connection(dc.ip_address, dc.port, dc.id,
                                                     loggers=self.client._log,
                                                     proxy=self.client._proxy))
        if not self.auth_key:
            log.debug(f"Exporting auth to DC {self.dc_id}")
            auth = await self.client(ExportAuthorizationRequest(self.dc_id))
            self.client._init_request.query = ImportAuthorizationRequest(id=auth.id,
                                                                         bytes=auth.bytes)
            req = InvokeWithLayerRequest(LAYER, self.client._init_request)
            await sender.send(req)
            self.auth_key = sender.auth_key
        return sender

    async def init_upload(self, file_id: int, file_size: int, part_size_kb: Optional[float] = None,
                          connection_count: Optional[int] = None) -> Tuple[int, int, bool]:
        """Open the upload connections.

        :param file_id: id the parts will be saved under.
        :param file_size: total size in bytes.
        :param part_size_kb: part size in KiB; chosen by
            ``utils.get_appropriated_part_size`` when omitted.
        :param connection_count: number of connections; derived from
            ``file_size`` when omitted.
        :return: ``(part_size, part_count, is_large)``; ``is_large`` is true
            for files above 10 MiB.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        # int() keeps sizes and counts integral even when part_size_kb is a float.
        part_size = int((part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024)
        part_count = (file_size + part_size - 1) // part_size
        is_large = file_size > 10 * 1024 * 1024
        await self._init_upload(connection_count, file_id, part_count, is_large)
        return part_size, part_count, is_large

    async def upload(self, part: bytes) -> None:
        """Hand ``part`` to the next sender in round-robin order."""
        await self.senders[self.upload_ticker].next(part)
        self.upload_ticker = (self.upload_ticker + 1) % len(self.senders)

    async def finish_upload(self) -> None:
        """Wait for outstanding parts and close all upload connections."""
        await self._cleanup()

    async def download(self, file: TypeLocation, file_size: int,
                       part_size_kb: Optional[float] = None,
                       connection_count: Optional[int] = None) -> AsyncGenerator[bytes, None]:
        """Yield the file's chunks in order, fetching one chunk per sender per round.

        Connections are closed when the generator finishes, fails, or is
        closed early by the consumer.

        :param file: location to download.
        :param file_size: total size in bytes.
        :param part_size_kb: part size in KiB; chosen by
            ``utils.get_appropriated_part_size`` when omitted.
        :param connection_count: number of connections; derived from
            ``file_size`` when omitted.
        """
        connection_count = connection_count or self._get_connection_count(file_size)
        part_size = int((part_size_kb or utils.get_appropriated_part_size(file_size)) * 1024)
        part_count = math.ceil(file_size / part_size)
        log.debug("Starting parallel download: "
                  f"{connection_count} {part_size} {part_count} {file!s}")
        await self._init_download(connection_count, file, part_count, part_size)

        try:
            part = 0
            exhausted = False
            while part < part_count and not exhausted:
                tasks = [self.loop.create_task(sender.next()) for sender in self.senders]
                try:
                    for task in tasks:
                        data = await task
                        if not data:
                            # Nothing more to read. Leaving the outer loop too
                            # prevents spinning forever when fewer parts arrive
                            # than part_count predicted.
                            exhausted = True
                            break
                        yield data
                        part += 1
                        log.debug(f"Part {part} downloaded")
                finally:
                    # Do not leave requests running once this round is abandoned.
                    for task in tasks:
                        if not task.done():
                            task.cancel()
        finally:
            log.debug("Parallel download finished, cleaning up connections")
            await self._cleanup()
230
+
231
+
232
# One lazily created asyncio.Lock per integer key (presumably a DC id - confirm
# against callers); the code visible in this file never acquires these locks.
parallel_transfer_locks: DefaultDict[int, asyncio.Lock] = defaultdict(asyncio.Lock)
233
+
234
+
235
def stream_file(file_to_stream: BinaryIO, chunk_size=1024):
    """Yield successive ``read(chunk_size)`` results until a read returns nothing.

    :param file_to_stream: object whose ``read`` method is called repeatedly.
    :param chunk_size: argument passed to every ``read`` call.
    """
    chunk = file_to_stream.read(chunk_size)
    while chunk:
        yield chunk
        chunk = file_to_stream.read(chunk_size)
241
+
242
+
243
async def _internal_transfer_to_telegram(client: TelegramClient,
                                         response: BinaryIO,
                                         progress_callback: callable
                                         ) -> Tuple[TypeInputFile, int]:
    """Upload ``response`` in parallel parts and return its input-file handle.

    :param client: client the upload connections are created from.
    :param response: binary stream to upload. Its size is taken from the file
        named by ``response.name``; streams without a ``name`` attribute are
        measured by seeking to their end instead.
    :param progress_callback: optional callable invoked as
        ``callback(current_position, file_size)`` after every chunk read; an
        awaitable result is awaited.
    :return: ``(input_file, file_size)``; ``input_file`` is an ``InputFileBig``
        for large files and an ``InputFile`` carrying the MD5 hex digest
        otherwise.
    """
    file_id = helpers.generate_random_long()
    try:
        file_size = os.path.getsize(response.name)
    except AttributeError:
        # No ``name`` attribute (e.g. an in-memory stream): measure by seeking
        # to the end, then restore the read position.
        position = response.tell()
        response.seek(0, os.SEEK_END)
        file_size = response.tell()
        response.seek(position)

    hash_md5 = hashlib.md5()
    uploader = ParallelTransferrer(client)
    part_size, part_count, is_large = await uploader.init_upload(file_id, file_size)
    try:
        buffer = bytearray()
        for data in stream_file(response):
            if progress_callback:
                r = progress_callback(response.tell(), file_size)
                if inspect.isawaitable(r):
                    await r
            if not is_large:
                hash_md5.update(data)
            # Fast path: a chunk that is exactly one part needs no buffering.
            if not buffer and len(data) == part_size:
                await uploader.upload(data)
                continue
            new_len = len(buffer) + len(data)
            if new_len >= part_size:
                # Fill the buffer up to one full part, send it, keep the rest.
                cutoff = part_size - len(buffer)
                buffer.extend(data[:cutoff])
                await uploader.upload(bytes(buffer))
                buffer.clear()
                buffer.extend(data[cutoff:])
            else:
                buffer.extend(data)
        if buffer:
            # Final, possibly short, part.
            await uploader.upload(bytes(buffer))
    finally:
        # Always release the upload connections, including when reading, the
        # progress callback, or an upload raised.
        await uploader.finish_upload()
    if is_large:
        return InputFileBig(file_id, part_count, "upload"), file_size
    else:
        return InputFile(file_id, part_count, "upload", hash_md5.hexdigest()), file_size
280
+
281
+
282
async def download_file(client: TelegramClient,
                        location: TypeLocation,
                        out: BinaryIO,
                        progress_callback: callable = None
                        ) -> BinaryIO:
    """Download ``location`` in parallel and write it to ``out``.

    :param client: client the download connections are created from.
    :param location: object to download; its ``size`` attribute is read and
        it is converted with ``utils.get_input_location``.
    :param out: writable binary stream that receives the chunks in order.
    :param progress_callback: optional callable invoked as
        ``callback(out.tell(), size)`` after every chunk written; an awaitable
        result is awaited.
    :return: ``out``.
    """
    size = location.size
    dc_id, location = utils.get_input_location(location)
    # NOTE: no lock is acquired here, so concurrent calls each open their own
    # set of connections to the DC.
    downloader = ParallelTransferrer(client, dc_id)
    downloaded = downloader.download(location, size)
    try:
        async for x in downloaded:
            out.write(x)
            if progress_callback:
                r = progress_callback(out.tell(), size)
                if inspect.isawaitable(r):
                    await r
    finally:
        # If writing or the callback failed, the generator is still suspended
        # with open connections: close it, then disconnect whatever is left.
        await downloaded.aclose()
        if downloader.senders is not None:
            await downloader._cleanup()

    return out
300
+
301
+
302
async def upload_file(client: TelegramClient,
                      file: BinaryIO,
                      progress_callback: callable = None,

                      ) -> TypeInputFile:
    """Upload ``file`` and return only the input-file handle.

    Thin wrapper around ``_internal_transfer_to_telegram`` that drops the file
    size from its ``(input_file, file_size)`` result.
    """
    input_file, _file_size = await _internal_transfer_to_telegram(client, file, progress_callback)
    return input_file