File size: 16,002 Bytes
be4a029
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
import hashlib
import json
import logging
import os
import warnings
import webbrowser
from pathlib import Path
from typing import Any

from gradio.blocks import BUILT_IN_THEMES
from gradio.themes import Default as DefaultTheme
from gradio.themes import ThemeClass
from gradio.utils import TupleNoPrint
from gradio_client import Client
from huggingface_hub import SpaceStorage

from trackio import context_vars, deploy, utils
from trackio.deploy import sync
from trackio.histogram import Histogram
from trackio.imports import import_csv, import_tf_events
from trackio.media import TrackioAudio, TrackioImage, TrackioVideo
from trackio.run import Run
from trackio.sqlite_storage import SQLiteStorage
from trackio.table import Table
from trackio.ui.main import demo
from trackio.utils import TRACKIO_DIR, TRACKIO_LOGO_DIR

# Raise the "httpx" logger's threshold to WARNING so that its lower-severity
# records are not emitted while Trackio talks to the dashboard server.
logging.getLogger("httpx").setLevel(logging.WARNING)

# Ignore the UserWarning issued from `gradio.helpers` whose message matches the
# pattern below (the brackets are escaped because `message` is a regex).
warnings.filterwarnings(
    "ignore",
    message="Empty session being created. Install gradio\\[oauth\\]",
    category=UserWarning,
    module="gradio.helpers",
)

# The package version is read at import time from the "version" field of the
# `package.json` file that sits in the same directory as this module.
__version__ = json.loads(Path(__file__).parent.joinpath("package.json").read_text())[
    "version"
]

# Names exported by `from trackio import *`.
__all__ = [
    "init",
    "log",
    "finish",
    "show",
    "sync",
    "delete_project",
    "import_csv",
    "import_tf_events",
    "Image",
    "Video",
    "Audio",
    "Table",
    "Histogram",
]

# Short public aliases for the media classes imported from `trackio.media`.
Image = TrackioImage
Video = TrackioVideo
Audio = TrackioAudio


# Module-level configuration mapping. It starts out empty; `init()` rebinds this
# name to the `config` attribute of the run it creates.
config = {}

# Theme name that `show()` falls back to when neither its `theme` argument nor
# the `TRACKIO_THEME` environment variable is set.
DEFAULT_THEME = "default"


def init(
    project: str,
    name: str | None = None,
    group: str | None = None,
    space_id: str | None = None,
    space_storage: SpaceStorage | None = None,
    dataset_id: str | None = None,
    config: dict | None = None,
    resume: str = "never",
    settings: Any = None,
    private: bool | None = None,
    embed: bool = True,
) -> Run:
    """
    Creates a new Trackio project and returns a [`Run`] object.

    Args:
        project (`str`):
            The name of the project (can be an existing project to continue tracking or
            a new project to start tracking from scratch).
        name (`str`, *optional*):
            The name of the run (if not provided, a default name will be generated).
        group (`str`, *optional*):
            The name of the group which this run belongs to in order to help organize
            related runs together. You can toggle the entire group's visibility in the
            dashboard.
        space_id (`str`, *optional*):
            If provided, the project will be logged to a Hugging Face Space instead of
            a local directory. Should be a complete Space name like
            `"username/reponame"` or `"orgname/reponame"`, or just `"reponame"` in which
            case the Space will be created in the currently-logged-in Hugging Face
            user's namespace. If the Space does not exist, it will be created. If the
            Space already exists, the project will be logged to it.
        space_storage ([`~huggingface_hub.SpaceStorage`], *optional*):
            Choice of persistent storage tier.
        dataset_id (`str`, *optional*):
            If a `space_id` is provided, a persistent Hugging Face Dataset will be
            created and the metrics will be synced to it every 5 minutes. Specify a
            Dataset with name like `"username/datasetname"` or `"orgname/datasetname"`,
            or `"datasetname"` (uses currently-logged-in Hugging Face user's namespace),
            or `None` (uses the same name as the Space but with the `"_dataset"`
            suffix). If the Dataset does not exist, it will be created. If the Dataset
            already exists, the project will be appended to it.
        config (`dict`, *optional*):
            A dictionary of configuration options. Provided for compatibility with
            `wandb.init()`.
        resume (`str`, *optional*, defaults to `"never"`):
            Controls how to handle resuming a run. Can be one of:

            - `"must"`: Must resume the run with the given name, raises error if run
              doesn't exist
            - `"allow"`: Resume the run if it exists, otherwise create a new run
            - `"never"`: Never resume a run, always create a new one
        private (`bool`, *optional*):
            Whether to make the Space private. If None (default), the repo will be
            public unless the organization's default is private. This value is ignored
            if the repo already exists.
        settings (`Any`, *optional*):
            Not used. Provided for compatibility with `wandb.init()`.
        embed (`bool`, *optional*, defaults to `True`):
            If running inside a jupyter/Colab notebook, whether the dashboard should
            automatically be embedded in the cell when trackio.init() is called.

    Returns:
        `Run`: A [`Run`] object that can be used to log metrics and finish the run.

    Raises:
        `ValueError`: If `dataset_id` is given without `space_id`, if `resume` is not
            one of `"must"`, `"allow"` or `"never"`, if `resume="must"` is used
            without a `name`, or if `resume="must"` names a run that does not exist
            in the project.
    """
    if settings is not None:
        warnings.warn(
            "* Warning: settings is not used. Provided for compatibility with wandb.init(). Please create an issue at: https://github.com/gradio-app/trackio/issues if you need a specific feature implemented."
        )

    if space_id is None and dataset_id is not None:
        raise ValueError("Must provide a `space_id` when `dataset_id` is provided.")

    # Reject bad `resume` arguments up front, before a dashboard server is launched
    # or a Space is created, so that an invalid call leaves no side effects behind.
    if resume not in ("must", "allow", "never"):
        raise ValueError("resume must be one of: 'must', 'allow', or 'never'")
    if resume == "must" and name is None:
        raise ValueError("Must provide a run name when resume='must'")

    space_id, dataset_id = utils.preprocess_space_and_dataset_ids(space_id, dataset_id)
    url = context_vars.current_server.get()
    share_url = context_vars.current_share_server.get()

    if url is None:
        # No server recorded in this context yet: start the local dashboard, or
        # use the Space id itself as the target when logging to a Space.
        if space_id is None:
            _, url, share_url = demo.launch(
                show_api=False,
                inline=False,
                quiet=True,
                prevent_thread_lock=True,
                show_error=True,
                favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
                allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
            )
        else:
            url = space_id
            share_url = None
        context_vars.current_server.set(url)
        context_vars.current_share_server.set(share_url)

    # Only announce and set up the project when it differs from the one already
    # active in this context (this also covers the "no project yet" case, since
    # `None` never equals a project name).
    if context_vars.current_project.get() != project:
        print(f"* Trackio project initialized: {project}")

        if dataset_id is not None:
            os.environ["TRACKIO_DATASET_ID"] = dataset_id
            print(
                f"* Trackio metrics will be synced to Hugging Face Dataset: {dataset_id}"
            )
        if space_id is None:
            print(f"* Trackio metrics logged to: {TRACKIO_DIR}")
            if utils.is_in_notebook() and embed:
                base_url = share_url + "/" if share_url else url
                full_url = utils.get_full_url(
                    base_url, project=project, write_token=demo.write_token
                )
                utils.embed_url_in_notebook(full_url)
            else:
                utils.print_dashboard_instructions(project)
        else:
            deploy.create_space_if_not_exists(
                space_id, space_storage, dataset_id, private
            )
            user_name, space_name = space_id.split("/")
            space_url = deploy.SPACE_HOST_URL.format(
                user_name=user_name, space_name=space_name
            )
            print(f"* View dashboard by going to: {space_url}")
            if utils.is_in_notebook() and embed:
                utils.embed_url_in_notebook(space_url)
    context_vars.current_project.set(project)

    # A Gradio client is only created for the local dashboard; runs that log to a
    # Space are handed `client=None`.
    client = None
    if not space_id:
        client = Client(url, verbose=False)

    # Look the run up once; the storage is only queried when a name was given.
    run_exists = name is not None and name in SQLiteStorage.get_runs(project)

    if resume == "must":
        if not run_exists:
            raise ValueError(f"Run '{name}' does not exist in project '{project}'")
        resumed = True
    elif resume == "allow":
        resumed = run_exists
    else:  # resume == "never" (validated above)
        if run_exists:
            warnings.warn(
                f"* Warning: resume='never' but a run '{name}' already exists in "
                f"project '{project}'. Generating a new name and instead. If you want "
                "to resume this run, call init() with resume='must' or resume='allow'."
            )
            # Drop the clashing name so that `Run` generates a fresh one.
            name = None
        resumed = False

    run = Run(
        url=url,
        project=project,
        client=client,
        name=name,
        group=group,
        config=config,
        space_id=space_id,
    )

    if resumed:
        print(f"* Resumed existing run: {run.name}")
    else:
        print(f"* Created new run: {run.name}")

    context_vars.current_run.set(run)
    # The `config` parameter shadows the module-level `config`, so the module
    # attribute has to be rebound through `globals()`.
    globals()["config"] = run.config
    return run


def log(metrics: dict, step: int | None = None) -> None:
    """
    Records a set of metrics on the run that is currently active.

    Args:
        metrics (`dict`):
            The metrics to record, forwarded unchanged to the active run.
        step (`int`, *optional*):
            The step number to associate with the metrics. If not provided, the
            step will be incremented automatically.

    Raises:
        `RuntimeError`: If no run is active in the current context, i.e.
            `trackio.init()` has not been called yet.
    """
    active_run = context_vars.current_run.get()
    if active_run is not None:
        active_run.log(metrics=metrics, step=step)
        return
    raise RuntimeError("Call trackio.init() before trackio.log().")


def finish():
    """
    Finishes the run that is currently active.

    Raises:
        `RuntimeError`: If no run is active in the current context, i.e.
            `trackio.init()` has not been called yet.
    """
    if (active_run := context_vars.current_run.get()) is None:
        raise RuntimeError("Call trackio.init() before trackio.finish().")
    active_run.finish()


def delete_project(project: str, force: bool = False) -> bool:
    """
    Deletes a project by removing its local SQLite database.

    The database file is removed together with its `-wal` and `-shm` sidecar
    files, when those are present.

    Args:
        project (`str`):
            The name of the project to delete.
        force (`bool`, *optional*, defaults to `False`):
            If `True`, deletes the project without prompting for confirmation.
            If `False`, prompts the user to confirm before deleting. Only `y` or
            `yes` (case-insensitive, surrounding whitespace ignored) confirms the
            deletion; if no answer can be read because standard input is closed,
            the deletion is cancelled.

    Returns:
        `bool`: `True` if the project was deleted, `False` otherwise (the project
        does not exist, the user declined, or a filesystem error occurred).
    """
    db_path = SQLiteStorage.get_project_db_path(project)

    if not db_path.exists():
        print(f"* Project '{project}' does not exist.")
        return False

    if not force:
        try:
            response = input(
                f"Are you sure you want to delete project '{project}'? "
                f"This will permanently delete all runs and metrics. (y/N): "
            )
        except EOFError:
            # No interactive stdin to read an answer from: treat it as the
            # default "N" instead of crashing the caller.
            response = ""
        if response.strip().lower() not in ("y", "yes"):
            print("* Deletion cancelled.")
            return False

    try:
        db_path.unlink()

        # `missing_ok=True` avoids the exists()/unlink() race on the sidecars,
        # which may legitimately be absent.
        for suffix in ("-wal", "-shm"):
            Path(str(db_path) + suffix).unlink(missing_ok=True)
    except OSError as e:
        print(f"* Error deleting project '{project}': {e}")
        return False

    print(f"* Project '{project}' has been deleted.")
    return True


def show(
    project: str | None = None,
    theme: str | ThemeClass | None = None,
    mcp_server: bool | None = None,
    color_palette: list[str] | None = None,
    *,
    open_browser: bool = True,
    block_thread: bool | None = None,
):
    """
    Launches the Trackio dashboard.

    Args:
        project (`str`, *optional*):
            The name of the project whose runs to show. If not provided, all projects
            will be shown and the user can select one.
        theme (`str` or `ThemeClass`, *optional*):
            A Gradio Theme to use for the dashboard instead of the default Gradio theme,
            can be a built-in theme (e.g. `'soft'`, `'citrus'`), a theme from the Hub
            (e.g. `"gstaff/xkcd"`), or a custom Theme class. If not provided, the
            `TRACKIO_THEME` environment variable will be used, or if that is not set, the
            default Gradio theme will be used.
        mcp_server (`bool`, *optional*):
            If `True`, the Trackio dashboard will be set up as an MCP server and certain
            functions will be added as MCP tools. If `None` (default behavior), then the
            `GRADIO_MCP_SERVER` environment variable will be used to determine if the
            MCP server should be enabled (which is `"True"` on Hugging Face Spaces).
        color_palette (`list[str]`, *optional*):
            A list of hex color codes to use for plot lines. If provided, it is
            stored comma-separated in the `TRACKIO_COLOR_PALETTE` environment
            variable. Example: `['#FF0000', '#00FF00', '#0000FF']`
        open_browser (`bool`, *optional*, defaults to `True`):
            If `True` and not in a notebook, a new browser tab will be opened with the
            dashboard. If `False`, the browser will not be opened.
        block_thread (`bool`, *optional*):
            If `True`, the main thread will be blocked until the dashboard is closed.
            If `None` (default behavior), then the main thread will not be blocked if the
            dashboard is launched in a notebook, otherwise the main thread will be blocked.

    Returns:
        A 4-tuple `(demo, url, share_url, full_url)`:

        - `demo`: The Trackio dashboard object that was launched.
        - `url`: The local URL of the dashboard.
        - `share_url`: The public share URL of the dashboard.
        - `full_url`: The full URL of the dashboard including the write token (built
          from the public share URL if there is one, otherwise from the local URL).
    """
    if color_palette is not None:
        os.environ["TRACKIO_COLOR_PALETTE"] = ",".join(color_palette)

    if not theme:
        theme = os.environ.get("TRACKIO_THEME", DEFAULT_THEME)

    if theme != DEFAULT_THEME:
        # TODO: This mirrors the theme-setting logic of Gradio Blocks, which is a
        # little hacky. In Gradio 6.0 the theme will be set in `launch()` instead,
        # at which point this code can be removed.
        if isinstance(theme, str):
            builtin_key = theme.lower()
            if builtin_key in BUILT_IN_THEMES:
                theme = BUILT_IN_THEMES[builtin_key]
            else:
                try:
                    theme = ThemeClass.from_hub(theme)
                except Exception as e:
                    warnings.warn(f"Cannot load {theme}. Caught Exception: {str(e)}")
                    theme = DefaultTheme()
        if not isinstance(theme, ThemeClass):
            warnings.warn("Theme should be a class loaded from gradio.themes")
            theme = DefaultTheme()

        # Install the resolved theme on the dashboard together with its derived
        # CSS, stylesheets and a SHA-256 digest of that CSS.
        theme_css = theme._get_theme_css()
        demo.theme = theme
        demo.theme_css = theme_css
        demo.stylesheets = theme._stylesheets
        demo.theme_hash = hashlib.sha256(theme_css.encode("utf-8")).hexdigest()

    # An explicit argument wins; otherwise defer to the environment variable.
    if mcp_server is None:
        use_mcp = os.environ.get("GRADIO_MCP_SERVER", "False") == "True"
    else:
        use_mcp = mcp_server

    _, url, share_url = demo.launch(
        show_api=use_mcp,
        quiet=True,
        inline=False,
        prevent_thread_lock=True,
        favicon_path=TRACKIO_LOGO_DIR / "trackio_logo_light.png",
        allowed_paths=[TRACKIO_LOGO_DIR, TRACKIO_DIR],
        mcp_server=use_mcp,
    )

    # Prefer the public share URL when one exists.
    base_url = share_url + "/" if share_url else url
    full_url = utils.get_full_url(
        base_url, project=project, write_token=demo.write_token
    )

    in_notebook = utils.is_in_notebook()
    if in_notebook:
        utils.embed_url_in_notebook(full_url)
    else:
        print(f"* Trackio UI launched at: {full_url}")
        if open_browser:
            webbrowser.open(full_url)

    # When the caller did not choose, block everywhere except in notebooks.
    if block_thread is None:
        block_thread = not in_notebook

    if block_thread:
        utils.block_main_thread_until_keyboard_interrupt()
    return TupleNoPrint((demo, url, share_url, full_url))