| import base64 | |
| import ctypes | |
| import gc | |
| import inspect | |
| import json | |
| import mmap | |
| import os | |
| import shutil | |
| import signal | |
| import sys | |
| import time | |
| import warnings | |
| from collections import defaultdict | |
| from concurrent.futures import as_completed, ThreadPoolExecutor | |
| from contextlib import contextmanager, nullcontext | |
| from contextvars import copy_context | |
| from dataclasses import dataclass | |
| from datetime import timedelta | |
| from functools import lru_cache as cache, partial, wraps | |
| from importlib import metadata | |
| import importlib | |
| from queue import Empty, Queue as ThreadQueue | |
| from threading import Thread | |
| from types import ModuleType, SimpleNamespace | |
| from typing import ( | |
| Any, Callable, Dict, Generator, Generic, List, Literal, NamedTuple, | |
| Optional, Set, Tuple, Type, TypedDict, TypeVar, Union, overload | |
| ) | |
| from typing_extensions import ( | |
| assert_never, ParamSpec, TypeAlias, Unpack, get_args | |
| ) | |
| from pathlib import Path | |
| from packaging import version | |
| import gradio as gr | |
| import httpx | |
| from gradio.context import Context, LocalContext | |
| from gradio.helpers import Progress, TrackedIterable | |
| from gradio.queueing import Queue | |
| from pydantic import BaseModel | |
| warnings.filterwarnings("ignore", category=UserWarning, message="Can't initialize NVML") | |
| try: | |
| import torch | |
| from torch.utils.weak import WeakTensorKeyDictionary | |
| except ImportError: | |
| torch = None | |
| WeakTensorKeyDictionary = dict | |
| if torch and "weights_only" in inspect.signature(torch.load).parameters: | |
| _original_torch_load = torch.load | |
| def patched_torch_load(*args, **kwargs): | |
| kwargs.setdefault("weights_only", False) | |
| return _original_torch_load(*args, **kwargs) | |
| torch.load = patched_torch_load | |
| try: | |
| from tqdm import tqdm as _tqdm | |
| except ImportError: | |
| _tqdm = None | |
| def boolean(value: str | None) -> bool: | |
| return value is not None and value.lower() in ("1", "t", "true") | |
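| # Example: boolean("TRUE") -> True, boolean("t") -> True, boolean("0") -> False, boolean(None) -> False | |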
| class Settings: | |
| def __init__(self): | |
| self.zero_gpu = boolean(os.getenv('SPACES_ZERO_GPU')) | |
| self.zero_device_api_url = os.getenv('SPACES_ZERO_DEVICE_API_URL') | |
| self.gradio_auto_wrap = boolean(os.getenv('SPACES_GRADIO_AUTO_WRAP')) | |
| self.zero_patch_torch_device = boolean(os.getenv('ZERO_GPU_PATCH_TORCH_DEVICE')) | |
| self.zero_gpu_v2 = boolean(os.getenv('ZEROGPU_V2')) | |
| self.zerogpu_size: Literal['auto', 'medium', 'large'] = os.getenv('ZEROGPU_SIZE', 'large') | |
| self.zerogpu_medium_size_threshold = int(os.getenv('ZEROGPU_MEDIUM_SIZE_THRESHOLD', 30 * 2**30)) | |
| ZEROGPU_OFFLOAD_DIR_DEFAULT = str(Path.home() / '.zerogpu' / 'tensors') | |
| self.zerogpu_offload_dir = os.getenv('ZEROGPU_OFFLOAD_DIR', ZEROGPU_OFFLOAD_DIR_DEFAULT) | |
| self.zerogpu_proc_self_cgroup_path = os.getenv('ZEROGPU_PROC_SELF_CGROUP_PATH', '/proc/self/cgroup') | |
| self.zerogpu_cuda_device_name = os.getenv('ZEROGPU_CUDA_DEVICE_NAME', "NVIDIA H200 MIG 3g.71gb") | |
| self.zerogpu_cuda_total_memory = int(os.getenv('ZEROGPU_CUDA_TOTAL_MEMORY', 74625056768)) | |
| self.zerogpu_cuda_reserved_memory = int(os.getenv('ZEROGPU_CUDA_RESERVED_MEMORY', 0)) | |
| self.zerogpu_cuda_capability_major = int(os.getenv('ZEROGPU_CUDA_CAPABILITY_MAJOR', 9)) | |
| self.zerogpu_cuda_capability_minor = int(os.getenv('ZEROGPU_CUDA_CAPABILITY_MINOR', 0)) | |
| self.zerogpu_cuda_multi_processor_count = int(os.getenv('ZEROGPU_CUDA_MULTI_PROCESSOR_COUNT', 60)) | |
| Config = Settings() | |
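| # Typical ZeroGPU Space environment (illustrative): SPACES_ZERO_GPU=true enables the code paths below, | |
| # SPACES_ZERO_DEVICE_API_URL points at the scheduling API, and ZEROGPU_SIZE selects 'auto', 'medium' or 'large'. | |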
| if Config.zero_gpu: | |
| if Config.zero_device_api_url is None: | |
| print("Error: SPACES_ZERO_DEVICE_API_URL environment variable must be set on ZeroGPU Spaces.", file=sys.stderr) | |
| GPUSizeConfig = Literal['auto', 'medium', 'large'] | |
| if Config.zerogpu_size not in get_args(GPUSizeConfig): | |
| print(f"Error: ZEROGPU_SIZE should be one of {', '.join(get_args(GPUSizeConfig))}", file=sys.stderr) | |
| T = TypeVar('T') | |
| def self_cgroup_device_path() -> str: | |
| try: | |
| cgroup_content = Path(Config.zerogpu_proc_self_cgroup_path).read_text() | |
| for line in cgroup_content.strip().split('\n'): | |
| contents = line.split(':devices:') | |
| if len(contents) == 2: | |
| return contents[1] | |
| except Exception as e: | |
| print(f"Could not determine cgroup device path: {e}", file=sys.stderr) | |
| return "" | |
| class SimpleQueue(ThreadQueue[T]): | |
| def put(self, obj: T): | |
| try: | |
| super().put(obj) | |
| except Exception as e: | |
| print(f"Error in SimpleQueue.put: {e}", file=sys.stderr) | |
| def close(self): | |
| try: | |
| pass | |
| except Exception as e: | |
| print(f"Error closing SimpleQueue: {e}", file=sys.stderr) | |
| def wlock_release(self): | |
| try: | |
| pass | |
| except Exception: | |
| pass | |
| def drop_params(fn: Callable[[], T]) -> Callable[..., T]: | |
| def drop(*args, **kwargs): | |
| return fn() | |
| return drop | |
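| # Example: drop_params(lambda: 42)(1, 2, key='x') returns 42; every positional and keyword argument is discarded. | |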
| def gradio_request_var(): | |
| try: | |
| from gradio.context import LocalContext | |
| return LocalContext.request | |
| except ImportError: | |
| print("Could not import Gradio LocalContext. Ensure Gradio version is at least 3.46.", file=sys.stderr) | |
| return None | |
| def malloc_trim(): | |
| try: | |
| ctypes.CDLL("libc.so.6").malloc_trim(0) | |
| except (OSError, AttributeError) as e: | |
| print(f"malloc_trim not available on this system: {e}", file=sys.stderr) | |
| debug = partial(print, 'SPACES_ZERO_GPU_DEBUG') | |
| def jwt_payload(token: str) -> dict[str, Any]: | |
| try: | |
| _, payload, _ = token.split('.') | |
| return json.loads(base64.urlsafe_b64decode(f'{payload}==')) | |
| except Exception as e: | |
| print(f"Error decoding JWT payload: {e}", file=sys.stderr) | |
| return {} | |
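| # A JWT is three dot-separated base64url segments (header.payload.signature); this decodes only the middle | |
| # (payload) segment. The claims used later in this file include 'user' and 'error' (see schedule()). | |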
| if torch: | |
| def empty_like_raw_alloc(tensor: torch.Tensor, **kwargs) -> torch.Tensor: | |
| empty = torch.empty_like(tensor, **{**kwargs, 'requires_grad': False}) | |
| if (nbytes := empty.untyped_storage().nbytes()) > 0: | |
| try: | |
| buffer = mmap.mmap(-1, nbytes, prot=mmap.PROT_READ | mmap.PROT_WRITE) | |
| buffer_torch = torch.frombuffer(buffer, dtype=torch.uint8) | |
| empty.set_(buffer_torch.untyped_storage(), 0, empty.shape, empty.stride()) | |
| except Exception as e: | |
| print(f"Failed to create mmap buffer for tensor: {e}", file=sys.stderr) | |
| empty.requires_grad_(kwargs.get('requires_grad', False)) | |
| return empty | |
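| # empty_like_raw_alloc (above) builds a tensor with the same shape/dtype/strides as `tensor` but backed by an | |
| # anonymous mmap buffer rather than the default CPU allocator (presumably to keep these placeholder tensors cheap). | |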
| Params = Tuple[Tuple[object, ...], Dict[str, Any]] | |
| Res = TypeVar('Res') | |
| Param = ParamSpec('Param') | |
| class EmptyKwargs(TypedDict): | |
| pass | |
| @dataclass | |
| class OkResult(Generic[Res]): | |
| value: Res | |
| @dataclass | |
| class ExceptionResult: | |
| traceback: str | |
| error_cls: str | |
| class AbortedResult: | |
| pass | |
| class EndResult: | |
| pass | |
| @dataclass | |
| class GradioQueueEvent: | |
| method_name: str | |
| args: tuple[Any, ...] | |
| kwargs: dict[str, Any] | |
| RegularResQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "GradioQueueEvent"] | |
| GeneratorResQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "EndResult", "GradioQueueEvent"] | |
| YieldQueueResult: TypeAlias = Union["OkResult[Res]", "ExceptionResult", "EndResult", "AbortedResult"] | |
| Duration: TypeAlias = Union[int, timedelta] | |
| DynamicDuration: TypeAlias = Union[Duration, Callable[Param, Duration], None] | |
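| # Example: a Duration may be given as plain seconds (60) or as timedelta(minutes=1); a DynamicDuration may also | |
| # be a callable that computes the duration from the wrapped function's arguments, or None. | |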
| if torch: | |
| class AliasId(NamedTuple): | |
| data_ptr: int | |
| dtype: torch.dtype | |
| shape: tuple[int, ...] | |
| stride: tuple[int, ...] | |
| @classmethod | |
| def from_tensor(cls, tensor: torch.Tensor): | |
| return cls( | |
| tensor.data_ptr(), | |
| tensor.dtype, | |
| tensor.shape, | |
| tensor.stride(), | |
| ) | |
| AllowToken = str | |
| NvidiaIndex = int | |
| NvidiaUUID = str | |
| CGroupPath = str | |
| TaskId = int | |
| GPUSize = Literal['medium', 'large'] | |
| AuthLevel = Literal['regular', 'pro'] | |
| QueuingReason = Literal['node', 'concurrency'] | |
| AUTHENTICATED_HEADER = 'X-Authenticated' | |
| QUEUING_REASON_HEADER = 'X-Queuing-Reason' | |
| class ScheduleResponse(BaseModel): | |
| idle: bool | |
| nvidiaIndex: int | |
| nvidiaUUID: str | |
| allowToken: str | |
| class ScheduleMetadata(BaseModel): | |
| auth: Optional[AuthLevel] = None | |
| queuing_reason: Optional[QueuingReason] = None | |
| class QuotaInfos(BaseModel): | |
| left: int | |
| wait: timedelta | |
| class QueueEvent(BaseModel): | |
| event: Literal['ping', 'failed', 'succeeded'] | |
| data: Optional[ScheduleResponse] = None | |
| def sse_parse(text: str): | |
| event, *data = text.strip().splitlines() | |
| assert event.startswith('event:') | |
| event = event[6:].strip() | |
| if event in ('ping', 'failed'): | |
| return QueueEvent(event=event) | |
| assert event == 'succeeded' | |
| (data,) = data | |
| assert data.startswith('data:') | |
| data = data[5:].strip() | |
| return QueueEvent(event=event, data=ScheduleResponse.parse_raw(data)) | |
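| # Example SSE frame handled by sse_parse (assumed shape): | |
| #   event: succeeded | |
| #   data: {"idle": true, "nvidiaIndex": 0, "nvidiaUUID": "GPU-...", "allowToken": "..."} | |
| # which yields QueueEvent(event='succeeded', data=ScheduleResponse(...)); 'ping' and 'failed' frames carry no data. | |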
| def sse_stream(res: httpx.Response) -> Generator[QueueEvent, Any, None]: | |
| for text in res.iter_text(): | |
| if len(text) == 0: | |
| break | |
| try: | |
| yield sse_parse(text) | |
| except GeneratorExit: | |
| res.close() | |
| break | |
| except Exception as e: | |
| print(f"Error parsing SSE event: {e}", file=sys.stderr) | |
| continue | |
| class APIClient: | |
| def __init__(self, client: httpx.Client): | |
| self.client = client | |
| def startup_report(self, cgroup_path: str, gpu_size: GPUSize) -> httpx.codes: | |
| try: | |
| res = self.client.post('/startup-report', params={'cgroupPath': cgroup_path, 'gpuSize': gpu_size}) | |
| return httpx.codes(res.status_code) | |
| except Exception as e: | |
| print(f"Failed to send startup report: {e}", file=sys.stderr) | |
| return httpx.codes.INTERNAL_SERVER_ERROR | |
| def schedule(self, cgroup_path: str, task_id: int = 0, token: str | None = None, token_version: int = 1, duration_seconds: int = 0, enable_queue: bool = True): | |
| try: | |
| params: dict[str, str | int | bool] = {'cgroupPath': cgroup_path, 'taskId': task_id, 'enableQueue': enable_queue, 'tokenVersion': token_version, 'durationSeconds': duration_seconds} | |
| if token is not None: | |
| params['token'] = token | |
| req = self.client.build_request(method='POST', url='/schedule', params=params) | |
| res = self.client.send(req, stream=True) | |
| status = httpx.codes(res.status_code) | |
| auth: AuthLevel | None = res.headers.get(AUTHENTICATED_HEADER) | |
| queuing_reason: QueuingReason | None = res.headers.get(QUEUING_REASON_HEADER) | |
| metadata = ScheduleMetadata(auth=auth, queuing_reason=queuing_reason) | |
| if status is not httpx.codes.OK and status is not httpx.codes.TOO_MANY_REQUESTS: | |
| res.close() | |
| return status, metadata | |
| if "text/event-stream" in res.headers.get('content-type', ''): | |
| return sse_stream(res), metadata | |
| res.read() | |
| if status is httpx.codes.TOO_MANY_REQUESTS: | |
| return QuotaInfos(**res.json()), metadata | |
| if status is httpx.codes.OK: | |
| return ScheduleResponse(**res.json()), metadata | |
| assert_never(status) | |
| except Exception as e: | |
| print(f"Error in APIClient.schedule: {e}", file=sys.stderr) | |
| return httpx.codes.INTERNAL_SERVER_ERROR, ScheduleMetadata() | |
| def allow(self, allow_token: str, pid: int): | |
| try: | |
| res = self.client.post('/allow', params={'allowToken': allow_token, 'pid': pid}) | |
| return httpx.codes(res.status_code) | |
| except Exception as e: | |
| print(f"Error in APIClient.allow: {e}", file=sys.stderr) | |
| return httpx.codes.INTERNAL_SERVER_ERROR | |
| def release(self, allow_token: str, fail: bool = False) -> httpx.codes: | |
| try: | |
| res = self.client.post('/release', params={'allowToken': allow_token, 'fail': fail}) | |
| return httpx.codes(res.status_code) | |
| except Exception as e: | |
| print(f"Error in APIClient.release: {e}", file=sys.stderr) | |
| return httpx.codes.INTERNAL_SERVER_ERROR | |
| def get_queue_size(self) -> float: | |
| try: | |
| res = self.client.get('/queue-size') | |
| assert res.status_code == 200, res.status_code | |
| return res.json() | |
| except Exception as e: | |
| print(f"Error in APIClient.get_queue_size: {e}", file=sys.stderr) | |
| return 0.0 | |
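| # Minimal usage sketch (assuming SPACES_ZERO_DEVICE_API_URL is set; api_client() below builds the real client): | |
| #   client = APIClient(httpx.Client(base_url=Config.zero_device_api_url, timeout=60)) | |
| #   status = client.startup_report(self_cgroup_device_path(), 'large') | |
| #   res, meta = client.schedule(cgroup_path=self_cgroup_device_path(), task_id=0) | |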
| def remove_tqdm_multiprocessing_lock(): | |
| if _tqdm is None: | |
| return | |
| try: | |
| tqdm_lock = _tqdm.get_lock() | |
| if hasattr(tqdm_lock, 'locks'): | |
| pass | |
| except Exception as e: | |
| print(f"Error while trying to remove tqdm multiprocessing lock: {e}", file=sys.stderr) | |
| tqdm = _tqdm | |
| try: | |
| Success = gr.Success | |
| except AttributeError: | |
| Success = gr.Info | |
| Level: TypeAlias = "Literal['success', 'info', 'warning']" | |
| def modal(level: Level): | |
| if level == 'info': return gr.Info | |
| if level == 'success': return Success | |
| if level == 'warning': return gr.Warning | |
| return gr.Info | |
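| # Example: modal('warning') returns gr.Warning, modal('success') returns Success (gr.Success when available, | |
| # else gr.Info), and any other value falls back to gr.Info. | |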
| class GradioPartialContext(NamedTuple): | |
| event_id: str | None | |
| in_event_listener: bool | |
| progress: Progress | None | |
| @staticmethod | |
| def get(): | |
| TrackedIterable.__reduce__ = tracked_iterable__reduce__ | |
| return GradioPartialContext( | |
| event_id=LocalContext.event_id.get(None), | |
| in_event_listener=LocalContext.in_event_listener.get(False), | |
| progress=LocalContext.progress.get(None), | |
| ) | |
| @staticmethod | |
| def apply(context: 'GradioPartialContext'): | |
| LocalContext.event_id.set(context.event_id) | |
| LocalContext.in_event_listener.set(context.in_event_listener) | |
| LocalContext.progress.set(context.progress) | |
| def get_queue_instance(): | |
| blocks = LocalContext.blocks.get(None) | |
| if blocks is None: return None | |
| return getattr(blocks, '_queue', None) | |
| def get_event(): | |
| queue = get_queue_instance() | |
| event_id = LocalContext.event_id.get(None) | |
| if queue is None or event_id is None: return None | |
| for job in getattr(queue, 'active_jobs', []): | |
| if job is None: continue | |
| for event in job: | |
| if getattr(event, '_id', None) == event_id: | |
| return event | |
| return None | |
| def get_server_port() -> int | None: | |
| from_request_context = True | |
| if (blocks := LocalContext.blocks.get(None)) is None: | |
| from_request_context = False | |
| if (blocks := Context.root_block) is None: return None | |
| if (server := getattr(blocks, "server", None)) is None: | |
| if from_request_context: | |
| warnings.warn("Gradio: No blocks.server inside a request") | |
| return -1 | |
| server_config = getattr(server, 'config', None) | |
| if isinstance(server_config, dict): | |
| return server_config.get('port') | |
| elif isinstance(server_config, Settings): | |
| warnings.warn("ZeroGPU: Gradio server.config appears to be the global ZeroGPU Config object. Cannot determine Gradio port from this object.") | |
| return None | |
| elif hasattr(server_config, 'port'): | |
| return server_config.port | |
| warnings.warn(f"ZeroGPU: Unexpected type for server.config ({type(server_config)}). Cannot determine Gradio port.") | |
| return None | |
| def try_process_queue_event(method_name: str, *args, **kwargs): | |
| queue = get_queue_instance() | |
| if queue is None: | |
| warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance") | |
| return | |
| method = getattr(queue, method_name, None) | |
| if callable(method): | |
| try: | |
| method(*args, **kwargs) | |
| except Exception as e: | |
| print(f"Error processing Gradio queue event '{method_name}': {e}", file=sys.stderr) | |
| QUEUE_RPC_METHODS = ["set_progress", "log_message"] | |
| def patch_gradio_queue(res_queue: Union[SimpleQueue[RegularResQueueResult | None], SimpleQueue[GeneratorResQueueResult | None]]): | |
| def rpc_method(method_name: str): | |
| def method(*args, **kwargs): | |
| if args and isinstance(args[0], Queue): args = args[1:] | |
| res_queue.put(GradioQueueEvent(method_name, args, kwargs)) | |
| return method | |
| for method_name in QUEUE_RPC_METHODS: | |
| if (method := getattr(Queue, method_name, None)) is None: | |
| warnings.warn(f"ZeroGPU: Gradio Queue has no {method_name} attribute") | |
| continue | |
| if not callable(method): | |
| warnings.warn(f"ZeroGPU: Gradio Queue {method_name} is not callable") | |
| continue | |
| setattr(Queue, method_name, rpc_method(method_name)) | |
| TrackedIterable.__reduce__ = tracked_iterable__reduce__ | |
| def tracked_iterable__reduce__(self): | |
| try: | |
| res: tuple = super(TrackedIterable, self).__reduce__() | |
| cls, base, state, *_ = res | |
| return cls, base, {**state, **{'iterable': None, '_tqdm': None}} | |
| except Exception: | |
| return object, (), {} | |
| def supports_auth(): | |
| try: | |
| return version.parse(gr.__version__) >= version.Version('4.27.0') | |
| except Exception: | |
| return False | |
| Param_one_launch = ParamSpec('Param_one_launch') | |
| def one_launch(task: Callable[Param_one_launch, None], *task_args: Param_one_launch.args, **task_kwargs: Param_one_launch.kwargs): | |
| _launch = gr.Blocks.launch | |
| def launch(*args, **kwargs): | |
| task(*task_args, **task_kwargs) | |
| gr.Blocks.launch = _launch | |
| return gr.Blocks.launch(*args, **kwargs) | |
| gr.Blocks.launch = launch | |
| class HTMLError(gr.Error): | |
| def __str__(self): return str(self.message) | |
| def error(title: str, message: str, html: bool = False): | |
| print(f"ERROR: {title} - {message}", file=sys.stderr) | |
| error_cls = HTMLError if html else gr.Error | |
| params = inspect.signature(gr.Error).parameters | |
| kwargs: dict[str, Any] = {} | |
| if 'title' in params: kwargs['title'] = title | |
| if 'print_exception' in params: kwargs['print_exception'] = False | |
| try: | |
| error_cls(message, **kwargs) | |
| except Exception: | |
| pass | |
| def info(title: str, message: str, level: Level = 'info'): | |
| print(f"INFO: {title} - {message}") | |
| info_cls = modal(level) | |
| params = inspect.signature(gr.Info).parameters | |
| kwargs: dict[str, Any] = {} | |
| if 'title' in params: kwargs['title'] = title | |
| try: | |
| info_cls(message, **kwargs) | |
| except Exception: | |
| pass | |
| TOKEN_HEADER = 'X-IP-Token' | |
| UNUSED_MESSAGE = "GPU device not used" | |
| NO_GPU_MESSAGE_REGULAR = "No GPU was available" | |
| NO_GPU_MESSAGE_INQUEUE = "No GPU was available after 60 seconds" | |
| EXAMPLES_RETRY_MESSAGE = "Try re-running outside of examples if it happened after clicking one" | |
| SIGNUP_ON_HF_TXT = "Create a free account" | |
| SIGNUP_ON_HF_URL = "https://huggingface.co/join" | |
| SUBSCRIBE_TO_PRO_TXT = "Subscribe to Pro" | |
| SUBSCRIBE_TO_PRO_URL = "https://huggingface.co/settings/billing/subscription" | |
| def api_client(): | |
| assert Config.zero_device_api_url is not None | |
| httpx_client = httpx.Client(base_url=Config.zero_device_api_url, timeout=60, verify=False) | |
| return APIClient(httpx_client) | |
| def startup_report_client(cgroup_path: str, gpu_size: GPUSize): | |
| retries, max_retries = 0, 2 | |
| client = api_client() | |
| status = None | |
| while retries <= max_retries: | |
| status = client.startup_report(cgroup_path, gpu_size) | |
| if status is not httpx.codes.NOT_FOUND: | |
| break | |
| time.sleep(1) | |
| retries += 1 | |
| if status is not httpx.codes.OK: | |
| print(f"Error while initializing ZeroGPU: status {status}", file=sys.stderr) | |
| def html_string(html_contents: str, text_contents: str): | |
| class HTMLString(str): | |
| def __str__(self): return text_contents | |
| return HTMLString(html_contents) | |
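| # html_string returns a str subclass whose raw value is html_contents but whose str() is text_contents, | |
| # presumably so an HTML message can degrade to plain text where HTML is not rendered. | |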
| def _toast_action(auth: AuthLevel | None, supports_html: bool, pro_message: str, unlogged_desc: str, logged_desc: str, ending: str) -> tuple[str, str]: | |
| if not supports_auth() or auth == 'pro': | |
| return pro_message, pro_message | |
| link = SIGNUP_ON_HF_URL if auth is None else SUBSCRIBE_TO_PRO_URL | |
| text = SIGNUP_ON_HF_TXT if auth is None else SUBSCRIBE_TO_PRO_TXT | |
| desc = unlogged_desc if auth is None else logged_desc | |
| desc += f" {ending}." | |
| style = ";".join(["white-space: nowrap", "text-underline-offset: 2px", "color: var(--body-text-color)"]) | |
| html = f'<a style="{style}" href="{link}">{text}</a> {desc}' | |
| markdown = f'[{text}]({link}) {desc}' | |
| return html, markdown | |
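| # Example: with auth=None the action links to SIGNUP_ON_HF_URL ("Create a free account"); with auth='regular' | |
| # it links to SUBSCRIBE_TO_PRO_URL ("Subscribe to Pro"); with auth='pro' (or when auth is unsupported) both the | |
| # HTML and markdown variants are just pro_message. | |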
| def schedule(task_id: int, request: gr.Request | None = None, duration: timedelta = timedelta(0), _first_attempt: bool = True) -> Optional[ScheduleResponse]: | |
| try: | |
| gradio_version = version.parse(gr.__version__) | |
| if gradio_version.major < 4: | |
| print("ZeroGPU is only compatible with Gradio 4+", file=sys.stderr) | |
| return None | |
| except Exception: | |
| print("Could not parse Gradio version.", file=sys.stderr) | |
| return None | |
| GRADIO_HTML_TOASTS = gradio_version >= version.Version('4.39') | |
| GRADIO_HANDSHAKE = gradio_version >= version.Version('5.16.1') | |
| token, payload = _get_token_and_payload(request) | |
| if token is not None and (token_error := payload.get('error')): | |
| info("ZeroGPU client warning", f"Falling back to IP-based quotas ({token_error})", level='warning') | |
| duration_seconds = duration.seconds | |
| res, meta = api_client().schedule(cgroup_path=self_cgroup_device_path(), task_id=task_id, token=token, token_version=2 if GRADIO_HANDSHAKE else 1, duration_seconds=duration_seconds) | |
| if isinstance(res, ScheduleResponse): | |
| print("This Space is currently using 0 minutes, 0 seconds of the huggingface.co plan.") | |
| return res | |
| if isinstance(res, QuotaInfos): | |
| requested = duration.seconds | |
| message = "" | |
| if res.wait < timedelta(0): | |
| message = f"The requested GPU duration ({requested}s) is larger than the maximum allowed" | |
| elif token is None: | |
| message = f"Space app has reached its GPU limit. {EXAMPLES_RETRY_MESSAGE}" | |
| else: | |
| if payload.get('user') is None and res.wait == timedelta(0): | |
| message = "You have exceeded your runs limit." | |
| else: | |
| gpu = "Pro GPU" if meta.auth == 'pro' else ("free GPU" if meta.auth == 'regular' else "GPU") | |
| message = f"You have exceeded your {gpu} quota ({requested}s requested vs. {res.left}s left). Try again in {res.wait}" | |
| print(f"ZeroGPU quota exceeded: {message}", file=sys.stderr) | |
| return None | |
| if not isinstance(res, httpx.codes): | |
| if meta.queuing_reason in ('node', None): info("ZeroGPU queue", "Waiting for a GPU to become available") | |
| elif meta.queuing_reason == 'concurrency': info("ZeroGPU queue", "Waiting for a GPU slot on this Space") | |
| else: assert_never(meta.queuing_reason) | |
| connection_event = get_event() | |
| if connection_event is None and request is not None: | |
| warnings.warn("ZeroGPU: Cannot get Gradio app Queue instance") | |
| while True: | |
| try: | |
| event = next(res) | |
| except StopIteration: | |
| print("Unexpected end of stream in schedule", file=sys.stderr) | |
| return None | |
| except httpx.RemoteProtocolError: | |
| if not _first_attempt: | |
| print("Error while re-trying after queue disconnect", file=sys.stderr) | |
| return None | |
| return schedule(task_id, request, duration, _first_attempt=False) | |
| except Exception as e: | |
| print(f"Error processing schedule event stream: {e}", file=sys.stderr) | |
| return None | |
| if event.event == 'ping': | |
| if connection_event is not None and not connection_event.alive: | |
| res.close() | |
| print("Connection closed by visitor while queueing", file=sys.stderr) | |
| return None | |
| continue | |
| if event.event == 'failed': | |
| if token is None: | |
| message = f"{NO_GPU_MESSAGE_INQUEUE}. {EXAMPLES_RETRY_MESSAGE}" | |
| else: | |
| _, details_markdown = _toast_action(auth=meta.auth, supports_html=GRADIO_HTML_TOASTS, pro_message="Retry later", unlogged_desc="to get a higher", logged_desc="to get the highest", ending="priority in ZeroGPU queues") | |
| message = f"{NO_GPU_MESSAGE_INQUEUE} {details_markdown}" | |
| print(f"ZeroGPU queue timeout: {message}", file=sys.stderr) | |
| return None | |
| if event.event == 'succeeded': | |
| assert event.data is not None | |
| if connection_event is not None and not connection_event.alive: | |
| release(event.data.allowToken) | |
| print("Connection closed by visitor on queue success", file=sys.stderr) | |
| return None | |
| info("ZeroGPU queue", "Successfully acquired a GPU", level='success') | |
| print("This Space is currently using 0 minutes, 0 seconds of the huggingface.co plan.") | |
| return event.data | |
| if res is httpx.codes.SERVICE_UNAVAILABLE: | |
| print(f"ZeroGPU client error: {NO_GPU_MESSAGE_REGULAR}", file=sys.stderr) | |
| return None | |
| if res is httpx.codes.UNAUTHORIZED: | |
| print("ZeroGPU client error: Expired ZeroGPU proxy token", file=sys.stderr) | |
| return None | |
| reason = httpx.codes.get_reason_phrase(res) if isinstance(res, int) else "Unknown" | |
| print(f"ZeroGPU API /schedule error: {res} ({reason})", file=sys.stderr) | |
| return None | |
| def allow(allow_token: str) -> None: | |
| process_id = os.getpid() | |
| if process_id == 1: | |
| print("CRITICAL: Allowing PID 1 on ZeroGPU will end up killing your Space. Aborting.", file=sys.stderr) | |
| return | |
| if api_client().allow(allow_token=allow_token, pid=process_id) is not httpx.codes.OK: | |
| print(f"API call to /allow failed for token {allow_token}", file=sys.stderr) | |
| def release(allow_token: str, *, fail: bool = False, allow_404: bool = True) -> None: | |
| res = api_client().release(allow_token=allow_token, fail=fail) | |
| if res is httpx.codes.NO_CONTENT: | |
| try: | |
| info("ZeroGPU client warning", UNUSED_MESSAGE, level='warning') | |
| except AttributeError: | |
| pass | |
| warnings.warn(UNUSED_MESSAGE, RuntimeWarning) | |
| return | |
| if res is httpx.codes.NOT_FOUND: | |
| if not allow_404: | |
| warnings.warn("ZeroGPU API /release warning: 404 Not Found") | |
| return | |
| if httpx.codes.is_success(res): | |
| return | |
| reason = httpx.codes.get_reason_phrase(res) if isinstance(res, int) else "Unknown" | |
| print(f"ZeroGPU API /release error: {res} ({reason})", file=sys.stderr) | |
| def _get_token(request: gr.Request | None) -> str | None: | |
| if request is None: return None | |
| headers = getattr(request, 'headers', None) | |
| if headers is None or not hasattr(headers, '__dict__'): | |
| print("ZeroGPU client error: Internal Gradio error (headers not found)", file=sys.stderr) | |
| return None | |
| if not hasattr(headers, 'get'): | |
| headers = headers.__dict__ | |
| return headers.get(TOKEN_HEADER.lower()) | |
| def _get_token_and_payload(request: gr.Request | None) -> tuple[str | None, dict[str, Any]]: | |
| token = _get_token(request) | |
| if token is None: return None, {} | |
| payload = jwt_payload(token) | |
| return token, payload | |
| def compute_base_free_memory(total_memory: int) -> int: | |
| pytorch_base_memory = 309002240 | |
| return total_memory - pytorch_base_memory - Config.zerogpu_cuda_reserved_memory | |
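| # Example with the defaults above: 74625056768 (total) - 309002240 (PyTorch base) - 0 (reserved) | |
| # = 74316054528 bytes reported as free. | |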
| CUDA_DEVICE_NAME_STATIC = Config.zerogpu_cuda_device_name | |
| CUDA_TOTAL_MEMORY_STATIC = Config.zerogpu_cuda_total_memory | |
| CUDA_MEM_GET_INFO_STATIC = (compute_base_free_memory(CUDA_TOTAL_MEMORY_STATIC), CUDA_TOTAL_MEMORY_STATIC) | |
| CUDA_DEVICE_CAPABILITY_STATIC = (Config.zerogpu_cuda_capability_major, Config.zerogpu_cuda_capability_minor) | |
| CUDA_DEVICE_PROPERTIES_STATIC = SimpleNamespace(name=CUDA_DEVICE_NAME_STATIC, major=CUDA_DEVICE_CAPABILITY_STATIC[0], minor=CUDA_DEVICE_CAPABILITY_STATIC[1], total_memory=CUDA_TOTAL_MEMORY_STATIC, multi_processor_count=Config.zerogpu_cuda_multi_processor_count) | |
| if torch: | |
| class MockCudaRuntime: | |
| def setDevice(self, device): | |
| pass | |
| def getDevice(self): | |
| return 0 | |
| def deviceSynchronize(self): | |
| pass | |
| def deviceGetStreamPriorityRange(self): | |
| return 0, 0 | |
| cudart = MockCudaRuntime() | |
| if torch and torch.version.cuda is not None and torch.version.cuda.startswith("12."): | |
| CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC = {"num_alloc_retries": 0, "num_ooms": 0, "max_split_size": -1, "num_sync_all_streams": 0, "num_device_alloc": 0, "num_device_free": 0, "allocation": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "segment": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "allocated_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "reserved_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "requested_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "oversize_allocations": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "oversize_segments": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}} | |
| else: | |
| CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC = {"num_alloc_retries": 0, "num_ooms": 0, "max_split_size": -1, "allocation": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "segment": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "allocated_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "reserved_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "active_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "inactive_split_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "requested_bytes": {"all": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "small_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "large_pool": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}}, "oversize_allocations": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}, "oversize_segments": {"current": 0, "peak": 0, "allocated": 0, "freed": 0}} | |
| def cudaMemGetInfo(device: int, /): | |
| return CUDA_MEM_GET_INFO_STATIC | |
| PAGE_SIZE = 4096 | |
| try: | |
| TOTAL_MEMORY = os.sysconf('SC_PAGE_SIZE') * os.sysconf('SC_PHYS_PAGES') | |
| except (ValueError, OSError, AttributeError): | |
| TOTAL_MEMORY = 8 * (1024**3) | |
| VM_MAX_SIZE = min(2**38, TOTAL_MEMORY // 2) | |
| BUFFER_SIZE = 128 * 2**20 | |
| BUFFER_COUNT = 2 | |
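| # Sizing notes: PAGE_SIZE is the assumed 4 KiB page, BUFFER_SIZE is 128 MiB per staging buffer with | |
| # BUFFER_COUNT=2 buffers, and VM_MAX_SIZE caps the scratch virtual allocation at min(256 GiB, half of RAM). | |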
| if torch: | |
| TensorWithSizes: TypeAlias = 'tuple[torch.Tensor, int, int]' | |
| if torch: | |
| @dataclass | |
| class ZeroGPUTensorPack: | |
| base_dir: str | |
| batches: list[list[TensorWithSizes]] | |
| big_tensors: list[TensorWithSizes] | |
| fakes: dict[torch.Tensor, list[torch.Tensor]] | |
| total_size: int | |
| def path(self): | |
| return f'{self.base_dir}/{id(self)}' | |
| def __del__(self): | |
| try: | |
| os.remove(self.path()) | |
| except (FileNotFoundError, TypeError, AttributeError): | |
| pass | |
| def write_packing(fd: int, tensor: torch.Tensor): | |
| try: | |
| clone = torch.empty_like(tensor) | |
| size = clone.untyped_storage().size() | |
| buffer = torch.UntypedStorage(VM_MAX_SIZE) | |
| buffer_ptr = buffer.data_ptr() | |
| offset = -buffer_ptr % PAGE_SIZE | |
| padding = -size % PAGE_SIZE | |
| clone.set_(buffer[offset:offset + size], 0, clone.shape, clone.stride()) | |
| clone.copy_(tensor) | |
| mv = memoryview((ctypes.c_char * (size + padding)).from_address(buffer_ptr + offset)) | |
| written_bytes = 0 | |
| while written_bytes < size + padding: | |
| written_bytes += os.write(fd, mv[written_bytes:]) | |
| except Exception as e: | |
| print(f"Error during tensor write packing: {e}", file=sys.stderr) | |
| def pack_tensors(tensors: set[torch.Tensor], fakes: dict[torch.Tensor, list[torch.Tensor]], offload_dir: str, callback: Callable[[int], None] | None = None): | |
| callback = (lambda b: None) if callback is None else callback | |
| batches: list[list[TensorWithSizes]] = [] | |
| big_tensors: list[TensorWithSizes] = [] | |
| tensors_with_sizes: list[tuple[torch.Tensor, int, int]] = [] | |
| for tensor in tensors: | |
| size = tensor.numel() * tensor.element_size() | |
| aligned_size = size + (-size % PAGE_SIZE) | |
| tensors_with_sizes.append((tensor, size, aligned_size)) | |
| current_batch, current_size = [], 0 | |
| for (tensor, size, aligned_size) in sorted(tensors_with_sizes, key=lambda item: item[2]): | |
| if aligned_size > BUFFER_SIZE: | |
| big_tensors.append((tensor, size, aligned_size)) | |
| continue | |
| current_size += aligned_size | |
| if current_size > BUFFER_SIZE: | |
| batches.append(current_batch) | |
| current_batch, current_size = [(tensor, size, aligned_size)], aligned_size | |
| else: | |
| current_batch.append((tensor, size, aligned_size)) | |
| if current_batch: | |
| batches.append(current_batch) | |
| get_meta = {tensor: empty_like_raw_alloc(tensor) for tensor in tensors} | |
| batches_meta = [[(get_meta[tensor], size, asize) for tensor, size, asize in batch] for batch in batches] | |
| big_tensors_meta = [(get_meta[tensor], size, asize) for tensor, size, asize in big_tensors] | |
| fakes_meta = {get_meta[tensor]: fake_list for tensor, fake_list in fakes.items()} | |
| pack = ZeroGPUTensorPack(base_dir=offload_dir, batches=batches_meta, big_tensors=big_tensors_meta, fakes=fakes_meta, total_size=sum([size for _, size, _ in tensors_with_sizes])) | |
| fd = -1 | |
| try: | |
| fd = os.open(pack.path(), os.O_CREAT | os.O_WRONLY | os.O_DIRECT) | |
| total_asize = sum([aligned_size for batch in batches for *_, aligned_size in batch]) | |
| total_asize += sum([aligned_size for *_, aligned_size in big_tensors]) | |
| if total_asize > 0: | |
| os.posix_fallocate(fd, 0, total_asize) | |
| for batch in batches: | |
| for tensor, size, _ in batch: | |
| write_packing(fd, tensor) | |
| callback(size) | |
| for tensor, size, _ in big_tensors: | |
| write_packing(fd, tensor) | |
| callback(size) | |
| return pack | |
| except Exception as e: | |
| print(f"Failed to pack tensors to disk: {e}", file=sys.stderr) | |
| return pack | |
| finally: | |
| if fd != -1: | |
| os.close(fd) | |
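| # Batching example (illustrative): with BUFFER_SIZE = 128 MiB, aligned tensors of 60, 60, 100 and 200 MiB are | |
| # split into batches [[60, 60], [100]] while the 200 MiB tensor goes to big_tensors, since it exceeds the buffer. | |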
| def pack_to_cuda(pack: ZeroGPUTensorPack, callback: Callable[[int], None] | None = None): | |
| callback = (lambda b: None) if callback is None else callback | |
| free_buffers: ThreadQueue[torch.Tensor] = ThreadQueue() | |
| read_buffers: ThreadQueue[torch.Tensor] = ThreadQueue() | |
| for _ in range(BUFFER_COUNT): | |
| free_buffers.put(torch.ByteTensor(BUFFER_SIZE).pin_memory()) | |
| def read(fd: int, buffer: torch.Tensor, size: int): | |
| mv = memoryview((ctypes.c_char * size).from_address(buffer.data_ptr())) | |
| read_bytes = 0 | |
| while read_bytes < size: | |
| read_bytes += os.readv(fd, [mv[read_bytes:]]) | |
| def disk_to_pin(fd: int): | |
| for batch in pack.batches: | |
| buffer = free_buffers.get() | |
| batch_size = sum([aligned_size for *_, aligned_size in batch]) | |
| read(fd, buffer, batch_size) | |
| read_buffers.put(buffer) | |
| for *_, aligned_size in pack.big_tensors: | |
| read_bytes = 0 | |
| while read_bytes < aligned_size: | |
| buffer = free_buffers.get() | |
| read_size = min(BUFFER_SIZE, aligned_size - read_bytes) | |
| read(fd, buffer, read_size) | |
| read_buffers.put(buffer) | |
| read_bytes += read_size | |
| def pin_to_cuda(): | |
| total_duration_in_callback = 0 | |
| for batch in pack.batches: | |
| buffer = read_buffers.get() | |
| offset = 0 | |
| cuda_storages = [] | |
| for tensor, size, aligned_size in batch: | |
| cuda_storages.append(buffer[offset:offset + size].cuda(non_blocking=True)) | |
| offset += aligned_size | |
| torch.cuda.synchronize() | |
| free_buffers.put(buffer) | |
| batch_total_size = 0 | |
| for (tensor, size, _), cuda_storage in zip(batch, cuda_storages): | |
| cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda') | |
| cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride()) | |
| for fake in pack.fakes[tensor]: | |
| fake.data = cuda_tensor | |
| batch_total_size += size | |
| t0 = time.perf_counter() | |
| callback(batch_total_size) | |
| total_duration_in_callback += time.perf_counter() - t0 | |
| for tensor, size, _ in pack.big_tensors: | |
| cuda_storage = torch.empty(size, dtype=torch.uint8, device='cuda') | |
| offset = 0 | |
| while offset < size: | |
| buffer = read_buffers.get() | |
| read_size = min(BUFFER_SIZE, size - offset) | |
| cuda_storage[offset:offset + read_size] = buffer[:read_size] | |
| offset += read_size | |
| torch.cuda.synchronize() | |
| free_buffers.put(buffer) | |
| t0 = time.perf_counter() | |
| callback(read_size) | |
| total_duration_in_callback += time.perf_counter() - t0 | |
| cuda_tensor = torch.tensor([], dtype=tensor.dtype, device='cuda') | |
| cuda_tensor = cuda_tensor.set_(cuda_storage.untyped_storage(), 0, tensor.shape, tensor.stride()) | |
| for fake in pack.fakes[tensor]: | |
| fake.data = cuda_tensor | |
| debug(f"{total_duration_in_callback=}") | |
| fd = -1 | |
| try: | |
| with ThreadPoolExecutor(2) as executor: | |
| fd = os.open(pack.path(), os.O_RDONLY | os.O_DIRECT) | |
| futures = [executor.submit(copy_context().run, disk_to_pin, fd), executor.submit(copy_context().run, pin_to_cuda)] | |
| for future in as_completed(futures): | |
| future.result() | |
| except Exception as e: | |
| print(f"Error during pack_to_cuda: {e}", file=sys.stderr) | |
| finally: | |
| if fd != -1: | |
| os.close(fd) | |
| @contextmanager | |
| def cuda_unavailable(torch_module: ModuleType): | |
| _is_available = torch_module.cuda.is_available | |
| torch_module.cuda.is_available = lambda: False | |
| yield | |
| torch_module.cuda.is_available = _is_available | |
| def maybe_import_bitsandbytes(): | |
| try: | |
| if torch is None: return None | |
| bnb_version = version.parse(metadata.version('bitsandbytes')) | |
| if bnb_version < version.parse('0.40.0'): | |
| print(f"Warning: ZeroGPU requires bitsandbytes >= 0.40.0 (installed: {bnb_version})", file=sys.stderr) | |
| return None | |
| ctx_factory = (lambda: cuda_unavailable(torch)) if bnb_version < version.parse('0.43.1') else nullcontext | |
| with (ctx := ctx_factory()): | |
| importlib.import_module('bitsandbytes') | |
| if not isinstance(ctx, nullcontext): | |
| print("↑ Those bitsandbytes warnings are expected on ZeroGPU ↑", file=sys.stderr) | |
| return ctx_factory | |
| except (ImportError, metadata.PackageNotFoundError): | |
| return None | |
| except Exception as e: | |
| print(f"Unexpected error during bitsandbytes check: {e}", file=sys.stderr) | |
| return None | |
| bnb_import_context = maybe_import_bitsandbytes() | |
| if bnb_import_context and torch: | |
| from torch.utils.weak import WeakTensorKeyDictionary | |
| with (import_ctx := bnb_import_context()): | |
| CUDASetup = None | |
| if not isinstance(import_ctx, nullcontext): | |
| from bitsandbytes.cuda_setup.main import CUDASetup | |
| from bitsandbytes import cextension, functional | |
| from bitsandbytes.nn import Int8Params, Params4bit | |
| _param_to_8bit = Int8Params.to | |
| _param_cuda_8bit = Int8Params.cuda | |
| _param_to_4bit = Params4bit.to | |
| _param_cuda_4bit = Params4bit.cuda | |
| TensorToArgs_bnb = Tuple[torch.device, torch.dtype, bool, torch.memory_format] | |
| to_ops_8bit: dict[Int8Params, TensorToArgs_bnb | None] = WeakTensorKeyDictionary() | |
| to_ops_4bit: dict[Params4bit, TensorToArgs_bnb | None] = WeakTensorKeyDictionary() | |
| def _to_op_register_8bit(self: Int8Params, *args, **kwargs): | |
| parsed = torch._C._nn._parse_to(*args, **kwargs) | |
| device, *_ = parsed | |
| if not isinstance(device, torch.device) or device.type != 'cuda': | |
| return _param_to_8bit(self, *args, **kwargs) | |
| to_ops_8bit[self] = parsed | |
| return self | |
| def _to_op_register_4bit(self: Params4bit, *args, **kwargs): | |
| parsed = torch._C._nn._parse_to(*args, **kwargs) | |
| device, *_ = parsed | |
| if not isinstance(device, torch.device) or device.type != 'cuda': | |
| return _param_to_4bit(self, *args, **kwargs) | |
| to_ops_4bit[self] = parsed | |
| return self | |
| def _cuda_op_arg_check_bnb(device: Union[torch.device, int, str, None]) -> bool: | |
| if device is None or isinstance(device, int): return True | |
| if isinstance(device, str): device = torch.device(device) | |
| return device.type == 'cuda' | |
| def _cuda_op_register_8bit(self: Int8Params, device: Union[torch.device, int, str, None] = None, **kwargs): | |
| if not _cuda_op_arg_check_bnb(device): return _param_cuda_8bit(self, device, **kwargs) | |
| to_ops_8bit[self] = None | |
| return self | |
| def _cuda_op_register_4bit(self: Params4bit, device: Union[torch.device, int, str, None] = None, **kwargs): | |
| if not _cuda_op_arg_check_bnb(device): return _param_cuda_4bit(self, device, **kwargs) | |
| to_ops_4bit[self] = None | |
| return self | |
| def _patch_bnb(): | |
| Int8Params.to = _to_op_register_8bit | |
| Int8Params.cuda = _cuda_op_register_8bit | |
| Params4bit.to = _to_op_register_4bit | |
| Params4bit.cuda = _cuda_op_register_4bit | |
| def _unpatch_bnb(): | |
| Int8Params.to = _param_to_8bit | |
| Int8Params.cuda = _param_cuda_8bit | |
| Params4bit.to = _param_to_4bit | |
| Params4bit.cuda = _param_cuda_4bit | |
| def _move_bnb(): | |
| if CUDASetup is not None: | |
| CUDASetup._instance = None | |
| importlib.reload(cextension) | |
| functional.lib = cextension.lib | |
| for tensor, parsed_args in to_ops_8bit.items(): | |
| dtype, memory_format = (parsed_args[1], parsed_args[3]) if parsed_args else (None, None) | |
| tensor.data = _param_to_8bit(tensor, device='cuda', dtype=dtype, memory_format=memory_format) | |
| for tensor, parsed_args in to_ops_4bit.items(): | |
| dtype, memory_format = (parsed_args[1], parsed_args[3]) if parsed_args else (None, None) | |
| tensor.data = _param_to_4bit(tensor, device='cuda', dtype=dtype, memory_format=memory_format) | |
| else: | |
| def _patch_bnb(): pass | |
| def _unpatch_bnb(): pass | |
| def _move_bnb(): pass | |
| patch_bnb = _patch_bnb | |
| unpatch_bnb = _unpatch_bnb | |
| move_bnb = _move_bnb | |
| class _BitsAndBytesManager: | |
| def patch(self): return patch_bnb() | |
| def unpatch(self): return unpatch_bnb() | |
| def move(self): return move_bnb() | |
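| # The bitsandbytes path mirrors the generic tensor handling: patch() intercepts Int8Params/Params4bit .to('cuda') | |
| # and .cuda() calls and records them in to_ops_* instead of moving data, unpatch() restores the original methods, | |
| # and move() replays the recorded moves once a real CUDA device is attached (reloading cextension first). | |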
| if torch: | |
| PINNED_MEMORY_RATIO_LIMIT = 0.1 | |
| OPS_INPUTS_CHECK_NO_RETURN = (torch.Tensor.equal,) | |
| OPS_INPUT_CHECK_SELF_RETURN = (torch.Tensor.set_, torch.ops.aten.set_.source_Tensor) | |
| OFFLOADED_ERROR_MESSAGE = "Cannot apply function {} on disk-offloaded Tensor {}" | |
| _tensor_make_subclass = torch.Tensor._make_subclass | |
| _asarray = torch.asarray | |
| _device = torch.device | |
| _cuda_init_v2 = torch._C._cuda_init | |
| _cuda_exchange_device = torch.cuda._exchange_device | |
| _cuda_available_v2 = torch.cuda.is_available | |
| _cuda_device_count_v2 = torch.cuda.device_count | |
| _cuda_current_device_v2 = torch.cuda.current_device | |
| _cuda_synchronize = torch.cuda.synchronize | |
| _cuda_get_device_capability_v2 = torch.cuda.get_device_capability | |
| _cuda_get_device_properties_v2 = torch.cuda.get_device_properties | |
| _cuda_get_device_name_v2 = torch.cuda.get_device_name | |
| _cuda_memory_stats_as_nested_dict = torch.cuda.memory.memory_stats_as_nested_dict | |
| _cuda_cudart = torch.cuda.cudart | |
| _cuda_maybe_exchange_device = getattr(torch.cuda, '_maybe_exchange_device', None) | |
| cuda_aliases: dict[torch.Tensor, torch.Tensor | None] = WeakTensorKeyDictionary() | |
| tensor_packs: list[ZeroGPUTensorPack] = [] | |
| class ZeroGPUTensor(torch.Tensor): pass | |
| def empty_fake(tensor: torch.Tensor): | |
| fake = empty_like_raw_alloc(tensor, requires_grad=tensor.requires_grad) | |
| if fake.__class__ != tensor.__class__: | |
| fake = _tensor_make_subclass(tensor.__class__, fake, require_grad=tensor.requires_grad) | |
| return fake | |
| def no_int_device(*args, **kwargs): | |
| if len(args) and isinstance(index := args[0], int): | |
| args = (f'cuda:{index}', *args[1:]) | |
| if isinstance(index := kwargs.get('device'), int): | |
| kwargs['device'] = f'cuda:{index}' | |
| return args, kwargs | |
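| # Example: no_int_device(0) returns (('cuda:0',), {}) and no_int_device(device=1) returns ((), {'device': 'cuda:1'}), | |
| # so bare integer device indices are normalized to 'cuda:N' strings before parsing. | |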
| class ZeroGPUFunctionMode(torch.overrides.TorchFunctionMode): | |
| def __torch_function__(self, func, types, args=(), kwargs: dict[str, Any] | None = None): | |
| kwargs = {} if kwargs is None else kwargs | |
| try: | |
| if func == torch._C._nn._parse_to: | |
| args, kwargs = no_int_device(*args, **kwargs) | |
| return func(*args, **kwargs) | |
| if func == torch.Tensor.cuda or func == torch.Tensor.cpu: | |
| memory_format = kwargs.get("memory_format") | |
| device_str = "cuda" if func == torch.Tensor.cuda else "cpu" | |
| to_kwargs = {"device": device_str} | |
| if memory_format is not None: to_kwargs["memory_format"] = memory_format | |
| return self.__torch_function__(torch.Tensor.to, types, (args[0],), to_kwargs) | |
| if func == torch.Tensor.to and len(args) > 1: | |
| parse_to_args, parse_to_kwargs = no_int_device(*args[1:], **kwargs) | |
| device, dtype, _, memory_format = torch._C._nn._parse_to(*parse_to_args, **parse_to_kwargs) | |
| return self.__torch_function__(torch.Tensor.to, types, (args[0],), {'device': device, 'dtype': dtype, 'memory_format': memory_format}) | |
| if func == torch.Tensor.data.__set__: | |
| self_tensor, target = args | |
| if target in cuda_aliases: | |
| if (target_original := cuda_aliases[target]) is None: | |
| print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), target), file=sys.stderr) | |
| return | |
| original = empty_fake(self_tensor) | |
| original.data = target_original | |
| cuda_aliases[self_tensor] = original | |
| elif self_tensor in cuda_aliases: | |
| del cuda_aliases[self_tensor] | |
| self_tensor.data = target | |
| return | |
| if func == torch.Tensor.device.__get__: | |
| tensor, = args | |
| if tensor in cuda_aliases: return torch.device('cuda', index=0) | |
| elif func == torch.Tensor.__repr__: | |
| tensor, = args | |
| if tensor in cuda_aliases: | |
| original = cuda_aliases[tensor] or tensor.to('meta') | |
| original_class = original.__class__ | |
| original.__class__ = ZeroGPUTensor | |
| try: | |
| return func(original, **kwargs) | |
| finally: | |
| original.__class__ = original_class | |
| elif func == torch.Tensor.untyped_storage: | |
| tensor, = args | |
| if tensor in cuda_aliases: | |
| if (original := cuda_aliases[tensor]) is None: | |
| print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), tensor), file=sys.stderr) | |
| return None | |
| res = func(original, **kwargs) | |
| res._zerogpu = True | |
| return res | |
| cuda: bool | None = None | |
| if (device := kwargs.get('device')) is not None: | |
| device = torch.device(device) | |
| cuda = device.type == 'cuda' | |
| if cuda: kwargs['device'] = torch.device('cpu') | |
| swapped, inputs_are_cuda = {}, set() | |
| def swap(t: torch.Tensor): | |
| nonlocal inputs_are_cuda | |
| if t not in cuda_aliases: | |
| inputs_are_cuda.add(False) | |
| return t | |
| original = cuda_aliases[t] | |
| if original is None: | |
| print(OFFLOADED_ERROR_MESSAGE.format(torch.overrides.resolve_name(func), t), file=sys.stderr) | |
| return t | |
| swapped[original] = t | |
| inputs_are_cuda.add(True) | |
| return original | |
| args_ = torch.utils._pytree.tree_map_only(torch.Tensor, swap, args) | |
| kwargs_ = torch.utils._pytree.tree_map_only(torch.Tensor, swap, kwargs) | |
| if inputs_are_cuda == {True} and cuda is not False: cuda = True | |
| if len(args) == 1 and torch.utils._python_dispatch.is_traceable_wrapper_subclass(wt := args[0]): | |
| if func in {torch.Tensor.detach, torch.ops.aten.alias.default, torch.ops.aten.clone.default}: | |
| with self: return torch.utils._python_dispatch.transform_subclass(wt, lambda _, t: func(t)) | |
| res = func(*args_, **kwargs_) | |
| for original, fake in swapped.items(): fake.data = empty_fake(original) | |
| if func in {torch.ops.aten.index.Tensor, torch.Tensor.__getitem__}: | |
| cuda = args[0] in cuda_aliases | |
| inputs_are_cuda = {cuda} | |
| if (isinstance(res, torch.Tensor) or func in OPS_INPUTS_CHECK_NO_RETURN) and not (func == torch.ops.aten.set_.source_Tensor and len(args_) == 3): | |
| st = args_[0] if len(args_) >= 1 and isinstance(args_[0], torch.Tensor) else None | |
| if (res is not st or func in OPS_INPUT_CHECK_SELF_RETURN) and inputs_are_cuda == {True, False}: | |
| print("RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 (ZeroGPU) and cpu!", file=sys.stderr) | |
| def register(t: torch.Tensor): | |
| if t in swapped and cuda is not False: return swapped[t] | |
| if cuda is not True: return t | |
| fake = empty_fake(t) | |
| cuda_aliases[fake] = t | |
| return fake | |
| return torch.utils._pytree.tree_map_only(torch.Tensor, register, res) | |
| except Exception as e: | |
| print(f"Error in ZeroGPUFunctionMode: {e}", file=sys.stderr) | |
| return func(*args, **kwargs) | |
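| # In short: while ZeroGPUFunctionMode is active, tensors "moved" to CUDA stay on CPU; a fake alias is registered | |
| # in cuda_aliases, reports device cuda:0, and ops are executed on the CPU originals with results re-registered as | |
| # CUDA fakes, so user code sees a CUDA device without any real GPU being attached yet. | |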
| class DefaultDispatchMode(torch.utils._python_dispatch.TorchDispatchMode): | |
| def __torch_dispatch__(self, func, types, args=(), kwargs: dict[str, Any] | None = None): | |
| return func(*args, **(kwargs or {})) | |
| function_mode = ZeroGPUFunctionMode() | |
| dispatch_mode = DefaultDispatchMode() | |
| def _untyped_storage_new_register(*args, **kwargs): | |
| cuda = False | |
| if (device := kwargs.get('device')) is not None and device.type == 'cuda': | |
| cuda = True | |
| del kwargs['device'] | |
| storage = torch._C.StorageBase.__new__(*args, **kwargs) | |
| if cuda: storage._zerogpu = True | |
| return storage | |
| def _untyped_storage_device(self): | |
| if hasattr(self, '_zerogpu'): return torch.device('cuda', index=0) | |
| return torch._C.StorageBase.device.__get__(self) | |
| def _tensor_make_subclass_function_mode(*args, **kwargs): | |
| with torch._C.DisableTorchFunction(): | |
| return function_mode.__torch_function__(_tensor_make_subclass, (), args=args, kwargs=kwargs) | |
| def _asarray_function_mode(*args, **kwargs): | |
| with torch._C.DisableTorchFunction(): | |
| return function_mode.__torch_function__(_asarray, (), args=args, kwargs=kwargs) | |
| class _DeviceStringOnlyMeta(type): | |
| def __instancecheck__(cls, instance): return isinstance(instance, _device) | |
| class _DeviceStringOnly(metaclass=_DeviceStringOnlyMeta): | |
| def __new__(cls, *args, **kwargs): | |
| args, kwargs = no_int_device(*args, **kwargs) | |
| return _device(*args, **kwargs) | |
| def _cuda_init_raise_v2(): | |
| print("ZeroGPU: CUDA must not be initialized in the main process; use the spaces.GPU decorator instead.", file=sys.stderr) | |
| def _cuda_dummy_exchange_device(device): | |
| assert device in {-1, 0} | |
| return device | |
| def patch_v2(): | |
| function_mode.__enter__() | |
| dispatch_mode.__enter__() | |
| torch.Tensor._make_subclass = _tensor_make_subclass_function_mode | |
| torch.UntypedStorage.__new__ = _untyped_storage_new_register | |
| torch.UntypedStorage.device = _untyped_storage_device | |
| torch.asarray = _asarray_function_mode | |
| torch.device = _DeviceStringOnly | |
| torch._C._cuda_init = _cuda_init_raise_v2 | |
| torch.cuda._exchange_device = _cuda_dummy_exchange_device | |
| torch.cuda.is_available = lambda: True | |
| torch.cuda.device_count = lambda: 1 | |
| torch.cuda.current_device = lambda: 0 | |
| torch.cuda.synchronize = lambda *args: None | |
| torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY_STATIC | |
| torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES_STATIC | |
| torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME_STATIC | |
| torch.cuda.memory.memory_stats_as_nested_dict = lambda *args, **kwargs: CUDA_MEMORY_STATS_AS_NESTED_DICT_STATIC | |
| torch.cuda.cudart = lambda: cudart | |
| if _cuda_maybe_exchange_device is not None: setattr(torch.cuda, '_maybe_exchange_device', _cuda_dummy_exchange_device) | |
| _BitsAndBytesManager().patch() | |
| def unpatch_v2(): | |
| try: | |
| dispatch_mode.__exit__(None, None, None) | |
| function_mode.__exit__(None, None, None) | |
| except RuntimeError: pass | |
| torch.Tensor._make_subclass = _tensor_make_subclass | |
| torch.UntypedStorage.__new__ = torch._C.StorageBase.__new__ | |
| torch.UntypedStorage.device = torch._C.StorageBase.device | |
| torch.asarray = _asarray | |
| torch.device = _device | |
| torch._C._cuda_init = _cuda_init_v2 | |
| torch.cuda._exchange_device = _cuda_exchange_device | |
| torch.cuda.is_available = _cuda_available_v2 | |
| torch.cuda.device_count = _cuda_device_count_v2 | |
| torch.cuda.current_device = _cuda_current_device_v2 | |
| torch.cuda.synchronize = _cuda_synchronize | |
| torch.cuda.get_device_capability = _cuda_get_device_capability_v2 | |
| torch.cuda.get_device_properties = _cuda_get_device_properties_v2 | |
| torch.cuda.get_device_name = _cuda_get_device_name_v2 | |
| torch.cuda.memory.memory_stats_as_nested_dict = _cuda_memory_stats_as_nested_dict | |
| torch.cuda.cudart = _cuda_cudart | |
| if _cuda_maybe_exchange_device is not None: setattr(torch.cuda, '_maybe_exchange_device', _cuda_maybe_exchange_device) | |
| _BitsAndBytesManager().unpatch() | |
| def _total_unpacked_size(): | |
| tensors = [t for t in cuda_aliases.values() if t is not None] | |
| deduped = {AliasId.from_tensor(t): t for t in tensors} | |
| return sum([t.numel() * t.element_size() for t in deduped.values()]) | |
| def _pack_v2_internal(offload_dir: str): | |
| originals, originals_dedup, fakes = set(), {}, defaultdict(list) | |
| for fake, original in cuda_aliases.items(): | |
| if original is not None: | |
| original_id = AliasId.from_tensor(original) | |
| if original_id not in originals_dedup: | |
| originals_dedup[original_id] = original | |
| originals.add(original) | |
| fakes[originals_dedup[original_id]].append(fake) | |
| total_size = _total_unpacked_size() | |
| progress_context = tqdm(total=total_size, unit='B', unit_scale=True, desc="ZeroGPU tensors packing") if tqdm is not None and total_size > 0 else nullcontext() | |
| with progress_context as progress: | |
| update = progress.update if progress is not None else lambda _: None | |
| pack = pack_tensors(originals, fakes, offload_dir, callback=update) | |
| tensor_packs.append(pack) | |
| for fake_list in fakes.values(): | |
| for fake in fake_list: cuda_aliases[fake] = None | |
| return total_size | |
| def pack_v2(): | |
| total_size = _pack_v2_internal(Config.zerogpu_offload_dir) | |
| gc.collect() | |
| malloc_trim() | |
| return total_size | |
| def init_v2(nvidia_uuid: str): | |
| os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid | |
| torch.Tensor([0]).cuda() | |
| def size_v2(): | |
| return _total_unpacked_size() + sum([p.total_size for p in tensor_packs]) | |
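| # _move_v2_internal uploads each deduplicated original to CUDA (pinned memory and non-blocking copies for tensors below PINNED_MEMORY_RATIO_LIMIT of the unpacked total), rebinds the fakes' .data to the CUDA copies, then restores the disk-packed tensors with pack_to_cuda. | |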
| def _move_v2_internal(callback: Callable[[int], None] | None = None): | |
| cb = callback or (lambda _: None) | |
| pinned_limit, moved = _total_unpacked_size() * PINNED_MEMORY_RATIO_LIMIT, {} | |
| for fake, original in cuda_aliases.items(): | |
| if original is not None: | |
| original_id = AliasId.from_tensor(original) | |
| if original_id not in moved: | |
| use_pinned = original.numel() * original.element_size() < pinned_limit | |
| original_cuda = original.pin_memory().cuda(non_blocking=True) if use_pinned else original.cuda() | |
| moved[original_id] = original_cuda | |
| cb(fake.numel() * fake.element_size()) | |
| torch.cuda.synchronize() | |
| for fake, original in cuda_aliases.items(): | |
| if original is not None: fake.data = moved[AliasId.from_tensor(original)] | |
| for tensor_pack in tensor_packs: pack_to_cuda(tensor_pack, callback=cb) | |
| _BitsAndBytesManager().move() | |
| def move_v2(callback: Callable[[int], None] | None = None): | |
| cb = callback or (lambda _: None) | |
| with ThreadPoolExecutor(1) as e: | |
| e.submit(copy_context().run, _move_v2_internal, callback=cb).result() | |
| torch.cuda.synchronize() | |
| def is_in_bad_fork_v2(): | |
| return False | |
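| # Legacy (pre-v2) path: torch is patched to report an A100 MIG slice, and every .to('cuda') / .cuda() request is deferred by recording it in to_ops; the actual transfer happens in move_legacy once a GPU is attached. | |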
| CUDA_DEVICE_NAME_LEGACY, CUDA_TOTAL_MEMORY_LEGACY = 'NVIDIA A100-SXM4-80GB MIG 3g.40gb', 42144366592 | |
| CUDA_MEM_GET_INFO_LEGACY = (41911451648, CUDA_TOTAL_MEMORY_LEGACY) | |
| CUDA_DEVICE_CAPABILITY_LEGACY = (8, 0) | |
| CUDA_DEVICE_PROPERTIES_LEGACY = SimpleNamespace(name=CUDA_DEVICE_NAME_LEGACY, major=8, minor=0, total_memory=CUDA_TOTAL_MEMORY_LEGACY, multi_processor_count=42) | |
| GENERIC_METHOD_NAMES = ['arange', 'as_tensor', 'asarray', 'bartlett_window', 'blackman_window', 'empty', 'empty_like', 'empty_strided', 'eye', 'full', 'full_like', 'hamming_window', 'hann_window', 'kaiser_window', 'linspace', 'logspace', 'ones', 'ones_like', 'rand', 'rand_like', 'randint', 'randint_like', 'randn', 'randn_like', 'randperm', 'range', 'sparse_bsc_tensor', 'sparse_bsr_tensor', 'sparse_compressed_tensor', 'sparse_coo_tensor', 'sparse_csc_tensor', 'sparse_csr_tensor', 'tensor', 'tril_indices', 'triu_indices', 'zeros', 'zeros_like'] | |
| TO_CUDA = (torch.device('cuda'), None, False, None) | |
| _tensor__deepcopy__, _tensor_to, _tensor_cuda, _tensor_cpu = torch.Tensor.__deepcopy__, torch.Tensor.to, torch.Tensor.cuda, torch.Tensor.cpu | |
| _torch_generics = {name: getattr(torch, name) for name in GENERIC_METHOD_NAMES} | |
| _cuda_init_legacy, _cuda_available_legacy, _cuda_device_count_legacy, _cuda_current_device_legacy = torch._C._cuda_init, torch.cuda.is_available, torch.cuda.device_count, torch.cuda.current_device | |
| _cuda_mem_get_info, _cuda_get_device_capability_legacy, _cuda_get_device_properties_legacy, _cuda_get_device_name_legacy = torch.cuda.mem_get_info, torch.cuda.get_device_capability, torch.cuda.get_device_properties, torch.cuda.get_device_name | |
| TensorToArgs_legacy = Tuple[Optional[torch.device], Optional[torch.dtype], bool, Optional[torch.memory_format]] | |
| to_ops: dict[torch.Tensor, TensorToArgs_legacy] = WeakTensorKeyDictionary() | |
| def _tensor_new_register(*args, **kwargs): | |
| new_tensor = torch._C._TensorBase.__new__(*args, **kwargs) | |
| if (base := getattr(new_tensor, '_base', None)) is not None and base in to_ops: | |
| to_ops[new_tensor] = to_ops[base] | |
| return new_tensor | |
| def _tensor_deepcopy_register(self: torch.Tensor, memo): | |
| new_tensor = _tensor__deepcopy__(self, memo) | |
| if isinstance(new_tensor, torch.Tensor) and self in to_ops: | |
| to_ops[new_tensor] = to_ops[self] | |
| return new_tensor | |
| @property | |
| def _tensor_device_property(self: torch.Tensor): | |
| if self in to_ops: return torch.device(type='cuda', index=0) | |
| del torch.Tensor.device | |
| try: return self.device | |
| finally: torch.Tensor.device = _tensor_device_property | |
| @property | |
| def _tensor_dtype_property(self: torch.Tensor): | |
| if self in to_ops and (to_dtype := to_ops[self][1]) is not None: return to_dtype | |
| del torch.Tensor.dtype | |
| try: return self.dtype | |
| finally: torch.Tensor.dtype = _tensor_dtype_property | |
| def _to_op_register(self: torch.Tensor, *args, **kwargs): | |
| parsed = torch._C._nn._parse_to(*args, **kwargs) | |
| device, dtype, *_ = parsed | |
| to_args = to_ops.pop(self, None) | |
| if device is None: | |
| if to_args is not None: | |
| to_ops[self] = (to_args[0], dtype, *to_args[2:]) | |
| return self | |
| return _tensor_to(self, *args, **kwargs) | |
| if device.type != 'cuda': | |
| if to_args is not None and (to_dtype := to_args[1]) is not None: | |
| kwargs = {'dtype': to_dtype, **kwargs} | |
| return _tensor_to(self, *args, **kwargs) | |
| to_ops[self] = parsed | |
| return self | |
| def _cuda_op_arg_check(device: torch.device | int | str | None) -> bool: | |
| if device is None or isinstance(device, int): return True | |
| if isinstance(device, str): device = torch.device(device) | |
| return device.type == 'cuda' | |
| def _cuda_op_register(self: torch.Tensor, device: torch.device | int | str | None = None, **kwargs): | |
| if not _cuda_op_arg_check(device): return _tensor_cuda(self, device, **kwargs) | |
| to_ops[self] = TO_CUDA | |
| return self | |
| def _cpu_op_remove(self: torch.Tensor, **kwargs): | |
| to_args = to_ops.pop(self, None) | |
| if to_args is not None and (to_dtype := to_args[1]) is not None: | |
| return _tensor_to(self, 'cpu', **{'dtype': to_dtype, **kwargs}) | |
| return _tensor_cpu(self, **kwargs) | |
| def _cuda_init_raise_legacy(): | |
| # Despite the name, this variant does not raise: attempts to initialize CUDA in the main process are silently ignored. | |
| pass | |
| def _generic_method_register(name: str, *args: Any, **kwargs: Any): | |
| try: | |
| device = torch.device(kwargs.get('device', "cpu")) | |
| except Exception: | |
| return _torch_generics[name](*args, **kwargs) | |
| if device.type != 'cuda': | |
| return _torch_generics[name](*args, **kwargs) | |
| tensor = _torch_generics[name](*args, **{**kwargs, 'device': "cpu"}) | |
| to_ops[tensor] = TO_CUDA | |
| return tensor | |
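| # patch_legacy installs the deferring tensor methods and CUDA stubs defined above; unpatch_legacy restores the saved originals. | |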
| def patch_legacy(): | |
| torch.Tensor.__deepcopy__ = _tensor_deepcopy_register | |
| torch.Tensor.__new__ = _tensor_new_register | |
| torch.Tensor.to = _to_op_register | |
| torch.Tensor.cuda = _cuda_op_register | |
| torch.Tensor.cpu = _cpu_op_remove | |
| if Config.zero_patch_torch_device: | |
| torch.Tensor.device = _tensor_device_property | |
| torch.Tensor.dtype = _tensor_dtype_property | |
| for name in GENERIC_METHOD_NAMES: setattr(torch, name, partial(_generic_method_register, name)) | |
| torch._C._cuda_init = _cuda_init_raise_legacy | |
| torch.cuda.is_available = lambda: True | |
| torch.cuda.device_count = lambda: 1 | |
| torch.cuda.current_device = lambda: 0 | |
| torch.cuda.mem_get_info = lambda *args, **kwargs: CUDA_MEM_GET_INFO_LEGACY | |
| torch.cuda.get_device_capability = lambda *args, **kwargs: CUDA_DEVICE_CAPABILITY_LEGACY | |
| torch.cuda.get_device_properties = lambda *args, **kwargs: CUDA_DEVICE_PROPERTIES_LEGACY | |
| torch.cuda.get_device_name = lambda *args, **kwargs: CUDA_DEVICE_NAME_LEGACY | |
| _BitsAndBytesManager().patch() | |
| def unpatch_legacy(): | |
| from contextlib import suppress | |
| torch.Tensor.__deepcopy__ = _tensor__deepcopy__ | |
| with suppress(AttributeError): del torch.Tensor.__new__ | |
| torch.Tensor.to = _tensor_to | |
| torch.Tensor.cuda = _tensor_cuda | |
| torch.Tensor.cpu = _tensor_cpu | |
| with suppress(AttributeError): del torch.Tensor.device | |
| with suppress(AttributeError): del torch.Tensor.dtype | |
| for name in GENERIC_METHOD_NAMES: setattr(torch, name, _torch_generics[name]) | |
| torch._C._cuda_init = _cuda_init_legacy | |
| torch.cuda.is_available = _cuda_available_legacy | |
| torch.cuda.device_count = _cuda_device_count_legacy | |
| torch.cuda.current_device = _cuda_current_device_legacy | |
| torch.cuda.mem_get_info = _cuda_mem_get_info | |
| torch.cuda.get_device_capability = _cuda_get_device_capability_legacy | |
| torch.cuda.get_device_properties = _cuda_get_device_properties_legacy | |
| torch.cuda.get_device_name = _cuda_get_device_name_legacy | |
| _BitsAndBytesManager().unpatch() | |
| def pack_legacy(): return 0 | |
| def init_legacy(nvidia_uuid: str): | |
| os.environ['CUDA_VISIBLE_DEVICES'] = nvidia_uuid | |
| torch.Tensor([0]).cuda() | |
| def size_legacy(): return 0 | |
| def move_legacy(callback: Callable[[int], None] | None = None): | |
| for tensor, parsed_args in to_ops.items(): | |
| _, dtype, _, memory_format = parsed_args | |
| tensor.data = _tensor_to(tensor, device='cuda', dtype=dtype, memory_format=memory_format) | |
| _BitsAndBytesManager().move() | |
| torch.cuda.synchronize() | |
| def is_in_bad_fork_legacy(): | |
| return False | |
| if torch: | |
| try: | |
| num_threads = torch.get_num_threads() | |
| torch.set_num_interop_threads(num_threads) | |
| except RuntimeError: pass | |
| if Config.zero_gpu_v2: | |
| _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork = patch_v2, unpatch_v2, pack_v2, init_v2, size_v2, move_v2, is_in_bad_fork_v2 | |
| else: | |
| _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork = patch_legacy, unpatch_legacy, pack_legacy, init_legacy, size_legacy, move_legacy, is_in_bad_fork_legacy | |
| else: | |
| def _placeholder_func(*args, **kwargs): pass | |
| def _placeholder_zero(*args, **kwargs): return 0 | |
| def _placeholder_false(*args, **kwargs): return False | |
| _patch, _unpatch, _init, _move = _placeholder_func, _placeholder_func, _placeholder_func, _placeholder_func | |
| _pack, _size = _placeholder_zero, _placeholder_zero | |
| _is_in_bad_fork = _placeholder_false | |
| patch_torch, unpatch_torch, pack_torch, init_torch, size_torch, move_torch, is_in_bad_fork_torch = _patch, _unpatch, _pack, _init, _size, _move, _is_in_bad_fork | |
| _patch_torch_global = patch_torch | |
| _unpatch_torch_global = unpatch_torch | |
| GENERATOR_GLOBAL_TIMEOUT = 20 * 60 | |
| SPAWN_PROGRESS_CLEANUP, SPAWN_PROGRESS_INIT = 0.1, 0.1 | |
| forked = False | |
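| # Worker runs the wrapped task in a daemon thread and exchanges arguments/results over SimpleQueues; a sentinel thread pushes None into the result queue when the worker thread exits. | |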
| class Worker(Generic[Res]): | |
| thread: Thread | |
| arg_queue: "SimpleQueue[tuple[Params, GradioPartialContext]]" | |
| res_queue: "SimpleQueue[Res | None]" | |
| _sentinel: "Thread" | |
| def __init__(self, task: Callable, is_generator: bool, allow_token: str, nvidia_uuid: str): | |
| self._sentinel = Thread(target=self._close_on_exit, daemon=True) | |
| self.arg_queue = SimpleQueue() | |
| self.res_queue = SimpleQueue() | |
| args = task, is_generator, self.arg_queue, self.res_queue, allow_token, nvidia_uuid, [] | |
| self.thread = Thread(target=self._worker_thread_wrapper, args=args, daemon=True) | |
| self.thread.start() | |
| self._sentinel.start() | |
| def _worker_thread_wrapper(self, task: Callable[..., Any], is_generator: bool, arg_queue: SimpleQueue[tuple[Params, GradioPartialContext]], res_queue: SimpleQueue[Any | None], allow_token: str, nvidia_uuid: str, fds: list[int]): | |
| global forked | |
| forked = True | |
| initialized = False | |
| while True: | |
| try: | |
| (args, kwargs), gradio_context = arg_queue.get() | |
| except (OSError, EOFError): break | |
| if not initialized: | |
| if (init_res := worker_init(res_queue=res_queue, allow_token=allow_token, nvidia_uuid=nvidia_uuid, fds=fds)) is not None: | |
| res_queue.put(init_res) | |
| return | |
| initialized = True | |
| GradioPartialContext.apply(gradio_context) | |
| context = copy_context() | |
| if is_generator: | |
| def iterate(): | |
| try: | |
| gen = task(*args, **kwargs) | |
| for res in gen: | |
| try: | |
| res_queue.put(OkResult(res)) | |
| except Exception as e: | |
| res_queue.put(exception_result(e)) | |
| break | |
| except Exception as e: | |
| res_queue.put(exception_result(e)) | |
| finally: | |
| res_queue.put(EndResult()) | |
| with ThreadPoolExecutor(1) as executor: | |
| executor.submit(context.run, iterate) | |
| else: | |
| def run_task(): | |
| try: | |
| res = OkResult(task(*args, **kwargs)) | |
| except Exception as e: | |
| res = exception_result(e) | |
| try: | |
| res_queue.put(res) | |
| except Exception as e: | |
| res_queue.put(exception_result(e)) | |
| with ThreadPoolExecutor(1) as executor: | |
| future = executor.submit(context.run, run_task) | |
| future.result() | |
| def _close_on_exit(self): | |
| self.thread.join() | |
| self.arg_queue.close() | |
| try: | |
| self.res_queue.wlock_release() | |
| except Exception: | |
| pass | |
| self.res_queue.put(None) | |
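| # worker_init runs once per worker before the first task: it closes inherited fds, calls allow() with the scheduler's token, temporarily unpatches torch, initializes CUDA on the assigned device, and moves the deferred tensors before re-applying the patch. | |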
| def worker_init(res_queue: Union["SimpleQueue[RegularResQueueResult | None]", "SimpleQueue[GeneratorResQueueResult | None]"], allow_token: str, nvidia_uuid: str, fds: list[int]) -> Optional[ExceptionResult]: | |
| for fd in fds: | |
| try: | |
| os.close(fd) | |
| except Exception as e: | |
| if isinstance(e, OSError) and e.errno == 9: continue | |
| return exception_result(e) | |
| try: | |
| pass | |
| except Exception as e: | |
| print(f"Error while trying to remove tqdm multiprocessing lock: {e}", file=sys.stderr) | |
| progress_context = tqdm(total=100, desc="ZeroGPU init", file=open(os.devnull, 'w')) if tqdm is not None and Config.zero_gpu_v2 else nullcontext() | |
| try: | |
| patch_gradio_queue(res_queue) | |
| with progress_context as p_bar: | |
| current_progress = 0 | |
| def update(n: float): | |
| nonlocal current_progress | |
| current_progress += n | |
| if p_bar is not None and hasattr(p_bar, 'n'): | |
| p_bar.update(round(current_progress * 100) - p_bar.n) | |
| allow(allow_token) | |
| update(SPAWN_PROGRESS_CLEANUP) | |
| _unpatch_torch_global() | |
| init_torch(nvidia_uuid) | |
| update(SPAWN_PROGRESS_INIT) | |
| callback = None | |
| if (transfer_size := size_torch()) > 0: | |
| remaining = 1 - (SPAWN_PROGRESS_CLEANUP + SPAWN_PROGRESS_INIT) | |
| def _callback(n): return update(n * remaining / transfer_size) | |
| callback = _callback | |
| move_torch(callback=callback) | |
| _patch_torch_global() | |
| except Exception as e: | |
| return exception_result(e) | |
| return None | |
| def process_duration(duration: Duration | None) -> timedelta: | |
| if duration is None: | |
| return timedelta(seconds=0) | |
| return duration if isinstance(duration, timedelta) else timedelta(seconds=duration) | |
| def static_duration(duration: DynamicDuration[Param], *args: Param.args, **kwargs: Param.kwargs) -> timedelta: | |
| if callable(duration): | |
| duration = duration(*args, **kwargs) | |
| return process_duration(duration) | |
| def exception_result(exc: Exception) -> ExceptionResult: | |
| import traceback  # local import: not among the module-level imports | |
| formatted = "".join(traceback.format_exception(type(exc), exc, exc.__traceback__)) | |
| return ExceptionResult(traceback=formatted, error_cls=exc.__class__.__name__) | |
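| # regular_function_wrapper wraps a plain function: it asks the scheduler for a GPU, reuses or spawns a Worker, forwards the Gradio context and queue events, and on any failure unpatches torch and retries the task on CPU. | |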
| def regular_function_wrapper(task: Callable[Param, Res], duration: DynamicDuration[Param]) -> Callable[Param, Optional[Res]]: | |
| request_var_getter = gradio_request_var | |
| workers: dict[NvidiaIndex, Worker[RegularResQueueResult[Res] | None]] = {} | |
| task_id = id(task) | |
| def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Optional[Res]: | |
| if forked: | |
| return task(*args, **kwargs) | |
| try: | |
| request_var = request_var_getter() | |
| request = request_var.get(None) if request_var else None | |
| duration_ = static_duration(duration, *args, **kwargs) | |
| schedule_response = schedule(task_id=task_id, request=request, duration=duration_) | |
| if schedule_response is None: | |
| raise RuntimeError("ZeroGPU schedule() returned no response") | |
| allow_token, nvidia_index, nvidia_uuid = schedule_response.allowToken, schedule_response.nvidiaIndex, schedule_response.nvidiaUUID | |
| release_fn = partial(release, allow_token) | |
| worker = workers.pop(nvidia_index, None) | |
| if not (worker and worker.thread.is_alive() and schedule_response.idle): | |
| worker = Worker(task, False, allow_token, nvidia_uuid) | |
| worker.arg_queue.put(((args, kwargs), GradioPartialContext.get())) | |
| while True: | |
| res = worker.res_queue.get() | |
| if res is None: | |
| release_fn(fail=True, allow_404=True) | |
| raise RuntimeError("ZeroGPU worker stopped before returning a result") | |
| if isinstance(res, ExceptionResult): | |
| release_fn(fail=True) | |
| raise RuntimeError(f"{res.error_cls} raised in the GPU worker:\n{res.traceback}") | |
| if isinstance(res, OkResult): | |
| release_fn() | |
| workers[nvidia_index] = worker | |
| return res.value | |
| if isinstance(res, GradioQueueEvent): | |
| try_process_queue_event(res.method_name, *res.args, **res.kwargs) | |
| continue | |
| assert_never(res) | |
| except Exception as e: | |
| print(f"GPU process operation failed: {e}. Falling back to CPU execution.", file=sys.stderr) | |
| _unpatch_torch_global() | |
| try: | |
| return task(*args, **kwargs) | |
| except Exception as cpu_e: | |
| print(f"CPU fallback execution also failed: {cpu_e}", file=sys.stderr) | |
| return None | |
| finally: | |
| _patch_torch_global() | |
| if hasattr(task, '__annotations__'): | |
| gradio_handler.__annotations__ = task.__annotations__ | |
| return gradio_handler | |
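| # generator_function_wrapper does the same for generator functions, draining worker results through a local queue so values are yielded as they arrive; EndResult releases the GPU and recycles the worker. | |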
| def generator_function_wrapper(task: Callable[Param, Generator[Res, None, None]], duration: DynamicDuration[Param]) -> Callable[Param, Generator[Res, None, None]]: | |
| request_var_getter = gradio_request_var | |
| workers: dict[NvidiaIndex, Worker[GeneratorResQueueResult[Res] | None]] = {} | |
| task_id = id(task) | |
| def gradio_handler(*args: Param.args, **kwargs: Param.kwargs) -> Generator[Res, None, None]: | |
| if forked: | |
| yield from task(*args, **kwargs) | |
| return | |
| try: | |
| request_var = request_var_getter() | |
| request = request_var.get(None) if request_var else None | |
| duration_ = static_duration(duration, *args, **kwargs) | |
| schedule_response = schedule(task_id=task_id, request=request, duration=duration_) | |
| if schedule_response is None: | |
| raise RuntimeError("ZeroGPU schedule() returned no response") | |
| allow_token, nvidia_index, nvidia_uuid = schedule_response.allowToken, schedule_response.nvidiaIndex, schedule_response.nvidiaUUID | |
| release_fn = partial(release, allow_token) | |
| worker = workers.pop(nvidia_index, None) | |
| if not (worker and worker.thread.is_alive() and schedule_response.idle): | |
| worker = Worker(task, True, allow_token, nvidia_uuid) | |
| worker.arg_queue.put(((args, kwargs), GradioPartialContext.get())) | |
| yield_queue: ThreadQueue[YieldQueueResult[Res]] = ThreadQueue() | |
| def fill_yield_queue(worker_instance): | |
| while True: | |
| res = worker_instance.res_queue.get() | |
| if res is None: | |
| release_fn(fail=True, allow_404=True) | |
| yield_queue.put(AbortedResult()) | |
| return | |
| if isinstance(res, ExceptionResult): | |
| release_fn(fail=True) | |
| yield_queue.put(res) | |
| return | |
| if isinstance(res, EndResult): | |
| release_fn() | |
| workers[nvidia_index] = worker_instance | |
| yield_queue.put(EndResult()) | |
| return | |
| if isinstance(res, OkResult): | |
| yield_queue.put(OkResult(res.value)) | |
| continue | |
| if isinstance(res, GradioQueueEvent): | |
| try_process_queue_event(res.method_name, *res.args, **res.kwargs) | |
| continue | |
| assert_never(res) | |
| with ThreadPoolExecutor(1) as e: | |
| e.submit(copy_context().run, fill_yield_queue, worker) | |
| while True: | |
| try: | |
| res = yield_queue.get(timeout=GENERATOR_GLOBAL_TIMEOUT) | |
| except Empty: | |
| raise RuntimeError(f"No result from the GPU worker for {GENERATOR_GLOBAL_TIMEOUT} seconds") | |
| if isinstance(res, AbortedResult): | |
| raise RuntimeError("ZeroGPU worker stopped before the generator completed") | |
| if isinstance(res, ExceptionResult): | |
| raise RuntimeError(f"{res.error_cls} raised in the GPU worker:\n{res.traceback}") | |
| if isinstance(res, EndResult): | |
| return | |
| if isinstance(res, OkResult): | |
| yield res.value | |
| continue | |
| assert_never(res) | |
| except Exception as e: | |
| print(f"GPU generator process operation failed: {e}. Falling back to CPU execution.", file=sys.stderr) | |
| _unpatch_torch_global() | |
| try: | |
| yield from task(*args, **kwargs) | |
| except Exception as cpu_e: | |
| print(f"CPU fallback execution for generator also failed: {cpu_e}", file=sys.stderr) | |
| finally: | |
| _patch_torch_global() | |
| if hasattr(task, '__annotations__'): | |
| gradio_handler.__annotations__ = task.__annotations__ | |
| return gradio_handler | |
| P_decorator = ParamSpec('P_decorator') | |
| R_decorator = TypeVar('R_decorator') | |
| decorated_cache: dict[Callable, Callable] = {} | |
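| # @GPU works both bare and with keyword arguments; the two overloads below resolve to _GPU, which caches the wrapped callable so repeated decoration returns the same object. | |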
| @overload | |
| def GPU(task: None = None, *, duration: DynamicDuration[P_decorator] = 0) -> Callable[[Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]]: ... | |
| @overload | |
| def GPU(task: Callable[P_decorator, R_decorator], *, duration: DynamicDuration[P_decorator] = 0) -> Callable[P_decorator, R_decorator]: ... | |
| def GPU(task: Optional[Callable[P_decorator, R_decorator]] = None, *, duration: DynamicDuration[P_decorator] = 0, **kwargs: Unpack[EmptyKwargs]) -> Union[Callable[[Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]], Callable[P_decorator, R_decorator]]: | |
| if "enable_queue" in kwargs: | |
| warnings.warn("`enable_queue` parameter is now ignored and always set to `True`") | |
| if task is None: | |
| return partial(_GPU, duration=duration) | |
| return _GPU(task, duration) | |
| def _GPU(task: Callable[P_decorator, R_decorator], duration: DynamicDuration[P_decorator]) -> Callable[P_decorator, R_decorator]: | |
| if not Config.zero_gpu: | |
| return task | |
| if sys.version_info.minor < 9: | |
| print("Error: Actually using @spaces.GPU on a ZeroGPU Space requires Python 3.9+", file=sys.stderr) | |
| return task | |
| if task in decorated_cache: | |
| return decorated_cache[task] | |
| if inspect.iscoroutinefunction(task): | |
| print("Error: Coroutine functions are not supported by @spaces.GPU.", file=sys.stderr) | |
| return task | |
| if inspect.isgeneratorfunction(task): | |
| decorated = generator_function_wrapper(task, duration) | |
| else: | |
| decorated = regular_function_wrapper(task, duration) | |
| setattr(decorated, 'zerogpu', True) | |
| decorated_cache.update({task: decorated, decorated: decorated}) | |
| return decorated | |
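| # Illustrative usage sketch (assumes this module is importable as `spaces`; the function name and body are placeholders): | |
| #   import spaces | |
| #   @spaces.GPU                      # bare decorator, or @spaces.GPU(duration=120) | |
| #   def generate(prompt): | |
| #       ...                          # run the CUDA workload here; it executes in a ZeroGPU worker | |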
| gradio_auto_wrap_enabled = Config.gradio_auto_wrap | |
| def disable_gradio_auto_wrap() -> None: | |
| global gradio_auto_wrap_enabled | |
| gradio_auto_wrap_enabled = False | |
| def enable_gradio_auto_wrap() -> None: | |
| global gradio_auto_wrap_enabled | |
| gradio_auto_wrap_enabled = True | |
| @overload | |
| def gradio_auto_wrap(task: Callable[Param, Res]) -> Callable[Param, Res]: ... | |
| @overload | |
| def gradio_auto_wrap(task: None) -> None: ... | |
| def gradio_auto_wrap(task: Optional[Callable[Param, Res]]) -> Optional[Callable[Param, Res]]: | |
| if not gradio_auto_wrap_enabled or not callable(task): | |
| return task | |
| if getattr(task, 'zerogpu', False): | |
| return task | |
| return GPU(task) | |
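| # _patch_gradio_auto_wrap hooks gradio.blocks.Block.set_event_trigger so every event handler registered by the app is passed through gradio_auto_wrap (and therefore GPU) automatically. | |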
| def _patch_gradio_auto_wrap(): | |
| if not Config.zero_gpu or not Config.gradio_auto_wrap: | |
| return | |
| try: | |
| from gradio.blocks import Block | |
| _original_set_event_trigger = Block.set_event_trigger | |
| except (ImportError, AttributeError): | |
| print("Warning: Could not find gradio.blocks.Block.set_event_trigger for auto-wrap patching. Auto-wrap disabled.", file=sys.stderr) | |
| return | |
| def _new_set_event_trigger(self, event_name: str, fn: Union[Callable, List[Callable], None], inputs, outputs, **kwargs): | |
| if fn is None: | |
| return _original_set_event_trigger(self, event_name, fn, inputs, outputs, **kwargs) | |
| if isinstance(fn, list): | |
| wrapped_fns = [gradio_auto_wrap(f) for f in fn] | |
| return _original_set_event_trigger(self, event_name, wrapped_fns, inputs, outputs, **kwargs) | |
| else: | |
| wrapped_fn = gradio_auto_wrap(fn) | |
| return _original_set_event_trigger(self, event_name, wrapped_fn, inputs, outputs, **kwargs) | |
| Block.set_event_trigger = _new_set_event_trigger | |
| print("Gradio Block event trigger patched for ZeroGPU auto-wrap.", file=sys.stderr) | |
| if sys.version_info.minor < 8: | |
| print("Warning: Importing PySpaces requires Python 3.8+", file=sys.stderr) | |
| try: | |
| if (gr_module := sys.modules.get("gradio")) is not None: | |
| getattr(gr_module, 'Blocks') | |
| except AttributeError: | |
| print("ImportError: Gradio does not have 'Blocks' attribute. Please check your Gradio installation.", file=sys.stderr) | |
| pass | |
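| # aoti_apply appears to be a stub in this variant: the compiled function is not applied and the module is simply kept on CPU. | |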
| def aoti_apply(compiled_fn: Any, module: Any): | |
| if torch is None: | |
| return module | |
| if hasattr(module, 'to') and isinstance(module, torch.nn.Module): | |
| module.to(device="cpu") | |
| return module | |
| __all__ = ["GPU", "gradio_auto_wrap", "disable_gradio_auto_wrap", "enable_gradio_auto_wrap", "aoti_apply"] | |
| if Config.zero_gpu: | |
| try: | |
| if is_in_bad_fork_torch(): | |
| print("Warning: CUDA was initialized before this module could patch torch", file=sys.stderr) | |
| except Exception as e: | |
| print(f"Could not check for bad fork: {e}", file=sys.stderr) | |
| def startup(): | |
| total_size = pack_torch() | |
| _patch_gradio_auto_wrap() | |
| if Config.zerogpu_size == 'auto': | |
| gpu_size = 'medium' if total_size < Config.zerogpu_medium_size_threshold else 'large' | |
| else: | |
| gpu_size = Config.zerogpu_size | |
| startup_report_client(self_cgroup_device_path(), gpu_size) | |
| _patch_torch_global() | |
| one_launch(startup) | |
| try: | |
| shutil.rmtree(Config.zerogpu_offload_dir, ignore_errors=True) | |
| Path(Config.zerogpu_offload_dir).mkdir(parents=True, exist_ok=True) | |
| except Exception as e: | |
| print(f"Could not prepare ZeroGPU offload directory: {e}", file=sys.stderr) |