ayousanz commited on
Commit
b35b196
·
verified ·
1 Parent(s): 5a16a79

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Maputo +0 -0
  2. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Maseru +0 -0
  3. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Mbabane +0 -0
  4. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Mogadishu +0 -0
  5. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Monrovia +0 -0
  6. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Nairobi +0 -0
  7. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Ndjamena +0 -0
  8. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Niamey +0 -0
  9. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Nouakchott +0 -0
  10. .venv/Lib/site-packages/tzdata/zoneinfo/Africa/Ouagadougou +0 -0
  11. .venv/Lib/site-packages/unidic_lite/dicdir/char.bin +3 -0
  12. .venv/Lib/site-packages/urllib3/contrib/__init__.py +0 -0
  13. .venv/Lib/site-packages/urllib3/contrib/__pycache__/socks.cpython-39.pyc +0 -0
  14. .venv/Lib/site-packages/urllib3/contrib/emscripten/__init__.py +16 -0
  15. .venv/Lib/site-packages/urllib3/contrib/emscripten/connection.py +254 -0
  16. .venv/Lib/site-packages/urllib3/contrib/emscripten/emscripten_fetch_worker.js +110 -0
  17. .venv/Lib/site-packages/urllib3/contrib/emscripten/fetch.py +418 -0
  18. .venv/Lib/site-packages/urllib3/contrib/emscripten/request.py +22 -0
  19. .venv/Lib/site-packages/urllib3/contrib/emscripten/response.py +285 -0
  20. .venv/Lib/site-packages/urllib3/contrib/socks.py +228 -0
  21. .venv/Lib/site-packages/wasabi-0.10.1.dist-info/INSTALLER +1 -0
  22. .venv/Lib/site-packages/wasabi-0.10.1.dist-info/LICENSE +21 -0
  23. .venv/Lib/site-packages/wasabi-0.10.1.dist-info/METADATA +575 -0
  24. .venv/Lib/site-packages/wasabi-0.10.1.dist-info/RECORD +21 -0
  25. .venv/Lib/site-packages/wasabi-0.10.1.dist-info/REQUESTED +0 -0
  26. .venv/Lib/site-packages/wasabi-0.10.1.dist-info/WHEEL +5 -0
  27. .venv/Lib/site-packages/wasabi-0.10.1.dist-info/top_level.txt +1 -0
  28. .venv/Lib/site-packages/wasabi-0.10.1.dist-info/zip-safe +1 -0
  29. .venv/Lib/site-packages/wasabi/__pycache__/__init__.cpython-39.pyc +0 -0
  30. .venv/Lib/site-packages/wasabi/__pycache__/about.cpython-39.pyc +0 -0
  31. .venv/Lib/site-packages/wasabi/__pycache__/markdown.cpython-39.pyc +0 -0
  32. .venv/Lib/site-packages/wasabi/__pycache__/printer.cpython-39.pyc +0 -0
  33. .venv/Lib/site-packages/wasabi/__pycache__/tables.cpython-39.pyc +0 -0
  34. .venv/Lib/site-packages/wasabi/__pycache__/traceback_printer.cpython-39.pyc +0 -0
  35. .venv/Lib/site-packages/wasabi/__pycache__/util.cpython-39.pyc +0 -0
  36. .venv/Lib/site-packages/wasabi/tests/__init__.py +0 -0
  37. .venv/Lib/site-packages/wasabi/tests/test_util.py +75 -0
  38. .venv/Lib/site-packages/xlstm/__init__.py +9 -0
  39. .venv/Lib/site-packages/xlstm/blocks/__init__.py +0 -0
  40. .venv/Lib/site-packages/xlstm/blocks/mlstm/__init__.py +1 -0
  41. .venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/__init__.cpython-39.pyc +0 -0
  42. .venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/backends.cpython-39.pyc +0 -0
  43. .venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/block.cpython-39.pyc +0 -0
  44. .venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/cell.cpython-39.pyc +0 -0
  45. .venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/layer.cpython-39.pyc +0 -0
  46. .venv/Lib/site-packages/xlstm/blocks/mlstm/backends.py +145 -0
  47. .venv/Lib/site-packages/xlstm/blocks/mlstm/block.py +34 -0
  48. .venv/Lib/site-packages/xlstm/blocks/mlstm/cell.py +140 -0
  49. .venv/Lib/site-packages/xlstm/blocks/mlstm/layer.py +180 -0
  50. .venv/Lib/site-packages/xlstm/blocks/slstm/__init__.py +0 -0
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Maputo ADDED
Binary file (131 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Maseru ADDED
Binary file (190 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Mbabane ADDED
Binary file (190 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Mogadishu ADDED
Binary file (191 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Monrovia ADDED
Binary file (164 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Nairobi ADDED
Binary file (191 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Ndjamena ADDED
Binary file (160 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Niamey ADDED
Binary file (180 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Nouakchott ADDED
Binary file (130 Bytes). View file
 
.venv/Lib/site-packages/tzdata/zoneinfo/Africa/Ouagadougou ADDED
Binary file (130 Bytes). View file
 
.venv/Lib/site-packages/unidic_lite/dicdir/char.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd31396563d8924645b80fd3c9aa7b13ca089d7748f25553a1d6bc3f9b511ae8
3
+ size 262496
.venv/Lib/site-packages/urllib3/contrib/__init__.py ADDED
File without changes
.venv/Lib/site-packages/urllib3/contrib/__pycache__/socks.cpython-39.pyc ADDED
Binary file (6.11 kB). View file
 
.venv/Lib/site-packages/urllib3/contrib/emscripten/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import urllib3.connection
4
+
5
+ from ...connectionpool import HTTPConnectionPool, HTTPSConnectionPool
6
+ from .connection import EmscriptenHTTPConnection, EmscriptenHTTPSConnection
7
+
8
+
9
+ def inject_into_urllib3() -> None:
10
+ # override connection classes to use emscripten specific classes
11
+ # n.b. mypy complains about the overriding of classes below
12
+ # if it isn't ignored
13
+ HTTPConnectionPool.ConnectionCls = EmscriptenHTTPConnection
14
+ HTTPSConnectionPool.ConnectionCls = EmscriptenHTTPSConnection
15
+ urllib3.connection.HTTPConnection = EmscriptenHTTPConnection # type: ignore[misc,assignment]
16
+ urllib3.connection.HTTPSConnection = EmscriptenHTTPSConnection # type: ignore[misc,assignment]
.venv/Lib/site-packages/urllib3/contrib/emscripten/connection.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import os
4
+ import typing
5
+
6
+ # use http.client.HTTPException for consistency with non-emscripten
7
+ from http.client import HTTPException as HTTPException # noqa: F401
8
+ from http.client import ResponseNotReady
9
+
10
+ from ..._base_connection import _TYPE_BODY
11
+ from ...connection import HTTPConnection, ProxyConfig, port_by_scheme
12
+ from ...exceptions import TimeoutError
13
+ from ...response import BaseHTTPResponse
14
+ from ...util.connection import _TYPE_SOCKET_OPTIONS
15
+ from ...util.timeout import _DEFAULT_TIMEOUT, _TYPE_TIMEOUT
16
+ from ...util.url import Url
17
+ from .fetch import _RequestError, _TimeoutError, send_request, send_streaming_request
18
+ from .request import EmscriptenRequest
19
+ from .response import EmscriptenHttpResponseWrapper, EmscriptenResponse
20
+
21
+ if typing.TYPE_CHECKING:
22
+ from ..._base_connection import BaseHTTPConnection, BaseHTTPSConnection
23
+
24
+
25
+ class EmscriptenHTTPConnection:
26
+ default_port: typing.ClassVar[int] = port_by_scheme["http"]
27
+ default_socket_options: typing.ClassVar[_TYPE_SOCKET_OPTIONS]
28
+
29
+ timeout: None | (float)
30
+
31
+ host: str
32
+ port: int
33
+ blocksize: int
34
+ source_address: tuple[str, int] | None
35
+ socket_options: _TYPE_SOCKET_OPTIONS | None
36
+
37
+ proxy: Url | None
38
+ proxy_config: ProxyConfig | None
39
+
40
+ is_verified: bool = False
41
+ proxy_is_verified: bool | None = None
42
+
43
+ _response: EmscriptenResponse | None
44
+
45
+ def __init__(
46
+ self,
47
+ host: str,
48
+ port: int = 0,
49
+ *,
50
+ timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
51
+ source_address: tuple[str, int] | None = None,
52
+ blocksize: int = 8192,
53
+ socket_options: _TYPE_SOCKET_OPTIONS | None = None,
54
+ proxy: Url | None = None,
55
+ proxy_config: ProxyConfig | None = None,
56
+ ) -> None:
57
+ self.host = host
58
+ self.port = port
59
+ self.timeout = timeout if isinstance(timeout, float) else 0.0
60
+ self.scheme = "http"
61
+ self._closed = True
62
+ self._response = None
63
+ # ignore these things because we don't
64
+ # have control over that stuff
65
+ self.proxy = None
66
+ self.proxy_config = None
67
+ self.blocksize = blocksize
68
+ self.source_address = None
69
+ self.socket_options = None
70
+ self.is_verified = False
71
+
72
+ def set_tunnel(
73
+ self,
74
+ host: str,
75
+ port: int | None = 0,
76
+ headers: typing.Mapping[str, str] | None = None,
77
+ scheme: str = "http",
78
+ ) -> None:
79
+ pass
80
+
81
+ def connect(self) -> None:
82
+ pass
83
+
84
+ def request(
85
+ self,
86
+ method: str,
87
+ url: str,
88
+ body: _TYPE_BODY | None = None,
89
+ headers: typing.Mapping[str, str] | None = None,
90
+ # We know *at least* botocore is depending on the order of the
91
+ # first 3 parameters so to be safe we only mark the later ones
92
+ # as keyword-only to ensure we have space to extend.
93
+ *,
94
+ chunked: bool = False,
95
+ preload_content: bool = True,
96
+ decode_content: bool = True,
97
+ enforce_content_length: bool = True,
98
+ ) -> None:
99
+ self._closed = False
100
+ if url.startswith("/"):
101
+ # no scheme / host / port included, make a full url
102
+ url = f"{self.scheme}://{self.host}:{self.port}" + url
103
+ request = EmscriptenRequest(
104
+ url=url,
105
+ method=method,
106
+ timeout=self.timeout if self.timeout else 0,
107
+ decode_content=decode_content,
108
+ )
109
+ request.set_body(body)
110
+ if headers:
111
+ for k, v in headers.items():
112
+ request.set_header(k, v)
113
+ self._response = None
114
+ try:
115
+ if not preload_content:
116
+ self._response = send_streaming_request(request)
117
+ if self._response is None:
118
+ self._response = send_request(request)
119
+ except _TimeoutError as e:
120
+ raise TimeoutError(e.message) from e
121
+ except _RequestError as e:
122
+ raise HTTPException(e.message) from e
123
+
124
+ def getresponse(self) -> BaseHTTPResponse:
125
+ if self._response is not None:
126
+ return EmscriptenHttpResponseWrapper(
127
+ internal_response=self._response,
128
+ url=self._response.request.url,
129
+ connection=self,
130
+ )
131
+ else:
132
+ raise ResponseNotReady()
133
+
134
+ def close(self) -> None:
135
+ self._closed = True
136
+ self._response = None
137
+
138
+ @property
139
+ def is_closed(self) -> bool:
140
+ """Whether the connection either is brand new or has been previously closed.
141
+ If this property is True then both ``is_connected`` and ``has_connected_to_proxy``
142
+ properties must be False.
143
+ """
144
+ return self._closed
145
+
146
+ @property
147
+ def is_connected(self) -> bool:
148
+ """Whether the connection is actively connected to any origin (proxy or target)"""
149
+ return True
150
+
151
+ @property
152
+ def has_connected_to_proxy(self) -> bool:
153
+ """Whether the connection has successfully connected to its proxy.
154
+ This returns False if no proxy is in use. Used to determine whether
155
+ errors are coming from the proxy layer or from tunnelling to the target origin.
156
+ """
157
+ return False
158
+
159
+
160
+ class EmscriptenHTTPSConnection(EmscriptenHTTPConnection):
161
+ default_port = port_by_scheme["https"]
162
+ # all this is basically ignored, as browser handles https
163
+ cert_reqs: int | str | None = None
164
+ ca_certs: str | None = None
165
+ ca_cert_dir: str | None = None
166
+ ca_cert_data: None | str | bytes = None
167
+ cert_file: str | None
168
+ key_file: str | None
169
+ key_password: str | None
170
+ ssl_context: typing.Any | None
171
+ ssl_version: int | str | None = None
172
+ ssl_minimum_version: int | None = None
173
+ ssl_maximum_version: int | None = None
174
+ assert_hostname: None | str | typing.Literal[False]
175
+ assert_fingerprint: str | None = None
176
+
177
+ def __init__(
178
+ self,
179
+ host: str,
180
+ port: int = 0,
181
+ *,
182
+ timeout: _TYPE_TIMEOUT = _DEFAULT_TIMEOUT,
183
+ source_address: tuple[str, int] | None = None,
184
+ blocksize: int = 16384,
185
+ socket_options: None
186
+ | _TYPE_SOCKET_OPTIONS = HTTPConnection.default_socket_options,
187
+ proxy: Url | None = None,
188
+ proxy_config: ProxyConfig | None = None,
189
+ cert_reqs: int | str | None = None,
190
+ assert_hostname: None | str | typing.Literal[False] = None,
191
+ assert_fingerprint: str | None = None,
192
+ server_hostname: str | None = None,
193
+ ssl_context: typing.Any | None = None,
194
+ ca_certs: str | None = None,
195
+ ca_cert_dir: str | None = None,
196
+ ca_cert_data: None | str | bytes = None,
197
+ ssl_minimum_version: int | None = None,
198
+ ssl_maximum_version: int | None = None,
199
+ ssl_version: int | str | None = None, # Deprecated
200
+ cert_file: str | None = None,
201
+ key_file: str | None = None,
202
+ key_password: str | None = None,
203
+ ) -> None:
204
+ super().__init__(
205
+ host,
206
+ port=port,
207
+ timeout=timeout,
208
+ source_address=source_address,
209
+ blocksize=blocksize,
210
+ socket_options=socket_options,
211
+ proxy=proxy,
212
+ proxy_config=proxy_config,
213
+ )
214
+ self.scheme = "https"
215
+
216
+ self.key_file = key_file
217
+ self.cert_file = cert_file
218
+ self.key_password = key_password
219
+ self.ssl_context = ssl_context
220
+ self.server_hostname = server_hostname
221
+ self.assert_hostname = assert_hostname
222
+ self.assert_fingerprint = assert_fingerprint
223
+ self.ssl_version = ssl_version
224
+ self.ssl_minimum_version = ssl_minimum_version
225
+ self.ssl_maximum_version = ssl_maximum_version
226
+ self.ca_certs = ca_certs and os.path.expanduser(ca_certs)
227
+ self.ca_cert_dir = ca_cert_dir and os.path.expanduser(ca_cert_dir)
228
+ self.ca_cert_data = ca_cert_data
229
+
230
+ self.cert_reqs = None
231
+
232
+ # The browser will automatically verify all requests.
233
+ # We have no control over that setting.
234
+ self.is_verified = True
235
+
236
+ def set_cert(
237
+ self,
238
+ key_file: str | None = None,
239
+ cert_file: str | None = None,
240
+ cert_reqs: int | str | None = None,
241
+ key_password: str | None = None,
242
+ ca_certs: str | None = None,
243
+ assert_hostname: None | str | typing.Literal[False] = None,
244
+ assert_fingerprint: str | None = None,
245
+ ca_cert_dir: str | None = None,
246
+ ca_cert_data: None | str | bytes = None,
247
+ ) -> None:
248
+ pass
249
+
250
+
251
+ # verify that this class implements BaseHTTP(s) connection correctly
252
+ if typing.TYPE_CHECKING:
253
+ _supports_http_protocol: BaseHTTPConnection = EmscriptenHTTPConnection("", 0)
254
+ _supports_https_protocol: BaseHTTPSConnection = EmscriptenHTTPSConnection("", 0)
.venv/Lib/site-packages/urllib3/contrib/emscripten/emscripten_fetch_worker.js ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ let Status = {
2
+ SUCCESS_HEADER: -1,
3
+ SUCCESS_EOF: -2,
4
+ ERROR_TIMEOUT: -3,
5
+ ERROR_EXCEPTION: -4,
6
+ };
7
+
8
+ let connections = {};
9
+ let nextConnectionID = 1;
10
+ const encoder = new TextEncoder();
11
+
12
+ self.addEventListener("message", async function (event) {
13
+ if (event.data.close) {
14
+ let connectionID = event.data.close;
15
+ delete connections[connectionID];
16
+ return;
17
+ } else if (event.data.getMore) {
18
+ let connectionID = event.data.getMore;
19
+ let { curOffset, value, reader, intBuffer, byteBuffer } =
20
+ connections[connectionID];
21
+ // if we still have some in buffer, then just send it back straight away
22
+ if (!value || curOffset >= value.length) {
23
+ // read another buffer if required
24
+ try {
25
+ let readResponse = await reader.read();
26
+
27
+ if (readResponse.done) {
28
+ // read everything - clear connection and return
29
+ delete connections[connectionID];
30
+ Atomics.store(intBuffer, 0, Status.SUCCESS_EOF);
31
+ Atomics.notify(intBuffer, 0);
32
+ // finished reading successfully
33
+ // return from event handler
34
+ return;
35
+ }
36
+ curOffset = 0;
37
+ connections[connectionID].value = readResponse.value;
38
+ value = readResponse.value;
39
+ } catch (error) {
40
+ console.log("Request exception:", error);
41
+ let errorBytes = encoder.encode(error.message);
42
+ let written = errorBytes.length;
43
+ byteBuffer.set(errorBytes);
44
+ intBuffer[1] = written;
45
+ Atomics.store(intBuffer, 0, Status.ERROR_EXCEPTION);
46
+ Atomics.notify(intBuffer, 0);
47
+ }
48
+ }
49
+
50
+ // send as much buffer as we can
51
+ let curLen = value.length - curOffset;
52
+ if (curLen > byteBuffer.length) {
53
+ curLen = byteBuffer.length;
54
+ }
55
+ byteBuffer.set(value.subarray(curOffset, curOffset + curLen), 0);
56
+
57
+ Atomics.store(intBuffer, 0, curLen); // store current length in bytes
58
+ Atomics.notify(intBuffer, 0);
59
+ curOffset += curLen;
60
+ connections[connectionID].curOffset = curOffset;
61
+
62
+ return;
63
+ } else {
64
+ // start fetch
65
+ let connectionID = nextConnectionID;
66
+ nextConnectionID += 1;
67
+ const intBuffer = new Int32Array(event.data.buffer);
68
+ const byteBuffer = new Uint8Array(event.data.buffer, 8);
69
+ try {
70
+ const response = await fetch(event.data.url, event.data.fetchParams);
71
+ // return the headers first via textencoder
72
+ var headers = [];
73
+ for (const pair of response.headers.entries()) {
74
+ headers.push([pair[0], pair[1]]);
75
+ }
76
+ let headerObj = {
77
+ headers: headers,
78
+ status: response.status,
79
+ connectionID,
80
+ };
81
+ const headerText = JSON.stringify(headerObj);
82
+ let headerBytes = encoder.encode(headerText);
83
+ let written = headerBytes.length;
84
+ byteBuffer.set(headerBytes);
85
+ intBuffer[1] = written;
86
+ // make a connection
87
+ connections[connectionID] = {
88
+ reader: response.body.getReader(),
89
+ intBuffer: intBuffer,
90
+ byteBuffer: byteBuffer,
91
+ value: undefined,
92
+ curOffset: 0,
93
+ };
94
+ // set header ready
95
+ Atomics.store(intBuffer, 0, Status.SUCCESS_HEADER);
96
+ Atomics.notify(intBuffer, 0);
97
+ // all fetching after this goes through a new postmessage call with getMore
98
+ // this allows for parallel requests
99
+ } catch (error) {
100
+ console.log("Request exception:", error);
101
+ let errorBytes = encoder.encode(error.message);
102
+ let written = errorBytes.length;
103
+ byteBuffer.set(errorBytes);
104
+ intBuffer[1] = written;
105
+ Atomics.store(intBuffer, 0, Status.ERROR_EXCEPTION);
106
+ Atomics.notify(intBuffer, 0);
107
+ }
108
+ }
109
+ });
110
+ self.postMessage({ inited: true });
.venv/Lib/site-packages/urllib3/contrib/emscripten/fetch.py ADDED
@@ -0,0 +1,418 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Support for streaming http requests in emscripten.
3
+
4
+ A few caveats -
5
+
6
+ Firstly, you can't do streaming http in the main UI thread, because atomics.wait isn't allowed.
7
+ Streaming only works if you're running pyodide in a web worker.
8
+
9
+ Secondly, this uses an extra web worker and SharedArrayBuffer to do the asynchronous fetch
10
+ operation, so it requires that you have crossOriginIsolation enabled, by serving over https
11
+ (or from localhost) with the two headers below set:
12
+
13
+ Cross-Origin-Opener-Policy: same-origin
14
+ Cross-Origin-Embedder-Policy: require-corp
15
+
16
+ You can tell if cross origin isolation is successfully enabled by looking at the global crossOriginIsolated variable in
17
+ javascript console. If it isn't, streaming requests will fallback to XMLHttpRequest, i.e. getting the whole
18
+ request into a buffer and then returning it. it shows a warning in the javascript console in this case.
19
+
20
+ Finally, the webworker which does the streaming fetch is created on initial import, but will only be started once
21
+ control is returned to javascript. Call `await wait_for_streaming_ready()` to wait for streaming fetch.
22
+
23
+ NB: in this code, there are a lot of javascript objects. They are named js_*
24
+ to make it clear what type of object they are.
25
+ """
26
+ from __future__ import annotations
27
+
28
+ import io
29
+ import json
30
+ from email.parser import Parser
31
+ from importlib.resources import files
32
+ from typing import TYPE_CHECKING, Any
33
+
34
+ import js # type: ignore[import-not-found]
35
+ from pyodide.ffi import ( # type: ignore[import-not-found]
36
+ JsArray,
37
+ JsException,
38
+ JsProxy,
39
+ to_js,
40
+ )
41
+
42
+ if TYPE_CHECKING:
43
+ from typing_extensions import Buffer
44
+
45
+ from .request import EmscriptenRequest
46
+ from .response import EmscriptenResponse
47
+
48
+ """
49
+ There are some headers that trigger unintended CORS preflight requests.
50
+ See also https://github.com/koenvo/pyodide-http/issues/22
51
+ """
52
+ HEADERS_TO_IGNORE = ("user-agent",)
53
+
54
+ SUCCESS_HEADER = -1
55
+ SUCCESS_EOF = -2
56
+ ERROR_TIMEOUT = -3
57
+ ERROR_EXCEPTION = -4
58
+
59
+ _STREAMING_WORKER_CODE = (
60
+ files(__package__)
61
+ .joinpath("emscripten_fetch_worker.js")
62
+ .read_text(encoding="utf-8")
63
+ )
64
+
65
+
66
+ class _RequestError(Exception):
67
+ def __init__(
68
+ self,
69
+ message: str | None = None,
70
+ *,
71
+ request: EmscriptenRequest | None = None,
72
+ response: EmscriptenResponse | None = None,
73
+ ):
74
+ self.request = request
75
+ self.response = response
76
+ self.message = message
77
+ super().__init__(self.message)
78
+
79
+
80
+ class _StreamingError(_RequestError):
81
+ pass
82
+
83
+
84
+ class _TimeoutError(_RequestError):
85
+ pass
86
+
87
+
88
+ def _obj_from_dict(dict_val: dict[str, Any]) -> JsProxy:
89
+ return to_js(dict_val, dict_converter=js.Object.fromEntries)
90
+
91
+
92
+ class _ReadStream(io.RawIOBase):
93
+ def __init__(
94
+ self,
95
+ int_buffer: JsArray,
96
+ byte_buffer: JsArray,
97
+ timeout: float,
98
+ worker: JsProxy,
99
+ connection_id: int,
100
+ request: EmscriptenRequest,
101
+ ):
102
+ self.int_buffer = int_buffer
103
+ self.byte_buffer = byte_buffer
104
+ self.read_pos = 0
105
+ self.read_len = 0
106
+ self.connection_id = connection_id
107
+ self.worker = worker
108
+ self.timeout = int(1000 * timeout) if timeout > 0 else None
109
+ self.is_live = True
110
+ self._is_closed = False
111
+ self.request: EmscriptenRequest | None = request
112
+
113
+ def __del__(self) -> None:
114
+ self.close()
115
+
116
+ # this is compatible with _base_connection
117
+ def is_closed(self) -> bool:
118
+ return self._is_closed
119
+
120
+ # for compatibility with RawIOBase
121
+ @property
122
+ def closed(self) -> bool:
123
+ return self.is_closed()
124
+
125
+ def close(self) -> None:
126
+ if not self.is_closed():
127
+ self.read_len = 0
128
+ self.read_pos = 0
129
+ self.int_buffer = None
130
+ self.byte_buffer = None
131
+ self._is_closed = True
132
+ self.request = None
133
+ if self.is_live:
134
+ self.worker.postMessage(_obj_from_dict({"close": self.connection_id}))
135
+ self.is_live = False
136
+ super().close()
137
+
138
+ def readable(self) -> bool:
139
+ return True
140
+
141
+ def writable(self) -> bool:
142
+ return False
143
+
144
+ def seekable(self) -> bool:
145
+ return False
146
+
147
+ def readinto(self, byte_obj: Buffer) -> int:
148
+ if not self.int_buffer:
149
+ raise _StreamingError(
150
+ "No buffer for stream in _ReadStream.readinto",
151
+ request=self.request,
152
+ response=None,
153
+ )
154
+ if self.read_len == 0:
155
+ # wait for the worker to send something
156
+ js.Atomics.store(self.int_buffer, 0, ERROR_TIMEOUT)
157
+ self.worker.postMessage(_obj_from_dict({"getMore": self.connection_id}))
158
+ if (
159
+ js.Atomics.wait(self.int_buffer, 0, ERROR_TIMEOUT, self.timeout)
160
+ == "timed-out"
161
+ ):
162
+ raise _TimeoutError
163
+ data_len = self.int_buffer[0]
164
+ if data_len > 0:
165
+ self.read_len = data_len
166
+ self.read_pos = 0
167
+ elif data_len == ERROR_EXCEPTION:
168
+ string_len = self.int_buffer[1]
169
+ # decode the error string
170
+ js_decoder = js.TextDecoder.new()
171
+ json_str = js_decoder.decode(self.byte_buffer.slice(0, string_len))
172
+ raise _StreamingError(
173
+ f"Exception thrown in fetch: {json_str}",
174
+ request=self.request,
175
+ response=None,
176
+ )
177
+ else:
178
+ # EOF, free the buffers and return zero
179
+ # and free the request
180
+ self.is_live = False
181
+ self.close()
182
+ return 0
183
+ # copy from int32array to python bytes
184
+ ret_length = min(self.read_len, len(memoryview(byte_obj)))
185
+ subarray = self.byte_buffer.subarray(
186
+ self.read_pos, self.read_pos + ret_length
187
+ ).to_py()
188
+ memoryview(byte_obj)[0:ret_length] = subarray
189
+ self.read_len -= ret_length
190
+ self.read_pos += ret_length
191
+ return ret_length
192
+
193
+
194
+ class _StreamingFetcher:
195
+ def __init__(self) -> None:
196
+ # make web-worker and data buffer on startup
197
+ self.streaming_ready = False
198
+
199
+ js_data_blob = js.Blob.new(
200
+ [_STREAMING_WORKER_CODE], _obj_from_dict({"type": "application/javascript"})
201
+ )
202
+
203
+ def promise_resolver(js_resolve_fn: JsProxy, js_reject_fn: JsProxy) -> None:
204
+ def onMsg(e: JsProxy) -> None:
205
+ self.streaming_ready = True
206
+ js_resolve_fn(e)
207
+
208
+ def onErr(e: JsProxy) -> None:
209
+ js_reject_fn(e) # Defensive: never happens in ci
210
+
211
+ self.js_worker.onmessage = onMsg
212
+ self.js_worker.onerror = onErr
213
+
214
+ js_data_url = js.URL.createObjectURL(js_data_blob)
215
+ self.js_worker = js.globalThis.Worker.new(js_data_url)
216
+ self.js_worker_ready_promise = js.globalThis.Promise.new(promise_resolver)
217
+
218
+ def send(self, request: EmscriptenRequest) -> EmscriptenResponse:
219
+ headers = {
220
+ k: v for k, v in request.headers.items() if k not in HEADERS_TO_IGNORE
221
+ }
222
+
223
+ body = request.body
224
+ fetch_data = {"headers": headers, "body": to_js(body), "method": request.method}
225
+ # start the request off in the worker
226
+ timeout = int(1000 * request.timeout) if request.timeout > 0 else None
227
+ js_shared_buffer = js.SharedArrayBuffer.new(1048576)
228
+ js_int_buffer = js.Int32Array.new(js_shared_buffer)
229
+ js_byte_buffer = js.Uint8Array.new(js_shared_buffer, 8)
230
+
231
+ js.Atomics.store(js_int_buffer, 0, ERROR_TIMEOUT)
232
+ js.Atomics.notify(js_int_buffer, 0)
233
+ js_absolute_url = js.URL.new(request.url, js.location).href
234
+ self.js_worker.postMessage(
235
+ _obj_from_dict(
236
+ {
237
+ "buffer": js_shared_buffer,
238
+ "url": js_absolute_url,
239
+ "fetchParams": fetch_data,
240
+ }
241
+ )
242
+ )
243
+ # wait for the worker to send something
244
+ js.Atomics.wait(js_int_buffer, 0, ERROR_TIMEOUT, timeout)
245
+ if js_int_buffer[0] == ERROR_TIMEOUT:
246
+ raise _TimeoutError(
247
+ "Timeout connecting to streaming request",
248
+ request=request,
249
+ response=None,
250
+ )
251
+ elif js_int_buffer[0] == SUCCESS_HEADER:
252
+ # got response
253
+ # header length is in second int of intBuffer
254
+ string_len = js_int_buffer[1]
255
+ # decode the rest to a JSON string
256
+ js_decoder = js.TextDecoder.new()
257
+ # this does a copy (the slice) because decode can't work on shared array
258
+ # for some silly reason
259
+ json_str = js_decoder.decode(js_byte_buffer.slice(0, string_len))
260
+ # get it as an object
261
+ response_obj = json.loads(json_str)
262
+ return EmscriptenResponse(
263
+ request=request,
264
+ status_code=response_obj["status"],
265
+ headers=response_obj["headers"],
266
+ body=_ReadStream(
267
+ js_int_buffer,
268
+ js_byte_buffer,
269
+ request.timeout,
270
+ self.js_worker,
271
+ response_obj["connectionID"],
272
+ request,
273
+ ),
274
+ )
275
+ elif js_int_buffer[0] == ERROR_EXCEPTION:
276
+ string_len = js_int_buffer[1]
277
+ # decode the error string
278
+ js_decoder = js.TextDecoder.new()
279
+ json_str = js_decoder.decode(js_byte_buffer.slice(0, string_len))
280
+ raise _StreamingError(
281
+ f"Exception thrown in fetch: {json_str}", request=request, response=None
282
+ )
283
+ else:
284
+ raise _StreamingError(
285
+ f"Unknown status from worker in fetch: {js_int_buffer[0]}",
286
+ request=request,
287
+ response=None,
288
+ )
289
+
290
+
291
+ # check if we are in a worker or not
292
+ def is_in_browser_main_thread() -> bool:
293
+ return hasattr(js, "window") and hasattr(js, "self") and js.self == js.window
294
+
295
+
296
+ def is_cross_origin_isolated() -> bool:
297
+ return hasattr(js, "crossOriginIsolated") and js.crossOriginIsolated
298
+
299
+
300
+ def is_in_node() -> bool:
301
+ return (
302
+ hasattr(js, "process")
303
+ and hasattr(js.process, "release")
304
+ and hasattr(js.process.release, "name")
305
+ and js.process.release.name == "node"
306
+ )
307
+
308
+
309
+ def is_worker_available() -> bool:
310
+ return hasattr(js, "Worker") and hasattr(js, "Blob")
311
+
312
+
313
+ _fetcher: _StreamingFetcher | None = None
314
+
315
+ if is_worker_available() and (
316
+ (is_cross_origin_isolated() and not is_in_browser_main_thread())
317
+ and (not is_in_node())
318
+ ):
319
+ _fetcher = _StreamingFetcher()
320
+ else:
321
+ _fetcher = None
322
+
323
+
324
+ def send_streaming_request(request: EmscriptenRequest) -> EmscriptenResponse | None:
325
+ if _fetcher and streaming_ready():
326
+ return _fetcher.send(request)
327
+ else:
328
+ _show_streaming_warning()
329
+ return None
330
+
331
+
332
+ _SHOWN_TIMEOUT_WARNING = False
333
+
334
+
335
+ def _show_timeout_warning() -> None:
336
+ global _SHOWN_TIMEOUT_WARNING
337
+ if not _SHOWN_TIMEOUT_WARNING:
338
+ _SHOWN_TIMEOUT_WARNING = True
339
+ message = "Warning: Timeout is not available on main browser thread"
340
+ js.console.warn(message)
341
+
342
+
343
+ _SHOWN_STREAMING_WARNING = False
344
+
345
+
346
+ def _show_streaming_warning() -> None:
347
+ global _SHOWN_STREAMING_WARNING
348
+ if not _SHOWN_STREAMING_WARNING:
349
+ _SHOWN_STREAMING_WARNING = True
350
+ message = "Can't stream HTTP requests because: \n"
351
+ if not is_cross_origin_isolated():
352
+ message += " Page is not cross-origin isolated\n"
353
+ if is_in_browser_main_thread():
354
+ message += " Python is running in main browser thread\n"
355
+ if not is_worker_available():
356
+ message += " Worker or Blob classes are not available in this environment." # Defensive: this is always False in browsers that we test in
357
+ if streaming_ready() is False:
358
+ message += """ Streaming fetch worker isn't ready. If you want to be sure that streaming fetch
359
+ is working, you need to call: 'await urllib3.contrib.emscripten.fetch.wait_for_streaming_ready()`"""
360
+ from js import console
361
+
362
+ console.warn(message)
363
+
364
+
365
+ def send_request(request: EmscriptenRequest) -> EmscriptenResponse:
366
+ try:
367
+ js_xhr = js.XMLHttpRequest.new()
368
+
369
+ if not is_in_browser_main_thread():
370
+ js_xhr.responseType = "arraybuffer"
371
+ if request.timeout:
372
+ js_xhr.timeout = int(request.timeout * 1000)
373
+ else:
374
+ js_xhr.overrideMimeType("text/plain; charset=ISO-8859-15")
375
+ if request.timeout:
376
+ # timeout isn't available on the main thread - show a warning in console
377
+ # if it is set
378
+ _show_timeout_warning()
379
+
380
+ js_xhr.open(request.method, request.url, False)
381
+ for name, value in request.headers.items():
382
+ if name.lower() not in HEADERS_TO_IGNORE:
383
+ js_xhr.setRequestHeader(name, value)
384
+
385
+ js_xhr.send(to_js(request.body))
386
+
387
+ headers = dict(Parser().parsestr(js_xhr.getAllResponseHeaders()))
388
+
389
+ if not is_in_browser_main_thread():
390
+ body = js_xhr.response.to_py().tobytes()
391
+ else:
392
+ body = js_xhr.response.encode("ISO-8859-15")
393
+ return EmscriptenResponse(
394
+ status_code=js_xhr.status, headers=headers, body=body, request=request
395
+ )
396
+ except JsException as err:
397
+ if err.name == "TimeoutError":
398
+ raise _TimeoutError(err.message, request=request)
399
+ elif err.name == "NetworkError":
400
+ raise _RequestError(err.message, request=request)
401
+ else:
402
+ # general http error
403
+ raise _RequestError(err.message, request=request)
404
+
405
+
406
+ def streaming_ready() -> bool | None:
407
+ if _fetcher:
408
+ return _fetcher.streaming_ready
409
+ else:
410
+ return None # no fetcher, return None to signify that
411
+
412
+
413
+ async def wait_for_streaming_ready() -> bool:
414
+ if _fetcher:
415
+ await _fetcher.js_worker_ready_promise
416
+ return True
417
+ else:
418
+ return False
.venv/Lib/site-packages/urllib3/contrib/emscripten/request.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+
5
+ from ..._base_connection import _TYPE_BODY
6
+
7
+
8
+ @dataclass
9
+ class EmscriptenRequest:
10
+ method: str
11
+ url: str
12
+ params: dict[str, str] | None = None
13
+ body: _TYPE_BODY | None = None
14
+ headers: dict[str, str] = field(default_factory=dict)
15
+ timeout: float = 0
16
+ decode_content: bool = True
17
+
18
+ def set_header(self, name: str, value: str) -> None:
19
+ self.headers[name.capitalize()] = value
20
+
21
+ def set_body(self, body: _TYPE_BODY | None) -> None:
22
+ self.body = body
.venv/Lib/site-packages/urllib3/contrib/emscripten/response.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import json as _json
4
+ import logging
5
+ import typing
6
+ from contextlib import contextmanager
7
+ from dataclasses import dataclass
8
+ from http.client import HTTPException as HTTPException
9
+ from io import BytesIO, IOBase
10
+
11
+ from ...exceptions import InvalidHeader, TimeoutError
12
+ from ...response import BaseHTTPResponse
13
+ from ...util.retry import Retry
14
+ from .request import EmscriptenRequest
15
+
16
+ if typing.TYPE_CHECKING:
17
+ from ..._base_connection import BaseHTTPConnection, BaseHTTPSConnection
18
+
19
+ log = logging.getLogger(__name__)
20
+
21
+
22
+ @dataclass
23
+ class EmscriptenResponse:
24
+ status_code: int
25
+ headers: dict[str, str]
26
+ body: IOBase | bytes
27
+ request: EmscriptenRequest
28
+
29
+
30
+ class EmscriptenHttpResponseWrapper(BaseHTTPResponse):
31
+ def __init__(
32
+ self,
33
+ internal_response: EmscriptenResponse,
34
+ url: str | None = None,
35
+ connection: BaseHTTPConnection | BaseHTTPSConnection | None = None,
36
+ ):
37
+ self._pool = None # set by pool class
38
+ self._body = None
39
+ self._response = internal_response
40
+ self._url = url
41
+ self._connection = connection
42
+ self._closed = False
43
+ super().__init__(
44
+ headers=internal_response.headers,
45
+ status=internal_response.status_code,
46
+ request_url=url,
47
+ version=0,
48
+ version_string="HTTP/?",
49
+ reason="",
50
+ decode_content=True,
51
+ )
52
+ self.length_remaining = self._init_length(self._response.request.method)
53
+ self.length_is_certain = False
54
+
55
+ @property
56
+ def url(self) -> str | None:
57
+ return self._url
58
+
59
+ @url.setter
60
+ def url(self, url: str | None) -> None:
61
+ self._url = url
62
+
63
+ @property
64
+ def connection(self) -> BaseHTTPConnection | BaseHTTPSConnection | None:
65
+ return self._connection
66
+
67
+ @property
68
+ def retries(self) -> Retry | None:
69
+ return self._retries
70
+
71
+ @retries.setter
72
+ def retries(self, retries: Retry | None) -> None:
73
+ # Override the request_url if retries has a redirect location.
74
+ self._retries = retries
75
+
76
+ def stream(
77
+ self, amt: int | None = 2**16, decode_content: bool | None = None
78
+ ) -> typing.Generator[bytes, None, None]:
79
+ """
80
+ A generator wrapper for the read() method. A call will block until
81
+ ``amt`` bytes have been read from the connection or until the
82
+ connection is closed.
83
+
84
+ :param amt:
85
+ How much of the content to read. The generator will return up to
86
+ much data per iteration, but may return less. This is particularly
87
+ likely when using compressed data. However, the empty string will
88
+ never be returned.
89
+
90
+ :param decode_content:
91
+ If True, will attempt to decode the body based on the
92
+ 'content-encoding' header.
93
+ """
94
+ while True:
95
+ data = self.read(amt=amt, decode_content=decode_content)
96
+
97
+ if data:
98
+ yield data
99
+ else:
100
+ break
101
+
102
+ def _init_length(self, request_method: str | None) -> int | None:
103
+ length: int | None
104
+ content_length: str | None = self.headers.get("content-length")
105
+
106
+ if content_length is not None:
107
+ try:
108
+ # RFC 7230 section 3.3.2 specifies multiple content lengths can
109
+ # be sent in a single Content-Length header
110
+ # (e.g. Content-Length: 42, 42). This line ensures the values
111
+ # are all valid ints and that as long as the `set` length is 1,
112
+ # all values are the same. Otherwise, the header is invalid.
113
+ lengths = {int(val) for val in content_length.split(",")}
114
+ if len(lengths) > 1:
115
+ raise InvalidHeader(
116
+ "Content-Length contained multiple "
117
+ "unmatching values (%s)" % content_length
118
+ )
119
+ length = lengths.pop()
120
+ except ValueError:
121
+ length = None
122
+ else:
123
+ if length < 0:
124
+ length = None
125
+
126
+ else: # if content_length is None
127
+ length = None
128
+
129
+ # Check for responses that shouldn't include a body
130
+ if (
131
+ self.status in (204, 304)
132
+ or 100 <= self.status < 200
133
+ or request_method == "HEAD"
134
+ ):
135
+ length = 0
136
+
137
+ return length
138
+
139
+ def read(
140
+ self,
141
+ amt: int | None = None,
142
+ decode_content: bool | None = None, # ignored because browser decodes always
143
+ cache_content: bool = False,
144
+ ) -> bytes:
145
+ if (
146
+ self._closed
147
+ or self._response is None
148
+ or (isinstance(self._response.body, IOBase) and self._response.body.closed)
149
+ ):
150
+ return b""
151
+
152
+ with self._error_catcher():
153
+ # body has been preloaded as a string by XmlHttpRequest
154
+ if not isinstance(self._response.body, IOBase):
155
+ self.length_remaining = len(self._response.body)
156
+ self.length_is_certain = True
157
+ # wrap body in IOStream
158
+ self._response.body = BytesIO(self._response.body)
159
+ if amt is not None and amt >= 0:
160
+ # don't cache partial content
161
+ cache_content = False
162
+ data = self._response.body.read(amt)
163
+ if self.length_remaining is not None:
164
+ self.length_remaining = max(self.length_remaining - len(data), 0)
165
+ if (self.length_is_certain and self.length_remaining == 0) or len(
166
+ data
167
+ ) < amt:
168
+ # definitely finished reading, close response stream
169
+ self._response.body.close()
170
+ return typing.cast(bytes, data)
171
+ else: # read all we can (and cache it)
172
+ data = self._response.body.read()
173
+ if cache_content:
174
+ self._body = data
175
+ if self.length_remaining is not None:
176
+ self.length_remaining = max(self.length_remaining - len(data), 0)
177
+ if len(data) == 0 or (
178
+ self.length_is_certain and self.length_remaining == 0
179
+ ):
180
+ # definitely finished reading, close response stream
181
+ self._response.body.close()
182
+ return typing.cast(bytes, data)
183
+
184
+ def read_chunked(
185
+ self,
186
+ amt: int | None = None,
187
+ decode_content: bool | None = None,
188
+ ) -> typing.Generator[bytes, None, None]:
189
+ # chunked is handled by browser
190
+ while True:
191
+ bytes = self.read(amt, decode_content)
192
+ if not bytes:
193
+ break
194
+ yield bytes
195
+
196
+ def release_conn(self) -> None:
197
+ if not self._pool or not self._connection:
198
+ return None
199
+
200
+ self._pool._put_conn(self._connection)
201
+ self._connection = None
202
+
203
+ def drain_conn(self) -> None:
204
+ self.close()
205
+
206
+ @property
207
+ def data(self) -> bytes:
208
+ if self._body:
209
+ return self._body
210
+ else:
211
+ return self.read(cache_content=True)
212
+
213
+ def json(self) -> typing.Any:
214
+ """
215
+ Deserializes the body of the HTTP response as a Python object.
216
+
217
+ The body of the HTTP response must be encoded using UTF-8, as per
218
+ `RFC 8529 Section 8.1 <https://www.rfc-editor.org/rfc/rfc8259#section-8.1>`_.
219
+
220
+ To use a custom JSON decoder pass the result of :attr:`HTTPResponse.data` to
221
+ your custom decoder instead.
222
+
223
+ If the body of the HTTP response is not decodable to UTF-8, a
224
+ `UnicodeDecodeError` will be raised. If the body of the HTTP response is not a
225
+ valid JSON document, a `json.JSONDecodeError` will be raised.
226
+
227
+ Read more :ref:`here <json_content>`.
228
+
229
+ :returns: The body of the HTTP response as a Python object.
230
+ """
231
+ data = self.data.decode("utf-8")
232
+ return _json.loads(data)
233
+
234
+ def close(self) -> None:
235
+ if not self._closed:
236
+ if isinstance(self._response.body, IOBase):
237
+ self._response.body.close()
238
+ if self._connection:
239
+ self._connection.close()
240
+ self._connection = None
241
+ self._closed = True
242
+
243
+ @contextmanager
244
+ def _error_catcher(self) -> typing.Generator[None, None, None]:
245
+ """
246
+ Catch Emscripten specific exceptions thrown by fetch.py,
247
+ instead re-raising urllib3 variants, so that low-level exceptions
248
+ are not leaked in the high-level api.
249
+
250
+ On exit, release the connection back to the pool.
251
+ """
252
+ from .fetch import _RequestError, _TimeoutError # avoid circular import
253
+
254
+ clean_exit = False
255
+
256
+ try:
257
+ yield
258
+ # If no exception is thrown, we should avoid cleaning up
259
+ # unnecessarily.
260
+ clean_exit = True
261
+ except _TimeoutError as e:
262
+ raise TimeoutError(str(e))
263
+ except _RequestError as e:
264
+ raise HTTPException(str(e))
265
+ finally:
266
+ # If we didn't terminate cleanly, we need to throw away our
267
+ # connection.
268
+ if not clean_exit:
269
+ # The response may not be closed but we're not going to use it
270
+ # anymore so close it now
271
+ if (
272
+ isinstance(self._response.body, IOBase)
273
+ and not self._response.body.closed
274
+ ):
275
+ self._response.body.close()
276
+ # release the connection back to the pool
277
+ self.release_conn()
278
+ else:
279
+ # If we have read everything from the response stream,
280
+ # return the connection back to the pool.
281
+ if (
282
+ isinstance(self._response.body, IOBase)
283
+ and self._response.body.closed
284
+ ):
285
+ self.release_conn()
.venv/Lib/site-packages/urllib3/contrib/socks.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module contains provisional support for SOCKS proxies from within
3
+ urllib3. This module supports SOCKS4, SOCKS4A (an extension of SOCKS4), and
4
+ SOCKS5. To enable its functionality, either install PySocks or install this
5
+ module with the ``socks`` extra.
6
+
7
+ The SOCKS implementation supports the full range of urllib3 features. It also
8
+ supports the following SOCKS features:
9
+
10
+ - SOCKS4A (``proxy_url='socks4a://...``)
11
+ - SOCKS4 (``proxy_url='socks4://...``)
12
+ - SOCKS5 with remote DNS (``proxy_url='socks5h://...``)
13
+ - SOCKS5 with local DNS (``proxy_url='socks5://...``)
14
+ - Usernames and passwords for the SOCKS proxy
15
+
16
+ .. note::
17
+ It is recommended to use ``socks5h://`` or ``socks4a://`` schemes in
18
+ your ``proxy_url`` to ensure that DNS resolution is done from the remote
19
+ server instead of client-side when connecting to a domain name.
20
+
21
+ SOCKS4 supports IPv4 and domain names with the SOCKS4A extension. SOCKS5
22
+ supports IPv4, IPv6, and domain names.
23
+
24
+ When connecting to a SOCKS4 proxy the ``username`` portion of the ``proxy_url``
25
+ will be sent as the ``userid`` section of the SOCKS request:
26
+
27
+ .. code-block:: python
28
+
29
+ proxy_url="socks4a://<userid>@proxy-host"
30
+
31
+ When connecting to a SOCKS5 proxy the ``username`` and ``password`` portion
32
+ of the ``proxy_url`` will be sent as the username/password to authenticate
33
+ with the proxy:
34
+
35
+ .. code-block:: python
36
+
37
+ proxy_url="socks5h://<username>:<password>@proxy-host"
38
+
39
+ """
40
+
41
+ from __future__ import annotations
42
+
43
+ try:
44
+ import socks # type: ignore[import-not-found]
45
+ except ImportError:
46
+ import warnings
47
+
48
+ from ..exceptions import DependencyWarning
49
+
50
+ warnings.warn(
51
+ (
52
+ "SOCKS support in urllib3 requires the installation of optional "
53
+ "dependencies: specifically, PySocks. For more information, see "
54
+ "https://urllib3.readthedocs.io/en/latest/advanced-usage.html#socks-proxies"
55
+ ),
56
+ DependencyWarning,
57
+ )
58
+ raise
59
+
60
+ import typing
61
+ from socket import timeout as SocketTimeout
62
+
63
+ from ..connection import HTTPConnection, HTTPSConnection
64
+ from ..connectionpool import HTTPConnectionPool, HTTPSConnectionPool
65
+ from ..exceptions import ConnectTimeoutError, NewConnectionError
66
+ from ..poolmanager import PoolManager
67
+ from ..util.url import parse_url
68
+
69
+ try:
70
+ import ssl
71
+ except ImportError:
72
+ ssl = None # type: ignore[assignment]
73
+
74
+
75
+ class _TYPE_SOCKS_OPTIONS(typing.TypedDict):
76
+ socks_version: int
77
+ proxy_host: str | None
78
+ proxy_port: str | None
79
+ username: str | None
80
+ password: str | None
81
+ rdns: bool
82
+
83
+
84
+ class SOCKSConnection(HTTPConnection):
85
+ """
86
+ A plain-text HTTP connection that connects via a SOCKS proxy.
87
+ """
88
+
89
+ def __init__(
90
+ self,
91
+ _socks_options: _TYPE_SOCKS_OPTIONS,
92
+ *args: typing.Any,
93
+ **kwargs: typing.Any,
94
+ ) -> None:
95
+ self._socks_options = _socks_options
96
+ super().__init__(*args, **kwargs)
97
+
98
+ def _new_conn(self) -> socks.socksocket:
99
+ """
100
+ Establish a new connection via the SOCKS proxy.
101
+ """
102
+ extra_kw: dict[str, typing.Any] = {}
103
+ if self.source_address:
104
+ extra_kw["source_address"] = self.source_address
105
+
106
+ if self.socket_options:
107
+ extra_kw["socket_options"] = self.socket_options
108
+
109
+ try:
110
+ conn = socks.create_connection(
111
+ (self.host, self.port),
112
+ proxy_type=self._socks_options["socks_version"],
113
+ proxy_addr=self._socks_options["proxy_host"],
114
+ proxy_port=self._socks_options["proxy_port"],
115
+ proxy_username=self._socks_options["username"],
116
+ proxy_password=self._socks_options["password"],
117
+ proxy_rdns=self._socks_options["rdns"],
118
+ timeout=self.timeout,
119
+ **extra_kw,
120
+ )
121
+
122
+ except SocketTimeout as e:
123
+ raise ConnectTimeoutError(
124
+ self,
125
+ f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
126
+ ) from e
127
+
128
+ except socks.ProxyError as e:
129
+ # This is fragile as hell, but it seems to be the only way to raise
130
+ # useful errors here.
131
+ if e.socket_err:
132
+ error = e.socket_err
133
+ if isinstance(error, SocketTimeout):
134
+ raise ConnectTimeoutError(
135
+ self,
136
+ f"Connection to {self.host} timed out. (connect timeout={self.timeout})",
137
+ ) from e
138
+ else:
139
+ # Adding `from e` messes with coverage somehow, so it's omitted.
140
+ # See #2386.
141
+ raise NewConnectionError(
142
+ self, f"Failed to establish a new connection: {error}"
143
+ )
144
+ else:
145
+ raise NewConnectionError(
146
+ self, f"Failed to establish a new connection: {e}"
147
+ ) from e
148
+
149
+ except OSError as e: # Defensive: PySocks should catch all these.
150
+ raise NewConnectionError(
151
+ self, f"Failed to establish a new connection: {e}"
152
+ ) from e
153
+
154
+ return conn
155
+
156
+
157
+ # We don't need to duplicate the Verified/Unverified distinction from
158
+ # urllib3/connection.py here because the HTTPSConnection will already have been
159
+ # correctly set to either the Verified or Unverified form by that module. This
160
+ # means the SOCKSHTTPSConnection will automatically be the correct type.
161
+ class SOCKSHTTPSConnection(SOCKSConnection, HTTPSConnection):
162
+ pass
163
+
164
+
165
+ class SOCKSHTTPConnectionPool(HTTPConnectionPool):
166
+ ConnectionCls = SOCKSConnection
167
+
168
+
169
+ class SOCKSHTTPSConnectionPool(HTTPSConnectionPool):
170
+ ConnectionCls = SOCKSHTTPSConnection
171
+
172
+
173
+ class SOCKSProxyManager(PoolManager):
174
+ """
175
+ A version of the urllib3 ProxyManager that routes connections via the
176
+ defined SOCKS proxy.
177
+ """
178
+
179
+ pool_classes_by_scheme = {
180
+ "http": SOCKSHTTPConnectionPool,
181
+ "https": SOCKSHTTPSConnectionPool,
182
+ }
183
+
184
+ def __init__(
185
+ self,
186
+ proxy_url: str,
187
+ username: str | None = None,
188
+ password: str | None = None,
189
+ num_pools: int = 10,
190
+ headers: typing.Mapping[str, str] | None = None,
191
+ **connection_pool_kw: typing.Any,
192
+ ):
193
+ parsed = parse_url(proxy_url)
194
+
195
+ if username is None and password is None and parsed.auth is not None:
196
+ split = parsed.auth.split(":")
197
+ if len(split) == 2:
198
+ username, password = split
199
+ if parsed.scheme == "socks5":
200
+ socks_version = socks.PROXY_TYPE_SOCKS5
201
+ rdns = False
202
+ elif parsed.scheme == "socks5h":
203
+ socks_version = socks.PROXY_TYPE_SOCKS5
204
+ rdns = True
205
+ elif parsed.scheme == "socks4":
206
+ socks_version = socks.PROXY_TYPE_SOCKS4
207
+ rdns = False
208
+ elif parsed.scheme == "socks4a":
209
+ socks_version = socks.PROXY_TYPE_SOCKS4
210
+ rdns = True
211
+ else:
212
+ raise ValueError(f"Unable to determine SOCKS version from {proxy_url}")
213
+
214
+ self.proxy_url = proxy_url
215
+
216
+ socks_options = {
217
+ "socks_version": socks_version,
218
+ "proxy_host": parsed.host,
219
+ "proxy_port": parsed.port,
220
+ "username": username,
221
+ "password": password,
222
+ "rdns": rdns,
223
+ }
224
+ connection_pool_kw["_socks_options"] = socks_options
225
+
226
+ super().__init__(num_pools, headers, **connection_pool_kw)
227
+
228
+ self.pool_classes_by_scheme = SOCKSProxyManager.pool_classes_by_scheme
.venv/Lib/site-packages/wasabi-0.10.1.dist-info/INSTALLER ADDED
@@ -0,0 +1 @@
 
 
1
+ uv
.venv/Lib/site-packages/wasabi-0.10.1.dist-info/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ The MIT License (MIT)
2
+
3
+ Copyright (C) 2018 Ines Montani
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in
13
+ all copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21
+ THE SOFTWARE.
.venv/Lib/site-packages/wasabi-0.10.1.dist-info/METADATA ADDED
@@ -0,0 +1,575 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: wasabi
3
+ Version: 0.10.1
4
+ Summary: A lightweight console printing and formatting toolkit
5
+ Home-page: https://ines.io
6
+ Author: Ines Montani
7
+ Author-email: [email protected]
8
+ License: MIT
9
+ Description-Content-Type: text/markdown
10
+ License-File: LICENSE
11
+
12
+ # wasabi: A lightweight console printing and formatting toolkit
13
+
14
+ Over the years, I've written countless implementations of coloring and
15
+ formatting utilities to output messages in our libraries like
16
+ [spaCy](https://spacy.io), [Thinc](https://github.com/explosion/thinc) and
17
+ [Prodigy](https://prodi.gy). While there are many other great open-source
18
+ options, I've always ended up wanting something slightly different or slightly
19
+ custom.
20
+
21
+ This package is still a work in progress and aims to bundle those utilities in a
22
+ standardised way so they can be shared across our other projects. It's super
23
+ lightweight, has zero dependencies and works across Python 2 and 3.
24
+
25
+ [![Azure Pipelines](https://img.shields.io/azure-devops/build/explosion-ai/public/1/master.svg?logo=azure-pipelines&style=flat-square)](https://dev.azure.com/explosion-ai/public/_build?definitionId=1)
26
+ [![PyPi](https://img.shields.io/pypi/v/wasabi.svg?style=flat-square&logo=pypi&logoColor=white)](https://pypi.python.org/pypi/wasabi)
27
+ [![conda](https://img.shields.io/conda/vn/conda-forge/wasabi.svg?style=flat-square&logo=conda-forge/logoColor=white)](https://anaconda.org/conda-forge/wasabi)
28
+ [![GitHub](https://img.shields.io/github/release/ines/wasabi/all.svg?style=flat-square&logo=github)](https://github.com/ines/wasabi)
29
+ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg?style=flat-square)](https://github.com/ambv/black)
30
+
31
+ <img width="609" src="https://user-images.githubusercontent.com/13643239/48663861-8c9ea000-ea96-11e8-8b04-d120c52276a8.png">
32
+
33
+ ## 💬 FAQ
34
+
35
+ ### Are you going to add more features?
36
+
37
+ Yes, there are still a few helpers and features to port over. However, the new
38
+ features will be heavily biased by what we (think we) need. I always appreciate
39
+ pull requests to improve the existing functionality – but I want to keep this
40
+ library as simple, lightweight and specific as possible.
41
+
42
+ ### Can I use this for my projects?
43
+
44
+ Sure, if you like it, feel free to adopt it! Just keep in mind that the package
45
+ is very specific and not intended to be a full-featured and fully customisable
46
+ formatting library. If that's what you're looking for, you might want to try
47
+ other packages – for example, [`colored`](https://pypi.org/project/colored/),
48
+ [`crayons`](https://github.com/kennethreitz/crayons),
49
+ [`colorful`](https://github.com/timofurrer/colorful),
50
+ [`tabulate`](https://bitbucket.org/astanin/python-tabulate),
51
+ [`console`](https://github.com/mixmastamyk/console) or
52
+ [`py-term`](https://github.com/gravmatt/py-term), to name a few.
53
+
54
+ ### Why `wasabi`?
55
+
56
+ I was looking for a short and descriptive name, but everything was already
57
+ taken. So I ended up naming this package after one of my rats, Wasabi. 🐀
58
+
59
+ ## ⌛️ Installation
60
+
61
+ ```bash
62
+ pip install wasabi
63
+ ```
64
+
65
+ ## 🎛 API
66
+
67
+ ### <kbd>function</kbd> `msg`
68
+
69
+ An instance of `Printer`, initialized with the default config. Useful as a quick
70
+ shortcut if you don't need to customize initialization.
71
+
72
+ ```python
73
+ from wasabi import msg
74
+
75
+ msg.good("Success!")
76
+ ```
77
+
78
+ ### <kbd>class</kbd> `Printer`
79
+
80
+ #### <kbd>method</kbd> `Printer.__init__`
81
+
82
+ ```python
83
+ from wasabi import Printer
84
+
85
+ msg = Printer()
86
+ ```
87
+
88
+ | Argument | Type | Description | Default |
89
+ | ----------------- | --------- | ------------------------------------------------------------- | ------------- |
90
+ | `pretty` | bool | Pretty-print output with colors and icons. | `True` |
91
+ | `no_print` | bool | Don't actually print, just return. | `False` |
92
+ | `colors` | dict | Add or overwrite color values, names mapped to `0`-`256`. | `None` |
93
+ | `icons` | dict | Add or overwrite icon. Name mapped to unicode. | `None` |
94
+ | `line_max` | int | Maximum line length (for divider). | `80` |
95
+ | `animation` | str | Steps of loading animation for `Printer.loading`. | `"⠙⠹⠸⠼⠴⠦⠧⠇⠏"` |
96
+ | `animation_ascii` | str | Alternative animation for ASCII terminals. | `"\|/-\\"` |
97
+ | `hide_animation` | bool | Don't display animation, e.g. for logs. | `False` |
98
+ | `ignore_warnings` | bool | Don't output messages of type `MESSAGE.WARN`. | `False` |
99
+ | `env_prefix` | str | Prefix for environment variables, e.g. `WASABI_LOG_FRIENDLY`. | `"WASABI"` |
100
+ | `timestamp` | bool | Add timestamp before output. | `False` |
101
+ | **RETURNS** | `Printer` | The initialized printer. | - |
102
+
103
+ #### <kbd>method</kbd> `Printer.text`
104
+
105
+ ```python
106
+ msg = Printer()
107
+ msg.text("Hello world!")
108
+ ```
109
+
110
+ | Argument | Type | Description | Default |
111
+ | ---------- | -------------- | ---------------------------------------------------------------------------------------------------------------------- | ------- |
112
+ | `title` | str | The main text to print. | `""` |
113
+ | `text` | str | Optional additional text to print. | `""` |
114
+ | `color` |  unicode / int | Color name or value. | `None` |
115
+ | `icon` | str | Name of icon to add. | `None` |
116
+ | `show` | bool | Whether to print or not. Can be used to only output messages under certain conditions, e.g. if the `--verbose` flag is set. | `True` |
117
+ | `spaced` | bool | Whether to add newlines around the output. | `False` |
118
+ | `no_print` | bool | Don't actually print, just return. Overwrites global setting. | `False` |
119
+ | `exits` | int | If set, perform a system exit with the given code after printing. | `None` |
120
+
121
+ #### <kbd>method</kbd> `Printer.good`, `Printer.fail`, `Printer.warn`, `Printer.info`
122
+
123
+ Print special formatted messages.
124
+
125
+ ```python
126
+ msg = Printer()
127
+ msg.good("Success")
128
+ msg.fail("Error")
129
+ msg.warn("Warning")
130
+ msg.info("Info")
131
+ ```
132
+
133
+ | Argument | Type | Description | Default |
134
+ | -------- | ---- | ---------------------------------------------------------------------------------------------------------------------- | ------- |
135
+ | `title` | str | The main text to print. | `""` |
136
+ | `text` | str | Optional additional text to print. | `""` |
137
+ | `show` | bool | Whether to print or not. Can be used to only output messages under certain conditions, e.g. if the `--verbose` flag is set. | `True` |
138
+ | `exits` | int | If set, perform a system exit with the given code after printing. | `None` |
139
+
140
+ #### <kbd>method</kbd> `Printer.divider`
141
+
142
+ Print a formatted divider.
143
+
144
+ ```python
145
+ msg = Printer()
146
+ msg.divider("Heading")
147
+ ```
148
+
149
+ | Argument | Type | Description | Default |
150
+ | -------- | ---- | ---------------------------------------------------------------------------------------------------------------------- | ------- |
151
+ | `text` | str | Headline text. If empty, only the line is printed. | `""` |
152
+ | `char` | str | Single line character to repeat. | `"="` |
153
+ | `show` | bool | Whether to print or not. Can be used to only output messages under certain conditions, e.g. if the `--verbose` flag is set. | `True` |
154
+ | `icon` | str | Optional icon to use with title. | `None` |
155
+
156
+ #### <kbd>contextmanager</kbd> `Printer.loading`
157
+
158
+ ```python
159
+ msg = Printer()
160
+ with msg.loading("Loading..."):
161
+ # Do something here that takes longer
162
+ time.sleep(10)
163
+ msg.good("Successfully loaded something!")
164
+ ```
165
+
166
+ | Argument | Type | Description | Default |
167
+ | -------- | ---- | ---------------------------------- | ----------------- |
168
+ | `text` | str | The text to display while loading. | `"Loading..."` |
169
+
170
+ #### <kbd>method</kbd> `Printer.table`, `Printer.row`
171
+
172
+ See [Tables](#tables).
173
+
174
+ #### <kbd>property</kbd> `Printer.counts`
175
+
176
+ Get the counts of how often the special printers were fired, e.g.
177
+ `MESSAGES.GOOD`. Can be used to print an overview like "X warnings"
178
+
179
+ ```python
180
+ msg = Printer()
181
+ msg.good("Success")
182
+ msg.fail("Error")
183
+ msg.warn("Error")
184
+
185
+ print(msg.counts)
186
+ # Counter({'good': 1, 'fail': 1, 'warn': 1, 'info': 0})
187
+ ```
188
+
189
+ | Argument | Type | Description |
190
+ | ----------- | --------- | ---------------------------------------------------- |
191
+ | **RETURNS** | `Counter` | The counts for the individual special message types. |
192
+
193
+ ### Tables
194
+
195
+ #### <kbd>function</kbd> `table`
196
+
197
+ Lightweight helper to format tabular data.
198
+
199
+ ```python
200
+ from wasabi import table
201
+
202
+ data = [("a1", "a2", "a3"), ("b1", "b2", "b3")]
203
+ header = ("Column 1", "Column 2", "Column 3")
204
+ widths = (8, 9, 10)
205
+ aligns = ("r", "c", "l")
206
+ formatted = table(data, header=header, divider=True, widths=widths, aligns=aligns)
207
+ ```
208
+
209
+ ```
210
+ Column 1 Column 2 Column 3
211
+ -------- --------- ----------
212
+ a1 a2 a3
213
+ b1 b2 b3
214
+ ```
215
+
216
+ | Argument | Type | Description | Default |
217
+ | ----------- | ------------------- | ----------------------------------------------------------------------------------------------------------------------------------- | -------- |
218
+ | `data` | iterable / dict | The data to render. Either a list of lists (one per row) or a dict for two-column tables. | |
219
+ | `header` | iterable | Optional header columns. | `None` |
220
+ | `footer` | iterable | Optional footer columns. | `None` |
221
+ | `divider` | bool | Show a divider line between header/footer and body. | `False` |
222
+ | `widths` | iterable / `"auto"` | Column widths in order. If `"auto"`, widths will be calculated automatically based on the largest value. | `"auto"` |
223
+ | `max_col` | int | Maximum column width. | `30` |
224
+ | `spacing` | int | Number of spaces between columns. | `3` |
225
+ | `aligns` | iterable / unicode | Columns alignments in order. `"l"` (left, default), `"r"` (right) or `"c"` (center). If a string, the value is used for all columns. | `None` |
226
+ | `multiline` | bool | If a cell value is a list or a tuple, render it on multiple lines, with one value per line. | `False` |
227
+ | `env_prefix` | unicode | Prefix for environment variables, e.g. WASABI_LOG_FRIENDLY. | `"WASABI"` |
228
+ | `color_values` | dict | Add or overwrite color values, name mapped to value. | `None` |
229
+ | `fg_colors` | iterable | Foreground colors, one per column. None can be specified for individual columns to retain the default background color. | `None` |
230
+ | `bg_colors` | iterable | Background colors, one per column. None can be specified for individual columns to retain the default background color. | `None` |
231
+ | **RETURNS** | str | The formatted table. | |
232
+
233
+ #### <kbd>function</kbd> `row`
234
+
235
+ ```python
236
+ from wasabi import row
237
+
238
+ data = ("a1", "a2", "a3")
239
+ formatted = row(data)
240
+ ```
241
+
242
+ ```
243
+ a1 a2 a3
244
+ ```
245
+
246
+ | Argument | Type | Description | Default |
247
+ | ----------- | ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- | -------- |
248
+ | `data` | iterable | The individual columns to format. | |
249
+ | `widths` | list / int / `"auto"` | Column widths, either one integer for all columns or an iterable of values. If "auto", widths will be calculated automatically based on the largest value. | `"auto"` |
250
+ | `spacing` | int | Number of spaces between columns. | `3` |
251
+ | `aligns` | list | Columns alignments in order. `"l"` (left), `"r"` (right) or `"c"` (center). | `None` |
252
+ | `env_prefix` | unicode | Prefix for environment variables, e.g. WASABI_LOG_FRIENDLY. | `"WASABI"` |
253
+ | `fg_colors` | list | Foreground colors for the columns, in order. None can be specified for individual columns to retain the default foreground color. | `None` |
254
+ | `bg_colors` | list | Background colors for the columns, in order. None can be specified for individual columns to retain the default background color. | `None` |
255
+ | **RETURNS** | str | The formatted row. | |
256
+
257
+ ### <kbd>class</kbd> `TracebackPrinter`
258
+
259
+ Helper to output custom formatted tracebacks and error messages. Currently used
260
+ in [Thinc](https://github.com/explosion/thinc).
261
+
262
+ #### <kbd>method</kbd> `TracebackPrinter.__init__`
263
+
264
+ Initialize a traceback printer.
265
+
266
+ ```python
267
+ from wasabi import TracebackPrinter
268
+
269
+ tb = TracebackPrinter(tb_base="thinc", tb_exclude=("check.py",))
270
+ ```
271
+
272
+ | Argument | Type | Description | Default |
273
+ | ----------------- | ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------- |
274
+ | `color_error` | str / int | Color name or code for errors (passed to `color` helper). | `"red"` |
275
+ | `color_tb` | str / int | Color name or code for traceback headline (passed to `color` helper). | `"blue"` |
276
+ | `color_highlight` | str / int | Color name or code for highlighted text (passed to `color` helper). | `"yellow"` |
277
+ | `indent` | int | Number of spaces to use for indentation. | `2` |
278
+ | `tb_base` | str | Name of directory to use to show relative paths. For example, `"thinc"` will look for the last occurrence of `"/thinc/"` in a path and only show the path to the right of it. | `None` |
279
+ | `tb_exclude` | tuple | List of filenames to exclude from traceback. | `tuple()` |
280
+ | **RETURNS** | `TracebackPrinter` | The traceback printer. | |
281
+
282
+ #### <kbd>method</kbd> `TracebackPrinter.__call__`
283
+
284
+ Output custom formatted tracebacks and errors.
285
+
286
+ ```python
287
+ from wasabi import TracebackPrinter
288
+ import traceback
289
+
290
+ tb = TracebackPrinter(tb_base="thinc", tb_exclude=("check.py",))
291
+
292
+ error = tb("Some error", "Error description", highlight="kwargs", tb=traceback.extract_stack())
293
+ raise ValueError(error)
294
+ ```
295
+
296
+ ```
297
+ Some error
298
+ Some error description
299
+
300
+ Traceback:
301
+ ├─ <lambda> [61] in .env/lib/python3.6/site-packages/pluggy/manager.py
302
+ ├─── _multicall [187] in .env/lib/python3.6/site-packages/pluggy/callers.py
303
+ └───── pytest_fixture_setup [969] in .env/lib/python3.6/site-packages/_pytest/fixtures.py
304
+ >>> result = call_fixture_func(fixturefunc, request, kwargs)
305
+ ```
306
+
307
+ | Argument | Type | Description | Default |
308
+ | ----------- | -------- | ------------------------------------------------------------------------------------------ | ------- |
309
+ | `title` | str | The message title. | |
310
+ | `*texts` | str | Optional texts to print (one per line). | |
311
+ | `highlight` | str | Optional sequence to highlight in the traceback, e.g. the bad value that caused the error. | `False` |
312
+ | `tb` | iterable | The traceback, e.g. generated by `traceback.extract_stack()`. | `None` |
313
+ | **RETURNS** | str | The formatted traceback. Can be printed or raised by custom exception. | |
314
+
315
+ ### <kbd>class</kbd> `MarkdownRenderer`
316
+
317
+ Helper to create Markdown-formatted content. Will store the blocks added to the Markdown document in order.
318
+
319
+ ```python
320
+ from wasabi import MarkdownRenderer
321
+
322
+ md = MarkdownRenderer()
323
+ md.add(md.title(1, "Hello world"))
324
+ md.add("This is a paragraph")
325
+ print(md.text)
326
+ ```
327
+
328
+ ### <kbd>method</kbd> `MarkdownRenderer.__init__`
329
+
330
+ Initialize a Markdown renderer.
331
+
332
+ ```python
333
+ from wasabi import MarkdownRenderer
334
+
335
+ md = MarkdownRenderer()
336
+ ```
337
+
338
+ | Argument | Type | Description | Default |
339
+ | ----------- | ------------------ | ------------------------------ | ------- |
340
+ | `no_emoji` | bool | Don't include emoji in titles. | `False` |
341
+ | **RETURNS** | `MarkdownRenderer` | The renderer. | |
342
+
343
+ ### <kbd>method</kbd> `MarkdownRenderer.add`
344
+
345
+ Add a block to the Markdown document.
346
+
347
+ ```python
348
+ from wasabi import MarkdownRenderer
349
+
350
+ md = MarkdownRenderer()
351
+ md.add("This is a paragraph")
352
+ ```
353
+
354
+ | Argument | Type | Description | Default |
355
+ | -------- | ---- | ------------------- | ------- |
356
+ | `text` | str | The content to add. | |
357
+
358
+ ### <kbd>property</kbd> `MarkdownRenderer.text`
359
+
360
+ The rendered Markdown document.
361
+
362
+ ```python
363
+ md = MarkdownRenderer()
364
+ md.add("This is a paragraph")
365
+ print(md.text)
366
+ ```
367
+
368
+ | Argument | Type | Description | Default |
369
+ | ----------- | ---- | -------------------------------- | ------- |
370
+ | **RETURNS** | str | The document as a single string. | |
371
+
372
+ ### <kbd>method</kbd> `MarkdownRenderer.table`
373
+
374
+ Create a Markdown-formatted table.
375
+
376
+ ```python
377
+ md = MarkdownRenderer()
378
+ table = md.table([("a", "b"), ("c", "d")], ["Column 1", "Column 2"])
379
+ md.add(table)
380
+ ```
381
+
382
+ <!-- prettier-ignore -->
383
+ ```markdown
384
+ | Column 1 | Column 2 |
385
+ | --- | --- |
386
+ | a | b |
387
+ | c | d |
388
+ ```
389
+
390
+ | Argument | Type | Description | Default |
391
+ | ----------- | ----------------------- | ------------------------------------------------------------------------------------ | ------- |
392
+ | `data` | Iterable[Iterable[str]] | The body, one iterable per row, containing an iterable of column contents. | |
393
+ | `header` | Iterable[str] | The column names. | |
394
+ | `aligns` | Iterable[str] | Columns alignments in order. `"l"` (left, default), `"r"` (right) or `"c"` (center). | `None` |
395
+ | **RETURNS** | str | The table. | |
396
+
397
+ ### <kbd>method</kbd> `MarkdownRenderer.title`
398
+
399
+ Create a Markdown-formatted heading.
400
+
401
+ ```python
402
+ md = MarkdownRenderer()
403
+ md.add(md.title(1, "Hello world"))
404
+ md.add(md.title(2, "Subheading", "💖"))
405
+ ```
406
+
407
+ ```markdown
408
+ # Hello world
409
+
410
+ ## 💖 Subheading
411
+ ```
412
+
413
+ | Argument | Type | Description | Default |
414
+ | ----------- | ---- | -------------------------------------- | ------- |
415
+ | `level` | int | The heading level, e.g. `3` for `###`. | |
416
+ | `text` | str | The heading text. | |
417
+ | `emoji` | str | Optional emoji to show before heading. | `None` |
418
+ | **RETURNS** | str | The rendered title. | |
419
+
420
+ ### <kbd>method</kbd> `MarkdownRenderer.list`
421
+
422
+ Create a Markdown-formatted non-nested list.
423
+
424
+ ```python
425
+ md = MarkdownRenderer()
426
+ md.add(md.list(["item", "other item"]))
427
+ md.add(md.list(["first item", "second item"], numbered=True))
428
+ ```
429
+
430
+ ```markdown
431
+ - item
432
+ - other item
433
+
434
+ 1. first item
435
+ 2. second item
436
+ ```
437
+
438
+ | Argument | Type | Description | Default |
439
+ | ----------- | ------------- | ------------------------------- | ------- |
440
+ | `items` | Iterable[str] | The list items. | |
441
+ | `numbered` | bool | Whether to use a numbered list. | `False` |
442
+ | **RETURNS** | str | The rendered list. | |
443
+
444
+ ### <kbd>method</kbd> `MarkdownRenderer.link`
445
+
446
+ Create a Markdown-formatted link.
447
+
448
+ ```python
449
+ md = MarkdownRenderer()
450
+ md.add(md.link("Google", "https://google.com"))
451
+ ```
452
+
453
+ ```markdown
454
+ [Google](https://google.com)
455
+ ```
456
+
457
+ | Argument | Type | Description | Default |
458
+ | ----------- | ---- | ------------------ | ------- |
459
+ | `text` | str | The link text. | |
460
+ | `url` | str | The link URL. | |
461
+ | **RETURNS** | str | The rendered link. | |
462
+
463
+ ### <kbd>method</kbd> `MarkdownRenderer.code_block`
464
+
465
+ Create a Markdown-formatted code block.
466
+
467
+ ```python
468
+ md = MarkdownRenderer()
469
+ md.add(md.code_block("import spacy", "python"))
470
+ ```
471
+
472
+ ````markdown
473
+ ```python
474
+ import spacy
475
+ ```
476
+ ````
477
+
478
+ | Argument | Type | Description | Default |
479
+ | ----------- | ---- | ------------------------ | ------- |
480
+ | `text` | str | The code text. | |
481
+ | `lang` | str | Optional code language. | `""` |
482
+ | **RETURNS** | str | The rendered code block. | |
483
+
484
+ ### <kbd>method</kbd> `MarkdownRenderer.code`, `MarkdownRenderer.bold`, `MarkdownRenderer.italic`
485
+
486
+ Create a Markdown-formatted text.
487
+
488
+ ```python
489
+ md = MarkdownRenderer()
490
+ md.add(md.code("import spacy"))
491
+ md.add(md.bold("Hello!"))
492
+ md.add(md.italic("Emphasis"))
493
+ ```
494
+
495
+ ```markdown
496
+ `import spacy`
497
+
498
+ **Hello!**
499
+
500
+ _Emphasis_
501
+ ```
502
+
503
+ ### Utilities
504
+
505
+ #### <kbd>function</kbd> `color`
506
+
507
+ ```python
508
+ from wasabi import color
509
+
510
+ formatted = color("This is a text", fg="white", bg="green", bold=True)
511
+ ```
512
+
513
+ | Argument | Type | Description | Default |
514
+ | ----------- | --------- | --------------------------------------------- | ------- |
515
+ | `text` | str | The text to be formatted. | - |
516
+ | `fg` | str / int | Foreground color. String name or `0` - `256`. | `None` |
517
+ | `bg` | str / int | Background color. String name or `0` - `256`. | `None` |
518
+ | `bold` | bool | Format the text in bold. | `False` |
519
+ | `underline` | bool | Format the text by underlining. | `False` |
520
+ | **RETURNS** | str | The formatted string. | |
521
+
522
+ #### <kbd>function</kbd> `wrap`
523
+
524
+ ```python
525
+ from wasabi import wrap
526
+
527
+ wrapped = wrap("Hello world, this is a text.", indent=2)
528
+ ```
529
+
530
+ | Argument | Type | Description | Default |
531
+ | ----------- | ---- | ------------------------------------------ | ------- |
532
+ | `text` | str | The text to wrap. | - |
533
+ | `wrap_max` | int | Maximum line width, including indentation. | `80` |
534
+ | `indent` | int | Number of spaces used for indentation. | `4` |
535
+ | **RETURNS** | str | The wrapped text with line breaks. | |
536
+
537
+ #### <kbd>function</kbd> `diff_strings`
538
+
539
+ ```python
540
+ from wasabi import diff_strings
541
+
542
+ diff = diff_strings("hello world!", "helloo world")
543
+ ```
544
+
545
+ | Argument | Type | Description | Default |
546
+ | ----------- | --------- | ---------------------------------------------------------------------------- | ------------------ |
547
+ | `a` | str | The first string to diff. | |
548
+ | `b` | str | The second string to diff. | |
549
+ | `fg` | str / int | Foreground color. String name or `0` - `256`. | `"black"` |
550
+ | `bg` | tuple | Background colors as `(insert, delete)` tuple of string name or `0` - `256`. | `("green", "red")` |
551
+ | **RETURNS** | str | The formatted diff. | |
552
+
553
+ ### Environment variables
554
+
555
+ Wasabi also respects the following environment variables. The prefix can be
556
+ customised on the `Printer` via the `env_prefix` argument. For example, setting
557
+ `env_prefix="SPACY"` will expect the environment variable `SPACY_LOG_FRIENDLY`.
558
+
559
+ | Name | Description |
560
+ | ---------------------- | ------------------------------------------------------ |
561
+ | `ANSI_COLORS_DISABLED` | Disable colors. |
562
+ | `WASABI_LOG_FRIENDLY` | Make output nicer for logs (no colors, no animations). |
563
+ | `WASABI_NO_PRETTY` | Disable pretty printing, e.g. colors and icons. |
564
+
565
+ ## 🔔 Run tests
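A quick, hypothetical example of the log-friendly mode described above (not part of the original README; `my_script.py` stands in for any script that prints via `wasabi`):

```bash
# Disable colors and loading animations for this run, e.g. when capturing logs.
# Set WASABI_NO_PRETTY=1 instead to turn off pretty printing entirely.
WASABI_LOG_FRIENDLY=1 python my_script.py > run.log
```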
566
+
567
+ Fork or clone the repo, make sure you have `pytest` installed and then run it on
568
+ the package directory. The tests are located in
569
+ [`/wasabi/tests`](/wasabi/tests).
570
+
571
+ ```bash
572
+ pip install pytest
573
+ cd wasabi
574
+ python -m pytest wasabi
575
+ ```
.venv/Lib/site-packages/wasabi-0.10.1.dist-info/RECORD ADDED
@@ -0,0 +1,21 @@
1
+ wasabi-0.10.1.dist-info/LICENSE,sha256=VFEGPy-fp93JVB9xDa69VMP0RZLK8W9vmGJsMhO-yT0,1079
2
+ wasabi-0.10.1.dist-info/METADATA,sha256=PqgL35FWFflAukT8zb2G9h3L0LDdZA-hfhATnYhPT4I,28325
3
+ wasabi-0.10.1.dist-info/RECORD,,
4
+ wasabi-0.10.1.dist-info/WHEEL,sha256=G16H4A3IeoQmnOrYV4ueZGKSjhipXx8zc8nu9FGlvMA,92
5
+ wasabi-0.10.1.dist-info/top_level.txt,sha256=W9sjXfJi_nAtQ-jjHMYsfeSlxTmnidUT6Eg0PqIYxkM,7
6
+ wasabi-0.10.1.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
7
+ wasabi-0.10.1.dist-info\INSTALLER,sha256=5hhM4Q4mYTT9z6QB6PGpUAW81PGNFrYrdXMj4oM_6ak,2
8
+ wasabi-0.10.1.dist-info\REQUESTED,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
+ wasabi/__init__.py,sha256=Y7ESsG4XE5A2C5ZiXh0_tnrUweiZW3yHJzt05ZPSIbo,406
10
+ wasabi/about.py,sha256=h52qkdZ4io5L_AJohzkCwF39SKEN3VZAAQBYHUj_V8E,222
11
+ wasabi/markdown.py,sha256=vjlUPQJ0DwCfJYBuuzP4UdzvLsdBKrPbEtZTtF_qRRE,3430
12
+ wasabi/printer.py,sha256=qxXdc7oxVT_CF03FYOAbBgwWso_HqniggwUUXqudcBU,9114
13
+ wasabi/tables.py,sha256=jATbMi_CUs0UGom5fRlLOh-GMbyhVudZQf20CWkMIHY,5853
14
+ wasabi/tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
15
+ wasabi/tests/test_markdown.py,sha256=RkuBjcWCzK5vZzfIE_TqsnDk23LfyIAA55sr8rlvvjY,1244
16
+ wasabi/tests/test_printer.py,sha256=GlgYIXAmrUMNe3C2zP3ClbDxpq5yzqxhzvGpgNWIuCI,7426
17
+ wasabi/tests/test_tables.py,sha256=qTgQC-FKSYhUFn4OlWeVYO2gmLBdhNV_gHb5EsGAob8,17567
18
+ wasabi/tests/test_traceback.py,sha256=neHeGSCQvuSWOvSnLfqvWuNGe4AKafJ5YNcbn-LBGhU,1692
19
+ wasabi/tests/test_util.py,sha256=YZM3ydI92f3DO-rfeiBe9Q1XlaCC7-y6M40l3iUBHZk,2524
20
+ wasabi/traceback_printer.py,sha256=1lII5g68jVhyEYPdkF_y6IIR0AhKgukKHxouGgaiAIE,4953
21
+ wasabi/util.py,sha256=zmcRGQhoKSUflnnk7aY3ugV8V3bReIunmzPm6tegQr8,7206
.venv/Lib/site-packages/wasabi-0.10.1.dist-info/REQUESTED ADDED
File without changes
.venv/Lib/site-packages/wasabi-0.10.1.dist-info/WHEEL ADDED
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: bdist_wheel (0.37.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
.venv/Lib/site-packages/wasabi-0.10.1.dist-info/top_level.txt ADDED
@@ -0,0 +1 @@
1
+ wasabi
.venv/Lib/site-packages/wasabi-0.10.1.dist-info/zip-safe ADDED
@@ -0,0 +1 @@
1
+
.venv/Lib/site-packages/wasabi/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (633 Bytes). View file
 
.venv/Lib/site-packages/wasabi/__pycache__/about.cpython-39.pyc ADDED
Binary file (420 Bytes). View file
 
.venv/Lib/site-packages/wasabi/__pycache__/markdown.cpython-39.pyc ADDED
Binary file (4.54 kB). View file
 
.venv/Lib/site-packages/wasabi/__pycache__/printer.cpython-39.pyc ADDED
Binary file (7.98 kB). View file
 
.venv/Lib/site-packages/wasabi/__pycache__/tables.cpython-39.pyc ADDED
Binary file (6.07 kB). View file
 
.venv/Lib/site-packages/wasabi/__pycache__/traceback_printer.cpython-39.pyc ADDED
Binary file (4.83 kB). View file
 
.venv/Lib/site-packages/wasabi/__pycache__/util.cpython-39.pyc ADDED
Binary file (6.91 kB). View file
 
.venv/Lib/site-packages/wasabi/tests/__init__.py ADDED
File without changes
.venv/Lib/site-packages/wasabi/tests/test_util.py ADDED
@@ -0,0 +1,75 @@
1
+ # coding: utf8
2
+ from __future__ import unicode_literals, print_function
3
+
4
+ import pytest
5
+ from wasabi.util import color, wrap, locale_escape, format_repr, diff_strings
6
+
7
+
8
+ def test_color():
9
+ assert color("test", fg="green") == "\x1b[38;5;2mtest\x1b[0m"
10
+ assert color("test", fg=4) == "\x1b[38;5;4mtest\x1b[0m"
11
+ assert color("test", bold=True) == "\x1b[1mtest\x1b[0m"
12
+ assert color("test", fg="red", underline=True) == "\x1b[4;38;5;1mtest\x1b[0m"
13
+ assert (
14
+ color("test", fg=7, bg="red", bold=True) == "\x1b[1;38;5;7;48;5;1mtest\x1b[0m"
15
+ )
16
+
17
+
18
+ def test_wrap():
19
+ text = "Hello world, this is a test."
20
+ assert wrap(text, indent=0) == text
21
+ assert wrap(text, indent=4) == " Hello world, this is a test."
22
+ assert wrap(text, wrap_max=10, indent=0) == "Hello\nworld,\nthis is a\ntest."
23
+ assert (
24
+ wrap(text, wrap_max=5, indent=2)
25
+ == " Hello\n world,\n this\n is\n a\n test."
26
+ )
27
+
28
+
29
+ def test_format_repr():
30
+ obj = {"hello": "world", "test": 123}
31
+ formatted = format_repr(obj)
32
+ assert formatted.replace("u'", "'") in [
33
+ "{'hello': 'world', 'test': 123}",
34
+ "{'test': 123, 'hello': 'world'}",
35
+ ]
36
+ formatted = format_repr(obj, max_len=10)
37
+ assert formatted.replace("u'", "'") in [
38
+ "{'hel ... 123}",
39
+ "{'tes ... rld'}",
40
+ "{'te ... rld'}",
41
+ ]
42
+ formatted = format_repr(obj, max_len=10, ellipsis="[...]")
43
+ assert formatted.replace("u'", "'") in [
44
+ "{'hel [...] 123}",
45
+ "{'tes [...] rld'}",
46
+ "{'te [...] rld'}",
47
+ ]
48
+
49
+
50
+ @pytest.mark.parametrize(
51
+ "text,non_ascii",
52
+ [
53
+ ("abc", ["abc"]),
54
+ ("\u2714 abc", ["? abc"]),
55
+ ("👻", ["??", "?"]), # On Python 3 windows, this becomes "?" instead of "??"
56
+ ],
57
+ )
58
+ def test_locale_escape(text, non_ascii):
59
+ result = locale_escape(text)
60
+ assert result == text or result in non_ascii
61
+ print(result)
62
+
63
+
64
+ def test_diff_strings():
65
+ a = "hello\nworld\nwide\nweb"
66
+ b = "yo\nwide\nworld\nweb"
67
+ expected = "\x1b[38;5;16;48;5;2myo\x1b[0m\n\x1b[38;5;16;48;5;2mwide\x1b[0m\n\x1b[38;5;16;48;5;1mhello\x1b[0m\nworld\n\x1b[38;5;16;48;5;1mwide\x1b[0m\nweb"
68
+ assert diff_strings(a, b) == expected
69
+
70
+
71
+ def test_diff_strings_with_symbols():
72
+ a = "hello\nworld\nwide\nweb"
73
+ b = "yo\nwide\nworld\nweb"
74
+ expected = "\x1b[38;5;16;48;5;2m+ yo\x1b[0m\n\x1b[38;5;16;48;5;2m+ wide\x1b[0m\n\x1b[38;5;16;48;5;1m- hello\x1b[0m\nworld\n\x1b[38;5;16;48;5;1m- wide\x1b[0m\nweb"
75
+ assert diff_strings(a, b, add_symbols=True) == expected
.venv/Lib/site-packages/xlstm/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ __version__ = "1.0.8"
2
+
3
+ from .blocks.mlstm.block import mLSTMBlock, mLSTMBlockConfig
4
+ from .blocks.mlstm.layer import mLSTMLayer, mLSTMLayerConfig
5
+ from .blocks.slstm.block import sLSTMBlock, sLSTMBlockConfig
6
+ from .blocks.slstm.layer import sLSTMLayer, sLSTMLayerConfig
7
+ from .components.feedforward import FeedForwardConfig, GatedFeedForward
8
+ from .xlstm_block_stack import xLSTMBlockStack, xLSTMBlockStackConfig
9
+ from .xlstm_lm_model import xLSTMLMModel, xLSTMLMModelConfig
.venv/Lib/site-packages/xlstm/blocks/__init__.py ADDED
File without changes
.venv/Lib/site-packages/xlstm/blocks/mlstm/__init__.py ADDED
@@ -0,0 +1 @@
1
+
.venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/__init__.cpython-39.pyc ADDED
Binary file (190 Bytes). View file
 
.venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/backends.cpython-39.pyc ADDED
Binary file (3.98 kB). View file
 
.venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/block.cpython-39.pyc ADDED
Binary file (1.33 kB). View file
 
.venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/cell.cpython-39.pyc ADDED
Binary file (3.95 kB). View file
 
.venv/Lib/site-packages/xlstm/blocks/mlstm/__pycache__/layer.cpython-39.pyc ADDED
Binary file (4.6 kB). View file
 
.venv/Lib/site-packages/xlstm/blocks/mlstm/backends.py ADDED
@@ -0,0 +1,145 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+ import math
4
+
5
+ import torch
6
+
7
+
8
+ def parallel_stabilized_simple(
9
+ queries: torch.Tensor,
10
+ keys: torch.Tensor,
11
+ values: torch.Tensor,
12
+ igate_preact: torch.Tensor,
13
+ fgate_preact: torch.Tensor,
14
+ lower_triangular_matrix: torch.Tensor = None,
15
+ stabilize_rowwise: bool = True,
16
+ eps: float = 1e-6,
17
+ **kwargs,
18
+ ) -> torch.Tensor:
19
+ """This is the mLSTM cell in parallel form.
20
+ This version is stabilized. We control the range of exp() arguments by
21
+ ensuring that they are always smaller than 0.0 by subtracting the maximum.
22
+
23
+ Args:
24
+ queries (torch.Tensor): (B, NH, S, DH)
25
+ keys (torch.Tensor): (B, NH, S, DH)
26
+ values (torch.Tensor): (B, NH, S, DH)
27
+ igate_preact (torch.Tensor): (B, NH, S, 1)
28
+ fgate_preact (torch.Tensor): (B, NH, S, 1)
29
+ lower_triangular_matrix (torch.Tensor, optional): (S,S). Defaults to None.
30
+ stabilize_rowwise (bool, optional): Whether to stabilize the combination matrix C rowwise (take maximum per row).
31
+ Alternative: Subtract the maximum over all rows. Defaults to True.
32
+
33
+ Returns:
34
+ torch.Tensor: (B, NH, S, DH), h_tilde_state
35
+ """
36
+
37
+ B, NH, S, DH = queries.shape
38
+ _dtype, _device = queries.dtype, queries.device
39
+
40
+ # forget gate matrix
41
+ log_fgates = torch.nn.functional.logsigmoid(fgate_preact) # (B, NH, S, 1)
42
+ if lower_triangular_matrix is None or S < lower_triangular_matrix.size(-1):
43
+ ltr = torch.tril(torch.ones((S, S), dtype=torch.bool, device=_device))
44
+ else:
45
+ ltr = lower_triangular_matrix
46
+ assert ltr.dtype == torch.bool, f"lower_triangular_matrix must be of dtype bool, got {ltr.dtype}"
47
+
48
+ log_fgates_cumsum = torch.cat(
49
+ [
50
+ torch.zeros((B, NH, 1, 1), dtype=_dtype, device=_device),
51
+ torch.cumsum(log_fgates, dim=-2),
52
+ ],
53
+ dim=-2,
54
+ ) # (B, NH, S+1, 1)
55
+ # for each batch/head this is a matrix of shape (S+1, S+1) containing the cumsum of the log forget gate values
56
+ # in the second dimension (column dimension). Each row is a copy of the first row.
57
+ # First entry of each row is zero.
58
+ rep_log_fgates_cumsum = log_fgates_cumsum.repeat(1, 1, 1, S + 1) # (B, NH, S+1, S+1)
59
+ # Now in each row cut off / subtract the forgetgate values of the later timesteps
60
+ # where col j > row i
61
+ _log_fg_matrix = rep_log_fgates_cumsum - rep_log_fgates_cumsum.transpose(-2, -1) # (B, NH, S+1, S+1)
62
+ # Causal masking & selection of the correct submatrix, such that forgetgate at timestep t is not applied
63
+ # to the input at timestep t
64
+ log_fg_matrix = torch.where(ltr, _log_fg_matrix[:, :, 1:, 1:], -float("inf")) # (B, NH, S, S)
65
+
66
+ # gate decay matrix D (combination of forget gate and input gate)
67
+ log_D_matrix = log_fg_matrix + igate_preact.transpose(-2, -1) # (B, NH, S, S)
68
+ # D matrix stabilization
69
+ if stabilize_rowwise:
70
+ max_log_D, _ = torch.max(log_D_matrix, dim=-1, keepdim=True) # (B, NH, S, 1)
71
+ else:
72
+ max_log_D = torch.max(log_D_matrix.view(B, NH, -1), dim=-1, keepdim=True)[0].unsqueeze(-1)
73
+ # (B, NH, 1, 1)
74
+ log_D_matrix_stabilized = log_D_matrix - max_log_D # (B, NH, S, S)
75
+ D_matrix = torch.exp(log_D_matrix_stabilized) # (B, NH, S, S)
76
+
77
+ keys_scaled = keys / math.sqrt(DH)
78
+
79
+ # combination matrix C
80
+ qk_matrix = queries @ keys_scaled.transpose(-2, -1) # (B, NH, S, S)
81
+ C_matrix = qk_matrix * D_matrix # (B, NH, S, S)
82
+ normalizer = torch.maximum(C_matrix.sum(dim=-1, keepdim=True).abs(), torch.exp(-max_log_D)) # (B, NH, S, 1)
83
+ # (B, NH, S, S)
84
+ C_matrix_normalized = C_matrix / (normalizer + eps)
85
+
86
+ # retrieved values
87
+ h_tilde_state = C_matrix_normalized @ values # (B, NH, S, DH)
88
+
89
+ return h_tilde_state
90
+
91
+
92
+ def recurrent_step_stabilized_simple(
93
+ c_state: torch.Tensor,
94
+ n_state: torch.Tensor,
95
+ m_state: torch.Tensor,
96
+ q: torch.Tensor,
97
+ k: torch.Tensor,
98
+ v: torch.Tensor,
99
+ igate_preact: torch.Tensor,
100
+ fgate_preact: torch.Tensor,
101
+ eps: float = 1e-6,
102
+ **kwargs,
103
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
104
+ """This is a single step of the mLSTM operation in recurrent form.
105
+
106
+ Args:
107
+ c_state (torch.Tensor): (B, NH, DH, DH)
108
+ n_state (torch.Tensor): (B, NH, DH, 1)
109
+ m_state (torch.Tensor): (B, NH, 1, 1)
110
+ q (torch.Tensor): (B, NH, 1, DH)
111
+ k (torch.Tensor): (B, NH, 1, DH)
112
+ v (torch.Tensor): (B, NH, 1, DH)
113
+ igate_preact (torch.Tensor): (B, NH, 1, 1)
114
+ fgate_preact (torch.Tensor): (B, NH, 1, 1)
115
+
116
+ Returns:
117
+ tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
118
+ (hidden_state [B, NH, DH], (c_state_new [B, NH, DH, DH], n_state_new [B, NH, DH, 1]], m_state_new [B, NH, 1, 1]))
119
+ """
120
+ B, NH, S, DH = q.shape
121
+ # projections
122
+ q, k, v = q.squeeze_(2).unsqueeze(-1), k.squeeze_(2).unsqueeze(-1), v.squeeze_(2).unsqueeze(-1) # (B, NH, DH, 1)
123
+
124
+ # gates
125
+ log_fg_act = torch.nn.functional.logsigmoid(fgate_preact) # (B, NH, 1, 1)
126
+
127
+ # update rule
128
+ m_state_new = torch.max(log_fg_act + m_state, igate_preact) # (B, NH, 1, 1)
129
+
130
+ fg_act = torch.exp(log_fg_act + m_state - m_state_new) # (B, NH, 1, 1)
131
+ ig_act = torch.exp(igate_preact - m_state_new) # (B, NH, 1, 1)
132
+
133
+ k_scaled = k / math.sqrt(DH)
134
+
135
+ c_state_new = fg_act * c_state + ig_act * (k_scaled @ v.transpose(-1, -2)) # (B, NH, DH, DH)
136
+ n_state_new = fg_act * n_state + ig_act * k_scaled # (B, NH, DH, 1)
137
+
138
+ h_num = q.transpose(-1, -2) @ c_state_new # (B, NH, 1, DH)
139
+
140
+ qn_dotproduct = q.transpose(-1, -2) @ n_state_new # (B, NH, 1, 1)
141
+ max_val = torch.exp(-m_state_new) # (B, NH, 1, 1)
142
+ h_denom = torch.maximum(qn_dotproduct.abs(), max_val) + eps
143
+ h = h_num / h_denom # (B, NH, 1, DH) / (B, NH, 1, 1) = (B, NH, 1, DH)
144
+
145
+ return h, (c_state_new, n_state_new, m_state_new)
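The two backends above are plain functions over tensors, so the shapes documented in their docstrings can be checked directly. A minimal, untested sketch with illustrative sizes; the import path follows the file location shown in this diff:

```python
import torch

from xlstm.blocks.mlstm.backends import parallel_stabilized_simple

# Illustrative sizes: batch, heads, sequence length, head dimension.
B, NH, S, DH = 2, 4, 16, 32

q = torch.randn(B, NH, S, DH)            # queries
k = torch.randn(B, NH, S, DH)            # keys
v = torch.randn(B, NH, S, DH)            # values
igate_preact = torch.randn(B, NH, S, 1)  # input gate pre-activations
fgate_preact = torch.randn(B, NH, S, 1)  # forget gate pre-activations

h_tilde = parallel_stabilized_simple(
    queries=q,
    keys=k,
    values=v,
    igate_preact=igate_preact,
    fgate_preact=fgate_preact,
)
print(h_tilde.shape)  # expected: torch.Size([2, 4, 16, 32])
```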
.venv/Lib/site-packages/xlstm/blocks/mlstm/block.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+ from dataclasses import dataclass, field
4
+
5
+ from ..xlstm_block import xLSTMBlock, xLSTMBlockConfig
6
+ from .layer import mLSTMLayerConfig
7
+
8
+
9
+ @dataclass
10
+ class mLSTMBlockConfig:
11
+ mlstm: mLSTMLayerConfig = field(default_factory=mLSTMLayerConfig)
12
+
13
+ # we initialize these with None to catch the case where they are not set
14
+ _num_blocks: int = None
15
+ _block_idx: int = None
16
+
17
+ def __post_init__(self):
18
+ self.mlstm._num_blocks = self._num_blocks
19
+ self.mlstm.__post_init__()
20
+
21
+
22
+ class mLSTMBlock(xLSTMBlock):
23
+ config_class = mLSTMBlockConfig
24
+
25
+ def __init__(self, config: mLSTMBlockConfig) -> None:
26
+ super().__init__(
27
+ config=xLSTMBlockConfig(
28
+ mlstm=config.mlstm,
29
+ slstm=None,
30
+ feedforward=None,
31
+ _num_blocks=config._num_blocks,
32
+ _block_idx=config._block_idx,
33
+ )
34
+ )
.venv/Lib/site-packages/xlstm/blocks/mlstm/cell.py ADDED
@@ -0,0 +1,140 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from ...components.init import bias_linspace_init_
9
+ from ...components.ln import MultiHeadLayerNorm
10
+ from .backends import parallel_stabilized_simple, recurrent_step_stabilized_simple
11
+
12
+
13
+ @dataclass
14
+ class mLSTMCellConfig:
15
+ context_length: int = -1
16
+ embedding_dim: int = -1
17
+ num_heads: int = -1
18
+
19
+
20
+ class mLSTMCell(nn.Module):
21
+ config_class = mLSTMCellConfig
22
+
23
+ def __init__(self, config: mLSTMCellConfig):
24
+ super().__init__()
25
+ self.config = config
26
+
27
+ self.backend_fn = parallel_stabilized_simple
28
+ self.backend_fn_step = recurrent_step_stabilized_simple
29
+
30
+ self.igate = nn.Linear(3 * config.embedding_dim, config.num_heads)
31
+ self.fgate = nn.Linear(3 * config.embedding_dim, config.num_heads)
32
+
33
+ self.outnorm = MultiHeadLayerNorm(ndim=config.embedding_dim, weight=True, bias=False)
34
+
35
+ self.register_buffer(
36
+ "causal_mask",
37
+ torch.tril(torch.ones(config.context_length, config.context_length, dtype=torch.bool)),
38
+ persistent=False,
39
+ )
40
+
41
+ self.reset_parameters()
42
+
43
+ def forward(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, **kwargs) -> torch.Tensor:
44
+ B, S, _ = q.shape # (B, S, H)
45
+
46
+ if_gate_input = torch.cat([q, k, v], dim=-1)
47
+ q = q.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
48
+ k = k.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
49
+ v = v.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
50
+
51
+ q = q.transpose(1, 2) # (B, NH, S, DH)
52
+ k = k.transpose(1, 2) # (B, NH, S, DH)
53
+ v = v.transpose(1, 2) # (B, NH, S, DH)
54
+
55
+ # compute input and forget gate pre-activations
56
+ igate_preact = self.igate(if_gate_input) # (B, S, NH)
57
+ igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
58
+ fgate_preact = self.fgate(if_gate_input) # (B, S, NH)
59
+ fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)#
60
+
61
+ h_state = self.backend_fn(
62
+ queries=q,
63
+ keys=k,
64
+ values=v,
65
+ igate_preact=igate_preact,
66
+ fgate_preact=fgate_preact,
67
+ lower_triangular_matrix=self.causal_mask,
68
+ ) # (B, NH, S, DH)
69
+
70
+ h_state_norm = self.outnorm(h_state) # (B, NH, S, DH)
71
+ h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1) # (B, NH, S, DH) -> (B, S, NH, DH) -> (B, S, H)
72
+
73
+ return h_state_norm
74
+
75
+ def step(
76
+ self,
77
+ q: torch.Tensor,
78
+ k: torch.Tensor,
79
+ v: torch.Tensor,
80
+ mlstm_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None,
81
+ **kwargs,
82
+ ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
83
+ B, S, _ = q.shape # (B, S, H)
84
+ assert S == 1, f"mLSTMCell.step only supports sequence length S=1, but got S={S}."
85
+
86
+ if_gate_input = torch.cat([q, k, v], dim=-1)
87
+ q = q.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
88
+ k = k.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
89
+ v = v.view(B, S, self.config.num_heads, -1) # (B, S, NH, DH)
90
+
91
+ _, _, NH, DH = q.shape
92
+
93
+ q = q.transpose(1, 2) # (B, NH, S, DH)
94
+ k = k.transpose(1, 2) # (B, NH, S, DH)
95
+ v = v.transpose(1, 2) # (B, NH, S, DH)
96
+
97
+ # compute input and forget gate pre-activations
98
+ igate_preact = self.igate(if_gate_input) # (B, S, NH)
99
+ igate_preact = igate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
100
+ fgate_preact = self.fgate(if_gate_input) # (B, S, NH)
101
+ fgate_preact = fgate_preact.transpose(-1, -2).unsqueeze(-1) # (B, NH, S, 1)
102
+
103
+ if mlstm_state is None:
104
+ c_state = torch.zeros(size=(B, NH, DH, DH), device=q.device, dtype=q.dtype)
105
+ n_state = torch.zeros(size=(B, NH, DH, 1), device=q.device, dtype=q.dtype)
106
+ m_state = torch.zeros(size=(B, NH, 1, 1), device=q.device, dtype=q.dtype)
107
+ else:
108
+ c_state, n_state, m_state = mlstm_state
109
+ c_state = c_state.to(device=q.device, dtype=q.dtype)
110
+ n_state = n_state.to(device=q.device, dtype=q.dtype)
111
+ m_state = m_state.to(device=q.device, dtype=q.dtype)
112
+
113
+ assert c_state.shape == (B, NH, DH, DH), f"Expected c_state shape {(B, NH, DH, DH)}, but got {c_state.shape}."
114
+ assert n_state.shape == (B, NH, DH, 1), f"Expected n_state shape {(B, NH, DH, 1)}, but got {n_state.shape}."
115
+ assert m_state.shape == (B, NH, 1, 1), f"Expected m_state shape {(B, NH, 1, 1)}, but got {m_state.shape}."
116
+
117
+ h_state, mlstm_state = self.backend_fn_step(
118
+ c_state=c_state,
119
+ n_state=n_state,
120
+ m_state=m_state,
121
+ q=q,
122
+ k=k,
123
+ v=v,
124
+ igate_preact=igate_preact,
125
+ fgate_preact=fgate_preact,
126
+ ) # (B, NH, 1 DH), ((B, NH, DH, DH), (B, NH, DH, 1), (B, NH, 1, 1))
127
+
128
+ h_state_norm = self.outnorm(h_state) # (B, NH, S, DH)
129
+ h_state_norm = h_state_norm.transpose(1, 2).reshape(B, S, -1) # (B, NH, S, DH) -> (B, S, NH, DH) -> (B, S, H)
130
+
131
+ return h_state_norm, mlstm_state
132
+
133
+ def reset_parameters(self):
134
+ self.outnorm.reset_parameters()
135
+ # forget gate initialization
136
+ torch.nn.init.zeros_(self.fgate.weight)
137
+ bias_linspace_init_(self.fgate.bias, start=3.0, end=6.0)
138
+ # input gate initialization
139
+ torch.nn.init.zeros_(self.igate.weight)
140
+ torch.nn.init.normal_(self.igate.bias, mean=0.0, std=0.1)
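A rough, untested usage sketch for the cell defined above, with illustrative sizes; it assumes the sibling components imported at the top of the file (e.g. `MultiHeadLayerNorm`, `bias_linspace_init_`) are installed as part of the same package, and the shapes follow the comments in `forward` and `step`:

```python
import torch

from xlstm.blocks.mlstm.cell import mLSTMCell, mLSTMCellConfig

# embedding_dim should be divisible by num_heads.
config = mLSTMCellConfig(context_length=16, embedding_dim=64, num_heads=4)
cell = mLSTMCell(config)

B, S, H = 2, 16, 64
q = torch.randn(B, S, H)
k = torch.randn(B, S, H)
v = torch.randn(B, S, H)

h = cell(q, k, v)  # parallel form over the full sequence
print(h.shape)     # expected: torch.Size([2, 16, 64])

# Recurrent single-step form (S must be 1); mlstm_state=None creates zero states.
h_step, state = cell.step(q[:, :1], k[:, :1], v[:, :1], mlstm_state=None)
print(h_step.shape)  # expected: torch.Size([2, 1, 64])
```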
.venv/Lib/site-packages/xlstm/blocks/mlstm/layer.py ADDED
@@ -0,0 +1,180 @@
1
+ # Copyright (c) NXAI GmbH and its affiliates 2024
2
+ # Maximilian Beck
3
+ from dataclasses import dataclass
4
+
5
+ import torch
6
+ from torch import nn
7
+
8
+ from ...components.conv import CausalConv1d, CausalConv1dConfig
9
+ from ...components.init import small_init_init_, wang_init_
10
+ from ...components.linear_headwise import (
11
+ LinearHeadwiseExpand,
12
+ LinearHeadwiseExpandConfig,
13
+ )
14
+ from ...utils import UpProjConfigMixin
15
+ from .cell import mLSTMCell, mLSTMCellConfig
16
+
17
+
18
+ @dataclass
19
+ class mLSTMLayerConfig(UpProjConfigMixin):
20
+ conv1d_kernel_size: int = 4
21
+ qkv_proj_blocksize: int = 4
22
+ num_heads: int = 4
23
+ proj_factor: float = 2.0
24
+
25
+ # will be set by the toplevel config
26
+ embedding_dim: int = -1
27
+ bias: bool = False
28
+ dropout: float = 0.0
29
+ context_length: int = -1
30
+
31
+ _num_blocks: int = 1
32
+ _inner_embedding_dim: int = None
33
+
34
+ def __post_init__(self):
35
+ self._set_proj_up_dim(embedding_dim=self.embedding_dim)
36
+ self._inner_embedding_dim = self._proj_up_dim
37
+
38
+
39
+ class mLSTMLayer(nn.Module):
40
+ config_class = mLSTMLayerConfig
41
+
42
+ def __init__(self, config: mLSTMLayerConfig):
43
+ super().__init__()
44
+ self.config = config
45
+
46
+ self.proj_up = nn.Linear(
47
+ in_features=self.config.embedding_dim,
48
+ out_features=2 * self.config._inner_embedding_dim,
49
+ bias=self.config.bias,
50
+ )
51
+
52
+ num_proj_heads = round(self.config._inner_embedding_dim // self.config.qkv_proj_blocksize)
53
+ self.q_proj = LinearHeadwiseExpand(
54
+ config=LinearHeadwiseExpandConfig(
55
+ in_features=self.config._inner_embedding_dim,
56
+ num_heads=num_proj_heads,
57
+ bias=self.config.bias,
58
+ )
59
+ )
60
+ self.k_proj = LinearHeadwiseExpand(
61
+ config=LinearHeadwiseExpandConfig(
62
+ in_features=self.config._inner_embedding_dim,
63
+ num_heads=num_proj_heads,
64
+ bias=self.config.bias,
65
+ )
66
+ )
67
+ self.v_proj = LinearHeadwiseExpand(
68
+ config=LinearHeadwiseExpandConfig(
69
+ in_features=self.config._inner_embedding_dim,
70
+ num_heads=num_proj_heads,
71
+ bias=self.config.bias,
72
+ )
73
+ )
74
+
75
+ self.conv1d = CausalConv1d(
76
+ config=CausalConv1dConfig(
77
+ feature_dim=self.config._inner_embedding_dim,
78
+ kernel_size=self.config.conv1d_kernel_size,
79
+ )
80
+ )
81
+ self.conv_act_fn = nn.SiLU()
82
+ self.mlstm_cell = mLSTMCell(
83
+ config=mLSTMCellConfig(
84
+ context_length=self.config.context_length,
85
+ embedding_dim=self.config._inner_embedding_dim,
86
+ num_heads=self.config.num_heads,
87
+ )
88
+ )
89
+ self.ogate_act_fn = nn.SiLU()
90
+
91
+ self.learnable_skip = nn.Parameter(torch.ones(self.config._inner_embedding_dim, requires_grad=True))
92
+
93
+ self.proj_down = nn.Linear(
94
+ in_features=self.config._inner_embedding_dim,
95
+ out_features=self.config.embedding_dim,
96
+ bias=self.config.bias,
97
+ )
98
+ self.dropout = nn.Dropout(self.config.dropout)
99
+ self.reset_parameters()
100
+
101
+ def forward(self, x: torch.Tensor, **kwargs) -> torch.Tensor:
102
+ B, S, _ = x.shape
103
+
104
+ # up-projection
105
+ x_inner = self.proj_up(x)
106
+ x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.config._inner_embedding_dim, dim=-1)
107
+
108
+ # mlstm branch
109
+ x_mlstm_conv = self.conv1d(x_mlstm)
110
+ x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
111
+
112
+ q = self.q_proj(x_mlstm_conv_act)
113
+ k = self.k_proj(x_mlstm_conv_act)
114
+ v = self.v_proj(x_mlstm)
115
+
116
+ h_tilde_state = self.mlstm_cell(q=q, k=k, v=v)
117
+
118
+ h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)
119
+
120
+ # output / z branch
121
+ h_state = h_tilde_state_skip * self.ogate_act_fn(z)
122
+
123
+ # down-projection
124
+ y = self.dropout(self.proj_down(h_state))
125
+ return y
126
+
127
+ def step(
128
+ self,
129
+ x: torch.Tensor,
130
+ mlstm_state: tuple[torch.Tensor, torch.Tensor, torch.Tensor] = None,
131
+ conv_state: tuple[torch.Tensor] = None,
132
+ ) -> tuple[torch.Tensor, dict[str, tuple[torch.Tensor, ...]]]:
133
+ B, S, _ = x.shape
134
+
135
+ # up-projection
136
+ x_inner = self.proj_up(x)
137
+ x_mlstm, z = torch.split(x_inner, split_size_or_sections=self.config._inner_embedding_dim, dim=-1)
138
+
139
+ # mlstm branch
140
+ x_mlstm_conv, conv_state = self.conv1d.step(x_mlstm, conv_state=conv_state)
141
+ x_mlstm_conv_act = self.conv_act_fn(x_mlstm_conv)
142
+
143
+ q = self.q_proj(x_mlstm_conv_act)
144
+ k = self.k_proj(x_mlstm_conv_act)
145
+ v = self.v_proj(x_mlstm)
146
+
147
+ h_tilde_state, mlstm_state = self.mlstm_cell.step(q=q, k=k, v=v, mlstm_state=mlstm_state)
148
+
149
+ h_tilde_state_skip = h_tilde_state + (self.learnable_skip * x_mlstm_conv_act)
150
+
151
+ # output / z branch
152
+ h_state = h_tilde_state_skip * self.ogate_act_fn(z)
153
+
154
+ # down-projection
155
+ y = self.dropout(self.proj_down(h_state))
156
+ return y, {"mlstm_state": mlstm_state, "conv_state": conv_state}
157
+
158
+ def reset_parameters(self):
159
+ # init inproj
160
+ small_init_init_(self.proj_up.weight, dim=self.config.embedding_dim)
161
+ if self.proj_up.bias is not None:
162
+ nn.init.zeros_(self.proj_up.bias)
163
+ # init outproj
164
+ wang_init_(self.proj_down.weight, dim=self.config.embedding_dim, num_blocks=self.config._num_blocks)
165
+ if self.proj_down.bias is not None:
166
+ nn.init.zeros_(self.proj_down.bias)
167
+
168
+ nn.init.ones_(self.learnable_skip)
169
+
170
+ def _init_qkv_proj(qkv_proj: LinearHeadwiseExpand):
171
+ # use the embedding dim instead of the inner embedding dim
172
+ small_init_init_(qkv_proj.weight, dim=self.config.embedding_dim)
173
+ if qkv_proj.bias is not None:
174
+ nn.init.zeros_(qkv_proj.bias)
175
+
176
+ _init_qkv_proj(self.q_proj)
177
+ _init_qkv_proj(self.k_proj)
178
+ _init_qkv_proj(self.v_proj)
179
+
180
+ self.mlstm_cell.reset_parameters()
.venv/Lib/site-packages/xlstm/blocks/slstm/__init__.py ADDED
File without changes