Spaces:
				
			
			
	
			
			
		Runtime error
		
	
	
	
			
			
	
	
	
	
		
		
		Runtime error
		
	Commit 
							
							ยท
						
						9dcfb8f
	
1
								Parent(s):
							
							8cf952e
								
update logo
Browse files- README.md +6 -2
- assets/logo-sdg.svg +0 -1
- assets/logo.svg +1 -42
- src/synthetic_dataset_generator/_tabbedinterface.py +5 -11
- src/synthetic_dataset_generator/app.py +3 -1
- src/synthetic_dataset_generator/apps/eval.py +35 -15
- src/synthetic_dataset_generator/apps/sft.py +18 -8
- src/synthetic_dataset_generator/apps/textcat.py +21 -5
    	
        README.md
    CHANGED
    
    | @@ -17,8 +17,12 @@ hf_oauth_scopes: | |
| 17 | 
             
            - manage-repos
         | 
| 18 | 
             
            - inference-api
         | 
| 19 | 
             
            ---
         | 
| 20 | 
            -
             | 
| 21 | 
            -
            <img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo | 
|  | |
|  | |
|  | |
|  | |
| 22 |  | 
| 23 | 
             
            
         | 
| 24 |  | 
|  | |
| 17 | 
             
            - manage-repos
         | 
| 18 | 
             
            - inference-api
         | 
| 19 | 
             
            ---
         | 
| 20 | 
            +
            <p align="center">
         | 
| 21 | 
            +
            <img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo.svg" alt="Synthetic Data Generator Logo" style="width: 80%;"/>
         | 
| 22 | 
            +
            </p>
         | 
| 23 | 
            +
            <p align="center">
         | 
| 24 | 
            +
              <h3>Build datasets using natural language</h3>
         | 
| 25 | 
            +
            </p>
         | 
| 26 |  | 
| 27 | 
             
            
         | 
| 28 |  | 
    	
        assets/logo-sdg.svg
    DELETED
    
    
    	
        assets/logo.svg
    CHANGED
    
    |  | 
|  | 
    	
        src/synthetic_dataset_generator/_tabbedinterface.py
    CHANGED
    
    | @@ -8,7 +8,6 @@ from collections.abc import Sequence | |
| 8 |  | 
| 9 | 
             
            import gradio as gr
         | 
| 10 | 
             
            from gradio.blocks import Blocks
         | 
| 11 | 
            -
            from gradio.components import HTML
         | 
| 12 | 
             
            from gradio.layouts import Tab, Tabs
         | 
| 13 | 
             
            from gradio.themes import ThemeClass as Theme
         | 
| 14 | 
             
            from gradio_client.documentation import document
         | 
| @@ -61,16 +60,11 @@ class TabbedInterface(Blocks): | |
| 61 | 
             
                        tab_names = [f"Tab {i}" for i in range(len(interface_list))]
         | 
| 62 | 
             
                    with self:
         | 
| 63 | 
             
                        if title:
         | 
| 64 | 
            -
                            HTML(value=title)
         | 
| 65 | 
            -
             | 
| 66 | 
            -
             | 
| 67 | 
            -
             | 
| 68 | 
            -
             | 
| 69 | 
            -
                                    pass
         | 
| 70 | 
            -
                                with gr.Column(scale=2):
         | 
| 71 | 
            -
                                    gr.LoginButton(
         | 
| 72 | 
            -
                                        value="Sign in", variant="primary", scale=2
         | 
| 73 | 
            -
                                    )
         | 
| 74 | 
             
                        with Tabs():
         | 
| 75 | 
             
                            for interface, tab_name in zip(interface_list, tab_names, strict=False):
         | 
| 76 | 
             
                                with Tab(label=tab_name):
         | 
|  | |
| 8 |  | 
| 9 | 
             
            import gradio as gr
         | 
| 10 | 
             
            from gradio.blocks import Blocks
         | 
|  | |
| 11 | 
             
            from gradio.layouts import Tab, Tabs
         | 
| 12 | 
             
            from gradio.themes import ThemeClass as Theme
         | 
| 13 | 
             
            from gradio_client.documentation import document
         | 
|  | |
| 60 | 
             
                        tab_names = [f"Tab {i}" for i in range(len(interface_list))]
         | 
| 61 | 
             
                    with self:
         | 
| 62 | 
             
                        if title:
         | 
| 63 | 
            +
                            gr.HTML(value=title)
         | 
| 64 | 
            +
                        gr.HTML(
         | 
| 65 | 
            +
                            "<div style='text-align: center;'><h3>Build datasets using natural language</h3></div>"
         | 
| 66 | 
            +
                        )
         | 
| 67 | 
            +
                        gr.LoginButton(value="Sign in", variant="primary", scale=2)
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 68 | 
             
                        with Tabs():
         | 
| 69 | 
             
                            for interface, tab_name in zip(interface_list, tab_names, strict=False):
         | 
| 70 | 
             
                                with Tab(label=tab_name):
         | 
    	
        src/synthetic_dataset_generator/app.py
    CHANGED
    
    | @@ -10,11 +10,13 @@ css = """ | |
| 10 | 
             
            .main_ui_logged_out{opacity: 0.3; pointer-events: none}
         | 
| 11 | 
             
            """
         | 
| 12 |  | 
|  | |
|  | |
| 13 | 
             
            demo = TabbedInterface(
         | 
| 14 | 
             
                [textcat_app, sft_app, eval_app, faq_app],
         | 
| 15 | 
             
                ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
         | 
| 16 | 
             
                css=css,
         | 
| 17 | 
            -
                title= | 
| 18 | 
             
                head="Synthetic Data Generator",
         | 
| 19 | 
             
                theme=theme,
         | 
| 20 | 
             
            )
         | 
|  | |
| 10 | 
             
            .main_ui_logged_out{opacity: 0.3; pointer-events: none}
         | 
| 11 | 
             
            """
         | 
| 12 |  | 
| 13 | 
            +
            image = """<img src="https://raw.githubusercontent.com/argilla-io/synthetic-data-generator/main/assets/logo-sdg.svg" alt="Synthetic Data Generator Logo" style="display: block; margin-left: auto; margin-right: auto; width: 75%; margin-bottom: -400px;"/>"""
         | 
| 14 | 
            +
             | 
| 15 | 
             
            demo = TabbedInterface(
         | 
| 16 | 
             
                [textcat_app, sft_app, eval_app, faq_app],
         | 
| 17 | 
             
                ["Text Classification", "Supervised Fine-Tuning", "Evaluation", "FAQ"],
         | 
| 18 | 
             
                css=css,
         | 
| 19 | 
            +
                title=image,
         | 
| 20 | 
             
                head="Synthetic Data Generator",
         | 
| 21 | 
             
                theme=theme,
         | 
| 22 | 
             
            )
         | 
    	
        src/synthetic_dataset_generator/apps/eval.py
    CHANGED
    
    | @@ -13,8 +13,9 @@ from datasets import ( | |
| 13 | 
             
                load_dataset,
         | 
| 14 | 
             
            )
         | 
| 15 | 
             
            from distilabel.distiset import Distiset
         | 
|  | |
| 16 | 
             
            from gradio_huggingfacehub_search import HuggingfaceHubSearch
         | 
| 17 | 
            -
            from huggingface_hub import HfApi
         | 
| 18 |  | 
| 19 | 
             
            from synthetic_dataset_generator.apps.base import (
         | 
| 20 | 
             
                hide_success_message,
         | 
| @@ -45,7 +46,10 @@ from synthetic_dataset_generator.utils import ( | |
| 45 |  | 
| 46 | 
             
            def get_iframe(hub_repo_id: str) -> str:
         | 
| 47 | 
             
                if not hub_repo_id:
         | 
| 48 | 
            -
                     | 
|  | |
|  | |
|  | |
| 49 |  | 
| 50 | 
             
                url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
         | 
| 51 | 
             
                iframe = f"""
         | 
| @@ -79,12 +83,14 @@ def get_valid_columns(dataframe: pd.DataFrame): | |
| 79 | 
             
                return instruction_valid_columns, response_valid_columns
         | 
| 80 |  | 
| 81 |  | 
| 82 | 
            -
            def load_dataset_from_hub( | 
|  | |
|  | |
| 83 | 
             
                if not repo_id:
         | 
| 84 | 
             
                    raise gr.Error("Hub repo id is required")
         | 
| 85 | 
            -
                subsets = get_dataset_config_names(repo_id)
         | 
| 86 | 
            -
                ds_dict = load_dataset(repo_id, subsets[0])
         | 
| 87 | 
            -
                splits = get_dataset_split_names(repo_id, subsets[0])
         | 
| 88 | 
             
                ds = ds_dict[splits[0]]
         | 
| 89 | 
             
                if num_rows:
         | 
| 90 | 
             
                    ds = ds.select(range(num_rows))
         | 
| @@ -601,7 +607,10 @@ with gr.Blocks() as app: | |
| 601 | 
             
                                search_type="dataset",
         | 
| 602 | 
             
                                sumbit_on_select=True,
         | 
| 603 | 
             
                            )
         | 
| 604 | 
            -
                             | 
|  | |
|  | |
|  | |
| 605 | 
             
                        with gr.Column(scale=3):
         | 
| 606 | 
             
                            search_out = gr.HTML(label="Dataset preview")
         | 
| 607 |  | 
| @@ -666,9 +675,9 @@ with gr.Blocks() as app: | |
| 666 | 
             
                                    inputs=[],
         | 
| 667 | 
             
                                    outputs=[eval_type],
         | 
| 668 | 
             
                                )
         | 
| 669 | 
            -
                             | 
| 670 | 
            -
                                 | 
| 671 | 
            -
             | 
| 672 | 
             
                        with gr.Column(scale=3):
         | 
| 673 | 
             
                            dataframe = gr.Dataframe(
         | 
| 674 | 
             
                                headers=["prompt", "completion", "evaluation"],
         | 
| @@ -724,7 +733,11 @@ with gr.Blocks() as app: | |
| 724 | 
             
                                    label="Distilabel Pipeline Code",
         | 
| 725 | 
             
                                )
         | 
| 726 |  | 
| 727 | 
            -
                search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out)
         | 
|  | |
|  | |
|  | |
|  | |
| 728 |  | 
| 729 | 
             
                load_btn.click(
         | 
| 730 | 
             
                    fn=load_dataset_from_hub,
         | 
| @@ -793,12 +806,8 @@ with gr.Blocks() as app: | |
| 793 | 
             
                    fn=generate_pipeline_code,
         | 
| 794 | 
             
                    inputs=[
         | 
| 795 | 
             
                        search_in,
         | 
| 796 | 
            -
                        aspects_instruction_response,
         | 
| 797 | 
            -
                        instruction_instruction_response,
         | 
| 798 | 
            -
                        response_instruction_response,
         | 
| 799 | 
             
                        prompt_template,
         | 
| 800 | 
             
                        structured_output,
         | 
| 801 | 
            -
                        num_rows,
         | 
| 802 | 
             
                        eval_type,
         | 
| 803 | 
             
                    ],
         | 
| 804 | 
             
                    outputs=[pipeline_code],
         | 
| @@ -808,5 +817,16 @@ with gr.Blocks() as app: | |
| 808 | 
             
                    outputs=[pipeline_code_ui],
         | 
| 809 | 
             
                )
         | 
| 810 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 811 | 
             
                app.load(fn=swap_visibility, outputs=main_ui)
         | 
| 812 | 
             
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
|  | |
| 13 | 
             
                load_dataset,
         | 
| 14 | 
             
            )
         | 
| 15 | 
             
            from distilabel.distiset import Distiset
         | 
| 16 | 
            +
            from gradio.oauth import OAuthToken  #
         | 
| 17 | 
             
            from gradio_huggingfacehub_search import HuggingfaceHubSearch
         | 
| 18 | 
            +
            from huggingface_hub import HfApi, repo_exists
         | 
| 19 |  | 
| 20 | 
             
            from synthetic_dataset_generator.apps.base import (
         | 
| 21 | 
             
                hide_success_message,
         | 
|  | |
| 46 |  | 
| 47 | 
             
            def get_iframe(hub_repo_id: str) -> str:
         | 
| 48 | 
             
                if not hub_repo_id:
         | 
| 49 | 
            +
                    return ""
         | 
| 50 | 
            +
             | 
| 51 | 
            +
                if not repo_exists(repo_id=hub_repo_id, repo_type="dataset"):
         | 
| 52 | 
            +
                    return ""
         | 
| 53 |  | 
| 54 | 
             
                url = f"https://huggingface.co/datasets/{hub_repo_id}/embed/viewer"
         | 
| 55 | 
             
                iframe = f"""
         | 
|  | |
| 83 | 
             
                return instruction_valid_columns, response_valid_columns
         | 
| 84 |  | 
| 85 |  | 
| 86 | 
            +
            def load_dataset_from_hub(
         | 
| 87 | 
            +
                repo_id: str, num_rows: int = 10, token: Union[OAuthToken, None] = None
         | 
| 88 | 
            +
            ):
         | 
| 89 | 
             
                if not repo_id:
         | 
| 90 | 
             
                    raise gr.Error("Hub repo id is required")
         | 
| 91 | 
            +
                subsets = get_dataset_config_names(repo_id, token=token)
         | 
| 92 | 
            +
                ds_dict = load_dataset(repo_id, subsets[0], token=token)
         | 
| 93 | 
            +
                splits = get_dataset_split_names(repo_id, subsets[0], token=token)
         | 
| 94 | 
             
                ds = ds_dict[splits[0]]
         | 
| 95 | 
             
                if num_rows:
         | 
| 96 | 
             
                    ds = ds.select(range(num_rows))
         | 
|  | |
| 607 | 
             
                                search_type="dataset",
         | 
| 608 | 
             
                                sumbit_on_select=True,
         | 
| 609 | 
             
                            )
         | 
| 610 | 
            +
                            with gr.Row():
         | 
| 611 | 
            +
                                load_btn = gr.Button("Load", variant="primary")
         | 
| 612 | 
            +
                                clear_btn_part = gr.Button("Clear", variant="secondary")
         | 
| 613 | 
            +
             | 
| 614 | 
             
                        with gr.Column(scale=3):
         | 
| 615 | 
             
                            search_out = gr.HTML(label="Dataset preview")
         | 
| 616 |  | 
|  | |
| 675 | 
             
                                    inputs=[],
         | 
| 676 | 
             
                                    outputs=[eval_type],
         | 
| 677 | 
             
                                )
         | 
| 678 | 
            +
                            with gr.Row():
         | 
| 679 | 
            +
                                btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
         | 
| 680 | 
            +
                                clear_btn_full = gr.Button("Clear", variant="secondary")
         | 
| 681 | 
             
                        with gr.Column(scale=3):
         | 
| 682 | 
             
                            dataframe = gr.Dataframe(
         | 
| 683 | 
             
                                headers=["prompt", "completion", "evaluation"],
         | 
|  | |
| 733 | 
             
                                    label="Distilabel Pipeline Code",
         | 
| 734 | 
             
                                )
         | 
| 735 |  | 
| 736 | 
            +
                search_in.submit(fn=get_iframe, inputs=search_in, outputs=search_out).then(
         | 
| 737 | 
            +
                    fn=lambda df: pd.DataFrame(columns=df.columns),
         | 
| 738 | 
            +
                    inputs=[dataframe],
         | 
| 739 | 
            +
                    outputs=[dataframe],
         | 
| 740 | 
            +
                )
         | 
| 741 |  | 
| 742 | 
             
                load_btn.click(
         | 
| 743 | 
             
                    fn=load_dataset_from_hub,
         | 
|  | |
| 806 | 
             
                    fn=generate_pipeline_code,
         | 
| 807 | 
             
                    inputs=[
         | 
| 808 | 
             
                        search_in,
         | 
|  | |
|  | |
|  | |
| 809 | 
             
                        prompt_template,
         | 
| 810 | 
             
                        structured_output,
         | 
|  | |
| 811 | 
             
                        eval_type,
         | 
| 812 | 
             
                    ],
         | 
| 813 | 
             
                    outputs=[pipeline_code],
         | 
|  | |
| 817 | 
             
                    outputs=[pipeline_code_ui],
         | 
| 818 | 
             
                )
         | 
| 819 |  | 
| 820 | 
            +
                clear_btn_part.click(fn=lambda x: "", inputs=[], outputs=[search_in])
         | 
| 821 | 
            +
                clear_btn_full.click(
         | 
| 822 | 
            +
                    fn=lambda df: ("", "", pd.DataFrame(columns=df.columns)),
         | 
| 823 | 
            +
                    inputs=[dataframe],
         | 
| 824 | 
            +
                    outputs=[
         | 
| 825 | 
            +
                        search_in,
         | 
| 826 | 
            +
                        instruction_instruction_response,
         | 
| 827 | 
            +
                        response_instruction_response,
         | 
| 828 | 
            +
                    ],
         | 
| 829 | 
            +
                )
         | 
| 830 | 
            +
             | 
| 831 | 
             
                app.load(fn=swap_visibility, outputs=main_ui)
         | 
| 832 | 
             
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
    	
        src/synthetic_dataset_generator/apps/sft.py
    CHANGED
    
    | @@ -78,6 +78,15 @@ def generate_sample_dataset(system_prompt, num_turns, progress=gr.Progress()): | |
| 78 | 
             
                return dataframe
         | 
| 79 |  | 
| 80 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 81 | 
             
            def generate_dataset(
         | 
| 82 | 
             
                system_prompt: str,
         | 
| 83 | 
             
                num_turns: int = 1,
         | 
| @@ -368,7 +377,7 @@ with gr.Blocks() as app: | |
| 368 | 
             
                                        "Create",
         | 
| 369 | 
             
                                        variant="primary",
         | 
| 370 | 
             
                                    )
         | 
| 371 | 
            -
                                     | 
| 372 | 
             
                                        "Clear",
         | 
| 373 | 
             
                                        variant="secondary",
         | 
| 374 | 
             
                                    )
         | 
| @@ -401,17 +410,12 @@ with gr.Blocks() as app: | |
| 401 | 
             
                                    btn_apply_to_sample_dataset = gr.Button(
         | 
| 402 | 
             
                                        "Save", variant="primary"
         | 
| 403 | 
             
                                    )
         | 
| 404 | 
            -
                                     | 
| 405 | 
             
                                        "Clear",
         | 
| 406 | 
             
                                        variant="secondary",
         | 
| 407 | 
             
                                    )
         | 
| 408 | 
             
                            with gr.Column(scale=3):
         | 
| 409 | 
            -
                                dataframe =  | 
| 410 | 
            -
                                    headers=["prompt", "completion"],
         | 
| 411 | 
            -
                                    wrap=True,
         | 
| 412 | 
            -
                                    height=500,
         | 
| 413 | 
            -
                                    interactive=False,
         | 
| 414 | 
            -
                                )
         | 
| 415 |  | 
| 416 | 
             
                        gr.HTML(value="<hr>")
         | 
| 417 | 
             
                        gr.Markdown(value="## 3. Generate your dataset")
         | 
| @@ -527,6 +531,12 @@ with gr.Blocks() as app: | |
| 527 | 
             
                        inputs=[],
         | 
| 528 | 
             
                        outputs=[pipeline_code_ui],
         | 
| 529 | 
             
                    )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 530 |  | 
| 531 | 
             
                    app.load(fn=swap_visibility, outputs=main_ui)
         | 
| 532 | 
             
                    app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
|  | |
| 78 | 
             
                return dataframe
         | 
| 79 |  | 
| 80 |  | 
| 81 | 
            +
            def _get_dataframe():
         | 
| 82 | 
            +
                return gr.Dataframe(
         | 
| 83 | 
            +
                    headers=["prompt", "completion"],
         | 
| 84 | 
            +
                    wrap=True,
         | 
| 85 | 
            +
                    height=500,
         | 
| 86 | 
            +
                    interactive=False,
         | 
| 87 | 
            +
                )
         | 
| 88 | 
            +
             | 
| 89 | 
            +
             | 
| 90 | 
             
            def generate_dataset(
         | 
| 91 | 
             
                system_prompt: str,
         | 
| 92 | 
             
                num_turns: int = 1,
         | 
|  | |
| 377 | 
             
                                        "Create",
         | 
| 378 | 
             
                                        variant="primary",
         | 
| 379 | 
             
                                    )
         | 
| 380 | 
            +
                                    clear_btn_part = gr.Button(
         | 
| 381 | 
             
                                        "Clear",
         | 
| 382 | 
             
                                        variant="secondary",
         | 
| 383 | 
             
                                    )
         | 
|  | |
| 410 | 
             
                                    btn_apply_to_sample_dataset = gr.Button(
         | 
| 411 | 
             
                                        "Save", variant="primary"
         | 
| 412 | 
             
                                    )
         | 
| 413 | 
            +
                                    clear_btn_full = gr.Button(
         | 
| 414 | 
             
                                        "Clear",
         | 
| 415 | 
             
                                        variant="secondary",
         | 
| 416 | 
             
                                    )
         | 
| 417 | 
             
                            with gr.Column(scale=3):
         | 
| 418 | 
            +
                                dataframe = _get_dataframe()
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 419 |  | 
| 420 | 
             
                        gr.HTML(value="<hr>")
         | 
| 421 | 
             
                        gr.Markdown(value="## 3. Generate your dataset")
         | 
|  | |
| 531 | 
             
                        inputs=[],
         | 
| 532 | 
             
                        outputs=[pipeline_code_ui],
         | 
| 533 | 
             
                    )
         | 
| 534 | 
            +
                    gr.on(
         | 
| 535 | 
            +
                        triggers=[clear_btn_part.click, clear_btn_full.click],
         | 
| 536 | 
            +
                        fn=lambda _: ("", "", 1, _get_dataframe()),
         | 
| 537 | 
            +
                        inputs=[dataframe],
         | 
| 538 | 
            +
                        outputs=[dataset_description, system_prompt, num_turns, dataframe],
         | 
| 539 | 
            +
                    )
         | 
| 540 |  | 
| 541 | 
             
                    app.load(fn=swap_visibility, outputs=main_ui)
         | 
| 542 | 
             
                    app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
    	
        src/synthetic_dataset_generator/apps/textcat.py
    CHANGED
    
    | @@ -35,6 +35,12 @@ from src.synthetic_dataset_generator.utils import ( | |
| 35 | 
             
            from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
         | 
| 36 |  | 
| 37 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 38 | 
             
            def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
         | 
| 39 | 
             
                progress(0.0, desc="Generating text classification task")
         | 
| 40 | 
             
                progress(0.3, desc="Initializing text generation")
         | 
| @@ -345,7 +351,7 @@ with gr.Blocks() as app: | |
| 345 | 
             
                                    "Create",
         | 
| 346 | 
             
                                    variant="primary",
         | 
| 347 | 
             
                                )
         | 
| 348 | 
            -
                                 | 
| 349 | 
             
                                    "Clear",
         | 
| 350 | 
             
                                    variant="secondary",
         | 
| 351 | 
             
                                )
         | 
| @@ -411,11 +417,9 @@ with gr.Blocks() as app: | |
| 411 | 
             
                            )
         | 
| 412 | 
             
                            with gr.Row():
         | 
| 413 | 
             
                                btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
         | 
| 414 | 
            -
                                 | 
| 415 | 
             
                        with gr.Column(scale=3):
         | 
| 416 | 
            -
                            dataframe =  | 
| 417 | 
            -
                                headers=["labels", "text"], wrap=True, height=500, interactive=False
         | 
| 418 | 
            -
                            )
         | 
| 419 |  | 
| 420 | 
             
                    gr.HTML("<hr>")
         | 
| 421 | 
             
                    gr.Markdown("## 3. Generate your dataset")
         | 
| @@ -553,5 +557,17 @@ with gr.Blocks() as app: | |
| 553 | 
             
                    outputs=[pipeline_code_ui],
         | 
| 554 | 
             
                )
         | 
| 555 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 556 | 
             
                app.load(fn=swap_visibility, outputs=main_ui)
         | 
| 557 | 
             
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 
|  | |
| 35 | 
             
            from synthetic_dataset_generator.constants import DEFAULT_BATCH_SIZE
         | 
| 36 |  | 
| 37 |  | 
| 38 | 
            +
            def _get_dataframe():
         | 
| 39 | 
            +
                return gr.Dataframe(
         | 
| 40 | 
            +
                    headers=["labels", "text"], wrap=True, height=500, interactive=False
         | 
| 41 | 
            +
                )
         | 
| 42 | 
            +
             | 
| 43 | 
            +
             | 
| 44 | 
             
            def generate_system_prompt(dataset_description, temperature, progress=gr.Progress()):
         | 
| 45 | 
             
                progress(0.0, desc="Generating text classification task")
         | 
| 46 | 
             
                progress(0.3, desc="Initializing text generation")
         | 
|  | |
| 351 | 
             
                                    "Create",
         | 
| 352 | 
             
                                    variant="primary",
         | 
| 353 | 
             
                                )
         | 
| 354 | 
            +
                                clear_btn_part = gr.Button(
         | 
| 355 | 
             
                                    "Clear",
         | 
| 356 | 
             
                                    variant="secondary",
         | 
| 357 | 
             
                                )
         | 
|  | |
| 417 | 
             
                            )
         | 
| 418 | 
             
                            with gr.Row():
         | 
| 419 | 
             
                                btn_apply_to_sample_dataset = gr.Button("Save", variant="primary")
         | 
| 420 | 
            +
                                clear_btn_full = gr.Button("Clear", variant="secondary")
         | 
| 421 | 
             
                        with gr.Column(scale=3):
         | 
| 422 | 
            +
                            dataframe = _get_dataframe()
         | 
|  | |
|  | |
| 423 |  | 
| 424 | 
             
                    gr.HTML("<hr>")
         | 
| 425 | 
             
                    gr.Markdown("## 3. Generate your dataset")
         | 
|  | |
| 557 | 
             
                    outputs=[pipeline_code_ui],
         | 
| 558 | 
             
                )
         | 
| 559 |  | 
| 560 | 
            +
                gr.on(
         | 
| 561 | 
            +
                    triggers=[clear_btn_part.click, clear_btn_full.click],
         | 
| 562 | 
            +
                    fn=lambda _: (
         | 
| 563 | 
            +
                        "",
         | 
| 564 | 
            +
                        "",
         | 
| 565 | 
            +
                        [],
         | 
| 566 | 
            +
                        _get_dataframe(),
         | 
| 567 | 
            +
                    ),
         | 
| 568 | 
            +
                    inputs=[dataframe],
         | 
| 569 | 
            +
                    outputs=[dataset_description, system_prompt, labels, dataframe],
         | 
| 570 | 
            +
                )
         | 
| 571 | 
            +
             | 
| 572 | 
             
                app.load(fn=swap_visibility, outputs=main_ui)
         | 
| 573 | 
             
                app.load(fn=get_org_dropdown, outputs=[org_name])
         | 

